Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/threading.py: 96.55%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-10-04 17:48 +0000

1# Copyright (C) 2024 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import contextvars 

16import threading 

17from concurrent import futures 

18from concurrent.futures import Future 

19from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar 

20 

21from buildgrid.server.logging import buildgrid_logger 

22 

# Type variable for the return type of callables submitted to / run by the helpers below.
_T = TypeVar("_T")

if TYPE_CHECKING:
    # ParamSpec is only imported while type checking, so typing_extensions is not needed at
    # runtime; the annotations in this module that mention _P are written as strings so they
    # are never evaluated at runtime either.
    from typing_extensions import ParamSpec

    # Parameter specification of callables submitted to / run by the helpers below.
    _P = ParamSpec("_P")


# Module-level logger used by ContextWorker for its start/stop messages.
LOGGER = buildgrid_logger(__name__)

32 

33 

class ContextThreadPoolExecutor(futures.ThreadPoolExecutor):
    """A ``ThreadPoolExecutor`` whose jobs run inside a copy of a ``contextvars`` context.

    Plain ``ThreadPoolExecutor`` workers run jobs in whatever context the worker thread has,
    so context variables set by the submitting code are not visible to the job. This executor
    wraps every job in ``Context.run`` so that they are.
    """

    def __init__(
        self,
        max_workers: Optional[int] = None,
        thread_name_prefix: str = "",
        initializer: Optional[Callable[[], Any]] = None,
        initargs: "Tuple[Any, ...]" = (),
        immediate_copy: bool = False,
    ) -> None:
        """
        Create a thread pool executor which forwards context from the creating thread.

        :param max_workers: Passed through to ``ThreadPoolExecutor``.
        :param thread_name_prefix: Passed through to ``ThreadPoolExecutor``.
        :param initializer: Passed through to ``ThreadPoolExecutor``. Note that it is called
            directly by the base class, not wrapped in a forwarded context.
        :param initargs: Passed through to ``ThreadPoolExecutor``.
        :param immediate_copy: If true, snapshot the current context when this executor is
            created, and run every job in a fresh copy of that snapshot. If false, the context
            is copied from the submitting thread each time a job is submitted.
        """
        # Snapshot taken at construction time; None means "copy at submit time instead".
        self._init_ctx: Optional[contextvars.Context] = None
        if immediate_copy:
            self._init_ctx = contextvars.copy_context()

        super().__init__(
            max_workers=max_workers,
            thread_name_prefix=thread_name_prefix,
            initializer=initializer,
            initargs=initargs,
        )

    def submit(self, fn: "Callable[_P, _T]", /, *args: "_P.args", **kwargs: "_P.kwargs") -> "Future[_T]":
        """
        Schedule ``fn(*args, **kwargs)`` to run in a copy of the forwarded context.

        ``fn`` is positional-only, matching ``ThreadPoolExecutor.submit``, so a job may itself
        receive a keyword argument named ``fn``.

        :returns: The ``Future`` for the job, as returned by ``ThreadPoolExecutor.submit``.
        """
        # Every job gets its own copy, so context-variable changes made by one job do not
        # leak into other jobs or back into the snapshot.
        if self._init_ctx is None:
            run = contextvars.copy_context().run
        else:
            run = self._init_ctx.copy().run

        # In newer versions of grpcio (>=1.60.0), jobs are submitted as ``<context>.run`` with
        # the real callable as the first positional argument. That context is not copied from
        # the initializer thread, so swap in our own ``run`` instead. Only ``Context.run`` is
        # intercepted; other bound Context methods (``copy``, ``get``, ...) are normal jobs.
        is_context_run = (
            isinstance(getattr(fn, "__self__", None), contextvars.Context)
            and getattr(fn, "__name__", None) == "run"
        )
        if is_context_run:
            return super().submit(run, *args, **kwargs)  # type: ignore[arg-type]
        return super().submit(run, fn, *args, **kwargs)  # type: ignore[arg-type]

72 

73 

class ContextThread(threading.Thread):
    """A ``threading.Thread`` whose target runs inside a copy of the creator's context.

    The context is captured when the thread object is constructed, not when ``start()`` is
    called. Changes the target makes to context variables stay inside the thread.
    """

    def __init__(
        self,
        target: "Callable[_P, _T]",
        name: Optional[str] = None,
        args: Iterable[Any] = (),
        kwargs: Optional[Dict[str, Any]] = None,
        *,
        daemon: Optional[bool] = None,
    ) -> None:
        snapshot = contextvars.copy_context()
        isolated = snapshot.copy()

        # Context.run takes the callable as its first positional argument, so the user's
        # target is prepended to the positional arguments and Context.run becomes the
        # actual thread target.
        forwarded_args = (target,) + tuple(args)

        super().__init__(
            target=isolated.run,
            name=name,
            args=forwarded_args,
            kwargs=kwargs,
            daemon=daemon,
        )

92 

93 

class ContextWorker:
    """Runs a long-lived task on a background ``ContextThread``.

    The task callable is handed a ``threading.Event`` that ``stop()`` sets to signal that
    shutdown has been requested. Usable as a context manager: entering starts the worker and
    exiting stops it.
    """

    def __init__(
        self,
        target: Callable[[threading.Event], None],
        name: Optional[str] = None,
        *,
        on_shutdown_requested: Optional[Callable[[], None]] = None,
    ) -> None:
        """
        Run a long-lived task in a thread, where the method is provided an Event that indicates if
        shutdown is requested.

        The thread itself is not created until ``start()``, so context variables set between
        construction and start are still captured by it.

        :param target: Callable run on the worker thread; receives the shutdown Event.
        :param name: Thread name; also sent as the ``name`` tag on stop log messages.
        :param on_shutdown_requested: Optional callback invoked by the first ``stop()`` call,
            right after the shutdown Event is set and before joining the thread.
        """
        self._target = target
        self._name = name
        self._on_shutdown_requested = on_shutdown_requested

        self._shutdown_requested = threading.Event()
        # Created lazily by start(); stays None until then.
        self._thread: Optional[ContextThread] = None

    def is_alive(self) -> bool:
        """Return True if the worker thread has been started and has not yet finished."""
        thread = self._thread
        if thread is None:
            return False
        return thread.is_alive()

    def __enter__(self) -> "ContextWorker":
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.stop()

    def start(self) -> None:
        """Create and start the daemon worker thread; does nothing if it was already created."""
        if self._thread is not None:
            return

        def _run() -> None:
            # Attributes are read when the thread runs, not when it is created.
            self._target(self._shutdown_requested)

        self._thread = ContextThread(target=_run, name=self._name, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Request shutdown, run the shutdown callback, and join the worker thread.

        Only the first call does anything; once the shutdown Event is set, later calls
        return immediately.
        """
        if self._shutdown_requested.is_set():
            return

        LOGGER.info("Stopping worker.", tags={"name": self._name})
        self._shutdown_requested.set()

        callback = self._on_shutdown_requested
        if callback:
            callback()

        thread = self._thread
        if thread is not None:
            thread.join()
        LOGGER.info("Stopped worker.", tags={"name": self._name})

    def wait(self, timeout: Optional[float] = None) -> None:
        """Block until the worker thread exits or ``timeout`` seconds pass; no-op if never started."""
        thread = self._thread
        if thread is not None:
            thread.join(timeout=timeout)