Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/scheduler/impl.py: 92.59%
850 statements
coverage.py v7.4.1, created at 2024-10-04 17:48 +0000
1# Copyright (C) 2019 Bloomberg LP
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# <http://www.apache.org/licenses/LICENSE-2.0>
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16import threading
17import uuid
18from collections import defaultdict
19from contextlib import ExitStack
20from datetime import datetime, timedelta
21from time import time
22from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Type, TypedDict, TypeVar, Union, cast
24from buildgrid_metering.client import SyncMeteringServiceClient
25from buildgrid_metering.models.dataclasses import ComputingUsage, Identity, Usage
26from google.protobuf.any_pb2 import Any as ProtoAny
27from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
28from grpc import Channel
29from sqlalchemy import and_, delete, func, insert, literal_column, or_, select, union, update
30from sqlalchemy.dialects import postgresql, sqlite
31from sqlalchemy.exc import IntegrityError
32from sqlalchemy.orm import Session, joinedload, selectinload
33from sqlalchemy.sql import ClauseElement
34from sqlalchemy.sql.expression import Insert, Select
36from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
37from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import (
38 Action,
39 ActionResult,
40 Command,
41 Digest,
42 ExecutedActionMetadata,
43 ExecuteOperationMetadata,
44 ExecuteResponse,
45 RequestMetadata,
46 ToolDetails,
47)
48from buildgrid._protos.build.buildbox.execution_stats_pb2 import ExecutionStatistics
49from buildgrid._protos.buildgrid.v2.identity_pb2 import ClientIdentity
50from buildgrid._protos.google.devtools.remoteworkers.v1test2.bots_pb2 import Lease
51from buildgrid._protos.google.longrunning import operations_pb2
52from buildgrid._protos.google.longrunning.operations_pb2 import Operation
53from buildgrid._protos.google.rpc import code_pb2, status_pb2
54from buildgrid._protos.google.rpc.status_pb2 import Status
55from buildgrid.server.actioncache.caches.action_cache_abc import ActionCacheABC
56from buildgrid.server.cas.storage.storage_abc import StorageABC
57from buildgrid.server.client.asset import AssetClient
58from buildgrid.server.client.logstream import logstream_client
59from buildgrid.server.context import current_instance, instance_context, try_current_instance
60from buildgrid.server.decorators import timed
61from buildgrid.server.enums import BotStatus, LeaseState, MeteringThrottleAction, MetricCategories, OperationStage
62from buildgrid.server.exceptions import (
63 BotSessionClosedError,
64 BotSessionMismatchError,
65 CancelledError,
66 DatabaseError,
67 InvalidArgumentError,
68 NotFoundError,
69 ResourceExhaustedError,
70 UpdateNotAllowedError,
71)
72from buildgrid.server.logging import buildgrid_logger
73from buildgrid.server.metrics_names import METRIC
74from buildgrid.server.metrics_utils import publish_counter_metric, publish_timer_metric, timer
75from buildgrid.server.operations.filtering import DEFAULT_SORT_KEYS, OperationFilter, SortKey
76from buildgrid.server.settings import DEFAULT_MAX_EXECUTION_TIMEOUT, SQL_SCHEDULER_METRICS_PUBLISH_INTERVAL_SECONDS
77from buildgrid.server.sql.models import Base as OrmBase
78from buildgrid.server.sql.models import (
79 BotEntry,
80 ClientIdentityEntry,
81 JobEntry,
82 LeaseEntry,
83 OperationEntry,
84 PlatformEntry,
85 RequestMetadataEntry,
86 digest_to_string,
87 job_platform_association,
88 string_to_digest,
89)
90from buildgrid.server.sql.provider import SqlProvider
91from buildgrid.server.sql.utils import (
92 build_custom_filters,
93 build_page_filter,
94 build_page_token,
95 build_sort_column_list,
96 extract_sort_keys,
97)
98from buildgrid.server.threading import ContextWorker
99from buildgrid.server.utils.digests import create_digest
101from .assigner import JobAssigner
102from .notifier import OperationsNotifier
103from .properties import PropertySet, hash_from_dict
105LOGGER = buildgrid_logger(__name__)
108PROTOBUF_MEDIA_TYPE = "application/x-protobuf"
109DIGEST_URI_TEMPLATE = "nih:sha-256;{digest_hash}"
112class SchedulerMetrics(TypedDict, total=False):
113 leases: Dict[LeaseState, int]
114 jobs: Dict[OperationStage, int]
117class AgedJobHandlerOptions(NamedTuple):
118 job_max_age: timedelta = timedelta(days=30)
119 handling_period: timedelta = timedelta(minutes=5)
120 max_handling_window: int = 10000
122 @staticmethod
123 def from_config(
124 job_max_age_cfg: Dict[str, float],
125 handling_period_cfg: Optional[Dict[str, float]] = None,
126 max_handling_window_cfg: Optional[int] = None,
127 ) -> "AgedJobHandlerOptions":
128 """Helper method for creating ``AgedJobHandlerOptions`` objects
129 If input configs are None, assign defaults"""
131 def _dict_to_timedelta(config: Dict[str, float]) -> timedelta:
132 return timedelta(
133 weeks=config.get("weeks", 0),
134 days=config.get("days", 0),
135 hours=config.get("hours", 0),
136 minutes=config.get("minutes", 0),
137 seconds=config.get("seconds", 0),
138 )
140 return AgedJobHandlerOptions(
141 job_max_age=_dict_to_timedelta(job_max_age_cfg) if job_max_age_cfg else timedelta(days=30),
142 handling_period=_dict_to_timedelta(handling_period_cfg) if handling_period_cfg else timedelta(minutes=5),
143 max_handling_window=max_handling_window_cfg if max_handling_window_cfg else 10000,
144 )
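# Illustrative sketch (not part of this module): the config dictionaries accepted by
# ``from_config`` use weeks/days/hours/minutes/seconds keys, any of which may be omitted;
# the values below are made up for the example.
example_options = AgedJobHandlerOptions.from_config(
    job_max_age_cfg={"days": 7, "hours": 12},  # jobs older than 7.5 days are handled
    handling_period_cfg={"minutes": 10},  # the background loop wakes every 10 minutes
    max_handling_window_cfg=5000,  # at most 5000 jobs handled per pass
)
assert example_options.job_max_age == timedelta(days=7, hours=12)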
147T = TypeVar("T", bound="Scheduler")
150class Scheduler:
151 RETRYABLE_STATUS_CODES = (code_pb2.INTERNAL, code_pb2.UNAVAILABLE)
153 def __init__(
154 self,
155 sql_provider: SqlProvider,
156 storage: StorageABC,
157 *,
158 sql_ro_provider: Optional[SqlProvider] = None,
159 sql_notifier_provider: Optional[SqlProvider] = None,
160 property_set: PropertySet,
161 action_cache: Optional[ActionCacheABC] = None,
162 action_browser_url: Optional[str] = None,
163 max_execution_timeout: int = DEFAULT_MAX_EXECUTION_TIMEOUT,
164 metering_client: Optional[SyncMeteringServiceClient] = None,
165 metering_throttle_action: Optional[MeteringThrottleAction] = None,
166 bot_session_keepalive_timeout: int = 600,
167 logstream_channel: Optional[Channel] = None,
168 logstream_instance: Optional[str] = None,
169 asset_client: Optional[AssetClient] = None,
170 queued_action_retention_hours: Optional[float] = None,
171 completed_action_retention_hours: Optional[float] = None,
172 action_result_retention_hours: Optional[float] = None,
173 enable_job_watcher: bool = False,
174 poll_interval: float = 1,
175 pruning_options: Optional[AgedJobHandlerOptions] = None,
176 queue_timeout_options: Optional[AgedJobHandlerOptions] = None,
177 max_job_attempts: int = 5,
178 job_assignment_interval: float = 1.0,
179 priority_assignment_percentage: int = 100,
180 max_queue_size: Optional[int] = None,
181 ) -> None:
182 self._stack = ExitStack()
184 self.storage = storage
186 self.poll_interval = poll_interval
187 self.pruning_options = pruning_options
188 self.queue_timeout_options = queue_timeout_options
189 self.max_job_attempts = max_job_attempts
191 self._sql = sql_provider
192 self._sql_ro = sql_ro_provider or sql_provider
193 self._sql_notifier = sql_notifier_provider or sql_provider
195 self.property_set = property_set
197 self.action_cache = action_cache
198 self.action_browser_url = (action_browser_url or "").rstrip("/")
199 self.max_execution_timeout = max_execution_timeout
200 self.enable_job_watcher = enable_job_watcher
201 self.metering_client = metering_client
202 self.metering_throttle_action = metering_throttle_action or MeteringThrottleAction.DEPRIORITIZE
203 self.bot_session_keepalive_timeout = bot_session_keepalive_timeout
204 self.logstream_channel = logstream_channel
205 self.logstream_instance = logstream_instance
206 self.asset_client = asset_client
207 self.queued_action_retention_hours = queued_action_retention_hours
208 self.completed_action_retention_hours = completed_action_retention_hours
209 self.action_result_retention_hours = action_result_retention_hours
210 self.max_queue_size = max_queue_size
212 # Overall Scheduler Metrics (totals of jobs/leases in each state)
213 # Publish these metrics less frequently since the SQL queries
214 # required to gather them can be expensive
215 self._last_scheduler_metrics_publish_time: Optional[datetime] = None
216 self._scheduler_metrics_publish_interval = timedelta(seconds=SQL_SCHEDULER_METRICS_PUBLISH_INTERVAL_SECONDS)
218 self.ops_notifier = OperationsNotifier(self._sql_notifier, self.poll_interval)
219 self.prune_timer = ContextWorker(name="JobPruner", target=self.prune_timer_loop)
220 self.queue_timer = ContextWorker(name="QueueTimeout", target=self.queue_timer_loop)
221 self.execution_timer = ContextWorker(name="ExecutionTimeout", target=self.execution_timer_loop)
222 self.session_expiry_timer = ContextWorker(self.session_expiry_timer_loop, "BotReaper")
223 self.job_assigner = JobAssigner(
224 self,
225 property_set=property_set,
226 job_assignment_interval=job_assignment_interval,
227 priority_percentage=priority_assignment_percentage,
228 )
230 def __repr__(self) -> str:
231 return f"Scheduler for `{repr(self._sql._engine.url)}`"
233 def __enter__(self: T) -> T:
234 self.start()
235 return self
237 def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
238 self.stop()
240 def start(
241 self,
242 *,
243 start_job_watcher: bool = True,
244 ) -> None:
245 self._stack.enter_context(self.storage)
246 if self.action_cache:
247 self._stack.enter_context(self.action_cache)
249 if self.logstream_channel:
250 self._stack.enter_context(self.logstream_channel)
251 if self.asset_client:
252 self._stack.enter_context(self.asset_client)
253 # Pruning configuration parameters
254 if self.pruning_options is not None:
255 LOGGER.info(f"Scheduler pruning enabled: {self.pruning_options}")
256 self._stack.enter_context(self.prune_timer)
257 else:
258 LOGGER.info("Scheduler pruning not enabled.")
260 # Queue timeout thread
261 if self.queue_timeout_options is not None:
262 LOGGER.info(f"Job queue timeout enabled: {self.queue_timeout_options}")
263 self._stack.enter_context(self.queue_timer)
264 else:
265 LOGGER.info("Job queue timeout not enabled.")
267 if start_job_watcher:
268 self._stack.enter_context(self.execution_timer)
269 self._stack.enter_context(self.ops_notifier)
271 def stop(self) -> None:
272 self._stack.close()
273 LOGGER.info("Stopped Scheduler.")
275 def _job_in_instance(self) -> ClauseElement:
276 return JobEntry.instance_name == current_instance()
278 def _bot_in_instance(self) -> ClauseElement:
279 return BotEntry.instance_name == current_instance()
281 def queue_job_action(
282 self,
283 *,
284 action: Action,
285 action_digest: Digest,
286 command: Command,
287 platform_requirements: Dict[str, List[str]],
288 property_label: str,
289 priority: int,
290 skip_cache_lookup: bool,
291 request_metadata: Optional[RequestMetadata] = None,
292 client_identity: Optional[ClientIdentityEntry] = None,
293 ) -> str:
294 """
295 De-duplicates or inserts a newly created job into the execution queue.
296 Returns an operation name associated with this job.
297 """
299 if not action.do_not_cache:
300 if operation_name := self.create_operation_for_existing_job(
301 action_digest=action_digest,
302 priority=priority,
303 request_metadata=request_metadata,
304 client_identity=client_identity,
305 ):
306 return operation_name
308 # The action may already have a result in the ActionCache; check it now.
309 # If so, we can create the new job already marked as completed.
310 execute_response: Optional[ExecuteResponse] = None
311 if self.action_cache and not action.do_not_cache and not skip_cache_lookup:
312 try:
313 action_result = self.action_cache.get_action_result(action_digest)
314 LOGGER.info("Job cache hit for action.", tags=dict(digest=action_digest))
315 execute_response = ExecuteResponse()
316 execute_response.result.CopyFrom(action_result)
317 execute_response.cached_result = True
318 except NotFoundError:
319 pass
320 except Exception:
321 LOGGER.exception("Checking ActionCache for action failed.", tags=dict(digest=action_digest))
323 # Extend retention for action
324 self._update_action_retention(action, action_digest, self.queued_action_retention_hours)
326 return self.create_operation_for_new_job(
327 action=action,
328 action_digest=action_digest,
329 command=command,
330 execute_response=execute_response,
331 platform_requirements=platform_requirements,
332 property_label=property_label,
333 priority=priority,
334 request_metadata=request_metadata,
335 client_identity=client_identity,
336 )
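# Hedged sketch of a caller queueing an action; `action`, `action_digest` and `command` are
# assumed to have been fetched from CAS already, and the platform_requirements and
# property_label values below are invented for illustration.
operation_name = scheduler.queue_job_action(
    action=action,
    action_digest=action_digest,
    command=command,
    platform_requirements={"OSFamily": ["linux"], "ISA": ["x86-64"]},
    property_label="linux-x86-64",
    priority=0,
    skip_cache_lookup=False,
)
# The returned name can later be passed to load_operation() to poll the job's progress.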
338 def create_operation_for_existing_job(
339 self,
340 *,
341 action_digest: Digest,
342 priority: int,
343 request_metadata: Optional[RequestMetadata],
344 client_identity: Optional[ClientIdentityEntry],
345 ) -> Optional[str]:
346 # Find a job with a matching action that isn't completed or cancelled and that can be cached.
347 find_existing_stmt = (
348 select(JobEntry)
349 .where(
350 JobEntry.action_digest == digest_to_string(action_digest),
351 JobEntry.stage != OperationStage.COMPLETED.value,
352 JobEntry.cancelled != True, # noqa: E712
353 JobEntry.do_not_cache != True, # noqa: E712
354 self._job_in_instance(),
355 )
356 .with_for_update()
357 )
359 with self._sql.session(exceptions_to_not_raise_on=[Exception]) as session:
360 if not (job := session.execute(find_existing_stmt).scalars().first()):
361 return None
363 # Re-queue for assignment if the new request has a higher priority (a lower value) and the job hasn't started yet.
364 if priority < job.priority and job.stage == OperationStage.QUEUED.value:
365 LOGGER.info("Job assigned a new priority.", tags=dict(job_name=job.name, priority=priority))
366 job.priority = priority
367 job.assigned = False
369 return self._create_operation(
370 session,
371 job_name=job.name,
372 request_metadata=request_metadata,
373 client_identity=client_identity,
374 )
376 def create_operation_for_new_job(
377 self,
378 *,
379 action: Action,
380 action_digest: Digest,
381 command: Command,
382 execute_response: Optional[ExecuteResponse],
383 platform_requirements: Dict[str, List[str]],
384 property_label: str,
385 priority: int,
386 request_metadata: Optional[RequestMetadata] = None,
387 client_identity: Optional[ClientIdentityEntry] = None,
388 ) -> str:
389 if execute_response is None and self.max_queue_size is not None:
390 # Using func.count here to avoid generating a subquery in the WHERE
391 # clause of the resulting query.
392 # https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.query.Query.count
393 queue_count_statement = select(func.count(JobEntry.name)).where(
394 JobEntry.assigned != True, # noqa: E712
395 self._job_in_instance(),
396 JobEntry.property_label == property_label,
397 JobEntry.stage == OperationStage.QUEUED.value,
398 )
399 else:
400 queue_count_statement = None
402 with self._sql.session(exceptions_to_not_raise_on=[Exception]) as session:
403 if queue_count_statement is not None:
404 queue_size = session.execute(queue_count_statement).scalar_one()
405 if queue_size >= self.max_queue_size:
406 raise ResourceExhaustedError(f"The platform's job queue is full: {property_label=}")
408 now = datetime.utcnow()
409 job = JobEntry(
410 instance_name=current_instance(),
411 name=str(uuid.uuid4()),
412 action=action.SerializeToString(),
413 action_digest=digest_to_string(action_digest),
414 do_not_cache=action.do_not_cache,
415 priority=priority,
416 stage=OperationStage.QUEUED.value,
417 create_timestamp=now,
418 queued_timestamp=now,
419 command=" ".join(command.arguments),
420 platform_requirements=hash_from_dict(platform_requirements),
421 platform=self._populate_platform_requirements(session, platform_requirements),
422 property_label=property_label,
423 n_tries=1,
424 )
425 if execute_response:
426 job.stage = OperationStage.COMPLETED.value
427 job.result = digest_to_string(self.storage.put_message(execute_response))
428 job.status_code = execute_response.status.code
429 job.worker_completed_timestamp = datetime.utcnow()
431 session.add(job)
433 return self._create_operation(
434 session,
435 job_name=job.name,
436 request_metadata=request_metadata,
437 client_identity=client_identity,
438 )
440 def _populate_platform_requirements(
441 self, session: Session, platform_requirements: Dict[str, List[str]]
442 ) -> List[PlatformEntry]:
443 if not platform_requirements:
444 return []
446 required_entries = {(k, v) for k, values in platform_requirements.items() for v in values}
447 conditions = [and_(PlatformEntry.key == k, PlatformEntry.value == v) for k, v in required_entries]
448 statement = select(PlatformEntry.key, PlatformEntry.value).where(or_(*conditions))
450 while missing := required_entries - {(k, v) for [k, v] in session.execute(statement).all()}:
451 try:
452 session.bulk_insert_mappings(PlatformEntry, [{"key": k, "value": v} for k, v in missing])
453 session.commit()
454 except IntegrityError:
455 session.rollback()
457 return list(session.execute(select(PlatformEntry).where(or_(*conditions))).scalars())
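# For illustration: a requirements dict with several values for one key is flattened into one
# (key, value) pair per value before being matched against PlatformEntry rows, exactly as in
# the set comprehension above.
platform_requirements = {"OSFamily": ["linux"], "ISA": ["x86-64", "sse4"]}
required_entries = {(k, v) for k, values in platform_requirements.items() for v in values}
assert required_entries == {("OSFamily", "linux"), ("ISA", "x86-64"), ("ISA", "sse4")}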
459 def create_operation(
460 self,
461 job_name: str,
462 *,
463 request_metadata: Optional[RequestMetadata] = None,
464 client_identity: Optional[ClientIdentityEntry] = None,
465 ) -> str:
466 with self._sql.session(exceptions_to_not_raise_on=[Exception]) as session:
467 if not (job := self._get_job(job_name, session, with_for_update=True)):
468 raise NotFoundError(f"Job name does not exist: [{job_name}]")
470 if job.cancelled:
471 raise CancelledError(f"Job {job_name} is cancelled")
473 return self._create_operation(
474 session, job_name=job_name, request_metadata=request_metadata, client_identity=client_identity
475 )
477 def _create_operation(
478 self,
479 session: Session,
480 *,
481 job_name: str,
482 request_metadata: Optional[RequestMetadata],
483 client_identity: Optional[ClientIdentityEntry],
484 ) -> str:
486 client_identity_id: Optional[int] = None
487 if client_identity:
488 client_identity_id = self.get_or_create_client_identity_in_store(session, client_identity).id
490 request_metadata_id: Optional[int] = None
491 if request_metadata:
492 request_metadata_id = self.get_or_create_request_metadata_in_store(session, request_metadata).id
494 request_metadata = request_metadata or RequestMetadata()
495 operation = OperationEntry(
496 name=str(uuid.uuid4()),
497 job_name=job_name,
498 client_identity_id=client_identity_id,
499 request_metadata_id=request_metadata_id,
500 )
501 session.add(operation)
502 return operation.name
504 def load_operation(self, operation_name: str) -> Operation:
505 statement = (
506 select(OperationEntry).join(JobEntry).where(OperationEntry.name == operation_name, self._job_in_instance())
507 )
508 with self._sql_ro.session(exceptions_to_not_raise_on=[Exception]) as session:
509 if op := session.execute(statement).scalars().first():
510 return self._load_operation(op)
512 raise NotFoundError(f"Operation name does not exist: [{operation_name}]")
514 def _load_operation(self, op: OperationEntry) -> Operation:
515 job: JobEntry = op.job
517 operation = operations_pb2.Operation(
518 name=op.name,
519 done=job.stage == OperationStage.COMPLETED.value or op.cancelled or job.cancelled,
520 )
521 metadata = ExecuteOperationMetadata(
522 stage=OperationStage.COMPLETED.value if operation.done else job.stage, # type: ignore[arg-type]
523 action_digest=string_to_digest(job.action_digest),
524 stderr_stream_name=job.stderr_stream_name or "",
525 stdout_stream_name=job.stdout_stream_name or "",
526 partial_execution_metadata=self.get_execute_action_metadata(job),
527 )
528 operation.metadata.Pack(metadata)
530 if job.cancelled or op.cancelled:
531 operation.error.CopyFrom(status_pb2.Status(code=code_pb2.CANCELLED))
532 elif job.status_code is not None and job.status_code != code_pb2.OK:
533 operation.error.CopyFrom(status_pb2.Status(code=job.status_code))
535 execute_response: Optional[ExecuteResponse] = None
536 if job.result:
537 result_digest = string_to_digest(job.result)
538 execute_response = self.storage.get_message(result_digest, ExecuteResponse)
539 if not execute_response:
540 operation.error.CopyFrom(status_pb2.Status(code=code_pb2.DATA_LOSS))
541 elif job.cancelled:
542 execute_response = ExecuteResponse(
543 status=status_pb2.Status(code=code_pb2.CANCELLED, message="Execution cancelled")
544 )
546 if execute_response:
547 if self.action_browser_url:
548 execute_response.message = f"{self.action_browser_url}/action/{job.action_digest}/"
549 operation.response.Pack(execute_response)
551 return operation
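# Hedged client-side sketch: the Operation returned by load_operation() packs its metadata and
# (when done) its response into protobuf Any fields, which a caller can unpack as shown below;
# `scheduler` and `operation_name` are assumed to exist.
operation = scheduler.load_operation(operation_name)
metadata = ExecuteOperationMetadata()
operation.metadata.Unpack(metadata)  # stage, action digest, stdout/stderr stream names
if operation.done and operation.response.Is(ExecuteResponse.DESCRIPTOR):
    response = ExecuteResponse()
    operation.response.Unpack(response)  # status, cached_result flag, ActionResult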
553 def _get_job(self, job_name: str, session: Session, with_for_update: bool = False) -> Optional[JobEntry]:
554 statement = select(JobEntry).where(JobEntry.name == job_name, self._job_in_instance())
555 if with_for_update:
556 statement = statement.with_for_update()
558 job: Optional[JobEntry] = session.execute(statement).scalars().first()
559 if job:
560 LOGGER.debug(
561 "Loaded job from db.",
562 tags=dict(job_name=job_name, job_stage=job.stage, result=job.result, instance_name=job.instance_name),
563 )
565 return job
567 def get_operation_job_name(self, operation_name: str) -> Optional[str]:
568 with self._sql_ro.session(exceptions_to_not_raise_on=[Exception]) as session:
569 if operation := self._get_operation(operation_name, session):
570 return operation.job_name
571 return None
573 def get_operation_request_metadata_by_name(self, operation_name: str) -> Optional[RequestMetadata]:
574 with self._sql_ro.session(exceptions_to_not_raise_on=[Exception]) as session:
575 operation = self._get_operation(operation_name, session)
576 if not operation or not operation.request_metadata:
577 return None
579 metadata = RequestMetadata(
580 tool_details=ToolDetails(
581 tool_name=operation.request_metadata.tool_name or "",
582 tool_version=operation.request_metadata.tool_version or "",
583 ),
584 action_id=operation.job.action_digest,
585 correlated_invocations_id=operation.request_metadata.correlated_invocations_id or "",
586 tool_invocation_id=operation.request_metadata.invocation_id or "",
587 action_mnemonic=operation.request_metadata.action_mnemonic or "",
588 configuration_id=operation.request_metadata.configuration_id or "",
589 target_id=operation.request_metadata.target_id or "",
590 )
592 return metadata
594 def get_client_identity_by_operation(self, operation_name: str) -> Optional[ClientIdentity]:
595 with self._sql_ro.session(exceptions_to_not_raise_on=[Exception]) as session:
596 operation = self._get_operation(operation_name, session)
597 if not operation or not operation.client_identity:
598 return None
600 return ClientIdentity(
601 actor=operation.client_identity.actor or "",
602 subject=operation.client_identity.subject or "",
603 workflow=operation.client_identity.workflow or "",
604 )
606 def _notify_job_updated(self, job_names: Union[str, List[str]], session: Session) -> None:
607 if self._sql.dialect == "postgresql":
608 if isinstance(job_names, str):
609 job_names = [job_names]
610 for job_name in job_names:
611 # Mypy bug? "execute" of "_SessionTypingCommon" has incompatible type "str"; expected "Executable"
612 session.execute(f"NOTIFY job_updated, '{job_name}';") # type: ignore[arg-type]
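# Hedged sketch of the consumer side (not how OperationsNotifier is actually implemented here):
# a PostgreSQL client listening on the same channel receives the job names sent by the NOTIFY
# statements above. Assumes psycopg2 and an invented DSN.
import psycopg2
conn = psycopg2.connect("dbname=buildgrid")
conn.set_session(autocommit=True)
with conn.cursor() as cur:
    cur.execute("LISTEN job_updated;")
conn.poll()  # after a NOTIFY arrives, notifications are queued on the connection
for notification in conn.notifies:
    print("job updated:", notification.payload)
conn.notifies.clear()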
614 def _get_operation(self, operation_name: str, session: Session) -> Optional[OperationEntry]:
615 statement = (
616 select(OperationEntry).join(JobEntry).where(OperationEntry.name == operation_name, self._job_in_instance())
617 )
618 return session.execute(statement).scalars().first()
620 def _batch_timeout_jobs(self, job_select_stmt: Select, status_code: int, message: str) -> int:
621 """Timeout all jobs selected by a query"""
622 with self._sql.session(sqlite_lock_immediately=True, exceptions_to_not_raise_on=[Exception]) as session:
623 # Get the full list of jobs to timeout
624 jobs = [job.name for job in session.execute(job_select_stmt).scalars().all()]
626 if jobs:
627 # Put response binary
628 response = remote_execution_pb2.ExecuteResponse(
629 status=status_pb2.Status(code=status_code, message=message)
630 )
631 response_binary = response.SerializeToString()
632 response_digest = create_digest(response_binary)
633 self.storage.bulk_update_blobs([(response_digest, response_binary)])
635 # Update response
636 stmt_timeout_jobs = (
637 update(JobEntry)
638 .where(JobEntry.name.in_(jobs))
639 .values(
640 stage=OperationStage.COMPLETED.value,
641 status_code=status_code,
642 result=digest_to_string(response_digest),
643 )
644 )
645 session.execute(stmt_timeout_jobs)
647 # Notify all jobs updated
648 self._notify_job_updated(jobs, session)
649 return len(jobs)
651 def execution_timer_loop(self, shutdown_requested: threading.Event) -> None:
652 """Periodically timeout aged executing jobs"""
653 while not shutdown_requested.is_set():
654 try:
655 self.cancel_jobs_exceeding_execution_timeout(self.max_execution_timeout)
656 except Exception as e:
657 LOGGER.exception("Failed to timeout aged executing jobs.", exc_info=e)
658 shutdown_requested.wait(timeout=self.poll_interval)
660 @timed(METRIC.SCHEDULER.EXECUTION_TIMEOUT_DURATION)
661 def cancel_jobs_exceeding_execution_timeout(self, max_execution_timeout: Optional[int] = None) -> None:
662 if not max_execution_timeout:
663 return
665 # Get the full list of jobs exceeding execution timeout
666 stale_jobs_statement = (
667 select(JobEntry)
668 .where(
669 JobEntry.stage == OperationStage.EXECUTING.value,
670 JobEntry.worker_start_timestamp <= datetime.utcnow() - timedelta(seconds=max_execution_timeout),
671 )
672 .with_for_update(skip_locked=True)
673 )
674 with self._sql.session(sqlite_lock_immediately=True, exceptions_to_not_raise_on=[Exception]) as session:
675 jobs = session.execute(stale_jobs_statement).scalars().all()
676 if not jobs:
677 return
679 response = remote_execution_pb2.ExecuteResponse(
680 status=status_pb2.Status(
681 code=code_pb2.DEADLINE_EXCEEDED,
682 message="Execution didn't finish within timeout threshold",
683 )
684 )
685 response_binary = response.SerializeToString()
686 response_digest = create_digest(response_binary)
687 self.storage.bulk_update_blobs([(response_digest, response_binary)])
689 for job in jobs:
690 executing_duration = datetime.utcnow() - (job.worker_start_timestamp or datetime.utcnow())
691 LOGGER.warning(
692 "Job has been executing for too long. Cancelling.",
693 tags=dict(
694 job_name=job.name,
695 executing_duration=executing_duration,
696 max_execution_timeout=max_execution_timeout,
697 ),
698 )
699 for op in job.operations:
700 op.cancelled = True
701 for lease in job.active_leases:
702 lease.state = LeaseState.CANCELLED.value
703 job.worker_completed_timestamp = datetime.utcnow()
704 job.stage = OperationStage.COMPLETED.value
705 job.cancelled = True
706 job.result = digest_to_string(response_digest)
708 for job in jobs:
709 self._notify_job_updated(job.name, session)
711 publish_counter_metric(METRIC.SCHEDULER.EXECUTION_TIMEOUT_COUNT, len(jobs))
713 def cancel_operation(self, operation_name: str) -> None:
714 statement = (
715 select(JobEntry)
716 .join(OperationEntry)
717 .where(OperationEntry.name == operation_name, self._job_in_instance())
718 .with_for_update()
719 )
720 with self._sql.session() as session:
721 if not (job := session.execute(statement).scalars().first()):
722 raise NotFoundError(f"Operation name does not exist: [{operation_name}]")
724 if job.stage == OperationStage.COMPLETED.value or job.cancelled:
725 return
727 for op in job.operations:
728 if op.name == operation_name:
729 if op.cancelled:
730 return
731 op.cancelled = True
733 if all(op.cancelled for op in job.operations):
734 for lease in job.active_leases:
735 lease.state = LeaseState.CANCELLED.value
736 job.worker_completed_timestamp = datetime.utcnow()
737 job.stage = OperationStage.COMPLETED.value
738 job.cancelled = True
740 self._notify_job_updated(job.name, session)
742 def list_operations(
743 self,
744 operation_filters: Optional[List[OperationFilter]] = None,
745 page_size: Optional[int] = None,
746 page_token: Optional[str] = None,
747 ) -> Tuple[List[operations_pb2.Operation], str]:
748 # Build filters and sort order
749 sort_keys = DEFAULT_SORT_KEYS
750 custom_filters = None
751 platform_filters = []
752 if operation_filters:
753 # Extract custom sort order (if present)
754 specified_sort_keys, non_sort_filters = extract_sort_keys(operation_filters)
756 # Only override sort_keys if there were sort keys actually present in the filter string
757 if specified_sort_keys:
758 sort_keys = specified_sort_keys
759 # Attach the operation name as a sort key for a deterministic order
760 # This will ensure that the ordering of results is consistent between queries
761 if not any(sort_key.name == "name" for sort_key in sort_keys):
762 sort_keys.append(SortKey(name="name", descending=False))
764 # Finally, compile the non-sort filters into a filter list
765 custom_filters = build_custom_filters(non_sort_filters)
766 platform_filters = [f for f in non_sort_filters if f.parameter == "platform"]
768 sort_columns = build_sort_column_list(sort_keys)
770 with self._sql_ro.session(exceptions_to_not_raise_on=[Exception]) as session:
771 statement = (
772 select(OperationEntry)
773 .join(JobEntry, OperationEntry.job_name == JobEntry.name)
774 .outerjoin(RequestMetadataEntry)
775 .outerjoin(ClientIdentityEntry)
776 )
777 statement = statement.filter(self._job_in_instance())
779 # If we're filtering by platform, filter using a subquery containing job names
780 # which match the specified platform properties.
781 #
782 # NOTE: A platform filter using `!=` will return only jobs which set that platform
783 # property to an explicitly different value; jobs which don't set the property are
784 # filtered out.
785 if platform_filters:
786 platform_clauses = []
787 for platform_filter in platform_filters:
788 key, value = platform_filter.value.split(":", 1)
789 platform_clauses.append(
790 and_(PlatformEntry.key == key, platform_filter.operator(PlatformEntry.value, value))
791 )
793 job_name_subquery = (
794 select(job_platform_association.c.job_name)
795 .filter(
796 job_platform_association.c.platform_id.in_(
797 select(PlatformEntry.id).filter(or_(*platform_clauses))
798 )
799 )
800 .group_by(job_platform_association.c.job_name)
801 .having(func.count() == len(platform_filters))
802 )
803 statement = statement.filter(JobEntry.name.in_(job_name_subquery))
805 # Apply custom filters (if present)
806 if custom_filters:
807 statement = statement.filter(*custom_filters)
809 # Apply sort order
810 statement = statement.order_by(*sort_columns)
812 # Apply pagination filter
813 if page_token:
814 page_filter = build_page_filter(page_token, sort_keys)
815 statement = statement.filter(page_filter)
816 if page_size:
817 # We limit the number of operations we fetch to the page_size. However, we
818 # fetch an extra operation to determine whether we need to provide a
819 # next_page_token.
820 statement = statement.limit(page_size + 1)
822 operations = session.execute(statement).scalars().all()
824 if not page_size or not operations:
825 next_page_token = ""
827 # If the number of results we got is less than or equal to our page_size,
828 # we're done with the operations listing and don't need to provide another
829 # page token
830 elif len(operations) <= page_size:
831 next_page_token = ""
832 else:
833 # Drop the last operation since we have an extra
834 operations.pop()
835 # Our page token will be the last row of our set
836 next_page_token = build_page_token(operations[-1], sort_keys)
837 return [self._load_operation(operation) for operation in operations], next_page_token
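# Illustrative paging loop over list_operations(): an empty next_page_token means the listing is
# complete. Assumes `scheduler` exists and this runs inside an instance context so
# current_instance() resolves.
page_token = None
while True:
    operations, page_token = scheduler.list_operations(page_size=100, page_token=page_token)
    for op in operations:
        ...  # each item is an operations_pb2.Operation
    if not page_token:
        break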
839 def list_workers(self, name_filter: str, page_number: int, page_size: int) -> Tuple[List[BotEntry], int]:
840 stmt = select(BotEntry, func.count().over().label("total"))
841 stmt = stmt.options(selectinload(BotEntry.job).selectinload(JobEntry.operations))
842 stmt = stmt.where(
843 or_(
844 BotEntry.name.ilike(f"%{name_filter}%"),
845 BotEntry.bot_id.ilike(f"%{name_filter}%"),
846 ),
847 BotEntry.instance_name == current_instance(),
848 )
849 stmt = stmt.order_by(BotEntry.bot_id)
851 if page_size:
852 stmt = stmt.limit(page_size)
853 if page_number > 1:
854 stmt = stmt.offset((page_number - 1) * page_size)
856 with self._sql.scoped_session() as session:
857 results = session.execute(stmt).all()
858 count = cast(int, results[0].total) if results else 0
859 session.expunge_all()
861 return [r[0] for r in results], count
863 def get_metrics(self) -> Optional[SchedulerMetrics]:
864 # Skip publishing overall scheduler metrics if we have recently published them
865 last_publish_time = self._last_scheduler_metrics_publish_time
866 time_since_publish = None
867 if last_publish_time:
868 time_since_publish = datetime.utcnow() - last_publish_time
869 if time_since_publish and time_since_publish < self._scheduler_metrics_publish_interval:
870 # Published too recently, skip
871 return None
873 def _get_query_leases_by_state(category: str) -> Select:
874 # Using func.count here to avoid generating a subquery in the WHERE
875 # clause of the resulting query.
876 # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
877 return select(
878 [
879 literal_column(f"'{category}'").label("category"),
880 LeaseEntry.state.label("bucket"),
881 func.count(LeaseEntry.id).label("value"),
882 ]
883 ).group_by(LeaseEntry.state)
885 def _cb_query_leases_by_state(leases_by_state: Dict[Any, Any]) -> Dict[Any, Any]:
886 # The database only returns counts > 0, so fill in the gaps
887 for state in LeaseState:
888 if state.value not in leases_by_state:
889 leases_by_state[state.value] = 0
890 return leases_by_state
892 def _get_query_jobs_by_stage(category: str) -> Select:
893 # Using func.count here to avoid generating a subquery in the WHERE
894 # clause of the resulting query.
895 # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
896 return select(
897 [
898 literal_column(f"'{category}'").label("category"),
899 JobEntry.stage.label("bucket"),
900 func.count(JobEntry.name).label("value"),
901 ]
902 ).group_by(JobEntry.stage)
904 def _cb_query_jobs_by_stage(jobs_by_stage: Dict[Any, Any]) -> Dict[Any, Any]:
905 # The database only returns counts > 0, so fill in the gaps
906 for stage in OperationStage:
907 if stage.value not in jobs_by_stage:
908 jobs_by_stage[stage.value] = 0
909 return jobs_by_stage
911 metrics: SchedulerMetrics = {}
912 # metrics to gather: (category_name, function_returning_query, callback_function)
913 metrics_to_gather = [
914 (MetricCategories.LEASES.value, _get_query_leases_by_state, _cb_query_leases_by_state),
915 (MetricCategories.JOBS.value, _get_query_jobs_by_stage, _cb_query_jobs_by_stage),
916 ]
918 statements = [query_fn(category) for category, query_fn, _ in metrics_to_gather]
919 metrics_statement = union(*statements)
921 try:
922 with self._sql_ro.session(exceptions_to_not_raise_on=[Exception]) as session:
923 results = session.execute(metrics_statement).all()
925 grouped_results: Dict[str, Any] = {category: {} for category, _, _ in results}
926 for category, bucket, value in results:
927 grouped_results[category][bucket] = value
929 for category, _, category_cb in metrics_to_gather:
930 metrics[category] = category_cb( # type: ignore[literal-required]
931 grouped_results.setdefault(category, {})
932 )
933 except DatabaseError:
934 LOGGER.warning("Unable to gather metrics due to a Database Error.")
935 return {}
937 # This is only updated within the metrics asyncio loop; no race conditions
938 self._last_scheduler_metrics_publish_time = datetime.utcnow()
940 return metrics
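# Illustrative shape of the returned metrics, assuming MetricCategories.LEASES.value and
# MetricCategories.JOBS.value are the "leases"/"jobs" keys of SchedulerMetrics; buckets are the
# integer enum values, with missing buckets zero-filled by the callbacks above (counts invented).
example_metrics = {
    "leases": {state.value: 0 for state in LeaseState},
    "jobs": {stage.value: 0 for stage in OperationStage},
}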
942 def _queued_jobs_by_capability(self, capability_hash: str) -> Select:
943 return (
944 select(JobEntry)
945 .with_for_update(skip_locked=True)
946 .where(
947 JobEntry.assigned != True, # noqa: E712
948 self._job_in_instance(),
949 JobEntry.platform_requirements == capability_hash,
950 JobEntry.stage == OperationStage.QUEUED.value,
951 )
952 )
954 def assign_n_leases_by_priority(
955 self,
956 *,
957 capability_hash: str,
958 bot_names: List[str],
959 ) -> List[str]:
960 job_statement = self._queued_jobs_by_capability(capability_hash).order_by(
961 JobEntry.priority, JobEntry.queued_timestamp
962 )
963 return self._assign_n_leases(job_statement=job_statement, bot_names=bot_names)
965 def assign_n_leases_by_age(
966 self,
967 *,
968 capability_hash: str,
969 bot_names: List[str],
970 ) -> List[str]:
971 job_statement = self._queued_jobs_by_capability(capability_hash).order_by(JobEntry.queued_timestamp)
972 return self._assign_n_leases(job_statement=job_statement, bot_names=bot_names)
974 @timed(METRIC.SCHEDULER.ASSIGNMENT_DURATION)
975 def _assign_n_leases(self, *, job_statement: Select, bot_names: List[str]) -> List[str]:
976 bot_statement = (
977 select(BotEntry)
978 .with_for_update(skip_locked=True)
979 .where(
980 BotEntry.lease_id.is_(None),
981 self._bot_in_instance(),
982 BotEntry.name.in_(bot_names),
983 BotEntry.expiry_time > datetime.utcnow(),
984 )
985 )
987 try:
988 with self._sql.session(sqlite_lock_immediately=True, exceptions_to_not_raise_on=[Exception]) as session:
989 jobs = session.execute(job_statement.limit(len(bot_names))).scalars().all()
990 bots = session.execute(bot_statement.limit(len(jobs))).scalars().all()
992 assigned_bot_names: List[str] = []
993 for job, bot in zip(jobs, bots):
994 job.assigned = True
995 job.queued_time_duration = int((datetime.utcnow() - job.queued_timestamp).total_seconds())
996 job.worker_start_timestamp = datetime.utcnow()
997 job.worker_completed_timestamp = None
998 bot.lease_id = job.name
999 bot.last_update_timestamp = datetime.utcnow()
1000 if job.active_leases:
1001 lease = job.active_leases[0]
1002 LOGGER.debug(
1003 "Reassigned existing lease.",
1004 tags=dict(
1005 job_name=job.name,
1006 bot_id=bot.bot_id,
1007 bot_name=bot.name,
1008 prev_lease_state=lease.state,
1009 prev_lease_status=lease.status,
1010 prev_bot_id=lease.worker_name,
1011 ),
1012 )
1013 lease.state = LeaseState.PENDING.value
1014 lease.status = None
1015 lease.worker_name = bot.bot_id
1016 else:
1017 LOGGER.debug(
1018 "Assigned new lease.", tags=dict(job_name=job.name, bot_id=bot.bot_id, bot_name=bot.name)
1019 )
1020 session.add(
1021 LeaseEntry(
1022 job_name=job.name,
1023 state=LeaseState.PENDING.value,
1024 status=None,
1025 worker_name=bot.bot_id,
1026 )
1027 )
1028 assigned_bot_names.append(bot.name)
1030 return assigned_bot_names
1031 except DatabaseError:
1032 LOGGER.warning("Will not assign any leases this time due to a Database Error.")
1033 return []
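# Worked illustration of the pairing above: zip() stops at the shorter of the two sequences, so
# with five matching jobs but only two idle bots, exactly two assignments are made per call.
queued_jobs = ["job-a", "job-b", "job-c", "job-d", "job-e"]
idle_bots = ["bot-1", "bot-2"]
assert list(zip(queued_jobs, idle_bots)) == [("job-a", "bot-1"), ("job-b", "bot-2")]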
1035 def queue_timer_loop(self, shutdown_requested: threading.Event) -> None:
1036 """Periodically timeout aged queued jobs"""
1038 if not (opts := self.queue_timeout_options):
1039 return
1041 job_max_age = opts.job_max_age
1042 period = opts.handling_period
1043 limit = opts.max_handling_window
1045 last_timeout_time = datetime.utcnow()
1046 while not shutdown_requested.is_set():
1047 now = datetime.utcnow()
1048 if now - last_timeout_time < period:
1049 LOGGER.info(f"Job queue timeout thread sleeping for {period} seconds")
1050 shutdown_requested.wait(timeout=period.total_seconds())
1051 continue
1053 timeout_jobs_scheduled_before = now - job_max_age
1054 try:
1055 with timer(METRIC.SCHEDULER.QUEUE_TIMEOUT_DURATION):
1056 num_timeout = self._timeout_queued_jobs_scheduled_before(timeout_jobs_scheduled_before, limit)
1057 LOGGER.info(f"Timed-out {num_timeout} queued jobs scheduled before {timeout_jobs_scheduled_before}")
1058 if num_timeout > 0:
1059 publish_counter_metric(METRIC.SCHEDULER.QUEUE_TIMEOUT_COUNT, num_timeout)
1061 except Exception as e:
1062 LOGGER.exception("Failed to timeout aged queued jobs.", exc_info=e)
1063 finally:
1064 last_timeout_time = now
1066 def _timeout_queued_jobs_scheduled_before(self, dt: datetime, limit: int) -> int:
1067 jobs_to_timeout_stmt = (
1068 select(JobEntry)
1069 .where(JobEntry.stage == OperationStage.QUEUED.value)
1070 .where(JobEntry.queued_timestamp < dt)
1071 .limit(limit)
1072 )
1073 return self._batch_timeout_jobs(
1074 jobs_to_timeout_stmt, code_pb2.UNAVAILABLE, "Operation has been queued for too long"
1075 )
1077 def prune_timer_loop(self, shutdown_requested: threading.Event) -> None:
1078 """Running in a background thread, this method wakes up periodically and deletes older records
1079 from the jobs tables using configurable parameters"""
1081 if not (opts := self.pruning_options):
1082 return
1084 job_max_age = opts.job_max_age
1085 pruning_period = opts.handling_period
1086 limit = opts.max_handling_window
1088 utc_last_prune_time = datetime.utcnow()
1089 while not shutdown_requested.is_set():
1090 utcnow = datetime.utcnow()
1091 if (utcnow - pruning_period) < utc_last_prune_time:
1092 LOGGER.info(f"Pruner thread sleeping for {pruning_period}(until {utcnow + pruning_period})")
1093 shutdown_requested.wait(timeout=pruning_period.total_seconds())
1094 continue
1096 delete_before_datetime = utcnow - job_max_age
1097 try:
1098 num_rows = self._delete_jobs_prior_to(delete_before_datetime, limit)
1099 LOGGER.info(f"Pruned {num_rows} row(s) from the jobs table older than {delete_before_datetime}")
1100 except Exception:
1101 LOGGER.exception("Caught exception while deleting jobs records.")
1102 finally:
1103 # Update even if error occurred to avoid potentially infinitely retrying
1104 utc_last_prune_time = utcnow
1106 LOGGER.info("Exiting pruner thread.")
1108 @timed(METRIC.SCHEDULER.PRUNE_DURATION)
1109 def _delete_jobs_prior_to(self, delete_before_datetime: datetime, limit: int) -> int:
1110 """Deletes older records from the jobs tables constrained by `delete_before_datetime` and `limit`"""
1111 delete_stmt = delete(JobEntry).where(
1112 JobEntry.name.in_(
1113 select(JobEntry.name)
1114 .with_for_update(skip_locked=True)
1115 .where(JobEntry.worker_completed_timestamp <= delete_before_datetime)
1116 .limit(limit)
1117 ),
1118 )
1120 with self._sql.session() as session:
1121 options = {"synchronize_session": "fetch"}
1122 num_rows_deleted: int = session.execute(delete_stmt, execution_options=options).rowcount # type: ignore
1124 if num_rows_deleted:
1125 publish_counter_metric(METRIC.SCHEDULER.PRUNE_COUNT, num_rows_deleted)
1127 return num_rows_deleted
1129 def _insert_on_conflict_do_nothing(self, model: Type[OrmBase]) -> Insert:
1130 # `Insert.on_conflict_do_nothing` is a SQLAlchemy "generative method": it
1131 # returns a modified copy of the statement it is called on. For
1132 # some reason mypy can't understand this, so the errors are ignored here.
1133 if self._sql.dialect == "sqlite":
1134 sqlite_insert: sqlite.Insert = sqlite.insert(model)
1135 return sqlite_insert.on_conflict_do_nothing() # type: ignore
1137 elif self._sql.dialect == "postgresql":
1138 insertion: postgresql.Insert = postgresql.insert(model)
1139 return insertion.on_conflict_do_nothing() # type: ignore
1141 else:
1142 # Fall back to the non-specific insert implementation. This doesn't
1143 # support `ON CONFLICT DO NOTHING`, so callers need to be careful to
1144 # still catch IntegrityErrors if other database backends are possible.
1145 return insert(model)
1147 def get_or_create_client_identity_in_store(
1148 self, session: Session, client_id: ClientIdentityEntry
1149 ) -> ClientIdentityEntry:
1150 """Get the ClientIdentity in the storage or create one.
1151 This helper function ensures the `client_id` row exists, creating it within the current transaction.
1153 Args:
1154 session (Session): sqlalchemy Session
1155 client_id (ClientIdentityEntry): identity of the client that creates an operation
1157 Returns:
1158 ClientIdentityEntry: identity of the client that creates an operation
1159 """
1160 insertion = self._insert_on_conflict_do_nothing(ClientIdentityEntry)
1161 insertion = insertion.values(
1162 {
1163 "instance": client_id.instance,
1164 "workflow": client_id.workflow,
1165 "actor": client_id.actor,
1166 "subject": client_id.subject,
1167 }
1168 )
1169 try:
1170 session.execute(insertion)
1172 # Handle unique constraint violation when using an unsupported database (i.e. not PostgreSQL or SQLite)
1173 except IntegrityError:
1174 LOGGER.debug("Handled IntegrityError when inserting client identity.")
1176 stmt = (
1177 select(ClientIdentityEntry)
1178 .where(ClientIdentityEntry.instance == client_id.instance)
1179 .where(ClientIdentityEntry.workflow == client_id.workflow)
1180 .where(ClientIdentityEntry.actor == client_id.actor)
1181 .where(ClientIdentityEntry.subject == client_id.subject)
1182 )
1184 result: ClientIdentityEntry = session.execute(stmt).scalar_one()
1185 return result
1187 def get_or_create_request_metadata_in_store(
1188 self, session: Session, request_metadata: RequestMetadata
1189 ) -> RequestMetadataEntry:
1190 insertion = self._insert_on_conflict_do_nothing(RequestMetadataEntry)
1191 insertion = insertion.values(
1192 {
1193 "action_mnemonic": request_metadata.action_mnemonic,
1194 "configuration_id": request_metadata.configuration_id,
1195 "correlated_invocations_id": request_metadata.correlated_invocations_id,
1196 "invocation_id": request_metadata.tool_invocation_id,
1197 "target_id": request_metadata.target_id,
1198 "tool_name": request_metadata.tool_details.tool_name,
1199 "tool_version": request_metadata.tool_details.tool_version,
1200 }
1201 )
1202 try:
1203 session.execute(insertion)
1205 # Handle unique constraint violation when using an unsupported database (i.e. not PostgreSQL or SQLite)
1206 except IntegrityError:
1207 LOGGER.debug("Handled IntegrityError when inserting request metadata.")
1209 stmt = (
1210 select(RequestMetadataEntry)
1211 .where(RequestMetadataEntry.action_mnemonic == request_metadata.action_mnemonic)
1212 .where(RequestMetadataEntry.configuration_id == request_metadata.configuration_id)
1213 .where(RequestMetadataEntry.correlated_invocations_id == request_metadata.correlated_invocations_id)
1214 .where(RequestMetadataEntry.invocation_id == request_metadata.tool_invocation_id)
1215 .where(RequestMetadataEntry.target_id == request_metadata.target_id)
1216 .where(RequestMetadataEntry.tool_name == request_metadata.tool_details.tool_name)
1217 .where(RequestMetadataEntry.tool_version == request_metadata.tool_details.tool_version)
1218 )
1220 result: RequestMetadataEntry = session.execute(stmt).scalar_one()
1221 return result
1223 def add_bot_entry(self, *, bot_session_id: str, bot_session_status: int) -> str:
1224 with self._sql.session() as session:
1225 # Check if the bot_id is already known. If it is, all leases associated with
1226 # it are requeued and the existing record is deleted. A new record is then
1227 # created with the new bot_id/name combination, just as it would be for an
1228 # unknown bot_id.
1229 locate_bot_stmt = (
1230 select(BotEntry).where(BotEntry.bot_id == bot_session_id, self._bot_in_instance()).with_for_update()
1231 )
1232 self._close_bot_sessions(session, session.execute(locate_bot_stmt).scalars().all())
1234 bot_name = f"{current_instance()}/{str(uuid.uuid4())}"
1235 session.add(
1236 BotEntry(
1237 name=bot_name,
1238 bot_id=bot_session_id,
1239 last_update_timestamp=datetime.utcnow(),
1240 lease_id=None,
1241 bot_status=bot_session_status,
1242 instance_name=current_instance(),
1243 expiry_time=datetime.utcnow() + timedelta(seconds=self.bot_session_keepalive_timeout),
1244 )
1245 )
1246 return bot_name
1248 def close_bot_sessions(self, bot_name: str) -> None:
1249 with self._sql.session() as session:
1250 locate_bot_stmt = (
1251 select(BotEntry).where(BotEntry.name == bot_name, self._bot_in_instance()).with_for_update()
1252 )
1253 self._close_bot_sessions(session, session.execute(locate_bot_stmt).scalars().all())
1255 def _close_bot_sessions(self, session: Session, bots: List[BotEntry]) -> None:
1256 for bot in bots:
1257 log_tags = {
1258 "instance_name": try_current_instance(),
1259 "request.bot_name": bot.name,
1260 "request.bot_id": bot.bot_id,
1261 "request.bot_status": bot.bot_status,
1262 }
1263 LOGGER.debug("Closing bot session.", tags=log_tags)
1264 if bot.lease_id:
1265 if job := self._get_job(bot.lease_id, session, with_for_update=True):
1266 for db_lease in job.active_leases:
1267 lease_tags = {**log_tags, "db.lease_id": job.name, "db.lease_state": db_lease.state}
1268 LOGGER.debug("Reassigning lease for bot session.", tags=lease_tags)
1269 self._retry_job_lease(session, job, db_lease)
1270 self._notify_job_updated(job.name, session)
1271 session.delete(bot)
1273 def session_expiry_timer_loop(self, shutdown_requested: threading.Event) -> None:
1274 LOGGER.info("Starting BotSession reaper.", tags=dict(keepalive_timeout=self.bot_session_keepalive_timeout))
1275 while not shutdown_requested.is_set():
1276 try:
1277 while self.reap_expired_sessions():
1278 if shutdown_requested.is_set():
1279 break
1280 except Exception as exception:
1281 LOGGER.exception(exception)
1282 shutdown_requested.wait(timeout=self.poll_interval)
1284 def reap_expired_sessions(self) -> bool:
1285 """
1286 Find and close expired bot sessions. Returns True if sessions were closed.
1287 Only closes a few sessions to minimize time in transaction.
1288 """
1290 with self._sql.session() as session:
1291 locate_bot_stmt = (
1292 select(BotEntry)
1293 .where(BotEntry.expiry_time < datetime.utcnow())
1294 .order_by(BotEntry.expiry_time.desc())
1295 .with_for_update(skip_locked=True)
1296 .limit(5)
1297 )
1298 if bots := cast(List[BotEntry], session.execute(locate_bot_stmt).scalars().all()):
1299 bots_by_instance: Dict[str, List[BotEntry]] = defaultdict(list)
1300 for bot in bots:
1301 LOGGER.warning(
1302 "BotSession has expired.",
1303 tags=dict(
1304 name=bot.name, bot_id=bot.bot_id, instance_name=bot.instance_name, deadline=bot.expiry_time
1305 ),
1306 )
1307 bots_by_instance[bot.instance_name].append(bot)
1308 for instance_name, instance_bots in bots_by_instance.items():
1309 with instance_context(instance_name):
1310 self._close_bot_sessions(session, instance_bots)
1311 return True
1312 return False
1314 @timed(METRIC.SCHEDULER.SYNCHRONIZE_DURATION)
1315 def synchronize_bot_lease(
1316 self, bot_name: str, bot_id: str, bot_status: int, session_lease: Optional[Lease]
1317 ) -> Optional[Lease]:
1318 log_tags = {
1319 "instance_name": try_current_instance(),
1320 "request.bot_id": bot_id,
1321 "request.bot_status": bot_status,
1322 "request.bot_name": bot_name,
1323 "request.lease_id": session_lease.id if session_lease else "",
1324 "request.lease_state": session_lease.state if session_lease else "",
1325 }
1327 with self._sql.session(exceptions_to_not_raise_on=[Exception]) as session:
1328 locate_bot_stmt = (
1329 select(BotEntry).where(BotEntry.bot_id == bot_id, self._bot_in_instance()).with_for_update()
1330 )
1331 bots: List[BotEntry] = session.execute(locate_bot_stmt).scalars().all()
1332 if not bots:
1333 raise InvalidArgumentError(f"Bot does not exist while validating leases. {log_tags}")
1335 # This is a tricky case: a new bot session has been created while an older session for the
1336 # same bot id is still waiting on leases. This can happen when a worker reboots but its old
1337 # connection context takes a long time to close. Here we DO NOT want to update anything
1338 # in the database, because the work/lease has already been re-assigned to the new session.
1339 # Closing anything in the database at this point would cause the newly restarted worker
1340 # to be cancelled prematurely.
1341 if len(bots) == 1 and bots[0].name != bot_name:
1342 raise BotSessionMismatchError(
1343 "Mismatch between client supplied bot_id/bot_name and buildgrid database record. "
1344 f"db.bot_name=[{bots[0].name}] {log_tags}"
1345 )
1347 # Everything at this point is wrapped in try/catch, so we can raise BotSessionMismatchError or
1348 # BotSessionClosedError and have the session be closed if preconditions from here out fail.
1349 try:
1350 # There should never be a time when two bot sessions exist for the same bot id. We have logic to
1351 # assert that old database entries for a given bot id are closed and deleted prior to making a
1352 # new one. If this does happen, shut everything down so we can hopefully recover.
1353 if len(bots) > 1:
1354 raise BotSessionMismatchError(
1355 "Bot id is registered to more than one bot session. "
1356 f"names=[{', '.join(bot.name for bot in bots)}] {log_tags}"
1357 )
1359 bot = bots[0]
1360 log_tags["db.lease_id"] = bot.lease_id
1362 # Validate that the lease_id matches the client and database if both are supplied.
1363 if (session_lease and session_lease.id and bot.lease_id) and (session_lease.id != bot.lease_id):
1364 raise BotSessionMismatchError(
1365 f"Mismatch between client supplied lease_id and buildgrid database record. {log_tags}"
1366 )
1368 # Update the expiry time.
1369 bot.expiry_time = datetime.utcnow() + timedelta(seconds=self.bot_session_keepalive_timeout)
1370 bot.last_update_timestamp = datetime.utcnow()
1371 bot.bot_status = bot_status
1373 # Validate the cases where the database doesn't know about any leases.
1374 if bot.lease_id is None:
1375 # If there's no lease in the database or session, we have nothing to update!
1376 if not session_lease:
1377 LOGGER.debug("No lease in session or database. Skipping.", tags=log_tags)
1378 return None
1380 # If the database has no lease, but the work is completed, we probably timed out the last call.
1381 if session_lease.state == LeaseState.COMPLETED.value:
1382 LOGGER.debug("No lease in database, but session lease is completed. Skipping.", tags=log_tags)
1383 return None
1385 # Otherwise, the bot session has a lease that the server doesn't know about. Bad bad bad.
1386 raise BotSessionClosedError(f"Bot session lease id does not match the database. {log_tags}")
1388 # Let's now lock the job so no more state transitions occur while we perform our updates.
1389 job = self._get_job(bot.lease_id, session, with_for_update=True)
1390 if not job:
1391 raise BotSessionClosedError(f"Bot session lease id points to non-existent job. {log_tags}")
1393 # If we don't have any leases assigned to the job now, someone interrupted us before locking.
1394 # Disconnect our bot from mutating this job.
1395 if not job.leases:
1396 raise BotSessionClosedError(f"Leases were changed while job was being locked. {log_tags}")
1398 db_lease = job.leases[0]
1399 log_tags["db.lease_state"] = db_lease.state
1401 # Assign:
1402 #
1403 # If the lease is in the PENDING state, this means that it is a new lease for the worker, which
1404 # it must acknowledge (the next time it calls UpdateBotSession) by changing the state to ACTIVE.
1405 #
1406 # Leases contain a “payload,” which is an Any proto that must be understandable to the bot.
1407 #
1408 # If at any time the bot issues a call to UpdateBotSession that is inconsistent with what the service
1409 # expects, the service can take appropriate action. For example, the service may have assigned a
1410 # lease to a bot, but the call gets interrupted before the bot receives the message, perhaps because
1411 # the UpdateBotSession call times out. As a result, the next call to UpdateBotSession from the bot
1412 # will not include the lease, and the service can immediately conclude that the lease needs to be
1413 # reassigned.
1414 #
1415 if not session_lease:
1416 if db_lease.state != LeaseState.PENDING.value:
1417 raise BotSessionClosedError(
1418 f"Session has no lease and database entry not in pending state. {log_tags}"
1419 )
1421 job.stage = OperationStage.EXECUTING.value
1422 if self.logstream_channel and self.logstream_instance is not None:
1423 try:
1424 action_digest = string_to_digest(job.action_digest)
1425 parent_base = f"{action_digest.hash}_{action_digest.size_bytes}_{int(time())}"
1426 with logstream_client(self.logstream_channel, self.logstream_instance) as ls_client:
1427 stdout_stream = ls_client.create(f"{parent_base}_stdout")
1428 stderr_stream = ls_client.create(f"{parent_base}_stderr")
1429 job.stdout_stream_name = stdout_stream.name
1430 job.stdout_stream_write_name = stdout_stream.write_resource_name
1431 job.stderr_stream_name = stderr_stream.name
1432 job.stderr_stream_write_name = stderr_stream.write_resource_name
1433 except Exception as e:
1434 LOGGER.warning("Failed to create log stream.", tags=log_tags, exc_info=e)
1436 self._notify_job_updated(job.name, session)
1437 LOGGER.debug("Pending lease sent to bot for ack.", tags=log_tags)
1438 return db_lease.to_protobuf()
1440 # At this point, we know that there's a lease both in the bot session and in the database.
1442 # Accept:
1443 #
1444 # If the lease is in the PENDING state, this means that it is a new lease for the worker,
1445 # which it must acknowledge (the next time it calls UpdateBotSession) by changing the state to ACTIVE
1446 #
1447 if session_lease.state == LeaseState.ACTIVE.value and db_lease.state == LeaseState.PENDING.value:
1448 db_lease.state = LeaseState.ACTIVE.value
1449 self._notify_job_updated(job.name, session)
1450 LOGGER.debug("Bot acked pending lease.", tags=log_tags)
1451 return session_lease
1453 # Complete:
1454 #
1455 # Once the assignment is complete - either because it finishes or because it times out - the bot
1456 # calls Bots.UpdateBotSession again, this time updating the state of the lease from accepted to
1457 # complete, and optionally by also populating the lease’s results field, which is another Any proto.
1458 # The service can then assign it new work (removing any completed leases).
1459 #
1460 # A successfully completed lease may go directly from PENDING to COMPLETED if, for example, the
1461 # lease was completed before the bot has had the opportunity to transition to ACTIVE, or if the
1462 # update transitioning the lease to the ACTIVE state was lost.
1463 #
1464 if session_lease.state == LeaseState.COMPLETED.value and db_lease.state in (
1465 LeaseState.PENDING.value,
1466 LeaseState.ACTIVE.value,
1467 ):
1468 log_tags["request.lease_status_code"] = session_lease.status.code
1469 log_tags["request.lease_status_message"] = session_lease.status.message
1470 log_tags["db.n_tries"] = job.n_tries
1472 bot.lease_id = None
1473 if (
1474 session_lease.status.code in self.RETRYABLE_STATUS_CODES
1475 and job.n_tries < self.max_job_attempts
1476 ):
1477 LOGGER.debug("Retrying bot lease.", tags=log_tags)
1478 self._retry_job_lease(session, job, db_lease)
1479 else:
1480 LOGGER.debug("Bot completed lease.", tags=log_tags)
1481 self._complete_lease(session, job, db_lease, session_lease.status, session_lease.result)
1483 self._notify_job_updated(job.name, session)
1484 return None
1486 # Cancel:
1487 #
1488 # At any time, the service may change the state of a lease from PENDING or ACTIVE to CANCELLED;
1489 # the bot may not change to this state. The service then waits for the bot to acknowledge the
1490 # change by updating its own status to CANCELLED as well. Once both the service and the bot agree,
1491 # the service may remove it from the list of leases.
1492 #
1493 if session_lease.state == db_lease.state == LeaseState.CANCELLED.value:
1494 bot.lease_id = None
1495 LOGGER.debug("Bot acked cancelled lease.", tags=log_tags)
1496 return None
1498 if db_lease.state == LeaseState.CANCELLED.value:
1499 session_lease.state = LeaseState.CANCELLED.value
1500 LOGGER.debug("Cancelled lease sent to bot for ack.", tags=log_tags)
1501 return session_lease
1503 if session_lease.state == LeaseState.CANCELLED.value:
1504 raise BotSessionClosedError(f"Illegal attempt from session to set state as cancelled. {log_tags}")
1506 # Keepalive:
1507 #
1508 # The Bot periodically calls Bots.UpdateBotSession, either if there’s a genuine change (for example,
1509 # an attached phone has died) or simply to let the service know that it’s alive and ready to receive
1510 # work. If the bot doesn’t call back on time, the service considers it to have died, and all work
1511 # from the bot to be lost.
1512 #
1513 if session_lease.state == db_lease.state:
1514 LOGGER.debug("Bot heartbeat acked.", tags=log_tags)
1515 return session_lease
1517 # Any other transition should really never happen... cover it anyways.
1518 raise BotSessionClosedError(f"Unsupported lease state transition. {log_tags}")
1519 # TODO allow creating a session with manual commit logic.
1520 # For now... Sneak the exception past the context manager.
1521 except (BotSessionMismatchError, BotSessionClosedError) as e:
1522 self._close_bot_sessions(session, bots)
1523 err = e
1524 raise err
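# --- Example sketch (not part of this module) --------------------------------
# A minimal, hedged sketch of the bot-side half of the lease lifecycle that the
# "Assign", "Accept" and "Complete" comments above describe: the worker acks a
# PENDING lease by moving it to ACTIVE, and later reports completion by packing
# an ActionResult into the lease's result field. Function names here are
# illustrative only.
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import ActionResult
from buildgrid._protos.google.devtools.remoteworkers.v1test2 import bots_pb2
from buildgrid._protos.google.rpc import code_pb2

def ack_pending_lease(lease: bots_pb2.Lease) -> None:
    # A bot may only move a lease from PENDING to ACTIVE; CANCELLED is reserved
    # for the service, and an illegal transition closes the session above.
    if lease.state == bots_pb2.LeaseState.PENDING:
        lease.state = bots_pb2.LeaseState.ACTIVE

def complete_lease(lease: bots_pb2.Lease, action_result: ActionResult) -> None:
    # Attach the result as an Any proto and mark the lease COMPLETED; this is
    # the shape the "Complete" branch above unpacks from session_lease.result.
    lease.result.Pack(action_result)
    lease.status.code = code_pb2.OK
    lease.state = bots_pb2.LeaseState.COMPLETED
# ------------------------------------------------------------------------------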
1526 def _retry_job_lease(self, session: Session, job: JobEntry, lease: LeaseEntry) -> None:
1527 # If the job was mutated before we could lock it, exit fast on terminal states.
1528 if job.cancelled or job.stage == OperationStage.COMPLETED.value:
1529 return
1531 if job.n_tries >= self.max_job_attempts:
1532 status = status_pb2.Status(
1533 code=code_pb2.ABORTED, message=f"Job was retried {job.n_tries} times unsuccessfully. Aborting."
1534 )
1535 self._complete_lease(session, job, lease, status=status)
1536 return
1538 job.stage = OperationStage.QUEUED.value
1539 job.assigned = False
1540 job.n_tries += 1
1542 lease.state = LeaseState.PENDING.value
1543 lease.status = None
1544 lease.worker_name = None
1546 def _complete_lease(
1547 self, session: Session, job: JobEntry, lease: LeaseEntry, status: Status, result: Optional[ProtoAny] = None
1548 ) -> None:
1549 lease.state = LeaseState.COMPLETED.value
1550 lease.status = status.code
1552 job.stage = OperationStage.COMPLETED.value
1553 job.status_code = status.code
1554 if not job.do_not_cache:
1555 job.do_not_cache = status.code != code_pb2.OK
1556 job.worker_completed_timestamp = datetime.utcnow()
1558 action_result = ActionResult()
1559 if result is not None and result.Is(action_result.DESCRIPTOR):
1560 result.Unpack(action_result)
1561 now = datetime.utcnow()
1562 action_result.execution_metadata.queued_timestamp.FromDatetime(job.queued_timestamp)
1563 action_result.execution_metadata.worker_start_timestamp.FromDatetime(job.worker_start_timestamp or now)
1564 action_result.execution_metadata.worker_completed_timestamp.FromDatetime(job.worker_completed_timestamp or now)
1565 response = ExecuteResponse(result=action_result, cached_result=False, status=status)
1567 job.result = digest_to_string(self.storage.put_message(response))
1569 if self.action_cache and result and not job.do_not_cache:
1570 action_digest = string_to_digest(job.action_digest)
1571 try:
1572 self.action_cache.update_action_result(action_digest, action_result)
1573 LOGGER.debug(
1574 "Stored action result in ActionCache.",
1575 tags=dict(action_result=action_result, digest=action_digest),
1576 )
1577 except UpdateNotAllowedError:
1578 # The configuration doesn't allow updating the old result
1579 LOGGER.exception(
1580 "ActionCache is not configured to allow updates, ActionResult wasn't updated.",
1581 tags=dict(digest=action_digest),
1582 )
1583 except Exception:
1584 LOGGER.exception(
1585 "Unable to update ActionCache, results will not be stored in the ActionCache.",
1586 tags=dict(digest=action_digest),
1587 )
1589 # Update retentions
1590 self._update_action_retention(
1591 Action.FromString(job.action),
1592 string_to_digest(job.action_digest),
1593 retention_hours=self.completed_action_retention_hours,
1594 )
1595 if action_result.ByteSize() > 0:
1596 self._update_action_result_retention(action_result, retention_hours=self.action_result_retention_hours)
1598 self._publish_execution_stats(session, job.name, action_result.execution_metadata)
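# --- Example sketch (not part of this module) --------------------------------
# A hedged sketch of reading back the ExecuteResponse that _complete_lease puts
# into CAS: job.result holds the string-encoded digest, and string_to_digest is
# the same module-level helper used for job.action_digest above.
from typing import Optional
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import ExecuteResponse
from buildgrid.server.cas.storage.storage_abc import StorageABC

def load_execute_response(storage: StorageABC, result_digest: str) -> Optional[ExecuteResponse]:
    # Returns None if the blob has already expired from storage.
    return storage.get_message(string_to_digest(result_digest), ExecuteResponse)
# ------------------------------------------------------------------------------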
1600 def count_bots_by_status(self) -> Dict[BotStatus, int]:
1601 """Count the number of bots with a particular status"""
1602 with self._sql.session() as session:
1603 query = (
1604 session.query(BotEntry.bot_status, func.count(BotEntry.bot_status))
1605 .group_by(BotEntry.bot_status)
1606 .filter(self._bot_in_instance())
1607 )
1609 result = {status: 0 for status in BotStatus}
1610 for [bot_status, count] in query.all():
1611 result[BotStatus(bot_status or 0)] = count
1612 return result
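# --- Example sketch (not part of this module) --------------------------------
# Hypothetical usage of count_bots_by_status: the returned mapping contains an
# entry for every BotStatus value, with zero for statuses that have no bots, so
# totals can be computed without key checks.
def total_registered_bots(scheduler) -> int:
    counts = scheduler.count_bots_by_status()
    return sum(counts.values())
# ------------------------------------------------------------------------------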
1614 def refresh_bot_expiry_time(self, bot_name: str, bot_id: str) -> datetime:
1615 """
1616 This update is done out-of-band from the main synchronize_bot_lease transaction, as there
1617 are cases where we will skip calling the synchronization, but still want the session to be
1618 updated such that it does not get reaped. This slightly duplicates the update happening in
1619 synchronize_bot_lease; however, that update is still required so that the job is not reaped
1620 during its assignment waiting period.
1622 This method should be called at the end of the update and create bot session methods.
1623 The returned datetime should be assigned to the deadline within the returned session proto.
1624 """
1626 locate_bot_stmt = (
1627 select(BotEntry)
1628 .where(BotEntry.name == bot_name, BotEntry.bot_id == bot_id, self._bot_in_instance())
1629 .with_for_update()
1630 )
1631 with self._sql.session() as session:
1632 if bot := session.execute(locate_bot_stmt).scalar():
1633 now = datetime.utcnow()
1634 bot.last_update_timestamp = now
1635 bot.expiry_time = now + timedelta(seconds=self.bot_session_keepalive_timeout)
1636 return bot.expiry_time
1637 raise BotSessionClosedError(f"Bot not found to fetch expiry. {bot_name=} {bot_id=}")
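# --- Example sketch (not part of this module) --------------------------------
# Sketch of the usage described in the docstring above: copy the refreshed
# expiry into the BotSession proto returned to the worker, so the deadline the
# bot sees matches the row that was just updated.
from buildgrid._protos.google.devtools.remoteworkers.v1test2 import bots_pb2

def stamp_session_expiry(scheduler, bot_session: bots_pb2.BotSession) -> None:
    expiry = scheduler.refresh_bot_expiry_time(bot_session.name, bot_session.bot_id)
    bot_session.expire_time.FromDatetime(expiry)
# ------------------------------------------------------------------------------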
1639 def get_metadata_for_leases(self, leases: Iterable[Lease]) -> List[Tuple[str, bytes]]:
1640 """Return a list of Job metadata for a given list of leases.
1642 Args:
1643 leases (list): List of leases to get Job metadata for.
1645 Returns:
1646 List of tuples of the form
1647 ``('executeoperationmetadata-bin', serialized_metadata)``.
1649 """
1650 metadata = []
1651 with self._sql_ro.session() as session:
1652 for lease in leases:
1653 job = self._get_job(lease.id, session)
1654 if job is not None:
1655 job_metadata = ExecuteOperationMetadata(
1656 stage=job.stage, # type: ignore[arg-type]
1657 action_digest=string_to_digest(job.action_digest),
1658 stderr_stream_name=job.stderr_stream_write_name or "",
1659 stdout_stream_name=job.stdout_stream_write_name or "",
1660 partial_execution_metadata=self.get_execute_action_metadata(job),
1661 )
1662 metadata.append(("executeoperationmetadata-bin", job_metadata.SerializeToString()))
1664 return metadata
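# --- Example sketch (not part of this module) --------------------------------
# The ('executeoperationmetadata-bin', bytes) pairs returned above are shaped
# as gRPC binary metadata. A hedged use, assuming `context` is a
# grpc.ServicerContext, is to attach them as trailing metadata on a Bots RPC.
def attach_lease_metadata(scheduler, leases, context) -> None:
    context.set_trailing_metadata(scheduler.get_metadata_for_leases(leases))
# ------------------------------------------------------------------------------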
1666 def get_execute_action_metadata(self, job: JobEntry) -> ExecutedActionMetadata:
1667 worker_name = ""
1668 if job.leases:
1669 worker_name = job.leases[-1].worker_name or ""
1670 executed_action_metadata = ExecutedActionMetadata(worker=worker_name)
1671 executed_action_metadata.queued_timestamp.FromDatetime(job.queued_timestamp)
1672 if job.worker_start_timestamp is not None:
1673 executed_action_metadata.worker_start_timestamp.FromDatetime(job.worker_start_timestamp)
1674 if job.worker_completed_timestamp is not None:
1675 executed_action_metadata.worker_completed_timestamp.FromDatetime(job.worker_completed_timestamp)
1677 return executed_action_metadata
1679 def _fetch_execution_stats(
1680 self, auxiliary_metadata: RepeatedCompositeFieldContainer[ProtoAny]
1681 ) -> Optional[ExecutionStatistics]:
1682 """Fetch ExecutionStatistics from Storage
1683 ProtoAny[Digest] -> ProtoAny[ExecutionStatistics]
1684 """
1685 for aux_metadata_any in auxiliary_metadata:
1686 # Get the wrapped digest
1687 if not aux_metadata_any.Is(Digest.DESCRIPTOR):
1688 continue
1689 aux_metadata_digest = Digest()
1690 try:
1691 aux_metadata_any.Unpack(aux_metadata_digest)
1692 # Get the blob from CAS
1693 execution_stats_any = self.storage.get_message(aux_metadata_digest, ProtoAny)
1694 # Get the wrapped ExecutionStatistics
1695 if execution_stats_any and execution_stats_any.Is(ExecutionStatistics.DESCRIPTOR):
1696 execution_stats = ExecutionStatistics()
1697 execution_stats_any.Unpack(execution_stats)
1698 return execution_stats
1699 except Exception as exc:
1700 LOGGER.exception(
1701 "Cannot fetch ExecutionStatistics from storage.",
1702 tags=dict(auxiliary_metadata=aux_metadata_digest),
1703 exc_info=exc,
1704 )
1705 return None
1706 return None
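# --- Example sketch (not part of this module) --------------------------------
# A hedged sketch of the producer side that _fetch_execution_stats expects: the
# CAS blob is an Any wrapping ExecutionStatistics, and auxiliary_metadata holds
# an Any wrapping the Digest of that blob. `storage` is any StorageABC.
from google.protobuf.any_pb2 import Any as ProtoAny
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import ExecutedActionMetadata
from buildgrid._protos.build.buildbox.execution_stats_pb2 import ExecutionStatistics
from buildgrid.server.cas.storage.storage_abc import StorageABC

def attach_execution_stats(
    storage: StorageABC, metadata: ExecutedActionMetadata, stats: ExecutionStatistics
) -> None:
    # Store the Any-wrapped statistics in CAS...
    wrapped_stats = ProtoAny()
    wrapped_stats.Pack(stats)
    stats_digest = storage.put_message(wrapped_stats)
    # ...then reference that blob from the auxiliary metadata via its Digest.
    wrapped_digest = ProtoAny()
    wrapped_digest.Pack(stats_digest)
    metadata.auxiliary_metadata.append(wrapped_digest)
# ------------------------------------------------------------------------------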
1708 def publish_execution_stats(self, job_name: str, execution_metadata: ExecutedActionMetadata) -> None:
1709 with self._sql_ro.session(expire_on_commit=False) as session:
1710 self._publish_execution_stats(session, job_name, execution_metadata)
1712 def _publish_execution_stats(
1713 self, session: Session, job_name: str, execution_metadata: ExecutedActionMetadata
1714 ) -> None:
1715 """Publish resource usage of the job"""
1716 queued = execution_metadata.queued_timestamp.ToDatetime()
1717 worker_start = execution_metadata.worker_start_timestamp.ToDatetime()
1718 worker_completed = execution_metadata.worker_completed_timestamp.ToDatetime()
1719 fetch_start = execution_metadata.input_fetch_start_timestamp.ToDatetime()
1720 fetch_completed = execution_metadata.input_fetch_completed_timestamp.ToDatetime()
1721 execution_start = execution_metadata.execution_start_timestamp.ToDatetime()
1722 execution_completed = execution_metadata.execution_completed_timestamp.ToDatetime()
1723 upload_start = execution_metadata.output_upload_start_timestamp.ToDatetime()
1724 upload_completed = execution_metadata.output_upload_completed_timestamp.ToDatetime()
1726 timings = {
1727 "Total": worker_completed - queued,
1728 # This calculates the queue time based purely on values set in the ActionResult's
1729 # ExecutedActionMetadata, which may differ from the job object's queued_timestamp.
1730 "Queued": worker_start - queued,
1731 "Worker": worker_completed - worker_start,
1732 "Fetch": fetch_completed - fetch_start,
1733 "Execution": execution_completed - execution_start,
1734 "Upload": upload_completed - upload_start,
1735 }
1736 for tag, value in timings.items():
1737 publish_timer_metric(METRIC.JOB.DURATION, value, state=tag)
1739 if self.metering_client is None or len(execution_metadata.auxiliary_metadata) == 0:
1740 return
1742 execution_stats = self._fetch_execution_stats(execution_metadata.auxiliary_metadata)
1743 if execution_stats is None:
1744 return
1745 usage = Usage(
1746 computing=ComputingUsage(
1747 utime=execution_stats.command_rusage.utime.ToMilliseconds(),
1748 stime=execution_stats.command_rusage.stime.ToMilliseconds(),
1749 maxrss=execution_stats.command_rusage.maxrss,
1750 inblock=execution_stats.command_rusage.inblock,
1751 oublock=execution_stats.command_rusage.oublock,
1752 )
1753 )
1755 try:
1756 operations = (
1757 session.query(OperationEntry)
1758 .where(OperationEntry.job_name == job_name)
1759 .options(joinedload(OperationEntry.client_identity))
1760 .all()
1761 )
1762 for op in operations:
1763 if op.client_identity is None:
1764 continue
1765 client_id = Identity(
1766 instance=op.client_identity.instance,
1767 workflow=op.client_identity.workflow,
1768 actor=op.client_identity.actor,
1769 subject=op.client_identity.subject,
1770 )
1771 self.metering_client.put_usage(identity=client_id, operation_name=op.name, usage=usage)
1772 except Exception as exc:
1773 LOGGER.exception("Cannot publish resource usage.", tags=dict(job_name=job_name), exc_info=exc)
1775 def _update_action_retention(
1776 self, action: Action, action_digest: Digest, retention_hours: Optional[float]
1777 ) -> None:
1778 if not self.asset_client or not retention_hours:
1779 return
1780 uri = DIGEST_URI_TEMPLATE.format(digest_hash=action_digest.hash)
1781 qualifier = {"resource_type": PROTOBUF_MEDIA_TYPE}
1782 expire_at = datetime.now() + timedelta(hours=retention_hours)
1783 referenced_blobs = [action.command_digest]
1784 referenced_directories = [action.input_root_digest]
1786 try:
1787 self.asset_client.push_blob(
1788 uris=[uri],
1789 qualifiers=qualifier,
1790 blob_digest=action_digest,
1791 expire_at=expire_at,
1792 referenced_blobs=referenced_blobs,
1793 referenced_directories=referenced_directories,
1794 )
1795 LOGGER.debug(
1796 "Extended the retention of action.", tags=dict(digest=action_digest, retention_hours=retention_hours)
1797 )
1798 except Exception:
1799 LOGGER.exception("Failed to push action as an asset.", tags=dict(digest=action_digest))
1800 # Not a fatal path, don't reraise here
1802 def _update_action_result_retention(self, action_result: ActionResult, retention_hours: Optional[float]) -> None:
1803 if not self.asset_client or not retention_hours:
1804 return
1805 digest = None
1806 try:
1807 # BuildGrid doesn't store action_result in CAS, but if we push it as an asset
1808 # we need it to be accessible
1809 digest = self.storage.put_message(action_result)
1811 uri = DIGEST_URI_TEMPLATE.format(digest_hash=digest.hash)
1812 qualifier = {"resource_type": PROTOBUF_MEDIA_TYPE}
1813 expire_at = datetime.now() + timedelta(hours=retention_hours)
1815 referenced_blobs: List[Digest] = []
1816 referenced_directories: List[Digest] = []
1818 for file in action_result.output_files:
1819 referenced_blobs.append(file.digest)
1820 for dir in action_result.output_directories:
1821 # Caveat: the underlying directories referenced by this `Tree` message are not referenced by this asset.
1822 # For clients who need to keep all referenced outputs,
1823 # consider setting `Action.output_directory_format` as `DIRECTORY_ONLY` or `TREE_AND_DIRECTORY`.
1824 if dir.tree_digest.ByteSize() != 0:
1825 referenced_blobs.append(dir.tree_digest)
1826 if dir.root_directory_digest.ByteSize() != 0:
1827 referenced_directories.append(dir.root_directory_digest)
1829 if action_result.stdout_digest.ByteSize() != 0:
1830 referenced_blobs.append(action_result.stdout_digest)
1831 if action_result.stderr_digest.ByteSize() != 0:
1832 referenced_blobs.append(action_result.stderr_digest)
1834 self.asset_client.push_blob(
1835 uris=[uri],
1836 qualifiers=qualifier,
1837 blob_digest=digest,
1838 expire_at=expire_at,
1839 referenced_blobs=referenced_blobs,
1840 referenced_directories=referenced_directories,
1841 )
1842 LOGGER.debug(
1843 "Extended the retention of action result.", tags=dict(digest=digest, retention_hours=retention_hours)
1844 )
1846 except Exception as e:
1847 LOGGER.exception("Failed to push action_result as an asset.", tags=dict(digest=digest), exc_info=e)
1848 # Not a fatal path, don't reraise here