Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/execution/instance.py: 93.75%

112 statements  


# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
ExecutionInstance
=================

An instance of the Remote Execution Service.
"""

import logging
from collections import defaultdict
from contextlib import ExitStack
from typing import Any, Dict, Iterable, Optional, Set, Tuple

from buildgrid_metering.models.dataclasses import Identity, RPCUsage, Usage

from buildgrid._exceptions import FailedPreconditionError, InvalidArgumentError
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import DESCRIPTOR as RE_DESCRIPTOR
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import (
    Action,
    ActionCacheUpdateCapabilities,
    CacheCapabilities,
    Command,
    Digest,
    DigestFunction,
    Platform,
    RequestMetadata,
)
from buildgrid._protos.google.longrunning.operations_pb2 import Operation
from buildgrid.server.metrics_names import SCHEDULER_QUEUE_ACTION_TIME_METRIC_NAME
from buildgrid.server.metrics_utils import DurationMetric
from buildgrid.server.persistence.sql.impl import SQLDataStore
from buildgrid.server.persistence.sql.models import ClientIdentityEntry
from buildgrid.server.servicer import Instance
from buildgrid.server.utils.context import CancellationContext
from buildgrid.utils import get_hash_type

# All priorities >= this value will not be throttled / deprioritized.
EXECUTION_DEPRIORITIZATION_LIMIT = 1
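# For example, with the limit of 1, a throttled client's request submitted at
# priority 0 is deprioritized to priority 1, while requests already at priority 1
# or above are queued unchanged.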

LOGGER = logging.getLogger(__name__)

class ExecutionInstance(Instance):
    SERVICE_NAME = RE_DESCRIPTOR.services_by_name["Execution"].full_name

    def __init__(
        self,
        scheduler: SQLDataStore,
        operation_stream_keepalive_timeout: Optional[int] = None,
    ) -> None:
        self._stack = ExitStack()
        self.scheduler = scheduler

        self._operation_stream_keepalive_timeout = operation_stream_keepalive_timeout

    # --- Public API ---

    def start(self) -> None:
        self.scheduler.start(start_job_watcher=True)
        self._stack.callback(self.scheduler.stop)

    def stop(self) -> None:
        self._stack.close()

    def set_instance_name(self, instance_name: str) -> None:
        super().set_instance_name(instance_name)
        self.scheduler.set_instance_name(instance_name)

    def hash_type(self) -> "DigestFunction.Value.ValueType":
        return get_hash_type()

    @DurationMetric(SCHEDULER_QUEUE_ACTION_TIME_METRIC_NAME, instanced=True)
    def execute(
        self,
        *,
        action_digest: Digest,
        skip_cache_lookup: bool,
        priority: int = 0,
        request_metadata: Optional[RequestMetadata] = None,
        client_identity: Optional[ClientIdentityEntry] = None,
    ) -> str:
        """
        Send a job for execution: queue the action and create an Operation to be
        associated with it. Returns the name of the new Operation.
        """

        action = self.scheduler.storage.get_message(action_digest, Action)
        if not action:
            raise FailedPreconditionError("Could not get action from storage.")

        command = self.scheduler.storage.get_message(action.command_digest, Command)
        if not command:
            raise FailedPreconditionError("Could not get command from storage.")

        if action.HasField("platform"):
            platform = action.platform
        elif command.HasField("platform"):
            platform = command.platform
        else:
            platform = Platform()

        platform_requirements: Dict[str, Set[str]] = defaultdict(set)
        for platform_property in platform.properties:
            name = platform_property.name
            if name not in self.scheduler.property_keys:
                raise FailedPreconditionError(
                    f"Unregistered platform property [{name}]. Known properties are: [{self.scheduler.property_keys}]"
                )
            if name in self.scheduler.unique_keys and name in platform_requirements:
                raise FailedPreconditionError(f"Platform property [{name}] can only be set once.")
            if name in self.scheduler.match_properties:
                platform_requirements[name].add(platform_property.value)
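
        # Illustration (hypothetical property keys): with property_keys={"OSFamily", "ISA"},
        # unique_keys={"OSFamily"}, and match_properties={"ISA"}, an action declaring
        # OSFamily twice is rejected, while multiple ISA values accumulate in one set.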

        should_throttle = self._should_throttle_execution(priority, client_identity)
        if should_throttle:
            # TODO: test_execution_instance is a total mess; it mocks far too much,
            # making the tests brittle. When possible, merge it into the
            # execution_service tests and use proper logging here. We should be able
            # to write `action_digest=[{action_digest.hash}/{action_digest.size_bytes}]`,
            # but can't: AttributeError: 'str' object has no attribute 'hash'
            LOGGER.info(
                f"Job priority throttled action_digest=[{action_digest}] "
                f"old_priority=[{priority}] new_priority=[{EXECUTION_DEPRIORITIZATION_LIMIT}]"
            )
            priority = EXECUTION_DEPRIORITIZATION_LIMIT

        operation_name = self.scheduler.queue_job_action(
            action=action,
            action_digest=action_digest,
            command=command,
            platform_requirements=platform_requirements,
            skip_cache_lookup=skip_cache_lookup,
            priority=priority,
            request_metadata=request_metadata,
            client_identity=client_identity,
        )
        self._meter_execution(client_identity, operation_name)
        return operation_name

    def stream_operation_updates(self, operation_name: str, context: CancellationContext) -> Iterable[Operation]:
        job_name = self.scheduler.get_operation_job_name(operation_name)
        if not job_name:
            raise InvalidArgumentError(f"Operation name does not exist: [{operation_name}]")
        # Start the listener as soon as we get the job name and re-query. This avoids potentially missing
        # the completed update if it triggers in between sending back the first result and the yield resuming.
        with self.scheduler.ops_notifier.subscription(job_name) as update_requested:
            yield (operation := self.scheduler.load_operation(operation_name))
            if operation.done:
                return

            # When the context is deactivated, we can quickly stop waiting.
            context.on_cancel(update_requested.set)
            while not context.is_cancelled():
                update_requested.wait(timeout=self._operation_stream_keepalive_timeout)
                update_requested.clear()

                if context.is_cancelled():
                    return

                yield (operation := self.scheduler.load_operation(operation_name))
                if operation.done:
                    return

    def get_storage_capabilities(self) -> CacheCapabilities:
        return self.scheduler.storage.get_capabilities()

    def get_action_cache_capabilities(self) -> Tuple[Optional[Any], Optional[ActionCacheUpdateCapabilities]]:
        hash_type = None
        capabilities = None
        if self.scheduler.action_cache is not None:
            capabilities = ActionCacheUpdateCapabilities()
            hash_type = self.scheduler.action_cache.hash_type()
            capabilities.update_enabled = self.scheduler.action_cache.allow_updates
        return hash_type, capabilities

    def _meter_execution(self, client_identity: Optional[ClientIdentityEntry], operation_name: str) -> None:
        """Meter the number of executions made by a client."""
        if self.scheduler.metering_client is None or client_identity is None:
            return
        # Build the usage record before entering the try block, so the except
        # handler can always safely reference `identity` and `usage`.
        identity = Identity(
            instance=client_identity.instance,
            workflow=client_identity.workflow,
            actor=client_identity.actor,
            subject=client_identity.subject,
        )
        usage = Usage(rpc=RPCUsage(execute=1))
        try:
            self.scheduler.metering_client.put_usage(identity, operation_name, usage)
        except Exception as exc:
            LOGGER.exception(f"Failed to publish execution usage, {identity=} {usage=}", exc_info=exc)

    def _should_throttle_execution(self, priority: int, client_identity: Optional[ClientIdentityEntry]) -> bool:
        if (
            priority >= EXECUTION_DEPRIORITIZATION_LIMIT
            or self.scheduler.metering_client is None
            or client_identity is None
        ):
            return False
        try:
            identity = Identity(
                instance=client_identity.instance,
                workflow=client_identity.workflow,
                actor=client_identity.actor,
                subject=client_identity.subject,
            )
            response = self.scheduler.metering_client.get_throttling(identity)
            if response.throttled:
                LOGGER.info(
                    "Execution request is throttled, client_id: [%s], usage: [%s]",
                    client_identity,
                    response.tracked_usage,
                )
            return response.throttled
        except Exception as exc:
            LOGGER.exception("Failed to get throttling information", exc_info=exc)
            return False
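
# A sketch of a metering-client test double, assuming only the two calls used
# above (put_usage, and get_throttling returning an object with `throttled` and
# `tracked_usage`); the real buildgrid_metering client interface may differ:
#
#     from types import SimpleNamespace
#
#     class StubMeteringClient:
#         def __init__(self, throttled: bool = False) -> None:
#             self.throttled = throttled
#             self.usages: list = []
#
#         def put_usage(self, identity, operation_name, usage) -> None:
#             self.usages.append((identity, operation_name, usage))
#
#         def get_throttling(self, identity):
#             return SimpleNamespace(throttled=self.throttled, tracked_usage={})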