Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/threading.py: 100.00%
56 statements
« prev ^ index » next coverage.py v7.4.1, created at 2025-03-13 15:36 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2025-03-13 15:36 +0000
1# Copyright (C) 2024 Bloomberg LP
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# <http://www.apache.org/licenses/LICENSE-2.0>
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import contextvars
16import threading
17from concurrent import futures
18from concurrent.futures import Future
19from typing import Any, Callable, Iterable, ParamSpec, TypeVar
21from buildgrid.server.logging import buildgrid_logger
# Module-level logger, created for this module's name.
LOGGER = buildgrid_logger(__name__)

# Typing helpers describing the callables forwarded to worker threads.
_T = TypeVar("_T")  # return type of a forwarded callable
_P = ParamSpec("_P")  # parameter specification of a forwarded callable
30class ContextThreadPoolExecutor(futures.ThreadPoolExecutor):
31 def __init__(
32 self,
33 max_workers: int | None = None,
34 thread_name_prefix: str = "",
35 initializer: Callable[[], Any] | None = None,
36 initargs: tuple[Any, ...] = (),
37 immediate_copy: bool = False,
38 ) -> None:
39 """
40 Create a thread pool executor which forwards context from the creating thread.
42 immediate_copy if true, copies the context when this threadpool object is created.
43 If false, the context will be copied as jobs are submitted to it.
44 """
46 self._init_ctx: contextvars.Context | None = None
47 if immediate_copy:
48 self._init_ctx = contextvars.copy_context()
50 super().__init__(
51 max_workers=max_workers,
52 thread_name_prefix=thread_name_prefix,
53 initializer=initializer,
54 initargs=initargs,
55 )
57 def submit(self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> Future[_T]:
58 if self._init_ctx is None:
59 run = contextvars.copy_context().run
60 else:
61 run = self._init_ctx.copy().run
63 # In newer versions of grpcio (>=1.60.0), a context is used, but is not copied from
64 # the initializer thread. We can use our own instead.
65 if isinstance(getattr(fn, "__self__", None), contextvars.Context):
66 return super().submit(run, *args, **kwargs) # type: ignore[arg-type]
67 return super().submit(run, fn, *args, **kwargs) # type: ignore[arg-type]
70class ContextThread(threading.Thread):
71 def __init__(
72 self,
73 target: "Callable[_P, _T]",
74 name: str | None = None,
75 args: Iterable[Any] = (),
76 kwargs: dict[str, Any] | None = None,
77 *,
78 daemon: bool | None = None,
79 ) -> None:
80 ctx = contextvars.copy_context()
81 super().__init__(
82 target=ctx.copy().run,
83 name=name,
84 args=(target, *args),
85 kwargs=kwargs,
86 daemon=daemon,
87 )
90class ContextWorker:
91 def __init__(
92 self,
93 target: Callable[[threading.Event], None],
94 name: str | None = None,
95 *,
96 on_shutdown_requested: Callable[[], None] | None = None,
97 ) -> None:
98 """
99 Run a long-lived task in a thread, where the method is provided an Event that indicates if
100 shutdown is requested. We delay creating the thread until started to allow the context
101 to continue to be populated.
102 """
104 self._shutdown_requested = threading.Event()
105 self._thread: ContextThread | None = None
107 self._target = target
108 self._name = name
109 self._on_shutdown_requested = on_shutdown_requested
111 def is_alive(self) -> bool:
112 return self._thread is not None and self._thread.is_alive()
114 def __enter__(self) -> "ContextWorker":
115 self.start()
116 return self
118 def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
119 self.stop()
121 def start(self) -> None:
122 if not self._thread:
123 self._thread = ContextThread(
124 target=lambda: self._target(self._shutdown_requested), name=self._name, daemon=True
125 )
126 self._thread.start()
128 def stop(self) -> None:
129 if not self._shutdown_requested.is_set():
130 LOGGER.info("Stopping worker.", tags=dict(name=self._name))
131 self._shutdown_requested.set()
132 if self._on_shutdown_requested:
133 self._on_shutdown_requested()
134 if self._thread:
135 self._thread.join()
136 LOGGER.info("Stopped worker.", tags=dict(name=self._name))
138 def wait(self, timeout: float | None = None) -> None:
139 if self._thread:
140 self._thread.join(timeout=timeout)