Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/persistence/sql/impl.py: 76.99%

552 statements  

coverage.py v6.4.1, created at 2022-06-22 21:04 +0000

1# Copyright (C) 2019 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16from contextlib import contextmanager 

17import logging 

18import os 

19import select 

20from threading import Thread 

21import time 

22from datetime import datetime, timedelta 

23from tempfile import NamedTemporaryFile 

24from itertools import chain, combinations 

25from typing import Dict, Iterable, List, Tuple, NamedTuple, Optional 

26 

27from alembic import command 

28from alembic.config import Config 

29from sqlalchemy import and_, create_engine, delete, event, func, union, literal_column, update 

30from sqlalchemy import select as sql_select 

31from sqlalchemy.future import Connection 

32from sqlalchemy.orm.session import sessionmaker 

33from sqlalchemy.sql.expression import Select 

34 

35from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 

36from buildgrid._protos.google.longrunning import operations_pb2 

37from buildgrid._enums import LeaseState, MetricCategories, OperationStage 

38from buildgrid.server.sql import sqlutils 

39from buildgrid.server.metrics_names import ( 

40 BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME, 

41 DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME, 

42 DATA_STORE_CREATE_JOB_TIME_METRIC_NAME, 

43 DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME, 

44 DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME, 

45 DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME, 

46 DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME, 

47 DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME, 

48 DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME, 

49 DATA_STORE_PRUNER_NUM_ROWS_DELETED_METRIC_NAME, 

50 DATA_STORE_PRUNER_DELETE_TIME_METRIC_NAME, 

51 DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME, 

52 DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME, 

53 DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME, 

54 DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME, 

55 DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME 

56) 

57from buildgrid.server.job_metrics import JobMetrics 

58from buildgrid.server.metrics_utils import DurationMetric, publish_timer_metric, Counter 

59from buildgrid.server.operations.filtering import OperationFilter, SortKey, DEFAULT_SORT_KEYS 

60from buildgrid.server.persistence.interface import DataStoreInterface 

61from buildgrid.server.persistence.sql.models import digest_to_string, Job, Lease, Operation 

62from buildgrid.server.persistence.sql.utils import ( 

63 build_page_filter, 

64 build_page_token, 

65 extract_sort_keys, 

66 build_custom_filters, 

67 build_sort_column_list 

68) 

69from buildgrid.settings import ( 

70 MAX_JOB_BLOCK_TIME, 

71 MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES, 

72 COOLDOWN_TIME_AFTER_POOL_DISPOSE_SECONDS, 

73 SQL_SCHEDULER_METRICS_PUBLISH_INTERVAL_SECONDS 

74) 

75from buildgrid.utils import JobState, hash_from_dict, convert_values_to_sorted_lists 

76 

77from buildgrid._exceptions import DatabaseError, RetriableDatabaseError 

78 

79 

80Session = sessionmaker(future=True) 

81 

82 

83def sqlite_on_connect(conn, record): 

84 conn.execute("PRAGMA journal_mode=WAL") 

85 conn.execute("PRAGMA synchronous=NORMAL") 

86 

87 

88class PruningOptions(NamedTuple): 

89 pruner_job_max_age: timedelta = timedelta(days=30) 

90 pruner_period: timedelta = timedelta(minutes=5) 

91 pruner_max_delete_window: int = 10000 

92 

93 @staticmethod 

94 def from_config(pruner_job_max_age_cfg: Dict[str, float], 

95 pruner_period_cfg: Dict[str, float] = None, 

96 pruner_max_delete_window_cfg: int = None): 

97 """ Helper method for creating ``PruningOptions`` objects 

98 If input configs are None, assign defaults """ 

99 def _dict_to_timedelta(config: Dict[str, float]) -> timedelta: 

100 return timedelta(weeks=config.get('weeks', 0), 

101 days=config.get('days', 0), 

102 hours=config.get('hours', 0), 

103 minutes=config.get('minutes', 0), 

104 seconds=config.get('seconds', 0)) 

105 

106 return PruningOptions(pruner_job_max_age=_dict_to_timedelta( 

107 pruner_job_max_age_cfg) if pruner_job_max_age_cfg else timedelta(days=30), 

108 pruner_period=_dict_to_timedelta( 

109 pruner_period_cfg) if pruner_period_cfg else timedelta(minutes=5), 

110 pruner_max_delete_window=pruner_max_delete_window_cfg 

111 if pruner_max_delete_window_cfg else 10000) 

112 
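A minimal usage sketch for the helper above (illustrative, not part of impl.py); the config values are invented:

    from datetime import timedelta

    opts = PruningOptions.from_config(
        pruner_job_max_age_cfg={'days': 7},         # prune jobs older than a week
        pruner_period_cfg={'minutes': 30},          # wake the pruner every 30 minutes
        pruner_max_delete_window_cfg=5000)          # delete at most 5000 rows per pass
    assert opts.pruner_job_max_age == timedelta(days=7)
    assert opts.pruner_period == timedelta(minutes=30)
    assert opts.pruner_max_delete_window == 5000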

113 

114class SQLDataStore(DataStoreInterface): 

115 

116 def __init__(self, storage, *, connection_string=None, automigrate=False, 

117 connection_timeout=5, poll_interval=1, 

118 pruning_options: Optional[PruningOptions] = None, 

119 **kwargs): 

120 super().__init__(storage) 

121 self.__logger = logging.getLogger(__name__) 

122 self.__logger.info("Creating SQL scheduler with: " 

123 f"automigrate=[{automigrate}], connection_timeout=[{connection_timeout}] " 

124 f"poll_interval=[{poll_interval}], " 

125 f"pruning_options=[{pruning_options}], " 

126 f"kwargs=[{kwargs}]") 

127 

128 self.response_cache: Dict[str, remote_execution_pb2.ExecuteResponse] = {} 

129 self.connection_timeout = connection_timeout 

130 self.poll_interval = poll_interval 

131 self.watcher = Thread(name="JobWatcher", target=self.wait_for_job_updates, daemon=True) 

132 self.watcher_keep_running = True 

133 

134 # Set-up temporary SQLite Database when connection string is not specified 

135 if not connection_string: 

136 # pylint: disable=consider-using-with 

137 tmpdbfile = NamedTemporaryFile(prefix='bgd-', suffix='.db') 

138 self._tmpdbfile = tmpdbfile # Make sure to keep this tempfile for the lifetime of this object 

139 self.__logger.warning("No connection string specified for the DataStore, " 

140 f"will use SQLite with tempfile: [{tmpdbfile.name}]") 

141 automigrate = True # since this is a temporary database, we always need to create it 

142 connection_string = f"sqlite:///{tmpdbfile.name}" 

143 

144 self._create_sqlalchemy_engine(connection_string, automigrate, connection_timeout, **kwargs) 

145 

146 self._sql_pool_dispose_helper = sqlutils.SQLPoolDisposeHelper(COOLDOWN_TIME_AFTER_POOL_DISPOSE_SECONDS, 

147 MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES, 

148 self.engine) 

149 

150 # Make a test query against the database to ensure the connection is valid 

151 with self.session(reraise=True) as session: 

152 session.execute(sql_select([1])).all() 

153 

154 self.watcher.start() 

155 

156 self.capabilities_cache: Dict[str, List[str]] = {} 

157 

158 # Pruning configuration parameters 

159 if pruning_options is not None: 

160 self.pruner_keep_running = True 

161 self.__logger.info(f"Scheduler pruning enabled: {pruning_options}") 

162 self.__pruner_thread = Thread(name="JobsPruner", target=self._do_prune, args=( 

163 pruning_options.pruner_job_max_age, pruning_options.pruner_period, 

164 pruning_options.pruner_max_delete_window), daemon=True) 

165 self.__pruner_thread.start() 

166 else: 

167 self.__logger.info("Scheduler pruning not enabled") 

168 

169 # Overall Scheduler Metrics (totals of jobs/leases in each state) 

170 # Publish those metrics a bit more sparsely since the SQL requests 

171 # required to gather them can become expensive 

172 self.__last_scheduler_metrics_publish_time = None 

173 self.__scheduler_metrics_publish_interval = timedelta( 

174 seconds=SQL_SCHEDULER_METRICS_PUBLISH_INTERVAL_SECONDS) 

175 
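A hypothetical construction sketch for this class (not taken from the BuildGrid docs); `storage` stands in for a CAS storage backend instance, and the DSN and pool settings are invented examples:

    from datetime import timedelta

    data_store = SQLDataStore(
        storage,                                        # assumed CAS storage backend instance
        connection_string="postgresql+psycopg2://bgd:bgd@localhost/bgd",  # hypothetical DSN
        automigrate=True,                               # run alembic migrations on start-up
        connection_timeout=5,
        poll_interval=1,
        pruning_options=PruningOptions(pruner_job_max_age=timedelta(days=7)),
        pool_size=5,                                    # forwarded to create_engine()
        pool_pre_ping=True)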

176 def _create_sqlalchemy_engine(self, connection_string, automigrate, connection_timeout, **kwargs): 

177 self.automigrate = automigrate 

178 

179 # Disallow sqlite in-memory because multi-threaded access to it is 

180 # complex and potentially problematic at best 

181 # ref: https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#threading-pooling-behavior 

182 if sqlutils.is_sqlite_inmemory_connection_string(connection_string): 

183 raise ValueError( 

184 f"Cannot use SQLite in-memory with BuildGrid (connection_string=[{connection_string}]). " 

185 "Use a file or leave the connection_string empty for a tempfile.") 

186 

187 if connection_timeout is not None: 

188 if "connect_args" not in kwargs: 

189 kwargs["connect_args"] = {} 

190 if sqlutils.is_sqlite_connection_string(connection_string): 

191 kwargs["connect_args"]["timeout"] = connection_timeout 

192 elif sqlutils.is_psycopg2_connection_string(connection_string): 

193 kwargs["connect_args"]["connect_timeout"] = connection_timeout 

194 # Additional PostgreSQL-specific timeouts,

195 # passed to the server as libpq options.

196 # Note that lock_timeout is in milliseconds (hence the *1000)

197 kwargs["connect_args"]["options"] = f'-c lock_timeout={connection_timeout * 1000}' 

198 

199 # Only pass the (known) kwargs that have been explicitly set by the user 

200 available_options = set([ 

201 'pool_size', 'max_overflow', 'pool_timeout', 'pool_pre_ping', 

202 'pool_recycle', 'connect_args' 

203 ]) 

204 kwargs_keys = set(kwargs.keys()) 

205 if not kwargs_keys.issubset(available_options): 

206 unknown_options = kwargs_keys - available_options 

207 raise TypeError(f"Unknown keyword arguments: [{unknown_options}]") 

208 

209 self.__logger.debug(f"SQLAlchemy additional kwargs: [{kwargs}]") 

210 

211 self.engine = create_engine(connection_string, echo=False, future=True, **kwargs) 

212 Session.configure(bind=self.engine) 

213 

214 if self.engine.dialect.name == "sqlite": 

215 event.listen(self.engine, "connect", sqlite_on_connect) 

216 

217 if self.automigrate: 

218 self._create_or_migrate_db(connection_string) 

219 
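To make the timeout handling above concrete, here is a sketch of the engine arguments this method would assemble for a PostgreSQL URL with connection_timeout=5 (the DSN and pool size are invented):

    from sqlalchemy import create_engine

    connect_args = {
        "connect_timeout": 5,                   # libpq connection timeout, in seconds
        "options": "-c lock_timeout=5000",      # server-side lock timeout, in milliseconds
    }
    engine = create_engine(
        "postgresql+psycopg2://bgd:bgd@localhost/bgd",  # hypothetical DSN
        echo=False, future=True,
        pool_size=5, connect_args=connect_args)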

220 def __repr__(self): 

221 return f"SQL data store interface for `{repr(self.engine.url)}`" 

222 

223 def activate_monitoring(self): 

224 # Don't do anything. This function needs to exist but there's no 

225 # need to actually toggle monitoring in this implementation. 

226 pass 

227 

228 def deactivate_monitoring(self): 

229 # Don't do anything. This function needs to exist but there's no 

230 # need to actually toggle monitoring in this implementation. 

231 pass 

232 

233 def _create_or_migrate_db(self, connection_string): 

234 self.__logger.warning("Will attempt migration to latest version if needed.") 

235 

236 config = Config() 

237 config.set_main_option("script_location", os.path.join(os.path.dirname(__file__), "alembic")) 

238 

239 with self.engine.begin() as connection: 

240 config.attributes['connection'] = connection 

241 command.upgrade(config, "head") 

242 

243 @contextmanager 

244 def session(self, *, sqlite_lock_immediately=False, reraise=False): 

245 # If we recently disposed of the SQL pool due to connection issues 

246 # allow for some cooldown period before we attempt more SQL 

247 self._sql_pool_dispose_helper.wait_if_cooldown_in_effect() 

248 

249 # Try to obtain a session 

250 try: 

251 session = Session() 

252 if sqlite_lock_immediately and session.bind.name == "sqlite": 

253 session.execute("BEGIN IMMEDIATE") 

254 except Exception as e: 

255 self.__logger.error("Unable to obtain a database session.", exc_info=True) 

256 raise DatabaseError("Unable to obtain a database session.") from e 

257 

258 # Yield the session and catch exceptions that occur while using it 

259 # to roll-back if needed 

260 try: 

261 yield session 

262 session.commit() 

263 except Exception as e: 

264 transient_dberr = self._sql_pool_dispose_helper.check_dispose_pool(session, e) 

265 if transient_dberr: 

266 self.__logger.warning("Rolling back database session due to transient database error.", exc_info=True) 

267 else: 

268 self.__logger.error("Error committing database session. Rolling back.", exc_info=True) 

269 try: 

270 session.rollback() 

271 except Exception: 

272 self.__logger.warning("Rollback error.", exc_info=True) 

273 

274 if reraise: 

275 if transient_dberr: 

276 raise RetriableDatabaseError("Database connection was temporarily interrupted, please retry", 

277 timedelta(seconds=COOLDOWN_TIME_AFTER_POOL_DISPOSE_SECONDS)) from e 

278 raise 

279 finally: 

280 session.close() 

281 
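A hypothetical caller sketch for the context manager above, using the imports at the top of this module; `data_store` stands in for a constructed SQLDataStore and the query simply counts unfinished jobs:

    def count_incomplete_jobs(data_store):
        statement = sql_select(func.count(Job.name)).where(
            Job.stage != OperationStage.COMPLETED.value)
        # reraise=True propagates DatabaseError / RetriableDatabaseError to the
        # caller instead of silently rolling back and carrying on.
        with data_store.session(reraise=True) as session:
            return session.execute(statement).scalar()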

282 def _get_job(self, job_name, session, with_for_update=False): 

283 statement = sql_select(Job).filter_by(name=job_name) 

284 if with_for_update: 

285 statement = statement.with_for_update() 

286 

287 job = session.execute(statement).scalars().first() 

288 if job: 

289 self.__logger.debug(f"Loaded job from db: name=[{job_name}], stage=[{job.stage}], result=[{job.result}]") 

290 

291 return job 

292 

293 def _check_job_timeout(self, job_internal, *, max_execution_timeout=None): 

294 """ Do a lazy check of maximum allowed job timeouts when clients try to retrieve 

295 an existing job. 

296 Cancel the job and related operations/leases, if we detect they have 

297 exceeded timeouts on access. 

298 

299 Returns the `buildgrid.server.Job` object, possibly updated with `cancelled=True`. 

300 """ 

301 if job_internal and max_execution_timeout and job_internal.worker_start_timestamp_as_datetime: 

302 if job_internal.operation_stage == OperationStage.EXECUTING: 

303 executing_duration = datetime.utcnow() - job_internal.worker_start_timestamp_as_datetime 

304 if executing_duration.total_seconds() >= max_execution_timeout: 

305 self.__logger.warning(f"Job=[{job_internal}] has been executing for " 

306 f"executing_duration=[{executing_duration}]. " 

307 f"max_execution_timeout=[{max_execution_timeout}] " 

308 "Cancelling.") 

309 job_internal.cancel_all_operations(data_store=self) 

310 self.__logger.info(f"Job=[{job_internal}] has been cancelled.") 

311 return job_internal 

312 

313 @DurationMetric(DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME, instanced=True) 

314 def get_job_by_action(self, action_digest, *, max_execution_timeout=None): 

315 statement = sql_select(Job).where( 

316 and_( 

317 Job.action_digest == digest_to_string(action_digest), 

318 Job.stage != OperationStage.COMPLETED.value 

319 ) 

320 ) 

321 

322 with self.session() as session: 

323 job = session.execute(statement).scalars().first() 

324 if job: 

325 internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url, 

326 instance_name=self._instance_name) 

327 return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout) 

328 return None 

329 

330 @DurationMetric(DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME, instanced=True) 

331 def get_job_by_name(self, name, *, max_execution_timeout=None): 

332 with self.session() as session: 

333 job = self._get_job(name, session) 

334 if job: 

335 internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url, 

336 instance_name=self._instance_name) 

337 return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout) 

338 return None 

339 

340 @DurationMetric(DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME, instanced=True) 

341 def get_job_by_operation(self, operation_name, *, max_execution_timeout=None): 

342 with self.session() as session: 

343 operation = self._get_operation(operation_name, session) 

344 if operation and operation.job: 

345 job = operation.job 

346 internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url, 

347 instance_name=self._instance_name) 

348 return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout) 

349 return None 

350 

351 def get_all_jobs(self): 

352 statement = sql_select(Job).where( 

353 Job.stage != OperationStage.COMPLETED.value 

354 ) 

355 

356 with self.session() as session: 

357 jobs = session.execute(statement).scalars() 

358 return [j.to_internal_job(self, action_browser_url=self._action_browser_url, 

359 instance_name=self._instance_name) for j in jobs] 

360 

361 def get_jobs_by_stage(self, operation_stage): 

362 statement = sql_select(Job).where( 

363 Job.stage == operation_stage.value 

364 ) 

365 

366 with self.session() as session: 

367 jobs = session.execute(statement).scalars() 

368 return [j.to_internal_job(self, no_result=True, action_browser_url=self._action_browser_url, 

369 instance_name=self._instance_name) for j in jobs] 

370 

371 def get_operation_request_metadata_by_name(self, operation_name): 

372 with self.session() as session: 

373 operation = self._get_operation(operation_name, session) 

374 if not operation: 

375 return None 

376 

377 return {'tool-name': operation.tool_name or '', 

378 'tool-version': operation.tool_version or '', 

379 'invocation-id': operation.invocation_id or '', 

380 'correlated-invocations-id': operation.correlated_invocations_id or ''} 

381 

382 @DurationMetric(DATA_STORE_CREATE_JOB_TIME_METRIC_NAME, instanced=True) 

383 def create_job(self, job): 

384 with self.session() as session: 

385 if self._get_job(job.name, session) is None: 

386 # Convert requirements values to sorted lists to make them json-serializable 

387 platform_requirements = job.platform_requirements 

388 convert_values_to_sorted_lists(platform_requirements) 

389 # Serialize the requirements 

390 platform_requirements_hash = hash_from_dict(platform_requirements) 

391 

392 session.add(Job( 

393 name=job.name, 

394 action=job.action.SerializeToString(), 

395 action_digest=digest_to_string(job.action_digest), 

396 do_not_cache=job.do_not_cache, 

397 priority=job.priority, 

398 operations=[], 

399 platform_requirements=platform_requirements_hash, 

400 stage=job.operation_stage.value, 

401 queued_timestamp=job.queued_timestamp_as_datetime, 

402 queued_time_duration=job.queued_time_duration.seconds, 

403 worker_start_timestamp=job.worker_start_timestamp_as_datetime, 

404 worker_completed_timestamp=job.worker_completed_timestamp_as_datetime 

405 )) 

406 

407 @DurationMetric(DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME, instanced=True) 

408 def queue_job(self, job_name): 

409 with self.session(sqlite_lock_immediately=True) as session: 

410 job = self._get_job(job_name, session, with_for_update=True) 

411 job.assigned = False 

412 

413 @DurationMetric(DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME, instanced=True) 

414 def update_job(self, job_name, changes, *, skip_notify=False): 

415 if "result" in changes: 

416 changes["result"] = digest_to_string(changes["result"]) 

417 if "action_digest" in changes: 

418 changes["action_digest"] = digest_to_string(changes["action_digest"]) 

419 

420 initial_values_for_metrics_use = {} 

421 

422 with self.session() as session: 

423 job = self._get_job(job_name, session) 

424 

425 # Keep track of the state right before we perform this update 

426 initial_values_for_metrics_use["stage"] = OperationStage(job.stage) 

427 

428 job.update(changes) 

429 if not skip_notify: 

430 self._notify_job_updated(job_name, session) 

431 

432 # Upon successful completion of the transaction above, publish metrics 

433 JobMetrics.publish_metrics_on_job_updates(initial_values_for_metrics_use, changes, self._instance_name) 

434 
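A hypothetical caller sketch: `changes` is a plain mapping of Job columns to new values, and digest-valued entries are converted to strings by the method above. The job name and values here are invented:

    data_store.update_job(
        "1a2b3c4d-example-job-name",                   # hypothetical job name
        {
            "stage": OperationStage.EXECUTING.value,
            "worker_start_timestamp": datetime.utcnow(),
        },
        skip_notify=False)                             # also NOTIFY listeners on PostgreSQL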

435 def _notify_job_updated(self, job_names, session): 

436 if self.engine.dialect.name == "postgresql": 

437 if isinstance(job_names, str): 

438 job_names = [job_names] 

439 for job_name in job_names: 

440 session.execute(f"NOTIFY job_updated, '{job_name}';") 

441 

442 def delete_job(self, job_name): 

443 if job_name in self.response_cache: 

444 del self.response_cache[job_name] 

445 

446 def wait_for_job_updates(self): 

447 self.__logger.info("Starting job watcher thread") 

448 if self.engine.dialect.name == "postgresql": 

449 self._listen_for_updates() 

450 else: 

451 self._poll_for_updates() 

452 

453 def _listen_for_updates(self): 

454 def _listen_loop(engine_conn: Connection): 

455 try: 

456 # Get the DBAPI connection object from the SQLAlchemy Engine.Connection wrapper 

457 connection_fairy = engine_conn.connection 

458 connection_fairy.cursor().execute("LISTEN job_updated;") # type: ignore 

459 connection_fairy.commit() 

460 except Exception: 

461 self.__logger.warning( 

462 "Could not start listening to DB for job updates", 

463 exc_info=True) 

464 # Let the context manager handle this 

465 raise 

466 

467 while self.watcher_keep_running: 

468 # Get the actual DBAPI connection 

469 dbapi_connection = connection_fairy.dbapi_connection # type: ignore 

470 # Wait until the connection is ready for reading. Timeout after 5 seconds 

471 # and try again if there was nothing to read. If the connection becomes 

472 # readable, collect the notifications it has received and handle them. 

473 # 

474 # See http://initd.org/psycopg/docs/advanced.html#async-notify 

475 if select.select([dbapi_connection], [], [], self.poll_interval) == ([], [], []): 

476 pass 

477 else: 

478 

479 try: 

480 dbapi_connection.poll() 

481 except Exception: 

482 self.__logger.warning("Error while polling for job updates", exc_info=True) 

483 # Let the context manager handle this 

484 raise 

485 

486 while dbapi_connection.notifies: 

487 notify = dbapi_connection.notifies.pop() 

488 with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME, 

489 instanced=True, instance_name=self._instance_name): 

490 with self.watched_jobs_lock: 

491 spec = self.watched_jobs.get(notify.payload) 

492 if spec is not None: 

493 try: 

494 new_job = self.get_job_by_name(notify.payload) 

495 except Exception: 

496 self.__logger.warning( 

497 f"Couldn't get watched job=[{notify.payload}] from DB", 

498 exc_info=True) 

499 # Let the context manager handle this 

500 raise 

501 

502 # If the job doesn't exist or an exception was suppressed by

503 # get_job_by_name, it returns None instead of the job 

504 if new_job is None: 

505 raise DatabaseError( 

506 f"get_job_by_name returned None for job=[{notify.payload}]") 

507 

508 new_state = JobState(new_job) 

509 if spec.last_state != new_state: 

510 spec.last_state = new_state 

511 spec.event.notify_change() 

512 

513 while self.watcher_keep_running: 

514 # Wait a few seconds if a database exception occurs and then try again 

515 # This could be a short disconnect 

516 try: 

517 # Use the session contextmanager 

518 # so that we can benefit from the common SQL error-handling 

519 with self.session(reraise=True) as session: 

520 # In our `LISTEN` call, we want to *bypass the ORM* 

521 # and *use the underlying Engine connection directly*. 

522 # (This is because using a `session.execute()` will 

523 # implicitly create a SQL transaction, causing 

524 # notifications to only be delivered when that transaction 

525 # is committed) 

526 _listen_loop(session.connection()) 

527 except Exception as e: 

528 self.__logger.warning(f"JobWatcher encountered exception: [{e}]; "

529 f"Retrying in poll_interval=[{self.poll_interval}] seconds.") 

530 # Sleep for a bit so that we give enough time for the 

531 # database to potentially recover 

532 time.sleep(self.poll_interval) 

533 
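A stand-alone sketch of the PostgreSQL LISTEN/NOTIFY mechanism used above, stripped of the session handling and watched-jobs bookkeeping (requires psycopg2 and a reachable database; the DSN is invented):

    import select
    import psycopg2

    conn = psycopg2.connect("dbname=bgd user=bgd")   # hypothetical DSN
    conn.autocommit = True                           # deliver notifications outside a transaction
    conn.cursor().execute("LISTEN job_updated;")

    while True:
        # Wait up to 5 seconds for the connection socket to become readable.
        if select.select([conn], [], [], 5) == ([], [], []):
            continue                                 # timed out, try again
        conn.poll()                                  # fetch any pending notifications
        while conn.notifies:
            notify = conn.notifies.pop()
            print(f"job updated: {notify.payload}")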

534 def _get_watched_jobs(self): 

535 statement = sql_select(Job).where( 

536 Job.name.in_(self.watched_jobs) 

537 ) 

538 

539 with self.session() as sess: 

540 jobs = sess.execute(statement).scalars().all() 

541 return [job.to_internal_job(self) for job in jobs] 

542 

543 def _poll_for_updates(self): 

544 def _poll_loop(): 

545 while self.watcher_keep_running: 

546 time.sleep(self.poll_interval) 

547 if self.watcher_keep_running: 

548 with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME, 

549 instanced=True, instance_name=self._instance_name): 

550 with self.watched_jobs_lock: 

551 if self.watcher_keep_running: 

552 try: 

553 watched_jobs = self._get_watched_jobs() 

554 except Exception as e: 

555 raise DatabaseError("Couldn't retrieve watched jobs from DB") from e 

556 

557 if watched_jobs is None: 

558 raise DatabaseError("_get_watched_jobs returned None") 

559 

560 for new_job in watched_jobs: 

561 if self.watcher_keep_running: 

562 spec = self.watched_jobs[new_job.name] 

563 new_state = JobState(new_job) 

564 if spec.last_state != new_state: 

565 spec.last_state = new_state 

566 spec.event.notify_change() 

567 

568 while self.watcher_keep_running: 

569 # Wait a few seconds if a database exception occurs and then try again 

570 try: 

571 _poll_loop() 

572 except DatabaseError as e: 

573 self.__logger.warning(f"JobWatcher encountered exception: [{e}]; "

574 f"Retrying in poll_interval=[{self.poll_interval}] seconds.") 

575 # Sleep for a bit so that we give enough time for the 

576 # database to potentially recover 

577 time.sleep(self.poll_interval) 

578 

579 @DurationMetric(DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME, instanced=True) 

580 def store_response(self, job, commit_changes=True): 

581 digest = self.storage.put_message(job.execute_response) 

582 changes = {"result": digest, "status_code": job.execute_response.status.code} 

583 self.response_cache[job.name] = job.execute_response 

584 

585 if commit_changes: 

586 self.update_job(job.name, 

587 changes, 

588 skip_notify=True) 

589 return None 

590 else: 

591 # The caller will batch the changes and commit to db 

592 return changes 

593 

594 def _get_operation(self, operation_name, session): 

595 statement = sql_select(Operation).where( 

596 Operation.name == operation_name 

597 ) 

598 return session.execute(statement).scalars().first() 

599 

600 def get_operations_by_stage(self, operation_stage): 

601 statement = sql_select(Operation).where( 

602 Operation.job.has(stage=operation_stage.value) 

603 ) 

604 

605 with self.session() as session: 

606 operations = session.execute(statement).scalars().all() 

607 # Return a set of job names here for now, to match the `MemoryDataStore` 

608 # implementation's behaviour 

609 return set(op.job.name for op in operations) 

610 

611 def _cancel_jobs_exceeding_execution_timeout(self, max_execution_timeout: int=None) -> None: 

612 if max_execution_timeout: 

613 stale_job_names = [] 

614 lazy_execution_timeout_threshold = datetime.utcnow() - timedelta(seconds=max_execution_timeout) 

615 

616 with self.session(sqlite_lock_immediately=True) as session: 

617 # Get the full list of jobs exceeding execution timeout 

618 stale_jobs_statement = sql_select(Job).with_for_update().where( 

619 and_( 

620 Job.stage == OperationStage.EXECUTING.value, 

621 Job.worker_start_timestamp <= lazy_execution_timeout_threshold 

622 ) 

623 ) 

624 stale_job_names = [job.name for job in session.execute(stale_jobs_statement).scalars().all()] 

625 

626 if stale_job_names: 

627 # Mark operations as cancelled 

628 stmt_mark_operations_cancelled = update(Operation).where( 

629 Operation.job_name.in_(stale_job_names) 

630 ).values(cancelled=True) 

631 session.execute(stmt_mark_operations_cancelled) 

632 

633 # Mark jobs as cancelled 

634 stmt_mark_jobs_cancelled = update(Job).where( 

635 Job.name.in_(stale_job_names) 

636 ).values(stage=OperationStage.COMPLETED.value, cancelled=True) 

637 session.execute(stmt_mark_jobs_cancelled) 

638 

639 # Notify all jobs updated 

640 self._notify_job_updated(stale_job_names, session) 

641 

642 if stale_job_names: 

643 self.__logger.info(f"Cancelled n=[{len(stale_job_names)}] jobs " 

644 f"with names={stale_job_names}" 

645 f"due to them exceeding execution_timeout=[" 

646 f"{max_execution_timeout}") 

647 

648 @DurationMetric(DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME, instanced=True) 

649 def list_operations(self, 

650 operation_filters: List[OperationFilter]=None, 

651 page_size: int=None, 

652 page_token: str=None, 

653 max_execution_timeout: int=None) -> Tuple[List[operations_pb2.Operation], str]: 

654 # Lazily timeout jobs as needed before returning the list! 

655 self._cancel_jobs_exceeding_execution_timeout(max_execution_timeout=max_execution_timeout) 

656 

657 # Build filters and sort order 

658 sort_keys = DEFAULT_SORT_KEYS 

659 custom_filters = None 

660 if operation_filters: 

661 # Extract custom sort order (if present) 

662 specified_sort_keys, non_sort_filters = extract_sort_keys(operation_filters) 

663 

664 # Only override sort_keys if there were sort keys actually present in the filter string 

665 if specified_sort_keys: 

666 sort_keys = specified_sort_keys 

667 # Attach the operation name as a sort key for a deterministic order 

668 # This will ensure that the ordering of results is consistent between queries 

669 if not any(sort_key.name == "name" for sort_key in sort_keys): 

670 sort_keys.append(SortKey(name="name", descending=False)) 

671 

672 # Finally, compile the non-sort filters into a filter list 

673 custom_filters = build_custom_filters(non_sort_filters) 

674 

675 sort_columns = build_sort_column_list(sort_keys) 

676 

677 with self.session() as session: 

678 statement = sql_select(Operation).join(Job, Operation.job_name == Job.name) 

679 

680 # Apply custom filters (if present) 

681 if custom_filters: 

682 statement = statement.filter(*custom_filters) 

683 

684 # Apply sort order 

685 statement = statement.order_by(*sort_columns) 

686 

687 # Apply pagination filter 

688 if page_token: 

689 page_filter = build_page_filter(page_token, sort_keys) 

690 statement = statement.filter(page_filter) 

691 if page_size: 

692 # We limit the number of operations we fetch to the page_size. However, we 

693 # fetch an extra operation to determine whether we need to provide a 

694 # next_page_token. 

695 statement = statement.limit(page_size + 1) 

696 

697 operations = session.execute(statement).scalars().all() 

698 

699 if not page_size or not operations: 

700 next_page_token = "" 

701 

702 # If the number of results we got is less than or equal to our page_size, 

703 # we're done with the operations listing and don't need to provide another 

704 # page token 

705 elif len(operations) <= page_size: 

706 next_page_token = "" 

707 else: 

708 # Drop the last operation since we have an extra 

709 operations.pop() 

710 # Our page token will be the last row of our set 

711 next_page_token = build_page_token(operations[-1], sort_keys) 

712 return [operation.to_protobuf(self) for operation in operations], next_page_token 

713 
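A hypothetical pagination loop over the method above; `data_store` stands in for a constructed SQLDataStore and filters are omitted:

    page_token = None
    while True:
        operations, page_token = data_store.list_operations(
            page_size=50, page_token=page_token)
        for operation in operations:
            print(operation.name)                    # operations_pb2.Operation messages
        if not page_token:                           # empty token means no further pages
            break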

714 @DurationMetric(DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME, instanced=True) 

715 def create_operation(self, operation_name, job_name, request_metadata=None): 

716 with self.session() as session: 

717 operation = Operation( 

718 name=operation_name, 

719 job_name=job_name 

720 ) 

721 if request_metadata is not None: 

722 if request_metadata.tool_invocation_id: 

723 operation.invocation_id = request_metadata.tool_invocation_id 

724 if request_metadata.correlated_invocations_id: 

725 operation.correlated_invocations_id = request_metadata.correlated_invocations_id 

726 if request_metadata.tool_details: 

727 operation.tool_name = request_metadata.tool_details.tool_name 

728 operation.tool_version = request_metadata.tool_details.tool_version 

729 session.add(operation) 

730 

731 @DurationMetric(DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME, instanced=True) 

732 def update_operation(self, operation_name, changes): 

733 with self.session() as session: 

734 operation = self._get_operation(operation_name, session) 

735 operation.update(changes) 

736 

737 def delete_operation(self, operation_name): 

738 # Don't do anything. This function needs to exist but there's no 

739 # need to actually delete operations in this implementation. 

740 pass 

741 

742 def get_leases_by_state(self, lease_state): 

743 statement = sql_select(Lease).where( 

744 Lease.state == lease_state.value 

745 ) 

746 

747 with self.session() as session: 

748 leases = session.execute(statement).scalars().all() 

749 # `lease.job_name` is the same as `lease.id` for a Lease protobuf 

750 return set(lease.job_name for lease in leases) 

751 

752 def get_metrics(self): 

753 # Skip publishing overall scheduler metrics if we have recently published them 

754 last_publish_time = self.__last_scheduler_metrics_publish_time 

755 time_since_publish = None 

756 if last_publish_time: 

757 time_since_publish = datetime.utcnow() - last_publish_time 

758 if time_since_publish and time_since_publish < self.__scheduler_metrics_publish_interval: 

759 # Published too recently, skip 

760 return None 

761 

762 def _get_query_leases_by_state(category: str) -> Select: 

763 # Using func.count here to avoid generating a subquery in the WHERE 

764 # clause of the resulting query. 

765 # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count 

766 return sql_select([ 

767 literal_column(f"'{category}'").label("category"), 

768 Lease.state.label("bucket"), 

769 func.count(Lease.id).label("value") 

770 ]).group_by(Lease.state) 

771 

772 def _cb_query_leases_by_state(leases_by_state): 

773 # The database only returns counts > 0, so fill in the gaps 

774 for state in LeaseState: 

775 if state.value not in leases_by_state: 

776 leases_by_state[state.value] = 0 

777 return leases_by_state 

778 

779 def _get_query_jobs_by_stage(category: str) -> Select: 

780 # Using func.count here to avoid generating a subquery in the WHERE 

781 # clause of the resulting query. 

782 # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count 

783 return sql_select([ 

784 literal_column(f"'{category}'").label("category"), 

785 Job.stage.label("bucket"), 

786 func.count(Job.name).label("value") 

787 ]).group_by(Job.stage) 

788 

789 def _cb_query_jobs_by_stage(jobs_by_stage): 

790 # The database only returns counts > 0, so fill in the gaps 

791 for stage in OperationStage: 

792 if stage.value not in jobs_by_stage: 

793 jobs_by_stage[stage.value] = 0 

794 return jobs_by_stage 

795 

796 metrics = {} 

797 # metrics to gather: (category_name, function_returning_query, callback_function) 

798 metrics_to_gather = [ 

799 (MetricCategories.LEASES.value, _get_query_leases_by_state, _cb_query_leases_by_state), 

800 (MetricCategories.JOBS.value, _get_query_jobs_by_stage, _cb_query_jobs_by_stage) 

801 ] 

802 

803 statements = [query_fn(category) for category, query_fn, _ in metrics_to_gather] 

804 metrics_statement = union(*statements) 

805 

806 try: 

807 with self.session() as session: 

808 results = session.execute(metrics_statement).all() 

809 

810 grouped_results = {category: {} for category, _, _ in results} 

811 for category, bucket, value in results: 

812 grouped_results[category][bucket] = value 

813 

814 for category, _, category_cb in metrics_to_gather: 

815 metrics[category] = category_cb(grouped_results.setdefault(category, {})) 

816 except DatabaseError: 

817 self.__logger.warning("Unable to gather metrics due to a Database Error.") 

818 return {} 

819 

820 # This is only updated within the metrics asyncio loop; no race conditions 

821 self.__last_scheduler_metrics_publish_time = datetime.utcnow() 

822 

823 return metrics 

824 

825 def _create_lease(self, lease, session, job=None, worker_name=None): 

826 if job is None: 

827 job = self._get_job(lease.id, session) 

828 # We only allow one lease, so if there's an existing one update it 

829 if job.active_leases: 

830 job.active_leases[0].state = lease.state 

831 job.active_leases[0].status = None 

832 job.active_leases[0].worker_name = worker_name 

833 else: 

834 session.add(Lease( 

835 job_name=lease.id, 

836 state=lease.state, 

837 status=None, 

838 worker_name=worker_name 

839 )) 

840 

841 def create_lease(self, lease): 

842 with self.session() as session: 

843 self._create_lease(lease, session) 

844 

845 @DurationMetric(DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME, instanced=True) 

846 def update_lease(self, job_name, changes): 

847 initial_values_for_metrics_use = {} 

848 

849 with self.session() as session: 

850 job = self._get_job(job_name, session) 

851 try: 

852 lease = job.active_leases[0] 

853 except IndexError: 

854 return 

855 

856 # Keep track of the state right before we perform this update 

857 initial_values_for_metrics_use["state"] = lease.state 

858 

859 lease.update(changes) 

860 

861 # Upon successful completion of the transaction above, publish metrics 

862 JobMetrics.publish_metrics_on_lease_updates(initial_values_for_metrics_use, changes, self._instance_name) 

863 

864 def load_unfinished_jobs(self): 

865 statement = sql_select(Job).where( 

866 Job.stage != OperationStage.COMPLETED.value 

867 ).order_by(Job.priority) 

868 

869 with self.session() as session: 

870 jobs = session.execute(statement).scalars().all() 

871 return [j.to_internal_job(self) for j in jobs] 

872 

873 def assign_lease_for_next_job(self, capabilities, callback, timeout=None): 

874 """Return a list of leases for the highest priority jobs that can be run by a worker. 

875 

876 NOTE: Currently the list only ever has one or zero leases. 

877 

878 Query the jobs table to find queued jobs which match the capabilities of 

879 a given worker, and return the one with the highest priority. Takes a 

880 dictionary of worker capabilities to compare with job requirements. 

881 

882 :param capabilities: Dictionary of worker capabilities to compare 

883 with job requirements when finding a job. 

884 :type capabilities: dict 

885 :param callback: Function to run on the next runnable job, should return 

886 a list of leases. 

887 :type callback: function 

888 :param timeout: time to wait for new jobs, capped at

889 MAX_JOB_BLOCK_TIME if longer.

890 :type timeout: int 

891 :returns: List of leases 

892 

893 """ 

894 if not timeout: 

895 return self._assign_job_leases(capabilities, callback) 

896 

897 # Cap the timeout if it's larger than MAX_JOB_BLOCK_TIME 

898 if timeout: 

899 timeout = min(timeout, MAX_JOB_BLOCK_TIME) 

900 

901 start = time.time() 

902 while time.time() + self.connection_timeout + 1 - start < timeout: 

903 leases = self._assign_job_leases(capabilities, callback) 

904 if leases: 

905 return leases 

906 time.sleep(0.5) 

907 if self.connection_timeout > timeout: 

908 self.__logger.warning( 

909 "Not providing any leases to the worker because the database connection " 

910 f"timeout ({self.connection_timeout} s) is longer than the remaining " 

911 "time to handle the request. " 

912 "Increase the worker's timeout to solve this problem.") 

913 return [] 

914 
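A hypothetical caller sketch for the method above; the capabilities are examples, and `make_leases` stands in for the callback that the bots service supplies (it must return a list of lease messages for the given internal job):

    worker_capabilities = {'OSFamily': ['Linux'], 'ISA': ['x86-32', 'x86-64']}

    def make_leases(internal_job):
        # The real callback builds lease protos for the job; details omitted here.
        ...

    leases = data_store.assign_lease_for_next_job(
        worker_capabilities, make_leases, timeout=10)
    if not leases:
        print("no matching queued job within the timeout")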

915 def flatten_capabilities(self, capabilities: Dict[str, List[str]]) -> List[Tuple[str, str]]: 

916 """ Flatten a capabilities dictionary, assuming all of its values are lists. E.g. 

917 

918 {'OSFamily': ['Linux'], 'ISA': ['x86-32', 'x86-64']} 

919 

920 becomes 

921 

922 [('OSFamily', 'Linux'), ('ISA', 'x86-32'), ('ISA', 'x86-64')] """ 

923 return [ 

924 (name, value) for name, value_list in capabilities.items() 

925 for value in value_list 

926 ] 

927 

928 def get_partial_capabilities(self, capabilities: Dict[str, List[str]]) -> Iterable[Dict[str, List[str]]]: 

929 """ Given a capabilities dictionary with all values as lists, 

930 yield all partial capabilities dictionaries. """ 

931 CAPABILITIES_WARNING_THRESHOLD = 10 

932 

933 caps_flat = self.flatten_capabilities(capabilities) 

934 

935 if len(caps_flat) > CAPABILITIES_WARNING_THRESHOLD: 

936 self.__logger.warning( 

937 "A worker with a large capabilities dictionary has been connected. " 

938 f"Processing its capabilities may take a while. Capabilities: {capabilities}") 

939 

940 # Using the itertools powerset recipe, construct the powerset of the tuples 

941 capabilities_powerset = chain.from_iterable(combinations(caps_flat, r) for r in range(len(caps_flat) + 1)) 

942 for partial_capability_tuples in capabilities_powerset: 

943 partial_dict: Dict[str, List[str]] = {} 

944 

945 for tup in partial_capability_tuples: 

946 partial_dict.setdefault(tup[0], []).append(tup[1]) 

947 yield partial_dict 

948 
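A worked example of the powerset expansion above (illustrative): two flattened capability pairs produce 2**2 == 4 partial dictionaries:

    from itertools import chain, combinations

    caps_flat = [('OSFamily', 'Linux'), ('ISA', 'x86-64')]
    subsets = list(chain.from_iterable(
        combinations(caps_flat, r) for r in range(len(caps_flat) + 1)))
    # subsets == [(),
    #             (('OSFamily', 'Linux'),),
    #             (('ISA', 'x86-64'),),
    #             (('OSFamily', 'Linux'), ('ISA', 'x86-64'))]
    # which get_partial_capabilities() turns into:
    # {}, {'OSFamily': ['Linux']}, {'ISA': ['x86-64']},
    # {'OSFamily': ['Linux'], 'ISA': ['x86-64']}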

949 def get_partial_capabilities_hashes(self, capabilities: Dict) -> List[str]: 

950 """ Given a list of configurations, obtain each partial configuration 

951 for each configuration, obtain the hash of each partial configuration, 

952 compile these into a list, and return the result. """ 

953 # Convert requirements values to sorted lists to make them json-serializable 

954 convert_values_to_sorted_lists(capabilities) 

955 

956 # Check to see if we've cached this value 

957 capabilities_digest = hash_from_dict(capabilities) 

958 try: 

959 return self.capabilities_cache[capabilities_digest] 

960 except KeyError: 

961 # On cache miss, expand the capabilities into each possible partial capabilities dictionary 

962 capabilities_list = [] 

963 for partial_capability in self.get_partial_capabilities(capabilities): 

964 capabilities_list.append(hash_from_dict(partial_capability)) 

965 

966 self.capabilities_cache[capabilities_digest] = capabilities_list 

967 return capabilities_list 

968 

969 @DurationMetric(BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME, instanced=True) 

970 def _assign_job_leases(self, capabilities, callback): 

971 # pylint: disable=singleton-comparison 

972 # Hash the capabilities 

973 capabilities_config_hashes = self.get_partial_capabilities_hashes(capabilities) 

974 leases = [] 

975 create_lease_start_time = None 

976 

977 # Select unassigned, queued jobs with platform requirements matching the 

978 # given capabilities, ordered by priority and age. 

979 job_statement = sql_select(Job).with_for_update(skip_locked=True).where( 

980 Job.stage == OperationStage.QUEUED.value, 

981 Job.assigned != True, 

982 Job.platform_requirements.in_(capabilities_config_hashes) 

983 ).order_by(Job.priority, Job.queued_timestamp) 

984 

985 try: 

986 with self.session(sqlite_lock_immediately=True) as session: 

987 job = session.execute(job_statement).scalars().first() 

988 if job: 

989 internal_job = job.to_internal_job(self) 

990 leases = callback(internal_job) 

991 if leases: 

992 job.assigned = True 

993 job.worker_start_timestamp = internal_job.worker_start_timestamp_as_datetime 

994 # Only create a lease if there isn't already an active one for this job 

995 create_lease_start_time = time.perf_counter() 

996 for lease in leases: 

997 self._create_lease(lease, session, job=job, worker_name=internal_job.worker_name) 

998 

999 # Calculate and publish the time taken to create leases. This is done explicitly 

1000 # rather than using the DurationMetric helper since we need to measure the actual 

1001 # execution time of the UPDATE and INSERT queries used in the lease assignment, and 

1002 # these are only executed on exiting the contextmanager.

1003 if create_lease_start_time is not None: 

1004 run_time = timedelta(seconds=time.perf_counter() - create_lease_start_time) 

1005 metadata = None 

1006 if self._instance_name is not None: 

1007 metadata = {'instance-name': self._instance_name} 

1008 publish_timer_metric(DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME, run_time, metadata=metadata) 

1009 

1010 except DatabaseError: 

1011 self.__logger.warning("Will not assign any leases this time due to a Database Error.") 

1012 

1013 return leases 

1014 

1015 def _do_prune(self, job_max_age: timedelta, pruning_period: timedelta, limit: int) -> None: 

1016 """ Running in a background thread, this method wakes up periodically and deletes older records 

1017 from the jobs tables using configurable parameters """ 

1018 

1019 utc_last_prune_time = datetime.utcnow() 

1020 while self.pruner_keep_running: 

1021 utcnow = datetime.utcnow() 

1022 if (utcnow - pruning_period) < utc_last_prune_time: 

1023 self.__logger.info(f"Pruner thread sleeping for {pruning_period} (until {utcnow + pruning_period})")

1024 time.sleep(pruning_period.total_seconds()) 

1025 continue 

1026 

1027 delete_before_datetime = utcnow - job_max_age 

1028 try: 

1029 with DurationMetric(DATA_STORE_PRUNER_DELETE_TIME_METRIC_NAME, 

1030 instance_name=self._instance_name, 

1031 instanced=True): 

1032 num_rows = self._delete_jobs_prior_to(delete_before_datetime, limit) 

1033 

1034 self.__logger.info( 

1035 f"Pruned {num_rows} row(s) from the jobs table older than {delete_before_datetime}") 

1036 

1037 if num_rows > 0: 

1038 with Counter(metric_name=DATA_STORE_PRUNER_NUM_ROWS_DELETED_METRIC_NAME, 

1039 instance_name=self._instance_name) as num_rows_deleted: 

1040 num_rows_deleted.increment(num_rows) 

1041 

1042 except Exception: 

1043 self.__logger.exception("Caught exception while deleting jobs records") 

1044 

1045 finally: 

1046 # Update even if error occurred to avoid potentially infinitely retrying 

1047 utc_last_prune_time = utcnow 

1048 

1049 self.__logger.info("Exiting pruner thread") 

1050 

1051 def _delete_jobs_prior_to(self, delete_before_datetime: datetime, limit: int) -> int: 

1052 """ Deletes older records from the jobs tables constrained by `delete_before_datetime` and `limit` """ 

1053 delete_stmt = delete(Job).where( 

1054 Job.name.in_( 

1055 sql_select(Job.name).with_for_update(skip_locked=True).where( 

1056 Job.worker_completed_timestamp <= delete_before_datetime 

1057 ).limit(limit) 

1058 ) 

1059 ) 

1060 

1061 with self.session(reraise=True) as session: 

1062 options = {"synchronize_session": "fetch"} 

1063 num_rows_deleted = session.execute(delete_stmt, execution_options=options).rowcount 

1064 

1065 return num_rows_deleted
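A stand-alone sketch that renders an equivalent pruning DELETE to PostgreSQL SQL, to make the subselect and SKIP LOCKED shape above visible (the cutoff and limit values are examples):

    from datetime import datetime, timedelta

    from sqlalchemy import delete
    from sqlalchemy import select as sql_select
    from sqlalchemy.dialects import postgresql

    from buildgrid.server.persistence.sql.models import Job

    cutoff = datetime.utcnow() - timedelta(days=30)
    stmt = delete(Job).where(
        Job.name.in_(
            sql_select(Job.name).with_for_update(skip_locked=True).where(
                Job.worker_completed_timestamp <= cutoff
            ).limit(10000)
        )
    )
    # Prints the dialect-specific SQL, with bound-parameter placeholders for the
    # cutoff and limit, so the nested SELECT ... FOR UPDATE SKIP LOCKED is visible.
    print(stmt.compile(dialect=postgresql.dialect()))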