Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/threading.py: 96.55%
58 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-10-04 17:48 +0000
1# Copyright (C) 2024 Bloomberg LP
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# <http://www.apache.org/licenses/LICENSE-2.0>
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import contextvars
16import threading
17from concurrent import futures
18from concurrent.futures import Future
19from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar
21from buildgrid.server.logging import buildgrid_logger
# Type variable for the return type of the callables handed to the executor and
# thread wrappers in this module.
_T = TypeVar("_T")

if TYPE_CHECKING:
    # ParamSpec is only imported while type checking. The annotations in this
    # module that mention _P are written as string literals, so the name is
    # not evaluated at runtime by those signatures.
    from typing_extensions import ParamSpec

    _P = ParamSpec("_P")


# Logger for this module, obtained from buildgrid_logger with the module's __name__.
LOGGER = buildgrid_logger(__name__)
class ContextThreadPoolExecutor(futures.ThreadPoolExecutor):
    """Thread pool executor whose jobs run inside a copied ``contextvars`` context."""

    def __init__(
        self,
        max_workers: Optional[int] = None,
        thread_name_prefix: str = "",
        initializer: Optional[Callable[[], Any]] = None,
        initargs: "Tuple[Any, ...]" = (),
        immediate_copy: bool = False,
    ) -> None:
        """
        Build a thread pool that forwards context variables into its jobs.

        max_workers, thread_name_prefix, initializer and initargs are passed
        unchanged to ``ThreadPoolExecutor``.

        immediate_copy selects which context the jobs see:
          - True: the context of the constructing thread is captured now, and
            every job runs in a copy of that capture.
          - False: the context of the submitting thread is copied on each
            ``submit`` call.
        """
        # The capture is stored before the base class is initialised, as submit() reads it.
        self._init_ctx: Optional[contextvars.Context] = contextvars.copy_context() if immediate_copy else None

        super().__init__(
            max_workers=max_workers,
            thread_name_prefix=thread_name_prefix,
            initializer=initializer,
            initargs=initargs,
        )

    def submit(self, fn: "Callable[_P, _T]", *args: "_P.args", **kwargs: "_P.kwargs") -> "Future[_T]":
        """
        Schedule ``fn(*args, **kwargs)`` to run inside a freshly copied context.

        Returns the ``Future`` produced by ``ThreadPoolExecutor.submit``.
        """
        captured = self._init_ctx
        # Each job gets its own copy, so a job never shares a Context object with another job.
        job_ctx = contextvars.copy_context() if captured is None else captured.copy()

        # In newer versions of grpcio (>=1.60.0), the submitted callable is already the ``run``
        # method of a Context, but that context is not copied from the initializer thread.
        # In that case the foreign ``run`` is dropped and ours is used in its place; the
        # remaining positional arguments already start with the real callable.
        owner = getattr(fn, "__self__", None)
        call_args = args if isinstance(owner, contextvars.Context) else (fn, *args)
        return super().submit(job_ctx.run, *call_args, **kwargs)  # type: ignore[arg-type]
class ContextThread(threading.Thread):
    """Thread whose target runs in a copy of the context of the thread that constructed it."""

    def __init__(
        self,
        target: "Callable[_P, _T]",
        name: Optional[str] = None,
        args: Iterable[Any] = (),
        kwargs: Optional[Dict[str, Any]] = None,
        *,
        daemon: Optional[bool] = None,
    ) -> None:
        """
        Capture the current context and arrange for ``target`` to run inside it.

        target is invoked as ``target(*args, **kwargs)`` through ``Context.run``.
        name, kwargs and daemon are passed unchanged to ``threading.Thread``.
        """
        # The context is captured at construction time, not when start() is called.
        creator_ctx = contextvars.copy_context()
        # A further copy of the capture is what the thread actually runs in.
        thread_ctx = creator_ctx.copy()

        # Context.run takes the callable as its first positional argument.
        run_args = (target,) + tuple(args)
        super().__init__(
            target=thread_ctx.run,
            name=name,
            args=run_args,
            kwargs=kwargs,
            daemon=daemon,
        )
class ContextWorker:
    """Owner of a single daemon ``ContextThread`` that runs a long-lived target."""

    def __init__(
        self,
        target: Callable[[threading.Event], None],
        name: Optional[str] = None,
        *,
        on_shutdown_requested: Optional[Callable[[], None]] = None,
    ) -> None:
        """
        Prepare a long-lived task to be run in a thread.

        target receives an Event which is set once shutdown has been requested.
        on_shutdown_requested, if given, is called by stop() right after that
        Event is set. The thread itself is only created by start(), which lets
        the context keep being populated until then.
        """
        self._target = target
        self._name = name
        self._on_shutdown_requested = on_shutdown_requested

        self._thread: Optional[ContextThread] = None
        self._shutdown_requested = threading.Event()

    def is_alive(self) -> bool:
        """Return True only if the thread has been created and is still running."""
        thread = self._thread
        if thread is None:
            return False
        return thread.is_alive()

    def __enter__(self) -> "ContextWorker":
        """Start the worker and return it."""
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Stop the worker when the ``with`` block is left."""
        self.stop()

    def start(self) -> None:
        """Create and start the thread; does nothing if a thread already exists."""
        if self._thread:
            return

        def _run() -> None:
            # Attributes are looked up when the thread runs, not when it is created.
            self._target(self._shutdown_requested)

        self._thread = ContextThread(target=_run, name=self._name, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """
        Request shutdown, run the shutdown callback and join the thread.

        Only the first call does anything; later calls see the Event already set.
        """
        if self._shutdown_requested.is_set():
            return

        LOGGER.info("Stopping worker.", tags={"name": self._name})
        self._shutdown_requested.set()

        callback = self._on_shutdown_requested
        if callback:
            callback()

        thread = self._thread
        if thread:
            thread.join()
        LOGGER.info("Stopped worker.", tags={"name": self._name})

    def wait(self, timeout: Optional[float] = None) -> None:
        """Join the thread for up to ``timeout`` seconds; does nothing if it was never started."""
        thread = self._thread
        if thread:
            thread.join(timeout=timeout)