Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/threading.py: 96.55%

58 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-06-11 15:37 +0000

1# Copyright (C) 2024 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import contextvars 

16import logging 

17import threading 

18from concurrent import futures 

19from concurrent.futures import Future 

20from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Tuple, TypeVar 

21 

# Module-wide logger used by the worker start/stop messages below.
LOGGER = logging.getLogger(__name__)

# Generic return type of callables handed to the executor / threads in this module.
_T = TypeVar("_T")

if TYPE_CHECKING:
    # ParamSpec is imported only for static type checkers; every annotation that
    # mentions ``_P`` is written as a string, so nothing here runs at import time.
    from typing_extensions import ParamSpec

    _P = ParamSpec("_P")

31 

32 

class ContextThreadPoolExecutor(futures.ThreadPoolExecutor):
    """A ``ThreadPoolExecutor`` whose jobs run inside a copy of a ``contextvars`` context.

    Each submitted job gets its own fresh copy of a context. By default that copy is
    taken from the submitting thread at ``submit`` time. With ``immediate_copy`` it is
    taken from a snapshot captured when the executor was constructed.
    """

    def __init__(
        self,
        max_workers: Optional[int] = None,
        thread_name_prefix: str = "",
        initializer: Optional[Callable[[], Any]] = None,
        initargs: "Tuple[Any, ...]" = (),
        immediate_copy: bool = False,
    ) -> None:
        """
        Create a thread pool executor which forwards context from the creating thread.

        Args:
            max_workers, thread_name_prefix, initializer, initargs: forwarded unchanged to
                ``concurrent.futures.ThreadPoolExecutor``.
            immediate_copy: when True, snapshot the current context now and base every job on
                (a copy of) that snapshot; when False, copy the submitter's context on each
                ``submit`` call.
        """
        # Snapshot taken up front only when requested; None means "copy at submit time".
        self._init_ctx: Optional[contextvars.Context] = (
            contextvars.copy_context() if immediate_copy else None
        )

        super().__init__(
            max_workers=max_workers,
            thread_name_prefix=thread_name_prefix,
            initializer=initializer,
            initargs=initargs,
        )

    def submit(self, fn: "Callable[_P, _T]", *args: "_P.args", **kwargs: "_P.kwargs") -> "Future[_T]":
        """Schedule ``fn(*args, **kwargs)`` to run inside a fresh context copy.

        If ``fn`` is itself a bound ``Context.run`` (its ``__self__`` is a
        ``contextvars.Context``), that foreign context is dropped, and ``args[0]`` is
        invoked with the remaining arguments inside this executor's context instead.
        """
        snapshot = self._init_ctx
        job_ctx = contextvars.copy_context() if snapshot is None else snapshot.copy()

        # In newer versions of grpcio (>=1.60.0), a context is used, but is not copied from
        # the initializer thread. We can use our own instead.
        replaces_foreign_ctx = isinstance(getattr(fn, "__self__", None), contextvars.Context)
        call_args = args if replaces_foreign_ctx else (fn, *args)

        return super().submit(job_ctx.run, *call_args, **kwargs)  # type: ignore[arg-type]

71 

72 

class ContextThread(threading.Thread):
    """A ``threading.Thread`` whose target runs inside a copy of the creator's context.

    The ``contextvars`` snapshot is taken when the thread object is constructed, not when
    ``start()`` is called. Values set by the creating thread before construction are
    therefore visible to ``target``. Changes made by ``target`` stay inside the copy and
    do not affect the creator's context.
    """

    def __init__(
        self,
        target: "Callable[_P, _T]",
        name: Optional[str] = None,
        args: Iterable[Any] = (),
        kwargs: Optional[Dict[str, Any]] = None,
        *,
        daemon: Optional[bool] = None,
    ) -> None:
        """
        Create a thread that will call ``target(*args, **kwargs)`` inside a context copy.

        Args:
            target: Callable to run in the new thread.
            name: Thread name, forwarded to ``threading.Thread``.
            args: Positional arguments for ``target``.
            kwargs: Keyword arguments for ``target``; ``None`` means no keyword arguments.
            daemon: Daemon flag, forwarded to ``threading.Thread``.
        """
        # copy_context() already returns a brand-new Context, so a single copy suffices.
        # The thread's real target is that context's ``run``; the user's callable is
        # passed as run's first positional argument.
        super().__init__(
            target=contextvars.copy_context().run,
            name=name,
            args=(target, *args),
            kwargs=kwargs,
            daemon=daemon,
        )

91 

92 

class ContextWorker:
    """Run a long-lived task in a background ``ContextThread``.

    The task is called with a ``threading.Event`` that becomes set once ``stop()`` has been
    called. The task can poll or wait on that Event to know when to exit. Instances can be
    used as context managers: entering starts the worker and exiting stops it.
    """

    def __init__(
        self,
        target: Callable[[threading.Event], None],
        name: Optional[str] = None,
        *,
        on_shutdown_requested: Optional[Callable[[], None]] = None,
    ) -> None:
        """
        Run a long-lived task in a thread, where the method is provided an Event that indicates if
        shutdown is requested. We delay creating the thread until started to allow the context
        to continue to be populated.

        Args:
            target: Called once, in the worker thread, with the shutdown Event.
            name: Thread name; also used in the start/stop log messages.
            on_shutdown_requested: Optional callback invoked from ``stop()`` in the caller's
                thread, after the Event is set and before the thread is joined. It can be used,
                for example, to wake a ``target`` that is blocked on something other than the
                Event.
        """
        self._shutdown_requested = threading.Event()
        # Created lazily by start(), so the context snapshot is taken as late as possible.
        self._thread: Optional[ContextThread] = None

        self._target = target
        self._name = name
        self._on_shutdown_requested = on_shutdown_requested

    def is_alive(self) -> bool:
        """Return True if the worker thread has been started and has not yet finished."""
        return self._thread is not None and self._thread.is_alive()

    def __enter__(self) -> "ContextWorker":
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # Returns None, so any exception raised inside the ``with`` body propagates.
        self.stop()

    def start(self) -> None:
        """Create and start the worker thread.

        The current context is captured at this point. Once the thread has been created,
        further calls (including calls after ``stop()``) do nothing.
        """
        if self._thread is None:
            self._thread = ContextThread(
                target=lambda: self._target(self._shutdown_requested),
                name=self._name,
                # Daemon thread: it will not keep the interpreter alive on exit.
                daemon=True,
            )
            self._thread.start()

    def stop(self) -> None:
        """Request shutdown and wait for the worker thread to finish.

        Only the first call has any effect. It sets the shutdown Event, runs
        ``on_shutdown_requested`` (if given), and joins the thread (if it was started).
        The join has no timeout, so the call blocks until ``target`` returns.
        Later calls return immediately without waiting.
        """
        if self._shutdown_requested.is_set():
            return

        # Lazy %-style arguments: the message is only formatted if INFO is enabled.
        LOGGER.info("Stopping %s worker", self._name)
        self._shutdown_requested.set()
        if self._on_shutdown_requested is not None:
            self._on_shutdown_requested()
        if self._thread is not None:
            self._thread.join()
        LOGGER.info("Stopped %s worker", self._name)

    def wait(self, timeout: Optional[float] = None) -> None:
        """Block until the worker thread finishes or ``timeout`` seconds elapse.

        Returns immediately if the worker was never started. This does not request shutdown.
        """
        if self._thread is not None:
            self._thread.join(timeout=timeout)