Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/persistence/sql/impl.py : 94.56%

# Copyright (C) 2019 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from contextlib import contextmanager
import logging
import os
import select
from threading import Thread, Lock
import time
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile
from itertools import chain, combinations
from typing import Any, Dict, Iterable, List, Tuple

from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, event, func, text, union, literal_column
from sqlalchemy.orm.session import sessionmaker, Session as SessionType

from buildgrid._protos.google.longrunning import operations_pb2
from buildgrid._enums import LeaseState, MetricCategories, OperationStage
from buildgrid.server.metrics_names import (
    BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME,
    DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME,
    DATA_STORE_CREATE_JOB_TIME_METRIC_NAME,
    DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME,
    DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME,
    DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME,
    DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME,
    DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME,
    DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME,
    DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME,
    DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME,
    DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME,
    DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME,
    DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME
)
from buildgrid.server.metrics_utils import DurationMetric, publish_timer_metric
from buildgrid.server.operations.filtering import OperationFilter, SortKey, DEFAULT_SORT_KEYS
from buildgrid.server.persistence.interface import DataStoreInterface
from buildgrid.server.persistence.sql.models import digest_to_string, Job, Lease, Operation
from buildgrid.server.persistence.sql.utils import (
    build_page_filter,
    build_page_token,
    extract_sort_keys,
    build_custom_filters,
    build_sort_column_list
)
from buildgrid.settings import MAX_JOB_BLOCK_TIME, MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES
from buildgrid.utils import JobState, hash_from_dict, convert_values_to_sorted_lists

from buildgrid._exceptions import DatabaseError


Session = sessionmaker()


def sqlite_on_connect(conn, record):
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA synchronous=NORMAL")


class SQLDataStore(DataStoreInterface):

    def __init__(self, storage, *, connection_string=None, automigrate=False,
                 connection_timeout=5, poll_interval=1, **kwargs):
        super().__init__()
        self.__logger = logging.getLogger(__name__)
        self.__logger.info("Creating SQL scheduler with: "
                           f"automigrate=[{automigrate}], connection_timeout=[{connection_timeout}] "
                           f"poll_interval=[{poll_interval}], kwargs=[{kwargs}]")

        self.storage = storage
        self.response_cache = {}
        self.connection_timeout = connection_timeout
        self.poll_interval = poll_interval
        self.watcher = Thread(name="JobWatcher", target=self.wait_for_job_updates, daemon=True)
        self.watcher_keep_running = True
        self.__dispose_pool_on_exceptions: Tuple[Any, ...] = tuple()
        self.__last_pool_dispose_time = None
        self.__last_pool_dispose_time_lock = Lock()

        # Set up a temporary SQLite database when no connection string is specified
        if not connection_string:
            tmpdbfile = NamedTemporaryFile(prefix='bgd-', suffix='.db')
            self._tmpdbfile = tmpdbfile  # Make sure to keep this tempfile for the lifetime of this object
            self.__logger.warning("No connection string specified for the DataStore, "
                                  f"will use SQLite with tempfile: [{tmpdbfile.name}]")
            automigrate = True  # since this is a temporary database, we always need to create it
            connection_string = f"sqlite:///{tmpdbfile.name}"

        self._create_sqlalchemy_engine(connection_string, automigrate, connection_timeout, **kwargs)

        # Make a test query against the database to ensure the connection is valid
        with self.session(reraise=True) as session:
            session.query(Job).first()

        self.watcher.start()

        self.capabilities_cache = {}
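
    # A minimal construction sketch (not part of this module): the storage object
    # and connection string below are illustrative assumptions, not values defined
    # here; only the keyword parameters themselves come from __init__ above.
    #
    #     data_store = SQLDataStore(
    #         storage,  # a CAS storage backend providing put_message()
    #         connection_string="postgresql://bgd:insecure@localhost:5432/bgd",
    #         automigrate=True,       # run the Alembic migrations on start-up
    #         connection_timeout=5,   # seconds before a connection attempt gives up
    #         poll_interval=1,        # seconds between job-update polls (non-PostgreSQL)
    #     )
    #
    # Leaving connection_string unset falls back to the SQLite tempfile path above.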

    def _create_sqlalchemy_engine(self, connection_string, automigrate, connection_timeout, **kwargs):
        self.automigrate = automigrate

        # Disallow sqlite in-memory because multi-threaded access to it is
        # complex and potentially problematic at best
        # ref: https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#threading-pooling-behavior
        if self._is_sqlite_inmemory_connection_string(connection_string):
            raise ValueError(
                f"Cannot use SQLite in-memory with BuildGrid (connection_string=[{connection_string}]). "
                "Use a file or leave the connection_string empty for a tempfile.")

        if connection_timeout is not None:
            if "connect_args" not in kwargs:
                kwargs["connect_args"] = {}
            if self._is_sqlite_connection_string(connection_string):
                kwargs["connect_args"]["timeout"] = connection_timeout
            else:
                kwargs["connect_args"]["connect_timeout"] = connection_timeout

        # Only pass the (known) kwargs that have been explicitly set by the user
        available_options = set([
            'pool_size', 'max_overflow', 'pool_timeout', 'pool_pre_ping',
            'pool_recycle', 'connect_args'
        ])
        kwargs_keys = set(kwargs.keys())
        if not kwargs_keys.issubset(available_options):
            unknown_options = kwargs_keys - available_options
            raise TypeError(f"Unknown keyword arguments: [{unknown_options}]")

        self.__logger.debug(f"SQLAlchemy additional kwargs: [{kwargs}]")

        self.engine = create_engine(connection_string, echo=False, **kwargs)
        Session.configure(bind=self.engine)

        if self.engine.dialect.name == "sqlite":
            event.listen(self.engine, "connect", sqlite_on_connect)

        self._configure_dialect_disposal_exceptions(self.engine.dialect.name)

        if self.automigrate:
            self._create_or_migrate_db(connection_string)

    def _is_sqlite_connection_string(self, connection_string):
        if connection_string:
            return connection_string.startswith("sqlite")
        return False

    def _is_sqlite_inmemory_connection_string(self, full_connection_string):
        if self._is_sqlite_connection_string(full_connection_string):
            # Valid connection_strings for in-memory SQLite which we don't support could look like:
            # "sqlite:///file:memdb1?option=value&cache=shared&mode=memory",
            # "sqlite:///file:memdb1?mode=memory&cache=shared",
            # "sqlite:///file:memdb1?cache=shared&mode=memory",
            # "sqlite:///file::memory:?cache=shared",
            # "sqlite:///file::memory:",
            # "sqlite:///:memory:",
            # "sqlite:///",
            # "sqlite://"
            # ref: https://www.sqlite.org/inmemorydb.html
            # Note that a user can also specify drivers, so prefix could become 'sqlite+driver:///'
            connection_string = full_connection_string

            uri_split_index = connection_string.find("?")
            if uri_split_index != -1:
                connection_string = connection_string[0:uri_split_index]

            if connection_string.endswith((":memory:", ":///", "://")):
                return True
            elif uri_split_index != -1:
                opts = full_connection_string[uri_split_index + 1:].split("&")
                if "mode=memory" in opts:
                    return True

        return False
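
    # Worked examples for the check above, derived from the connection strings
    # listed in its comment (the file paths are illustrative, nothing here is executed):
    #
    #     _is_sqlite_inmemory_connection_string("sqlite:///:memory:")                           -> True
    #     _is_sqlite_inmemory_connection_string("sqlite://")                                    -> True
    #     _is_sqlite_inmemory_connection_string("sqlite:///file:memdb1?mode=memory&cache=shared") -> True
    #     _is_sqlite_inmemory_connection_string("sqlite:////var/lib/bgd/state.db")              -> False
    #     _is_sqlite_inmemory_connection_string("postgresql://bgd@localhost/bgd")               -> False (not SQLite)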

    def __repr__(self):
        return f"SQL data store interface for `{repr(self.engine.url)}`"

    def activate_monitoring(self):
        # Don't do anything. This function needs to exist but there's no
        # need to actually toggle monitoring in this implementation.
        pass

    def deactivate_monitoring(self):
        # Don't do anything. This function needs to exist but there's no
        # need to actually toggle monitoring in this implementation.
        pass

    def _configure_dialect_disposal_exceptions(self, dialect: str):
        self.__dispose_pool_on_exceptions = self._get_dialect_disposal_exceptions(dialect)

    def _get_dialect_disposal_exceptions(self, dialect: str) -> Tuple[Any, ...]:
        dialect_errors: Tuple[Any, ...] = tuple()
        if dialect == 'postgresql':
            import psycopg2  # pylint: disable=import-outside-toplevel
            dialect_errors = (psycopg2.errors.ReadOnlySqlTransaction, psycopg2.errors.AdminShutdown)
        return dialect_errors

    def _create_or_migrate_db(self, connection_string):
        self.__logger.warning("Will attempt migration to latest version if needed.")

        config = Config()
        config.set_main_option("script_location", os.path.join(os.path.dirname(__file__), "alembic"))

        with self.engine.begin() as connection:
            config.attributes['connection'] = connection
            command.upgrade(config, "head")

    @contextmanager
    def session(self, *, sqlite_lock_immediately=False, reraise=False):
        # Try to obtain a session
        try:
            session = Session()
            if sqlite_lock_immediately and session.bind.name == "sqlite":
                session.execute("BEGIN IMMEDIATE")
        except Exception as e:
            self.__logger.error("Unable to obtain a database session.", exc_info=True)
            raise DatabaseError("Unable to obtain a database session.") from e

        # Yield the session and catch exceptions that occur while using it
        # to roll-back if needed
        try:
            yield session
            session.commit()
        except Exception as e:
            self.__logger.error("Error committing database session. Rolling back.", exc_info=True)
            self._check_dispose_pool(session, e)
            try:
                session.rollback()
            except Exception:
                self.__logger.warning("Rollback error.", exc_info=True)

            if reraise:
                raise
        finally:
            session.close()
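
    # Minimal sketch of how this context manager is used throughout this class
    # (commit on success, rollback and optional re-raise on failure):
    #
    #     with self.session(reraise=True) as session:
    #         session.query(Job).filter_by(name=job_name).first()
    #
    # With reraise left as False, exceptions are logged and swallowed after the
    # rollback, so callers must tolerate the body's changes being discarded.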

    def _check_dispose_pool(self, session: SessionType, e: Exception):
        # Only do this if the config is relevant
        if not self.__dispose_pool_on_exceptions:
            return

        # Make sure we have a SQL-related cause to check, otherwise skip
        if e.__cause__ and isinstance(e.__cause__, Exception):
            cause_type = type(e.__cause__)
            # Only allow disposal every MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES
            now = datetime.utcnow()
            only_if_after = None

            # Let's see if this exception is related to known disconnect exceptions
            is_connection_error = cause_type in self.__dispose_pool_on_exceptions

            if is_connection_error:
                # Make sure this connection will not be re-used
                session.invalidate()
                self.__logger.info(
                    f'Detected a SQL exception=[{cause_type.__name__}] related to the connection. '
                    'Invalidating this connection.'
                )

                # Check if we should dispose the rest of the checked in connections
                with self.__last_pool_dispose_time_lock:
                    if self.__last_pool_dispose_time:
                        only_if_after = self.__last_pool_dispose_time + \
                            timedelta(minutes=MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES)
                    if only_if_after and now < only_if_after:
                        return

                    # OK, we haven't disposed the pool recently
                    self.__last_pool_dispose_time = now
                    self.engine.dispose()
                    self.__logger.info('Disposing pool checked in connections so that they get recreated')

    def _get_job(self, job_name, session, with_for_update=False):
        jobs = session.query(Job)
        if with_for_update:
            jobs = jobs.with_for_update()
        jobs = jobs.filter_by(name=job_name)

        job = jobs.first()
        if job:
            self.__logger.debug(f"Loaded job from db: name=[{job_name}], stage=[{job.stage}], result=[{job.result}]")

        return job

    def _check_job_timeout(self, job_internal, *, max_execution_timeout=None):
        """ Do a lazy check of maximum allowed job timeouts when clients try to retrieve
        an existing job.
        Cancel the job and related operations/leases, if we detect they have
        exceeded timeouts on access.

        Returns the `buildgrid.server.Job` object, possibly updated with `cancelled=True`.
        """
        if job_internal and max_execution_timeout and job_internal.worker_start_timestamp_as_datetime:
            if job_internal.operation_stage == OperationStage.EXECUTING:
                executing_duration = datetime.utcnow() - job_internal.worker_start_timestamp_as_datetime
                if executing_duration.total_seconds() >= max_execution_timeout:
                    self.__logger.warning(f"Job=[{job_internal}] has been executing for "
                                          f"executing_duration=[{executing_duration}]. "
                                          f"max_execution_timeout=[{max_execution_timeout}] "
                                          "Cancelling.")
                    job_internal.cancel_all_operations(data_store=self)
                    self.__logger.info(f"Job=[{job_internal}] has been cancelled.")
        return job_internal
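
    # Worked example of the lazy timeout above (values are illustrative): with
    # max_execution_timeout=7200 seconds, a job whose worker_start_timestamp is
    # 2h30m old has executing_duration.total_seconds() == 9000 >= 7200, so all of
    # its operations are cancelled the next time a client fetches it; a job that
    # started 30 minutes ago (1800 seconds) is returned unchanged.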

    @DurationMetric(DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME, instanced=True)
    def get_job_by_action(self, action_digest, *, max_execution_timeout=None):
        with self.session() as session:
            jobs = session.query(Job).filter_by(action_digest=digest_to_string(action_digest))
            jobs = jobs.filter(Job.stage != OperationStage.COMPLETED.value)
            job = jobs.first()
            if job:
                internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url,
                                                   instance_name=self._instance_name)
                return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout)
        return None

    @DurationMetric(DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME, instanced=True)
    def get_job_by_name(self, name, *, max_execution_timeout=None):
        with self.session() as session:
            job = self._get_job(name, session)
            if job:
                internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url,
                                                   instance_name=self._instance_name)
                return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout)
        return None

    @DurationMetric(DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME, instanced=True)
    def get_job_by_operation(self, operation_name, *, max_execution_timeout=None):
        with self.session() as session:
            operation = self._get_operation(operation_name, session)
            if operation and operation.job:
                job = operation.job
                internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url,
                                                   instance_name=self._instance_name)
                return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout)
        return None

    def get_all_jobs(self):
        with self.session() as session:
            jobs = session.query(Job).filter(Job.stage != OperationStage.COMPLETED.value)
            return [j.to_internal_job(self, action_browser_url=self._action_browser_url,
                                      instance_name=self._instance_name) for j in jobs]

    def get_jobs_by_stage(self, operation_stage):
        with self.session() as session:
            jobs = session.query(Job).filter(Job.stage == operation_stage.value)
            return [j.to_internal_job(self, no_result=True, action_browser_url=self._action_browser_url,
                                      instance_name=self._instance_name) for j in jobs]

    @DurationMetric(DATA_STORE_CREATE_JOB_TIME_METRIC_NAME, instanced=True)
    def create_job(self, job):
        with self.session() as session:
            if self._get_job(job.name, session) is None:
                # Convert requirements values to sorted lists to make them json-serializable
                platform_requirements = job.platform_requirements
                convert_values_to_sorted_lists(platform_requirements)
                # Serialize the requirements
                platform_requirements_hash = hash_from_dict(platform_requirements)

                session.add(Job(
                    name=job.name,
                    action=job.action.SerializeToString(),
                    action_digest=digest_to_string(job.action_digest),
                    do_not_cache=job.do_not_cache,
                    priority=job.priority,
                    operations=[],
                    platform_requirements=platform_requirements_hash,
                    stage=job.operation_stage.value,
                    queued_timestamp=job.queued_timestamp_as_datetime,
                    queued_time_duration=job.queued_time_duration.seconds,
                    worker_start_timestamp=job.worker_start_timestamp_as_datetime,
                    worker_completed_timestamp=job.worker_completed_timestamp_as_datetime
                ))

    @DurationMetric(DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME, instanced=True)
    def queue_job(self, job_name):
        with self.session(sqlite_lock_immediately=True) as session:
            job = self._get_job(job_name, session, with_for_update=True)
            job.assigned = False

    @DurationMetric(DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME, instanced=True)
    def update_job(self, job_name, changes, *, skip_notify=False):
        if "result" in changes:
            changes["result"] = digest_to_string(changes["result"])
        if "action_digest" in changes:
            changes["action_digest"] = digest_to_string(changes["action_digest"])

        with self.session() as session:
            job = self._get_job(job_name, session)
            job.update(changes)
            if not skip_notify:
                self._notify_job_updated(job_name, session)

    def _notify_job_updated(self, job_names, session):
        if self.engine.dialect.name == "postgresql":
            if isinstance(job_names, str):
                job_names = [job_names]
            for job_name in job_names:
                session.execute(f"NOTIFY job_updated, '{job_name}';")

    def delete_job(self, job_name):
        if job_name in self.response_cache:
            del self.response_cache[job_name]

    def wait_for_job_updates(self):
        self.__logger.info("Starting job watcher thread")
        if self.engine.dialect.name == "postgresql":
            self._listen_for_updates()
        else:
            self._poll_for_updates()

    def _listen_for_updates(self):
        def _listen_loop():
            try:
                conn = self.engine.connect()
                conn.execute(text("LISTEN job_updated;").execution_options(autocommit=True))
            except Exception as e:
                raise DatabaseError("Could not start listening to DB for job updates") from e

            while self.watcher_keep_running:
                # Wait until the connection is ready for reading. Time out after
                # poll_interval seconds and try again if there was nothing to read.
                # If the connection becomes readable, collect the notifications it
                # has received and handle them.
                #
                # See http://initd.org/psycopg/docs/advanced.html#async-notify
                if select.select([conn.connection], [], [], self.poll_interval) == ([], [], []):
                    pass
                else:
                    try:
                        conn.connection.poll()
                    except Exception as e:
                        raise DatabaseError("Error while polling for job updates") from e

                    while conn.connection.notifies:
                        notify = conn.connection.notifies.pop()
                        with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME,
                                            instanced=True, instance_name=self._instance_name):
                            with self.watched_jobs_lock:
                                spec = self.watched_jobs.get(notify.payload)
                                if spec is not None:
                                    try:
                                        new_job = self.get_job_by_name(notify.payload)
                                    except Exception as e:
                                        raise DatabaseError(
                                            f"Couldn't get watched job=[{notify.payload}] from DB") from e

                                    # If the job doesn't exist or an exception was suppressed by
                                    # get_job_by_name, it returns None instead of the job
                                    if new_job is None:
                                        raise DatabaseError(
                                            f"get_job_by_name returned None for job=[{notify.payload}]")

                                    new_state = JobState(new_job)
                                    if spec.last_state != new_state:
                                        spec.last_state = new_state
                                        spec.event.notify_change()

        while self.watcher_keep_running:
            # Wait a few seconds if a database exception occurs and then try again.
            # This could be a short disconnect
            try:
                _listen_loop()
            except DatabaseError as e:
                self.__logger.warning(f"JobWatcher encountered exception: [{e}]; "
                                      f"Retrying in poll_interval=[{self.poll_interval}] seconds.")
                # Sleep for a bit so that we give enough time for the
                # database to potentially recover
                time.sleep(self.poll_interval)
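
    # Sketch of the PostgreSQL notification round-trip driving the loop above
    # (SQL shown for illustration; the payload is whatever job name
    # _notify_job_updated sends):
    #
    #     -- executed by _notify_job_updated() inside a request's session:
    #     NOTIFY job_updated, 'some-job-name';
    #
    #     -- executed once by _listen_loop() on its dedicated connection:
    #     LISTEN job_updated;
    #
    # Non-PostgreSQL backends have no NOTIFY support, so wait_for_job_updates()
    # falls back to _poll_for_updates() below.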

    def _get_watched_jobs(self):
        with self.session() as sess:
            jobs = sess.query(Job).filter(
                Job.name.in_(self.watched_jobs)
            )
            return [job.to_internal_job(self) for job in jobs.all()]

    def _poll_for_updates(self):
        def _poll_loop():
            while self.watcher_keep_running:
                time.sleep(self.poll_interval)
                if self.watcher_keep_running:
                    with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME,
                                        instanced=True, instance_name=self._instance_name):
                        with self.watched_jobs_lock:
                            if self.watcher_keep_running:
                                try:
                                    watched_jobs = self._get_watched_jobs()
                                except Exception as e:
                                    raise DatabaseError("Couldn't retrieve watched jobs from DB") from e

                                if watched_jobs is None:
                                    raise DatabaseError("_get_watched_jobs returned None")

                                for new_job in watched_jobs:
                                    if self.watcher_keep_running:
                                        spec = self.watched_jobs[new_job.name]
                                        new_state = JobState(new_job)
                                        if spec.last_state != new_state:
                                            spec.last_state = new_state
                                            spec.event.notify_change()

        while self.watcher_keep_running:
            # Wait a few seconds if a database exception occurs and then try again
            try:
                _poll_loop()
            except DatabaseError as e:
                self.__logger.warning(f"JobWatcher encountered exception: [{e}]; "
                                      f"Retrying in poll_interval=[{self.poll_interval}] seconds.")
                # Sleep for a bit so that we give enough time for the
                # database to potentially recover
                time.sleep(self.poll_interval)

    @DurationMetric(DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME, instanced=True)
    def store_response(self, job, commit_changes=True):
        digest = self.storage.put_message(job.execute_response)
        changes = {"result": digest, "status_code": job.execute_response.status.code}
        self.response_cache[job.name] = job.execute_response

        if commit_changes:
            self.update_job(job.name,
                            changes,
                            skip_notify=True)
            return None
        else:
            # The caller will batch the changes and commit to db
            return changes

    def _get_operation(self, operation_name, session):
        operations = session.query(Operation).filter_by(name=operation_name)
        return operations.first()

    def get_operations_by_stage(self, operation_stage):
        with self.session() as session:
            operations = session.query(Operation)
            operations = operations.filter(Operation.job.has(stage=operation_stage.value))
            operations = operations.all()
            # Return a set of job names here for now, to match the `MemoryDataStore`
            # implementation's behaviour
            return set(op.job.name for op in operations)

    def _cancel_jobs_exceeding_execution_timeout(self, max_execution_timeout: int=None) -> None:
        if max_execution_timeout:
            stale_job_names = []
            lazy_execution_timeout_threshold = datetime.utcnow() - timedelta(seconds=max_execution_timeout)

            jobs_table = Job.__table__
            operations_table = Operation.__table__

            with self.session(sqlite_lock_immediately=True) as session:
                # Get the full list of jobs exceeding execution timeout
                stale_jobs = session.query(Job).filter_by(stage=OperationStage.EXECUTING.value)
                stale_jobs = stale_jobs.filter(Job.worker_start_timestamp <= lazy_execution_timeout_threshold)
                stale_job_names = [job.name for job in stale_jobs.with_for_update().all()]

                if stale_job_names:
                    # Mark operations as cancelled
                    stmt_mark_operations_cancelled = operations_table.update().where(
                        operations_table.c.job_name.in_(stale_job_names)
                    ).values(cancelled=True)
                    session.execute(stmt_mark_operations_cancelled)

                    # Mark jobs as cancelled
                    stmt_mark_jobs_cancelled = jobs_table.update().where(
                        jobs_table.c.name.in_(stale_job_names)
                    ).values(stage=OperationStage.COMPLETED.value, cancelled=True)
                    session.execute(stmt_mark_jobs_cancelled)

                    # Notify all jobs updated
                    self._notify_job_updated(stale_job_names, session)

            if stale_job_names:
                self.__logger.info(f"Cancelled n=[{len(stale_job_names)}] jobs "
                                   f"with names={stale_job_names} "
                                   f"due to them exceeding execution_timeout=["
                                   f"{max_execution_timeout}]")

    @DurationMetric(DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME, instanced=True)
    def list_operations(self,
                        operation_filters: List[OperationFilter]=None,
                        page_size: int=None,
                        page_token: str=None,
                        max_execution_timeout: int=None) -> Tuple[List[operations_pb2.Operation], str]:
        # Lazily timeout jobs as needed before returning the list!
        self._cancel_jobs_exceeding_execution_timeout(max_execution_timeout=max_execution_timeout)

        # Build filters and sort order
        sort_keys = DEFAULT_SORT_KEYS
        custom_filters = None
        if operation_filters:
            # Extract custom sort order (if present)
            specified_sort_keys, non_sort_filters = extract_sort_keys(operation_filters)

            # Only override sort_keys if there were sort keys actually present in the filter string
            if specified_sort_keys:
                sort_keys = specified_sort_keys
                # Attach the operation name as a sort key for a deterministic order
                # This will ensure that the ordering of results is consistent between queries
                if not any(sort_key.name == "name" for sort_key in sort_keys):
                    sort_keys.append(SortKey(name="name", descending=False))

            # Finally, compile the non-sort filters into a filter list
            custom_filters = build_custom_filters(non_sort_filters)

        sort_columns = build_sort_column_list(sort_keys)

        with self.session() as session:
            results = session.query(Operation).join(Job, Operation.job_name == Job.name)

            # Apply custom filters (if present)
            if custom_filters:
                results = results.filter(*custom_filters)

            # Apply sort order
            results = results.order_by(*sort_columns)

            # Apply pagination filter
            if page_token:
                page_filter = build_page_filter(page_token, sort_keys)
                results = results.filter(page_filter)
            if page_size:
                # We limit the number of operations we fetch to the page_size. However, we
                # fetch an extra operation to determine whether we need to provide a
                # next_page_token.
                results = results.limit(page_size + 1)

            operations = list(results)

            if not page_size or not operations:
                next_page_token = ""

            # If the number of results we got is less than or equal to our page_size,
            # we're done with the operations listing and don't need to provide another
            # page token
            elif len(operations) <= page_size:
                next_page_token = ""
            else:
                # Drop the last operation since we have an extra
                operations.pop()
                # Our page token will be the last row of our set
                next_page_token = build_page_token(operations[-1], sort_keys)
            return [operation.to_protobuf(self) for operation in operations], next_page_token
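
    # Minimal pagination sketch for the method above (data_store refers to an
    # SQLDataStore instance and the page size is an illustrative choice): keep
    # passing the returned token back in until it comes back empty.
    #
    #     operations, token = data_store.list_operations(page_size=50)
    #     while token:
    #         more, token = data_store.list_operations(page_size=50, page_token=token)
    #         operations.extend(more)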

    @DurationMetric(DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME, instanced=True)
    def create_operation(self, operation_name, job_name, request_metadata=None):
        with self.session() as session:
            operation = Operation(
                name=operation_name,
                job_name=job_name
            )
            if request_metadata is not None:
                if request_metadata.tool_invocation_id:
                    operation.invocation_id = request_metadata.tool_invocation_id
                if request_metadata.correlated_invocations_id:
                    operation.correlated_invocations_id = request_metadata.correlated_invocations_id
                if request_metadata.tool_details:
                    operation.tool_name = request_metadata.tool_details.tool_name
                    operation.tool_version = request_metadata.tool_details.tool_version
            session.add(operation)

    @DurationMetric(DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME, instanced=True)
    def update_operation(self, operation_name, changes):
        with self.session() as session:
            operation = self._get_operation(operation_name, session)
            operation.update(changes)

    def delete_operation(self, operation_name):
        # Don't do anything. This function needs to exist but there's no
        # need to actually delete operations in this implementation.
        pass

    def get_leases_by_state(self, lease_state):
        with self.session() as session:
            leases = session.query(Lease).filter_by(state=lease_state.value)
            leases = leases.all()
            # `lease.job_name` is the same as `lease.id` for a Lease protobuf
            return set(lease.job_name for lease in leases)

    def get_metrics(self):

        def _get_query_leases_by_state(session, category):
            # Using func.count here to avoid generating a subquery in the WHERE
            # clause of the resulting query.
            # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
            query = session.query(literal_column(category).label("category"),
                                  Lease.state.label("bucket"),
                                  func.count(Lease.id).label("value"))
            query = query.group_by(Lease.state)
            return query

        def _cb_query_leases_by_state(leases_by_state):
            # The database only returns counts > 0, so fill in the gaps
            for state in LeaseState:
                if state.value not in leases_by_state:
                    leases_by_state[state.value] = 0
            return leases_by_state

        def _get_query_operations_by_stage(session, category):
            # Using func.count here to avoid generating a subquery in the WHERE
            # clause of the resulting query.
            # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
            query = session.query(literal_column(category).label("category"),
                                  Job.stage.label("bucket"),
                                  func.count(Operation.name).label("value"))
            query = query.join(Job)
            query = query.group_by(Job.stage)
            return query

        def _cb_query_operations_by_stage(operations_by_stage):
            # The database only returns counts > 0, so fill in the gaps
            for stage in OperationStage:
                if stage.value not in operations_by_stage:
                    operations_by_stage[stage.value] = 0
            return operations_by_stage

        def _get_query_jobs_by_stage(session, category):
            # Using func.count here to avoid generating a subquery in the WHERE
            # clause of the resulting query.
            # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
            query = session.query(literal_column(category).label("category"),
                                  Job.stage.label("bucket"),
                                  func.count(Job.name).label("value"))
            query = query.group_by(Job.stage)
            return query

        def _cb_query_jobs_by_stage(jobs_by_stage):
            # The database only returns counts > 0, so fill in the gaps
            for stage in OperationStage:
                if stage.value not in jobs_by_stage:
                    jobs_by_stage[stage.value] = 0
            return jobs_by_stage

        metrics = {}
        try:
            with self.session() as session:
                # metrics to gather: (category_name, function_returning_query, callback_function)
                metrics_to_gather = [(MetricCategories.LEASES.value, _get_query_leases_by_state,
                                      _cb_query_leases_by_state),
                                     (MetricCategories.OPERATIONS.value, _get_query_operations_by_stage,
                                      _cb_query_operations_by_stage),
                                     (MetricCategories.JOBS.value, _get_query_jobs_by_stage,
                                      _cb_query_jobs_by_stage)]

                union_query = union(*[query_fn(session, f"'{category}'")
                                      for category, query_fn, _ in metrics_to_gather])
                union_results = session.execute(union_query).fetchall()

                grouped_results = {category: {} for category, _, _ in union_results}
                for category, bucket, value in union_results:
                    grouped_results[category][bucket] = value

                for category, _, category_cb in metrics_to_gather:
                    metrics[category] = category_cb(grouped_results.setdefault(category, {}))
        except DatabaseError:
            self.__logger.warning("Unable to gather metrics due to a Database Error.")
            return {}

        return metrics
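
    # Approximate shape of the dictionary returned above (bucket keys are the
    # enum values; the counts shown are made-up examples):
    #
    #     {
    #         MetricCategories.LEASES.value: {LeaseState.PENDING.value: 2, ...},
    #         MetricCategories.OPERATIONS.value: {OperationStage.QUEUED.value: 5, ...},
    #         MetricCategories.JOBS.value: {OperationStage.EXECUTING.value: 1, ...},
    #     }
    #
    # with every enum member present, since the callbacks fill in zero counts for
    # buckets the database did not return.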

    def _create_lease(self, lease, session, job=None):
        if job is None:
            job = self._get_job(lease.id, session)
            job = job.to_internal_job(self)
        session.add(Lease(
            job_name=lease.id,
            state=lease.state,
            status=None,
            worker_name=job.worker_name
        ))

    def create_lease(self, lease):
        with self.session() as session:
            self._create_lease(lease, session)

    @DurationMetric(DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME, instanced=True)
    def update_lease(self, job_name, changes):
        with self.session() as session:
            job = self._get_job(job_name, session)
            lease = job.active_leases[0]
            lease.update(changes)

    def load_unfinished_jobs(self):
        with self.session() as session:
            jobs = session.query(Job)
            jobs = jobs.filter(Job.stage != OperationStage.COMPLETED.value)
            jobs = jobs.order_by(Job.priority)
            return [j.to_internal_job(self) for j in jobs.all()]

    def assign_lease_for_next_job(self, capabilities, callback, timeout=None):
        """Return a list of leases for the highest priority jobs that can be run by a worker.

        NOTE: Currently the list only ever has one or zero leases.

        Query the jobs table to find queued jobs which match the capabilities of
        a given worker, and return the one with the highest priority. Takes a
        dictionary of worker capabilities to compare with job requirements.

        :param capabilities: Dictionary of worker capabilities to compare
            with job requirements when finding a job.
        :type capabilities: dict
        :param callback: Function to run on the next runnable job, should return
            a list of leases.
        :type callback: function
        :param timeout: Time to wait for new jobs, capped at MAX_JOB_BLOCK_TIME
            if longer.
        :type timeout: int
        :returns: List of leases

        """
        if not timeout:
            return self._assign_job_leases(capabilities, callback)

        # Cap the timeout if it's larger than MAX_JOB_BLOCK_TIME
        if timeout:
            timeout = min(timeout, MAX_JOB_BLOCK_TIME)

        start = time.time()
        while time.time() + self.connection_timeout + 1 - start < timeout:
            leases = self._assign_job_leases(capabilities, callback)
            if leases:
                return leases
            time.sleep(0.5)
        if self.connection_timeout > timeout:
            self.__logger.warning(
                "Not providing any leases to the worker because the database connection "
                f"timeout ({self.connection_timeout} s) is longer than the remaining "
                "time to handle the request. "
                "Increase the worker's timeout to solve this problem.")
        return []
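
    # Sketch of how a caller (e.g. the bots service) might invoke the method above.
    # The callback body and capability values are illustrative assumptions, not
    # part of this module; the only contract used here is that the callback takes
    # the internal job and returns a list of leases.
    #
    #     def _callback(job):
    #         # Build and return a list with at most one lease for `job`; how the
    #         # lease is built is up to the caller.
    #         return [make_lease_for(job)]   # make_lease_for is hypothetical
    #
    #     leases = data_store.assign_lease_for_next_job(
    #         {'OSFamily': ['Linux'], 'ISA': ['x86-64']}, _callback, timeout=30)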

    def flatten_capabilities(self, capabilities: Dict[str, List[str]]) -> List[Tuple[str, str]]:
        """ Flatten a capabilities dictionary, assuming all of its values are lists. E.g.

            {'OSFamily': ['Linux'], 'ISA': ['x86-32', 'x86-64']}

        becomes

            [('OSFamily', 'Linux'), ('ISA', 'x86-32'), ('ISA', 'x86-64')] """
        return [
            (name, value) for name, value_list in capabilities.items()
            for value in value_list
        ]

    def get_partial_capabilities(self, capabilities: Dict[str, List[str]]) -> Iterable[Dict[str, List[str]]]:
        """ Given a capabilities dictionary with all values as lists,
        yield all partial capabilities dictionaries. """
        CAPABILITIES_WARNING_THRESHOLD = 10

        caps_flat = self.flatten_capabilities(capabilities)

        if len(caps_flat) > CAPABILITIES_WARNING_THRESHOLD:
            self.__logger.warning(
                "A worker with a large capabilities dictionary has been connected. "
                f"Processing its capabilities may take a while. Capabilities: {capabilities}")

        # Using the itertools powerset recipe, construct the powerset of the tuples
        capabilities_powerset = chain.from_iterable(combinations(caps_flat, r) for r in range(len(caps_flat) + 1))
        for partial_capability_tuples in capabilities_powerset:
            partial_dict: Dict[str, List[str]] = {}

            for tup in partial_capability_tuples:
                partial_dict.setdefault(tup[0], []).append(tup[1])
            yield partial_dict
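
    # Worked example of the powerset expansion above, using a reduced variant of
    # the docstring's capabilities (the yielded order follows the combinations order):
    #
    #     {'OSFamily': ['Linux'], 'ISA': ['x86-64']}  yields
    #     {}, {'OSFamily': ['Linux']}, {'ISA': ['x86-64']},
    #     {'OSFamily': ['Linux'], 'ISA': ['x86-64']}
    #
    # so a job requiring only {'ISA': ['x86-64']} hashes to one of the worker's
    # partial-capability hashes and can be matched in _assign_job_leases below.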

    def get_partial_capabilities_hashes(self, capabilities: Dict) -> List[str]:
        """ Given a capabilities dictionary, obtain the hash of each partial
        capabilities dictionary, compile these into a list, and return the result. """
        # Convert requirements values to sorted lists to make them json-serializable
        convert_values_to_sorted_lists(capabilities)

        # Check to see if we've cached this value
        capabilities_digest = hash_from_dict(capabilities)
        try:
            return self.capabilities_cache[capabilities_digest]
        except KeyError:
            # On cache miss, expand the capabilities into each possible partial capabilities dictionary
            capabilities_list = []
            for partial_capability in self.get_partial_capabilities(capabilities):
                capabilities_list.append(hash_from_dict(partial_capability))

            self.capabilities_cache[capabilities_digest] = capabilities_list
            return capabilities_list

    @DurationMetric(BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME, instanced=True)
    def _assign_job_leases(self, capabilities, callback):
        # pylint: disable=singleton-comparison
        # Hash the capabilities
        capabilities_config_hashes = self.get_partial_capabilities_hashes(capabilities)
        leases = []
        try:
            create_lease_start_time = None
            with self.session(sqlite_lock_immediately=True) as session:
                jobs = session.query(Job).with_for_update(skip_locked=True)
                jobs = jobs.filter(Job.stage == OperationStage.QUEUED.value)
                jobs = jobs.filter(Job.assigned != True)  # noqa
                jobs = jobs.filter(Job.platform_requirements.in_(capabilities_config_hashes))
                job = jobs.order_by(Job.priority, Job.queued_timestamp).first()
                # This worker can take this job if it can handle all of its configurations
                if job:
                    internal_job = job.to_internal_job(self)
                    leases = callback(internal_job)
                    if leases:
                        job.assigned = True
                        job.worker_start_timestamp = internal_job.worker_start_timestamp_as_datetime
                        create_lease_start_time = time.perf_counter()
                        for lease in leases:
                            self._create_lease(lease, session, job=internal_job)

            # Calculate and publish the time taken to create leases. This is done explicitly
            # rather than using the DurationMetric helper since we need to measure the actual
            # execution time of the UPDATE and INSERT queries used in the lease assignment, and
            # these are only executed on exiting the contextmanager.
            if create_lease_start_time is not None:
                run_time = timedelta(seconds=time.perf_counter() - create_lease_start_time)
                metadata = None
                if self._instance_name is not None:
                    metadata = {'instance-name': self._instance_name}
                publish_timer_metric(DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME, run_time, metadata=metadata)

        except DatabaseError:
            self.__logger.warning("Will not assign any leases this time due to a Database Error.")

        return leases