Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/persistence/sql/impl.py: 76.99%
552 statements
coverage.py v6.4.1, created at 2022-06-22 21:04 +0000

# Copyright (C) 2019 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
import logging
import os
import select
from threading import Thread
import time
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile
from itertools import chain, combinations
from typing import Dict, Iterable, List, Tuple, NamedTuple, Optional

from alembic import command
from alembic.config import Config
from sqlalchemy import and_, create_engine, delete, event, func, union, literal_column, update
from sqlalchemy import select as sql_select
from sqlalchemy.future import Connection
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.sql.expression import Select

from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
from buildgrid._protos.google.longrunning import operations_pb2
from buildgrid._enums import LeaseState, MetricCategories, OperationStage
from buildgrid.server.sql import sqlutils
from buildgrid.server.metrics_names import (
    BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME,
    DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME,
    DATA_STORE_CREATE_JOB_TIME_METRIC_NAME,
    DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME,
    DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME,
    DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME,
    DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME,
    DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME,
    DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME,
    DATA_STORE_PRUNER_NUM_ROWS_DELETED_METRIC_NAME,
    DATA_STORE_PRUNER_DELETE_TIME_METRIC_NAME,
    DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME,
    DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME,
    DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME,
    DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME,
    DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME
)
from buildgrid.server.job_metrics import JobMetrics
from buildgrid.server.metrics_utils import DurationMetric, publish_timer_metric, Counter
from buildgrid.server.operations.filtering import OperationFilter, SortKey, DEFAULT_SORT_KEYS
from buildgrid.server.persistence.interface import DataStoreInterface
from buildgrid.server.persistence.sql.models import digest_to_string, Job, Lease, Operation
from buildgrid.server.persistence.sql.utils import (
    build_page_filter,
    build_page_token,
    extract_sort_keys,
    build_custom_filters,
    build_sort_column_list
)
from buildgrid.settings import (
    MAX_JOB_BLOCK_TIME,
    MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES,
    COOLDOWN_TIME_AFTER_POOL_DISPOSE_SECONDS,
    SQL_SCHEDULER_METRICS_PUBLISH_INTERVAL_SECONDS
)
from buildgrid.utils import JobState, hash_from_dict, convert_values_to_sorted_lists
from buildgrid._exceptions import DatabaseError, RetriableDatabaseError


Session = sessionmaker(future=True)


def sqlite_on_connect(conn, record):
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA synchronous=NORMAL")


class PruningOptions(NamedTuple):
    pruner_job_max_age: timedelta = timedelta(days=30)
    pruner_period: timedelta = timedelta(minutes=5)
    pruner_max_delete_window: int = 10000

    @staticmethod
    def from_config(pruner_job_max_age_cfg: Dict[str, float],
                    pruner_period_cfg: Dict[str, float] = None,
                    pruner_max_delete_window_cfg: int = None):
        """ Helper method for creating ``PruningOptions`` objects
            If input configs are None, assign defaults """
        def _dict_to_timedelta(config: Dict[str, float]) -> timedelta:
            return timedelta(weeks=config.get('weeks', 0),
                             days=config.get('days', 0),
                             hours=config.get('hours', 0),
                             minutes=config.get('minutes', 0),
                             seconds=config.get('seconds', 0))

        return PruningOptions(pruner_job_max_age=_dict_to_timedelta(
            pruner_job_max_age_cfg) if pruner_job_max_age_cfg else timedelta(days=30),
            pruner_period=_dict_to_timedelta(
                pruner_period_cfg) if pruner_period_cfg else timedelta(minutes=5),
            pruner_max_delete_window=pruner_max_delete_window_cfg
            if pruner_max_delete_window_cfg else 10000)
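

# Illustrative sketch (not part of the original module): assuming a scheduler
# configuration that expresses the pruning durations as nested dictionaries,
# ``from_config`` can be driven as below. The literal values are made-up examples.
#
#     options = PruningOptions.from_config(
#         pruner_job_max_age_cfg={'days': 7},
#         pruner_period_cfg={'minutes': 10},
#         pruner_max_delete_window_cfg=5000)
#     assert options.pruner_job_max_age == timedelta(days=7)
#     # Missing or empty configs fall back to the defaults:
#     assert PruningOptions.from_config({}).pruner_period == timedelta(minutes=5)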


class SQLDataStore(DataStoreInterface):

    def __init__(self, storage, *, connection_string=None, automigrate=False,
                 connection_timeout=5, poll_interval=1,
                 pruning_options: Optional[PruningOptions] = None,
                 **kwargs):
        super().__init__(storage)
        self.__logger = logging.getLogger(__name__)
        self.__logger.info("Creating SQL scheduler with: "
                           f"automigrate=[{automigrate}], connection_timeout=[{connection_timeout}] "
                           f"poll_interval=[{poll_interval}], "
                           f"pruning_options=[{pruning_options}], "
                           f"kwargs=[{kwargs}]")

        self.response_cache: Dict[str, remote_execution_pb2.ExecuteResponse] = {}
        self.connection_timeout = connection_timeout
        self.poll_interval = poll_interval
        self.watcher = Thread(name="JobWatcher", target=self.wait_for_job_updates, daemon=True)
        self.watcher_keep_running = True

        # Set up a temporary SQLite database when no connection string is specified
        if not connection_string:
            # pylint: disable=consider-using-with
            tmpdbfile = NamedTemporaryFile(prefix='bgd-', suffix='.db')
            self._tmpdbfile = tmpdbfile  # Make sure to keep this tempfile for the lifetime of this object
            self.__logger.warning("No connection string specified for the DataStore, "
                                  f"will use SQLite with tempfile: [{tmpdbfile.name}]")
            automigrate = True  # since this is a temporary database, we always need to create it
            connection_string = f"sqlite:///{tmpdbfile.name}"

        self._create_sqlalchemy_engine(connection_string, automigrate, connection_timeout, **kwargs)

        self._sql_pool_dispose_helper = sqlutils.SQLPoolDisposeHelper(COOLDOWN_TIME_AFTER_POOL_DISPOSE_SECONDS,
                                                                      MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES,
                                                                      self.engine)

        # Make a test query against the database to ensure the connection is valid
        with self.session(reraise=True) as session:
            session.execute(sql_select([1])).all()

        self.watcher.start()

        self.capabilities_cache: Dict[str, List[str]] = {}

        # Pruning configuration parameters
        if pruning_options is not None:
            self.pruner_keep_running = True
            self.__logger.info(f"Scheduler pruning enabled: {pruning_options}")
            self.__pruner_thread = Thread(name="JobsPruner", target=self._do_prune, args=(
                pruning_options.pruner_job_max_age, pruning_options.pruner_period,
                pruning_options.pruner_max_delete_window), daemon=True)
            self.__pruner_thread.start()
        else:
            self.__logger.info("Scheduler pruning not enabled")

        # Overall Scheduler Metrics (totals of jobs/leases in each state)
        # Publish those metrics a bit more sparsely since the SQL requests
        # required to gather them can become expensive
        self.__last_scheduler_metrics_publish_time = None
        self.__scheduler_metrics_publish_interval = timedelta(
            seconds=SQL_SCHEDULER_METRICS_PUBLISH_INTERVAL_SECONDS)
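
    # Illustrative sketch (not part of the original module): wiring the data store
    # up against a PostgreSQL database. The connection string and pool settings
    # below are made-up example values; ``storage`` is whatever CAS storage backend
    # the surrounding server configuration already provides.
    #
    #     data_store = SQLDataStore(
    #         storage,
    #         connection_string="postgresql://bgd:bgd@localhost:5432/bgd",
    #         automigrate=True,
    #         connection_timeout=5,
    #         poll_interval=1,
    #         pruning_options=PruningOptions.from_config({'days': 14}),
    #         pool_size=5,
    #         pool_pre_ping=True)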

    def _create_sqlalchemy_engine(self, connection_string, automigrate, connection_timeout, **kwargs):
        self.automigrate = automigrate

        # Disallow sqlite in-memory because multi-threaded access to it is
        # complex and potentially problematic at best
        # ref: https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#threading-pooling-behavior
        if sqlutils.is_sqlite_inmemory_connection_string(connection_string):
            raise ValueError(
                f"Cannot use SQLite in-memory with BuildGrid (connection_string=[{connection_string}]). "
                "Use a file or leave the connection_string empty for a tempfile.")

        if connection_timeout is not None:
            if "connect_args" not in kwargs:
                kwargs["connect_args"] = {}
            if sqlutils.is_sqlite_connection_string(connection_string):
                kwargs["connect_args"]["timeout"] = connection_timeout
            elif sqlutils.is_psycopg2_connection_string(connection_string):
                kwargs["connect_args"]["connect_timeout"] = connection_timeout
                # Additional Postgres-specific timeouts, passed as libpq options.
                # Note that these timeouts are in milliseconds (hence * 1000).
                kwargs["connect_args"]["options"] = f'-c lock_timeout={connection_timeout * 1000}'

        # Only pass the (known) kwargs that have been explicitly set by the user
        available_options = set([
            'pool_size', 'max_overflow', 'pool_timeout', 'pool_pre_ping',
            'pool_recycle', 'connect_args'
        ])
        kwargs_keys = set(kwargs.keys())
        if not kwargs_keys.issubset(available_options):
            unknown_options = kwargs_keys - available_options
            raise TypeError(f"Unknown keyword arguments: [{unknown_options}]")

        self.__logger.debug(f"SQLAlchemy additional kwargs: [{kwargs}]")

        self.engine = create_engine(connection_string, echo=False, future=True, **kwargs)
        Session.configure(bind=self.engine)

        if self.engine.dialect.name == "sqlite":
            event.listen(self.engine, "connect", sqlite_on_connect)

        if self.automigrate:
            self._create_or_migrate_db(connection_string)

    def __repr__(self):
        return f"SQL data store interface for `{repr(self.engine.url)}`"

    def activate_monitoring(self):
        # Don't do anything. This function needs to exist but there's no
        # need to actually toggle monitoring in this implementation.
        pass

    def deactivate_monitoring(self):
        # Don't do anything. This function needs to exist but there's no
        # need to actually toggle monitoring in this implementation.
        pass

    def _create_or_migrate_db(self, connection_string):
        self.__logger.warning("Will attempt migration to latest version if needed.")

        config = Config()
        config.set_main_option("script_location", os.path.join(os.path.dirname(__file__), "alembic"))

        with self.engine.begin() as connection:
            config.attributes['connection'] = connection
            command.upgrade(config, "head")

    @contextmanager
    def session(self, *, sqlite_lock_immediately=False, reraise=False):
        # If we recently disposed of the SQL pool due to connection issues
        # allow for some cooldown period before we attempt more SQL
        self._sql_pool_dispose_helper.wait_if_cooldown_in_effect()

        # Try to obtain a session
        try:
            session = Session()
            if sqlite_lock_immediately and session.bind.name == "sqlite":
                session.execute("BEGIN IMMEDIATE")
        except Exception as e:
            self.__logger.error("Unable to obtain a database session.", exc_info=True)
            raise DatabaseError("Unable to obtain a database session.") from e

        # Yield the session and catch exceptions that occur while using it
        # to roll-back if needed
        try:
            yield session
            session.commit()
        except Exception as e:
            transient_dberr = self._sql_pool_dispose_helper.check_dispose_pool(session, e)
            if transient_dberr:
                self.__logger.warning("Rolling back database session due to transient database error.",
                                      exc_info=True)
            else:
                self.__logger.error("Error committing database session. Rolling back.", exc_info=True)
            try:
                session.rollback()
            except Exception:
                self.__logger.warning("Rollback error.", exc_info=True)

            if reraise:
                if transient_dberr:
                    raise RetriableDatabaseError(
                        "Database connection was temporarily interrupted, please retry",
                        timedelta(seconds=COOLDOWN_TIME_AFTER_POOL_DISPOSE_SECONDS)) from e
                raise
        finally:
            session.close()
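
    # Illustrative sketch (not part of the original module): how a caller might use
    # the context manager above. With ``reraise=False`` (the default) errors are
    # swallowed after rollback, so callers that need to react to failures pass
    # ``reraise=True`` and handle the raised errors themselves.
    #
    #     try:
    #         with self.session(reraise=True) as session:
    #             session.execute(sql_select([1])).all()
    #     except RetriableDatabaseError:
    #         pass  # transient outage: back off and retry later
    #     except DatabaseError:
    #         pass  # non-transient failure: surface to the caller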

    def _get_job(self, job_name, session, with_for_update=False):
        statement = sql_select(Job).filter_by(name=job_name)
        if with_for_update:
            statement = statement.with_for_update()

        job = session.execute(statement).scalars().first()
        if job:
            self.__logger.debug(
                f"Loaded job from db: name=[{job_name}], stage=[{job.stage}], result=[{job.result}]")

        return job

    def _check_job_timeout(self, job_internal, *, max_execution_timeout=None):
        """ Do a lazy check of maximum allowed job timeouts when clients try to retrieve
            an existing job.
            Cancel the job and related operations/leases, if we detect they have
            exceeded timeouts on access.

            Returns the `buildgrid.server.Job` object, possibly updated with `cancelled=True`.
        """
        if job_internal and max_execution_timeout and job_internal.worker_start_timestamp_as_datetime:
            if job_internal.operation_stage == OperationStage.EXECUTING:
                executing_duration = datetime.utcnow() - job_internal.worker_start_timestamp_as_datetime
                if executing_duration.total_seconds() >= max_execution_timeout:
                    self.__logger.warning(f"Job=[{job_internal}] has been executing for "
                                          f"executing_duration=[{executing_duration}]. "
                                          f"max_execution_timeout=[{max_execution_timeout}] "
                                          "Cancelling.")
                    job_internal.cancel_all_operations(data_store=self)
                    self.__logger.info(f"Job=[{job_internal}] has been cancelled.")
        return job_internal

    @DurationMetric(DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME, instanced=True)
    def get_job_by_action(self, action_digest, *, max_execution_timeout=None):
        statement = sql_select(Job).where(
            and_(
                Job.action_digest == digest_to_string(action_digest),
                Job.stage != OperationStage.COMPLETED.value
            )
        )

        with self.session() as session:
            job = session.execute(statement).scalars().first()
            if job:
                internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url,
                                                   instance_name=self._instance_name)
                return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout)
        return None

    @DurationMetric(DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME, instanced=True)
    def get_job_by_name(self, name, *, max_execution_timeout=None):
        with self.session() as session:
            job = self._get_job(name, session)
            if job:
                internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url,
                                                   instance_name=self._instance_name)
                return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout)
        return None

    @DurationMetric(DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME, instanced=True)
    def get_job_by_operation(self, operation_name, *, max_execution_timeout=None):
        with self.session() as session:
            operation = self._get_operation(operation_name, session)
            if operation and operation.job:
                job = operation.job
                internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url,
                                                   instance_name=self._instance_name)
                return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout)
        return None

    def get_all_jobs(self):
        statement = sql_select(Job).where(
            Job.stage != OperationStage.COMPLETED.value
        )

        with self.session() as session:
            jobs = session.execute(statement).scalars()
            return [j.to_internal_job(self, action_browser_url=self._action_browser_url,
                                      instance_name=self._instance_name) for j in jobs]

    def get_jobs_by_stage(self, operation_stage):
        statement = sql_select(Job).where(
            Job.stage == operation_stage.value
        )

        with self.session() as session:
            jobs = session.execute(statement).scalars()
            return [j.to_internal_job(self, no_result=True, action_browser_url=self._action_browser_url,
                                      instance_name=self._instance_name) for j in jobs]

    def get_operation_request_metadata_by_name(self, operation_name):
        with self.session() as session:
            operation = self._get_operation(operation_name, session)
            if not operation:
                return None

            return {'tool-name': operation.tool_name or '',
                    'tool-version': operation.tool_version or '',
                    'invocation-id': operation.invocation_id or '',
                    'correlated-invocations-id': operation.correlated_invocations_id or ''}

    @DurationMetric(DATA_STORE_CREATE_JOB_TIME_METRIC_NAME, instanced=True)
    def create_job(self, job):
        with self.session() as session:
            if self._get_job(job.name, session) is None:
                # Convert requirements values to sorted lists to make them json-serializable
                platform_requirements = job.platform_requirements
                convert_values_to_sorted_lists(platform_requirements)
                # Serialize the requirements
                platform_requirements_hash = hash_from_dict(platform_requirements)

                session.add(Job(
                    name=job.name,
                    action=job.action.SerializeToString(),
                    action_digest=digest_to_string(job.action_digest),
                    do_not_cache=job.do_not_cache,
                    priority=job.priority,
                    operations=[],
                    platform_requirements=platform_requirements_hash,
                    stage=job.operation_stage.value,
                    queued_timestamp=job.queued_timestamp_as_datetime,
                    queued_time_duration=job.queued_time_duration.seconds,
                    worker_start_timestamp=job.worker_start_timestamp_as_datetime,
                    worker_completed_timestamp=job.worker_completed_timestamp_as_datetime
                ))

    @DurationMetric(DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME, instanced=True)
    def queue_job(self, job_name):
        with self.session(sqlite_lock_immediately=True) as session:
            job = self._get_job(job_name, session, with_for_update=True)
            job.assigned = False

    @DurationMetric(DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME, instanced=True)
    def update_job(self, job_name, changes, *, skip_notify=False):
        if "result" in changes:
            changes["result"] = digest_to_string(changes["result"])
        if "action_digest" in changes:
            changes["action_digest"] = digest_to_string(changes["action_digest"])

        initial_values_for_metrics_use = {}

        with self.session() as session:
            job = self._get_job(job_name, session)

            # Keep track of the state right before we perform this update
            initial_values_for_metrics_use["stage"] = OperationStage(job.stage)

            job.update(changes)
            if not skip_notify:
                self._notify_job_updated(job_name, session)

        # Upon successful completion of the transaction above, publish metrics
        JobMetrics.publish_metrics_on_job_updates(initial_values_for_metrics_use, changes, self._instance_name)

    def _notify_job_updated(self, job_names, session):
        if self.engine.dialect.name == "postgresql":
            if isinstance(job_names, str):
                job_names = [job_names]
            for job_name in job_names:
                session.execute(f"NOTIFY job_updated, '{job_name}';")

    def delete_job(self, job_name):
        if job_name in self.response_cache:
            del self.response_cache[job_name]

    def wait_for_job_updates(self):
        self.__logger.info("Starting job watcher thread")
        if self.engine.dialect.name == "postgresql":
            self._listen_for_updates()
        else:
            self._poll_for_updates()

    def _listen_for_updates(self):
        def _listen_loop(engine_conn: Connection):
            try:
                # Get the DBAPI connection object from the SQLAlchemy Engine.Connection wrapper
                connection_fairy = engine_conn.connection
                connection_fairy.cursor().execute("LISTEN job_updated;")  # type: ignore
                connection_fairy.commit()
            except Exception:
                self.__logger.warning(
                    "Could not start listening to DB for job updates",
                    exc_info=True)
                # Let the context manager handle this
                raise

            while self.watcher_keep_running:
                # Get the actual DBAPI connection
                dbapi_connection = connection_fairy.dbapi_connection  # type: ignore
                # Wait until the connection is ready for reading. Time out after
                # `poll_interval` seconds and try again if there was nothing to read.
                # If the connection becomes readable, collect the notifications it
                # has received and handle them.
                #
                # See http://initd.org/psycopg/docs/advanced.html#async-notify
                if select.select([dbapi_connection], [], [], self.poll_interval) == ([], [], []):
                    pass
                else:
                    try:
                        dbapi_connection.poll()
                    except Exception:
                        self.__logger.warning("Error while polling for job updates", exc_info=True)
                        # Let the context manager handle this
                        raise

                    while dbapi_connection.notifies:
                        notify = dbapi_connection.notifies.pop()
                        with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME,
                                            instanced=True, instance_name=self._instance_name):
                            with self.watched_jobs_lock:
                                spec = self.watched_jobs.get(notify.payload)
                                if spec is not None:
                                    try:
                                        new_job = self.get_job_by_name(notify.payload)
                                    except Exception:
                                        self.__logger.warning(
                                            f"Couldn't get watched job=[{notify.payload}] from DB",
                                            exc_info=True)
                                        # Let the context manager handle this
                                        raise

                                    # If the job doesn't exist or an exception was suppressed by
                                    # get_job_by_name, it returns None instead of the job
                                    if new_job is None:
                                        raise DatabaseError(
                                            f"get_job_by_name returned None for job=[{notify.payload}]")

                                    new_state = JobState(new_job)
                                    if spec.last_state != new_state:
                                        spec.last_state = new_state
                                        spec.event.notify_change()

        while self.watcher_keep_running:
            # Wait a few seconds if a database exception occurs and then try again
            # This could be a short disconnect
            try:
                # Use the session contextmanager
                # so that we can benefit from the common SQL error-handling
                with self.session(reraise=True) as session:
                    # In our `LISTEN` call, we want to *bypass the ORM*
                    # and *use the underlying Engine connection directly*.
                    # (This is because using a `session.execute()` will
                    # implicitly create a SQL transaction, causing
                    # notifications to only be delivered when that transaction
                    # is committed)
                    _listen_loop(session.connection())
            except Exception as e:
                self.__logger.warning(f"JobWatcher encountered exception: [{e}]; "
                                      f"Retrying in poll_interval=[{self.poll_interval}] seconds.")
                # Sleep for a bit so that we give enough time for the
                # database to potentially recover
                time.sleep(self.poll_interval)

    def _get_watched_jobs(self):
        statement = sql_select(Job).where(
            Job.name.in_(self.watched_jobs)
        )

        with self.session() as sess:
            jobs = sess.execute(statement).scalars().all()
            return [job.to_internal_job(self) for job in jobs]

    def _poll_for_updates(self):
        def _poll_loop():
            while self.watcher_keep_running:
                time.sleep(self.poll_interval)
                if self.watcher_keep_running:
                    with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME,
                                        instanced=True, instance_name=self._instance_name):
                        with self.watched_jobs_lock:
                            if self.watcher_keep_running:
                                try:
                                    watched_jobs = self._get_watched_jobs()
                                except Exception as e:
                                    raise DatabaseError("Couldn't retrieve watched jobs from DB") from e

                                if watched_jobs is None:
                                    raise DatabaseError("_get_watched_jobs returned None")

                                for new_job in watched_jobs:
                                    if self.watcher_keep_running:
                                        spec = self.watched_jobs[new_job.name]
                                        new_state = JobState(new_job)
                                        if spec.last_state != new_state:
                                            spec.last_state = new_state
                                            spec.event.notify_change()

        while self.watcher_keep_running:
            # Wait a few seconds if a database exception occurs and then try again
            try:
                _poll_loop()
            except DatabaseError as e:
                self.__logger.warning(f"JobWatcher encountered exception: [{e}]; "
                                      f"Retrying in poll_interval=[{self.poll_interval}] seconds.")
                # Sleep for a bit so that we give enough time for the
                # database to potentially recover
                time.sleep(self.poll_interval)

    @DurationMetric(DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME, instanced=True)
    def store_response(self, job, commit_changes=True):
        digest = self.storage.put_message(job.execute_response)
        changes = {"result": digest, "status_code": job.execute_response.status.code}
        self.response_cache[job.name] = job.execute_response

        if commit_changes:
            self.update_job(job.name,
                            changes,
                            skip_notify=True)
            return None
        else:
            # The caller will batch the changes and commit to db
            return changes

    def _get_operation(self, operation_name, session):
        statement = sql_select(Operation).where(
            Operation.name == operation_name
        )
        return session.execute(statement).scalars().first()

    def get_operations_by_stage(self, operation_stage):
        statement = sql_select(Operation).where(
            Operation.job.has(stage=operation_stage.value)
        )

        with self.session() as session:
            operations = session.execute(statement).scalars().all()
            # Return a set of job names here for now, to match the `MemoryDataStore`
            # implementation's behaviour
            return set(op.job.name for op in operations)

    def _cancel_jobs_exceeding_execution_timeout(self, max_execution_timeout: int=None) -> None:
        if max_execution_timeout:
            stale_job_names = []
            lazy_execution_timeout_threshold = datetime.utcnow() - timedelta(seconds=max_execution_timeout)

            with self.session(sqlite_lock_immediately=True) as session:
                # Get the full list of jobs exceeding execution timeout
                stale_jobs_statement = sql_select(Job).with_for_update().where(
                    and_(
                        Job.stage == OperationStage.EXECUTING.value,
                        Job.worker_start_timestamp <= lazy_execution_timeout_threshold
                    )
                )
                stale_job_names = [job.name for job in session.execute(stale_jobs_statement).scalars().all()]

                if stale_job_names:
                    # Mark operations as cancelled
                    stmt_mark_operations_cancelled = update(Operation).where(
                        Operation.job_name.in_(stale_job_names)
                    ).values(cancelled=True)
                    session.execute(stmt_mark_operations_cancelled)

                    # Mark jobs as cancelled
                    stmt_mark_jobs_cancelled = update(Job).where(
                        Job.name.in_(stale_job_names)
                    ).values(stage=OperationStage.COMPLETED.value, cancelled=True)
                    session.execute(stmt_mark_jobs_cancelled)

                    # Notify all jobs updated
                    self._notify_job_updated(stale_job_names, session)

            if stale_job_names:
                self.__logger.info(f"Cancelled n=[{len(stale_job_names)}] jobs "
                                   f"with names={stale_job_names} "
                                   f"due to them exceeding execution_timeout=["
                                   f"{max_execution_timeout}]")

    @DurationMetric(DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME, instanced=True)
    def list_operations(self,
                        operation_filters: List[OperationFilter]=None,
                        page_size: int=None,
                        page_token: str=None,
                        max_execution_timeout: int=None) -> Tuple[List[operations_pb2.Operation], str]:
        # Lazily timeout jobs as needed before returning the list!
        self._cancel_jobs_exceeding_execution_timeout(max_execution_timeout=max_execution_timeout)

        # Build filters and sort order
        sort_keys = DEFAULT_SORT_KEYS
        custom_filters = None
        if operation_filters:
            # Extract custom sort order (if present)
            specified_sort_keys, non_sort_filters = extract_sort_keys(operation_filters)

            # Only override sort_keys if there were sort keys actually present in the filter string
            if specified_sort_keys:
                sort_keys = specified_sort_keys
                # Attach the operation name as a sort key for a deterministic order
                # This will ensure that the ordering of results is consistent between queries
                if not any(sort_key.name == "name" for sort_key in sort_keys):
                    sort_keys.append(SortKey(name="name", descending=False))

            # Finally, compile the non-sort filters into a filter list
            custom_filters = build_custom_filters(non_sort_filters)

        sort_columns = build_sort_column_list(sort_keys)

        with self.session() as session:
            statement = sql_select(Operation).join(Job, Operation.job_name == Job.name)

            # Apply custom filters (if present)
            if custom_filters:
                statement = statement.filter(*custom_filters)

            # Apply sort order
            statement = statement.order_by(*sort_columns)

            # Apply pagination filter
            if page_token:
                page_filter = build_page_filter(page_token, sort_keys)
                statement = statement.filter(page_filter)
            if page_size:
                # We limit the number of operations we fetch to the page_size. However, we
                # fetch an extra operation to determine whether we need to provide a
                # next_page_token.
                statement = statement.limit(page_size + 1)

            operations = session.execute(statement).scalars().all()

            if not page_size or not operations:
                next_page_token = ""

            # If the number of results we got is less than or equal to our page_size,
            # we're done with the operations listing and don't need to provide another
            # page token
            elif len(operations) <= page_size:
                next_page_token = ""
            else:
                # Drop the last operation since we have an extra
                operations.pop()
                # Our page token will be the last row of our set
                next_page_token = build_page_token(operations[-1], sort_keys)
            return [operation.to_protobuf(self) for operation in operations], next_page_token
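
    # Illustrative sketch (not part of the original module): paging through every
    # operation using the token returned above. ``data_store`` stands for an
    # SQLDataStore instance and ``page_size=50`` is an arbitrary example value.
    #
    #     page_token = None
    #     while True:
    #         operations, page_token = data_store.list_operations(
    #             page_size=50, page_token=page_token)
    #         for operation in operations:
    #             ...  # process each operations_pb2.Operation
    #         if not page_token:
    #             break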

    @DurationMetric(DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME, instanced=True)
    def create_operation(self, operation_name, job_name, request_metadata=None):
        with self.session() as session:
            operation = Operation(
                name=operation_name,
                job_name=job_name
            )
            if request_metadata is not None:
                if request_metadata.tool_invocation_id:
                    operation.invocation_id = request_metadata.tool_invocation_id
                if request_metadata.correlated_invocations_id:
                    operation.correlated_invocations_id = request_metadata.correlated_invocations_id
                if request_metadata.tool_details:
                    operation.tool_name = request_metadata.tool_details.tool_name
                    operation.tool_version = request_metadata.tool_details.tool_version
            session.add(operation)

    @DurationMetric(DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME, instanced=True)
    def update_operation(self, operation_name, changes):
        with self.session() as session:
            operation = self._get_operation(operation_name, session)
            operation.update(changes)

    def delete_operation(self, operation_name):
        # Don't do anything. This function needs to exist but there's no
        # need to actually delete operations in this implementation.
        pass

    def get_leases_by_state(self, lease_state):
        statement = sql_select(Lease).where(
            Lease.state == lease_state.value
        )

        with self.session() as session:
            leases = session.execute(statement).scalars().all()
            # `lease.job_name` is the same as `lease.id` for a Lease protobuf
            return set(lease.job_name for lease in leases)

    def get_metrics(self):
        # Skip publishing overall scheduler metrics if we have recently published them
        last_publish_time = self.__last_scheduler_metrics_publish_time
        time_since_publish = None
        if last_publish_time:
            time_since_publish = datetime.utcnow() - last_publish_time
        if time_since_publish and time_since_publish < self.__scheduler_metrics_publish_interval:
            # Published too recently, skip
            return None

        def _get_query_leases_by_state(category: str) -> Select:
            # Using func.count here to avoid generating a subquery in the WHERE
            # clause of the resulting query.
            # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
            return sql_select([
                literal_column(f"'{category}'").label("category"),
                Lease.state.label("bucket"),
                func.count(Lease.id).label("value")
            ]).group_by(Lease.state)

        def _cb_query_leases_by_state(leases_by_state):
            # The database only returns counts > 0, so fill in the gaps
            for state in LeaseState:
                if state.value not in leases_by_state:
                    leases_by_state[state.value] = 0
            return leases_by_state

        def _get_query_jobs_by_stage(category: str) -> Select:
            # Using func.count here to avoid generating a subquery in the WHERE
            # clause of the resulting query.
            # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
            return sql_select([
                literal_column(f"'{category}'").label("category"),
                Job.stage.label("bucket"),
                func.count(Job.name).label("value")
            ]).group_by(Job.stage)

        def _cb_query_jobs_by_stage(jobs_by_stage):
            # The database only returns counts > 0, so fill in the gaps
            for stage in OperationStage:
                if stage.value not in jobs_by_stage:
                    jobs_by_stage[stage.value] = 0
            return jobs_by_stage

        metrics = {}
        # metrics to gather: (category_name, function_returning_query, callback_function)
        metrics_to_gather = [
            (MetricCategories.LEASES.value, _get_query_leases_by_state, _cb_query_leases_by_state),
            (MetricCategories.JOBS.value, _get_query_jobs_by_stage, _cb_query_jobs_by_stage)
        ]

        statements = [query_fn(category) for category, query_fn, _ in metrics_to_gather]
        metrics_statement = union(*statements)

        try:
            with self.session() as session:
                results = session.execute(metrics_statement).all()

                grouped_results = {category: {} for category, _, _ in results}
                for category, bucket, value in results:
                    grouped_results[category][bucket] = value

                for category, _, category_cb in metrics_to_gather:
                    metrics[category] = category_cb(grouped_results.setdefault(category, {}))
        except DatabaseError:
            self.__logger.warning("Unable to gather metrics due to a Database Error.")
            return {}

        # This is only updated within the metrics asyncio loop; no race conditions
        self.__last_scheduler_metrics_publish_time = datetime.utcnow()

        return metrics
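
    # Illustrative sketch (not part of the original module): the mapping returned by
    # get_metrics() groups counts by category and bucket, with zero-filled gaps.
    # Roughly this shape (the numbers below are made up):
    #
    #     {
    #         MetricCategories.LEASES.value: {LeaseState.PENDING.value: 3,
    #                                         LeaseState.ACTIVE.value: 1, ...},
    #         MetricCategories.JOBS.value: {OperationStage.QUEUED.value: 4,
    #                                       OperationStage.EXECUTING.value: 1, ...},
    #     }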

    def _create_lease(self, lease, session, job=None, worker_name=None):
        if job is None:
            job = self._get_job(lease.id, session)
        # We only allow one lease, so if there's an existing one update it
        if job.active_leases:
            job.active_leases[0].state = lease.state
            job.active_leases[0].status = None
            job.active_leases[0].worker_name = worker_name
        else:
            session.add(Lease(
                job_name=lease.id,
                state=lease.state,
                status=None,
                worker_name=worker_name
            ))

    def create_lease(self, lease):
        with self.session() as session:
            self._create_lease(lease, session)

    @DurationMetric(DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME, instanced=True)
    def update_lease(self, job_name, changes):
        initial_values_for_metrics_use = {}

        with self.session() as session:
            job = self._get_job(job_name, session)
            try:
                lease = job.active_leases[0]
            except IndexError:
                return

            # Keep track of the state right before we perform this update
            initial_values_for_metrics_use["state"] = lease.state

            lease.update(changes)

        # Upon successful completion of the transaction above, publish metrics
        JobMetrics.publish_metrics_on_lease_updates(initial_values_for_metrics_use, changes, self._instance_name)

    def load_unfinished_jobs(self):
        statement = sql_select(Job).where(
            Job.stage != OperationStage.COMPLETED.value
        ).order_by(Job.priority)

        with self.session() as session:
            jobs = session.execute(statement).scalars().all()
            return [j.to_internal_job(self) for j in jobs]

    def assign_lease_for_next_job(self, capabilities, callback, timeout=None):
        """Return a list of leases for the highest priority jobs that can be run by a worker.

        NOTE: Currently the list only ever has one or zero leases.

        Query the jobs table to find queued jobs which match the capabilities of
        a given worker, and return the one with the highest priority. Takes a
        dictionary of worker capabilities to compare with job requirements.

        :param capabilities: Dictionary of worker capabilities to compare
            with job requirements when finding a job.
        :type capabilities: dict
        :param callback: Function to run on the next runnable job; should return
            a list of leases.
        :type callback: function
        :param timeout: Time to wait for new jobs; capped at MAX_JOB_BLOCK_TIME
            if longer.
        :type timeout: int
        :returns: List of leases.
        """
        if not timeout:
            return self._assign_job_leases(capabilities, callback)

        # Cap the timeout if it's larger than MAX_JOB_BLOCK_TIME
        if timeout:
            timeout = min(timeout, MAX_JOB_BLOCK_TIME)

        start = time.time()
        while time.time() + self.connection_timeout + 1 - start < timeout:
            leases = self._assign_job_leases(capabilities, callback)
            if leases:
                return leases
            time.sleep(0.5)
        if self.connection_timeout > timeout:
            self.__logger.warning(
                "Not providing any leases to the worker because the database connection "
                f"timeout ({self.connection_timeout} s) is longer than the remaining "
                "time to handle the request. "
                "Increase the worker's timeout to solve this problem.")
        return []
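
    # Illustrative sketch (not part of the original module): how a bots-interface
    # caller might drive the method above. The callback receives the internal job
    # and is expected to build and return its leases; ``make_lease_for`` is a
    # hypothetical helper standing in for that logic, and ``data_store`` is an
    # SQLDataStore instance.
    #
    #     def _callback(job):
    #         return [make_lease_for(job)]
    #
    #     leases = data_store.assign_lease_for_next_job(
    #         {'OSFamily': ['Linux'], 'ISA': ['x86-64']}, _callback, timeout=30)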

    def flatten_capabilities(self, capabilities: Dict[str, List[str]]) -> List[Tuple[str, str]]:
        """ Flatten a capabilities dictionary, assuming all of its values are lists. E.g.

            {'OSFamily': ['Linux'], 'ISA': ['x86-32', 'x86-64']}

        becomes

            [('OSFamily', 'Linux'), ('ISA', 'x86-32'), ('ISA', 'x86-64')] """
        return [
            (name, value) for name, value_list in capabilities.items()
            for value in value_list
        ]

    def get_partial_capabilities(self, capabilities: Dict[str, List[str]]) -> Iterable[Dict[str, List[str]]]:
        """ Given a capabilities dictionary with all values as lists,
            yield all partial capabilities dictionaries. """
        CAPABILITIES_WARNING_THRESHOLD = 10

        caps_flat = self.flatten_capabilities(capabilities)

        if len(caps_flat) > CAPABILITIES_WARNING_THRESHOLD:
            self.__logger.warning(
                "A worker with a large capabilities dictionary has been connected. "
                f"Processing its capabilities may take a while. Capabilities: {capabilities}")

        # Using the itertools powerset recipe, construct the powerset of the tuples
        capabilities_powerset = chain.from_iterable(combinations(caps_flat, r) for r in range(len(caps_flat) + 1))
        for partial_capability_tuples in capabilities_powerset:
            partial_dict: Dict[str, List[str]] = {}

            for tup in partial_capability_tuples:
                partial_dict.setdefault(tup[0], []).append(tup[1])
            yield partial_dict
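
    # Illustrative sketch (not part of the original module): for a small worker the
    # powerset expansion above yields every partial capabilities dictionary,
    # starting from the empty one. For example:
    #
    #     list(self.get_partial_capabilities({'OSFamily': ['Linux'], 'ISA': ['x86-64']}))
    #     # -> [{},
    #     #     {'OSFamily': ['Linux']},
    #     #     {'ISA': ['x86-64']},
    #     #     {'OSFamily': ['Linux'], 'ISA': ['x86-64']}]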

    def get_partial_capabilities_hashes(self, capabilities: Dict) -> List[str]:
        """ Expand the given capabilities dictionary into each of its partial
            configurations, hash each partial configuration, and return the
            resulting list of hashes. Results are cached per capabilities digest. """
        # Convert requirements values to sorted lists to make them json-serializable
        convert_values_to_sorted_lists(capabilities)

        # Check to see if we've cached this value
        capabilities_digest = hash_from_dict(capabilities)
        try:
            return self.capabilities_cache[capabilities_digest]
        except KeyError:
            # On cache miss, expand the capabilities into each possible partial capabilities dictionary
            capabilities_list = []
            for partial_capability in self.get_partial_capabilities(capabilities):
                capabilities_list.append(hash_from_dict(partial_capability))

            self.capabilities_cache[capabilities_digest] = capabilities_list
            return capabilities_list

    @DurationMetric(BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME, instanced=True)
    def _assign_job_leases(self, capabilities, callback):
        # pylint: disable=singleton-comparison
        # Hash the capabilities
        capabilities_config_hashes = self.get_partial_capabilities_hashes(capabilities)
        leases = []
        create_lease_start_time = None

        # Select unassigned, queued jobs with platform requirements matching the
        # given capabilities, ordered by priority and age.
        job_statement = sql_select(Job).with_for_update(skip_locked=True).where(
            Job.stage == OperationStage.QUEUED.value,
            Job.assigned != True,
            Job.platform_requirements.in_(capabilities_config_hashes)
        ).order_by(Job.priority, Job.queued_timestamp)

        try:
            with self.session(sqlite_lock_immediately=True) as session:
                job = session.execute(job_statement).scalars().first()
                if job:
                    internal_job = job.to_internal_job(self)
                    leases = callback(internal_job)
                    if leases:
                        job.assigned = True
                        job.worker_start_timestamp = internal_job.worker_start_timestamp_as_datetime
                        # Only create a lease if there isn't already an active one for this job
                        create_lease_start_time = time.perf_counter()
                        for lease in leases:
                            self._create_lease(lease, session, job=job, worker_name=internal_job.worker_name)

            # Calculate and publish the time taken to create leases. This is done explicitly
            # rather than using the DurationMetric helper since we need to measure the actual
            # execution time of the UPDATE and INSERT queries used in the lease assignment, and
            # these are only executed on exiting the contextmanager.
            if create_lease_start_time is not None:
                run_time = timedelta(seconds=time.perf_counter() - create_lease_start_time)
                metadata = None
                if self._instance_name is not None:
                    metadata = {'instance-name': self._instance_name}
                publish_timer_metric(DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME, run_time, metadata=metadata)

        except DatabaseError:
            self.__logger.warning("Will not assign any leases this time due to a Database Error.")

        return leases

    def _do_prune(self, job_max_age: timedelta, pruning_period: timedelta, limit: int) -> None:
        """ Running in a background thread, this method wakes up periodically and deletes older records
            from the jobs tables using configurable parameters """
        utc_last_prune_time = datetime.utcnow()
        while self.pruner_keep_running:
            utcnow = datetime.utcnow()
            if (utcnow - pruning_period) < utc_last_prune_time:
                self.__logger.info(
                    f"Pruner thread sleeping for {pruning_period} (until {utcnow + pruning_period})")
                time.sleep(pruning_period.total_seconds())
                continue

            delete_before_datetime = utcnow - job_max_age
            try:
                with DurationMetric(DATA_STORE_PRUNER_DELETE_TIME_METRIC_NAME,
                                    instance_name=self._instance_name,
                                    instanced=True):
                    num_rows = self._delete_jobs_prior_to(delete_before_datetime, limit)

                self.__logger.info(
                    f"Pruned {num_rows} row(s) from the jobs table older than {delete_before_datetime}")

                if num_rows > 0:
                    with Counter(metric_name=DATA_STORE_PRUNER_NUM_ROWS_DELETED_METRIC_NAME,
                                 instance_name=self._instance_name) as num_rows_deleted:
                        num_rows_deleted.increment(num_rows)

            except Exception:
                self.__logger.exception("Caught exception while deleting jobs records")

            finally:
                # Update even if an error occurred, to avoid potentially retrying forever
                utc_last_prune_time = utcnow

        self.__logger.info("Exiting pruner thread")

    def _delete_jobs_prior_to(self, delete_before_datetime: datetime, limit: int) -> int:
        """ Deletes older records from the jobs tables constrained by `delete_before_datetime` and `limit` """
        delete_stmt = delete(Job).where(
            Job.name.in_(
                sql_select(Job.name).with_for_update(skip_locked=True).where(
                    Job.worker_completed_timestamp <= delete_before_datetime
                ).limit(limit)
            )
        )

        with self.session(reraise=True) as session:
            options = {"synchronize_session": "fetch"}
            num_rows_deleted = session.execute(delete_stmt, execution_options=options).rowcount

        return num_rows_deleted