Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/threading.py: 100.00%

56 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2025-03-13 15:36 +0000

1# Copyright (C) 2024 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import contextvars 

16import threading 

17from concurrent import futures 

18from concurrent.futures import Future 

19from typing import Any, Callable, Iterable, ParamSpec, TypeVar 

20 

21from buildgrid.server.logging import buildgrid_logger 

22 

# Type variables used to carry a callable's parameter list (_P) and return
# type (_T) through the wrappers below, e.g. ``Callable[_P, _T]`` in
# ContextThreadPoolExecutor.submit, whose result is typed ``Future[_T]``.
_T = TypeVar("_T")
_P = ParamSpec("_P")


# Module-level logger obtained from BuildGrid's logging helper; it is called
# with a ``tags=`` keyword by ContextWorker.stop().
LOGGER = buildgrid_logger(__name__)

28 

29 

class ContextThreadPoolExecutor(futures.ThreadPoolExecutor):
    """A ``ThreadPoolExecutor`` that runs every submitted job inside a ``contextvars.Context``.

    Plain ``ThreadPoolExecutor`` workers do not see the ``ContextVar`` values of
    the thread that queued the work. This subclass wraps each ``submit()`` call
    in ``Context.run`` so that they do.

    Only work that goes through ``submit()`` is wrapped. The ``initializer``
    given to the constructor is passed to ``ThreadPoolExecutor`` unchanged.
    """

    def __init__(
        self,
        max_workers: int | None = None,
        thread_name_prefix: str = "",
        initializer: Callable[[], Any] | None = None,
        initargs: tuple[Any, ...] = (),
        immediate_copy: bool = False,
    ) -> None:
        """
        Build the executor and decide where job contexts come from.

        max_workers, thread_name_prefix, initializer and initargs are forwarded
        unchanged to ``concurrent.futures.ThreadPoolExecutor``.

        immediate_copy: when true, the constructing thread's context is
        snapshotted now, and every job runs in a copy of that snapshot. When
        false, the submitting thread's context is copied each time a job is
        submitted.
        """
        # None means "no fixed snapshot; copy the caller's context on every submit()".
        self._init_ctx: contextvars.Context | None = (
            contextvars.copy_context() if immediate_copy else None
        )

        super().__init__(
            max_workers=max_workers,
            thread_name_prefix=thread_name_prefix,
            initializer=initializer,
            initargs=initargs,
        )

    def submit(self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> Future[_T]:
        """Schedule ``fn(*args, **kwargs)`` to run inside a private copy of the chosen context.

        Each job gets its own Context object. Variable changes made by one job
        therefore never reach other jobs, the stored snapshot, or the submitter.
        """
        snapshot = self._init_ctx
        job_ctx = contextvars.copy_context() if snapshot is None else snapshot.copy()

        # Callers such as grpcio (>=1.60.0) submit a bound ``Context.run`` of their
        # own, with the real callable as the first positional argument. That context
        # is not taken from the thread we want to inherit from, so its ``run`` is
        # dropped. The remaining arguments are run in our context instead.
        # Assumes args[0] is the real callable in that case.
        owner = getattr(fn, "__self__", None)
        if isinstance(owner, contextvars.Context):
            return super().submit(job_ctx.run, *args, **kwargs)  # type: ignore[arg-type]
        return super().submit(job_ctx.run, fn, *args, **kwargs)  # type: ignore[arg-type]

68 

69 

70class ContextThread(threading.Thread): 

71 def __init__( 

72 self, 

73 target: "Callable[_P, _T]", 

74 name: str | None = None, 

75 args: Iterable[Any] = (), 

76 kwargs: dict[str, Any] | None = None, 

77 *, 

78 daemon: bool | None = None, 

79 ) -> None: 

80 ctx = contextvars.copy_context() 

81 super().__init__( 

82 target=ctx.copy().run, 

83 name=name, 

84 args=(target, *args), 

85 kwargs=kwargs, 

86 daemon=daemon, 

87 ) 

88 

89 

class ContextWorker:
    """Manage one long-lived background ``ContextThread`` controlled by a shutdown ``threading.Event``.

    Usable directly (``start()`` / ``stop()`` / ``wait()``) or as a context
    manager: entering starts the worker and exiting stops it.
    """

    def __init__(
        self,
        target: Callable[[threading.Event], None],
        name: str | None = None,
        *,
        on_shutdown_requested: Callable[[], None] | None = None,
    ) -> None:
        """
        Describe a worker without starting it.

        target: the long-running function. It receives the shutdown Event and
            should return once that Event is set.
        name: thread name, also used in the stop log messages.
        on_shutdown_requested: optional hook invoked by ``stop()`` right after
            the Event is set, before the thread is joined.

        The thread itself is only created in ``start()``. Its context snapshot
        therefore reflects whatever the caller has populated by then, rather
        than the state at construction time.
        """
        self._target = target
        self._name = name
        self._on_shutdown_requested = on_shutdown_requested

        # Passed to the target. Once set, it also marks the worker as stopped.
        self._shutdown_requested = threading.Event()
        # Created lazily by start(); stays None until then.
        self._thread: ContextThread | None = None

    def is_alive(self) -> bool:
        """Return True if the worker thread has been started and is still running."""
        thread = self._thread
        if thread is None:
            return False
        return thread.is_alive()

    def __enter__(self) -> "ContextWorker":
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.stop()

    def start(self) -> None:
        """Create and start the worker thread. Calls after the first one do nothing."""
        if self._thread:
            return

        # Read self._target / self._shutdown_requested when the thread runs,
        # not when the closure is built.
        def _run_target() -> None:
            self._target(self._shutdown_requested)

        self._thread = ContextThread(target=_run_target, name=self._name, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Request shutdown, run the shutdown hook, and wait for the thread to exit.

        Only the first call does anything; later calls return immediately.
        """
        if self._shutdown_requested.is_set():
            return

        LOGGER.info("Stopping worker.", tags={"name": self._name})
        self._shutdown_requested.set()

        # NOTE(review): if this hook raises, the join below is skipped. Because
        # the Event is already set, later stop() calls will not join either.
        hook = self._on_shutdown_requested
        if hook:
            hook()

        # join() has no timeout here. NOTE(review): calling stop() from inside
        # the worker thread itself would make join() raise RuntimeError.
        thread = self._thread
        if thread:
            thread.join()

        LOGGER.info("Stopped worker.", tags={"name": self._name})

    def wait(self, timeout: float | None = None) -> None:
        """Block until the worker thread exits or ``timeout`` seconds pass. Does nothing if never started."""
        if not self._thread:
            return
        self._thread.join(timeout=timeout)