Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/persistence/sql/impl.py: 77.70%

556 statements  

1# Copyright (C) 2019 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16from contextlib import contextmanager 

17import logging 

18import os 

19import select 

20from threading import Thread 

21import time 

22from datetime import datetime, timedelta 

23from tempfile import NamedTemporaryFile 

24from itertools import chain, combinations 

25from typing import Dict, Iterable, List, Tuple, NamedTuple, Optional 

26 

27from alembic import command 

28from alembic.config import Config 

29from sqlalchemy import (create_engine, event, func, union, literal_column, sql) 

30from sqlalchemy.engine import Connection as EngineConnection 

31from sqlalchemy.orm.session import sessionmaker 

32 

33from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 

34from buildgrid._protos.google.longrunning import operations_pb2 

35from buildgrid._enums import LeaseState, MetricCategories, OperationStage 

36from buildgrid.server.sql import sqlutils 

37from buildgrid.server.metrics_names import ( 

38 BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME, 

39 DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME, 

40 DATA_STORE_CREATE_JOB_TIME_METRIC_NAME, 

41 DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME, 

42 DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME, 

43 DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME, 

44 DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME, 

45 DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME, 

46 DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME, 

47 DATA_STORE_PRUNER_NUM_ROWS_DELETED_METRIC_NAME, 

48 DATA_STORE_PRUNER_DELETE_TIME_METRIC_NAME, 

49 DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME, 

50 DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME, 

51 DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME, 

52 DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME, 

53 DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME 

54) 

55from buildgrid.server.job_metrics import JobMetrics 

56from buildgrid.server.metrics_utils import DurationMetric, publish_timer_metric, Counter 

57from buildgrid.server.operations.filtering import OperationFilter, SortKey, DEFAULT_SORT_KEYS 

58from buildgrid.server.persistence.interface import DataStoreInterface 

59from buildgrid.server.persistence.sql.models import digest_to_string, Job, Lease, Operation 

60from buildgrid.server.persistence.sql.utils import ( 

61 build_page_filter, 

62 build_page_token, 

63 extract_sort_keys, 

64 build_custom_filters, 

65 build_sort_column_list 

66) 

67from buildgrid.settings import ( 

68 MAX_JOB_BLOCK_TIME, 

69 MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES, 

70 COOLDOWN_TIME_AFTER_POOL_DISPOSE_SECONDS, 

71 SQL_SCHEDULER_METRICS_PUBLISH_INTERVAL_SECONDS 

72) 

73from buildgrid.utils import JobState, hash_from_dict, convert_values_to_sorted_lists 

74 

75from buildgrid._exceptions import DatabaseError, RetriableDatabaseError 

76 

77 

78Session = sessionmaker() 

79 

80 

81def sqlite_on_connect(conn, record): 

82 conn.execute("PRAGMA journal_mode=WAL") 

83 conn.execute("PRAGMA synchronous=NORMAL") 

84 

85 

86class PruningOptions(NamedTuple): 

87 pruner_job_max_age: timedelta = timedelta(days=30) 

88 pruner_period: timedelta = timedelta(minutes=5) 

89 pruner_max_delete_window: int = 10000 

90 

91 @staticmethod 

92 def from_config(pruner_job_max_age_cfg: Dict[str, float], 

93 pruner_period_cfg: Dict[str, float] = None, 

94 pruner_max_delete_window_cfg: int = None): 

95 """ Helper method for creating ``PruningOptions`` objects 

96 If input configs are None, assign defaults """ 

97 def _dict_to_timedelta(config: Dict[str, float]) -> timedelta: 

98 return timedelta(weeks=config.get('weeks', 0), 

99 days=config.get('days', 0), 

100 hours=config.get('hours', 0), 

101 minutes=config.get('minutes', 0), 

102 seconds=config.get('seconds', 0)) 

103 

104 return PruningOptions(pruner_job_max_age=_dict_to_timedelta( 

105 pruner_job_max_age_cfg) if pruner_job_max_age_cfg else timedelta(days=30), 

106 pruner_period=_dict_to_timedelta( 

107 pruner_period_cfg) if pruner_period_cfg else timedelta(minutes=5), 

108 pruner_max_delete_window=pruner_max_delete_window_cfg 

109 if pruner_max_delete_window_cfg else 10000) 

110 
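
A minimal usage sketch (not part of the module) of the staticmethod above, assuming the duration keys read by ``_dict_to_timedelta``:

    from datetime import timedelta

    # Hypothetical config values; passing None falls back to the defaults shown above.
    options = PruningOptions.from_config(
        pruner_job_max_age_cfg={'days': 7},        # prune jobs older than one week
        pruner_period_cfg={'minutes': 10},         # wake the pruner every 10 minutes
        pruner_max_delete_window_cfg=5000)         # delete at most 5000 rows per pass
    assert options.pruner_job_max_age == timedelta(days=7)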

111 

112class SQLDataStore(DataStoreInterface): 

113 

114 def __init__(self, storage, *, connection_string=None, automigrate=False, 

115 connection_timeout=5, poll_interval=1, 

116 pruning_options: Optional[PruningOptions] = None, 

117 **kwargs): 

118 super().__init__(storage) 

119 self.__logger = logging.getLogger(__name__) 

120 self.__logger.info("Creating SQL scheduler with: " 

121 f"automigrate=[{automigrate}], connection_timeout=[{connection_timeout}] " 

122 f"poll_interval=[{poll_interval}], " 

123 f"pruning_options=[{pruning_options}], " 

124 f"kwargs=[{kwargs}]") 

125 

126 self.response_cache: Dict[str, remote_execution_pb2.ExecuteResponse] = {} 

127 self.connection_timeout = connection_timeout 

128 self.poll_interval = poll_interval 

129 self.watcher = Thread(name="JobWatcher", target=self.wait_for_job_updates, daemon=True) 

130 self.watcher_keep_running = True 

131 

132 # Set-up temporary SQLite Database when connection string is not specified 

133 if not connection_string: 

134 # pylint: disable=consider-using-with 

135 tmpdbfile = NamedTemporaryFile(prefix='bgd-', suffix='.db') 

136 self._tmpdbfile = tmpdbfile # Make sure to keep this tempfile for the lifetime of this object 

137 self.__logger.warning("No connection string specified for the DataStore, " 

138 f"will use SQLite with tempfile: [{tmpdbfile.name}]") 

139 automigrate = True # since this is a temporary database, we always need to create it 

140 connection_string = f"sqlite:///{tmpdbfile.name}" 

141 

142 self._create_sqlalchemy_engine(connection_string, automigrate, connection_timeout, **kwargs) 

143 

144 self._sql_pool_dispose_helper = sqlutils.SQLPoolDisposeHelper(COOLDOWN_TIME_AFTER_POOL_DISPOSE_SECONDS, 

145 MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES, 

146 self.engine) 

147 

148 # Make a test query against the database to ensure the connection is valid 

149 with self.session(reraise=True) as session: 

150 session.query(Job).first() 

151 

152 self.watcher.start() 

153 

154 self.capabilities_cache: Dict[str, List[str]] = {} 

155 

156 # Pruning configuration parameters 

157 if pruning_options is not None: 

158 self.pruner_keep_running = True 

159 self.__logger.info(f"Scheduler pruning enabled: {pruning_options}") 

160 self.__pruner_thread = Thread(name="JobsPruner", target=self._do_prune, args=( 

161 pruning_options.pruner_job_max_age, pruning_options.pruner_period, 

162 pruning_options.pruner_max_delete_window), daemon=True) 

163 self.__pruner_thread.start() 

164 else: 

165 self.__logger.info("Scheduler pruning not enabled") 

166 

167 # Overall Scheduler Metrics (totals of jobs/leases in each state) 

168 # Publish those metrics a bit more sparsely since the SQL requests 

169 # required to gather them can become expensive 

170 self.__last_scheduler_metrics_publish_time = None 

171 self.__scheduler_metrics_publish_interval = timedelta( 

172 seconds=SQL_SCHEDULER_METRICS_PUBLISH_INTERVAL_SECONDS) 

173 

174 def _create_sqlalchemy_engine(self, connection_string, automigrate, connection_timeout, **kwargs): 

175 self.automigrate = automigrate 

176 

177 # Disallow sqlite in-memory because multi-threaded access to it is 

178 # complex and potentially problematic at best 

179 # ref: https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#threading-pooling-behavior 

180 if sqlutils.is_sqlite_inmemory_connection_string(connection_string): 

181 raise ValueError( 

182 f"Cannot use SQLite in-memory with BuildGrid (connection_string=[{connection_string}]). " 

183 "Use a file or leave the connection_string empty for a tempfile.") 

184 

185 if connection_timeout is not None: 

186 if "connect_args" not in kwargs: 

187 kwargs["connect_args"] = {} 

188 if sqlutils.is_sqlite_connection_string(connection_string): 

189 kwargs["connect_args"]["timeout"] = connection_timeout 

190 elif sqlutils.is_psycopg2_connection_string(connection_string): 

191 kwargs["connect_args"]["connect_timeout"] = connection_timeout 

192 # Additional postgres-specific timeouts, 

193 # passed as additional libpq options. 

194 # Note that these timeouts are in milliseconds (so *1000) 

195 kwargs["connect_args"]["options"] = f'-c lock_timeout={connection_timeout * 1000}' 

196 

197 # Only pass the (known) kwargs that have been explicitly set by the user 

198 available_options = set([ 

199 'pool_size', 'max_overflow', 'pool_timeout', 'pool_pre_ping', 

200 'pool_recycle', 'connect_args' 

201 ]) 

202 kwargs_keys = set(kwargs.keys()) 

203 if not kwargs_keys.issubset(available_options): 

204 unknown_options = kwargs_keys - available_options 

205 raise TypeError(f"Unknown keyword arguments: [{unknown_options}]") 

206 

207 self.__logger.debug(f"SQLAlchemy additional kwargs: [{kwargs}]") 

208 

209 self.engine = create_engine(connection_string, echo=False, **kwargs) 

210 Session.configure(bind=self.engine) 

211 

212 if self.engine.dialect.name == "sqlite": 

213 event.listen(self.engine, "connect", sqlite_on_connect) 

214 

215 if self.automigrate: 

216 self._create_or_migrate_db(connection_string) 

217 
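
A hedged sketch of constructing the data store with the whitelisted engine keyword arguments listed above; the connection string and storage object are placeholders, not values from this module:

    # Hypothetical construction; extra kwargs are forwarded to create_engine(),
    # and anything outside the whitelist above raises TypeError first.
    data_store = SQLDataStore(
        storage,                                             # a storage backend instance (placeholder)
        connection_string="postgresql://bgd@localhost/bgd",  # placeholder DSN
        automigrate=True,
        connection_timeout=5,
        pool_size=5,
        max_overflow=10,
        pool_pre_ping=True)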

218 def __repr__(self): 

219 return f"SQL data store interface for `{repr(self.engine.url)}`" 

220 

221 def activate_monitoring(self): 

222 # Don't do anything. This function needs to exist but there's no 

223 # need to actually toggle monitoring in this implementation. 

224 pass 

225 

226 def deactivate_monitoring(self): 

227 # Don't do anything. This function needs to exist but there's no 

228 # need to actually toggle monitoring in this implementation. 

229 pass 

230 

231 def _create_or_migrate_db(self, connection_string): 

232 self.__logger.warning("Will attempt migration to latest version if needed.") 

233 

234 config = Config() 

235 config.set_main_option("script_location", os.path.join(os.path.dirname(__file__), "alembic")) 

236 

237 with self.engine.begin() as connection: 

238 config.attributes['connection'] = connection 

239 command.upgrade(config, "head") 

240 

241 @contextmanager 

242 def session(self, *, sqlite_lock_immediately=False, reraise=False): 

243 # If we recently disposed of the SQL pool due to connection issues 

244 # allow for some cooldown period before we attempt more SQL 

245 self._sql_pool_dispose_helper.wait_if_cooldown_in_effect() 

246 

247 # Try to obtain a session 

248 try: 

249 session = Session() 

250 if sqlite_lock_immediately and session.bind.name == "sqlite": 

251 session.execute("BEGIN IMMEDIATE") 

252 except Exception as e: 

253 self.__logger.error("Unable to obtain a database session.", exc_info=True) 

254 raise DatabaseError("Unable to obtain a database session.") from e 

255 

256 # Yield the session and catch exceptions that occur while using it 

257 # to roll-back if needed 

258 try: 

259 yield session 

260 session.commit() 

261 except Exception as e: 

262 transient_dberr = self._sql_pool_dispose_helper.check_dispose_pool(session, e) 

263 if transient_dberr: 

264 self.__logger.warning("Rolling back database session due to transient database error.", exc_info=True) 

265 else: 

266 self.__logger.error("Error committing database session. Rolling back.", exc_info=True) 

267 try: 

268 session.rollback() 

269 except Exception: 

270 self.__logger.warning("Rollback error.", exc_info=True) 

271 

272 if reraise: 

273 if transient_dberr: 

274 raise RetriableDatabaseError("Database connection was temporarily interrupted, please retry", 

275 timedelta(seconds=COOLDOWN_TIME_AFTER_POOL_DISPOSE_SECONDS)) from e 

276 raise 

277 finally: 

278 session.close() 

279 
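
A short usage sketch of the ``session()`` contextmanager defined above (``data_store`` stands in for an SQLDataStore instance): it commits on clean exit, rolls back on error, and with ``reraise=True`` surfaces transient connection failures as RetriableDatabaseError instead of suppressing them.

    # Sketch: read a job inside a managed session; the job name is a placeholder.
    with data_store.session(reraise=True) as session:
        job = session.query(Job).filter_by(name="example-job-name").first()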

280 def _get_job(self, job_name, session, with_for_update=False): 

281 jobs = session.query(Job) 

282 if with_for_update: 

283 jobs = jobs.with_for_update() 

284 jobs = jobs.filter_by(name=job_name) 

285 

286 job = jobs.first() 

287 if job: 

288 self.__logger.debug(f"Loaded job from db: name=[{job_name}], stage=[{job.stage}], result=[{job.result}]") 

289 

290 return job 

291 

292 def _check_job_timeout(self, job_internal, *, max_execution_timeout=None): 

293 """ Do a lazy check of maximum allowed job timeouts when clients try to retrieve 

294 an existing job. 

295 Cancel the job and related operations/leases, if we detect they have 

296 exceeded timeouts on access. 

297 

298 Returns the `buildgrid.server.Job` object, possibly updated with `cancelled=True`. 

299 """ 

300 if job_internal and max_execution_timeout and job_internal.worker_start_timestamp_as_datetime: 

301 if job_internal.operation_stage == OperationStage.EXECUTING: 

302 executing_duration = datetime.utcnow() - job_internal.worker_start_timestamp_as_datetime 

303 if executing_duration.total_seconds() >= max_execution_timeout: 

304 self.__logger.warning(f"Job=[{job_internal}] has been executing for " 

305 f"executing_duration=[{executing_duration}]. " 

306 f"max_execution_timeout=[{max_execution_timeout}] " 

307 "Cancelling.") 

308 job_internal.cancel_all_operations(data_store=self) 

309 self.__logger.info(f"Job=[{job_internal}] has been cancelled.") 

310 return job_internal 

311 

312 @DurationMetric(DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME, instanced=True) 

313 def get_job_by_action(self, action_digest, *, max_execution_timeout=None): 

314 with self.session() as session: 

315 jobs = session.query(Job).filter_by(action_digest=digest_to_string(action_digest)) 

316 jobs = jobs.filter(Job.stage != OperationStage.COMPLETED.value) 

317 job = jobs.first() 

318 if job: 

319 internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url, 

320 instance_name=self._instance_name) 

321 return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout) 

322 return None 

323 

324 @DurationMetric(DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME, instanced=True) 

325 def get_job_by_name(self, name, *, max_execution_timeout=None): 

326 with self.session() as session: 

327 job = self._get_job(name, session) 

328 if job: 

329 internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url, 

330 instance_name=self._instance_name) 

331 return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout) 

332 return None 

333 

334 @DurationMetric(DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME, instanced=True) 

335 def get_job_by_operation(self, operation_name, *, max_execution_timeout=None): 

336 with self.session() as session: 

337 operation = self._get_operation(operation_name, session) 

338 if operation and operation.job: 

339 job = operation.job 

340 internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url, 

341 instance_name=self._instance_name) 

342 return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout) 

343 return None 

344 

345 def get_all_jobs(self): 

346 with self.session() as session: 

347 jobs = session.query(Job).filter(Job.stage != OperationStage.COMPLETED.value) 

348 return [j.to_internal_job(self, action_browser_url=self._action_browser_url, 

349 instance_name=self._instance_name) for j in jobs] 

350 

351 def get_jobs_by_stage(self, operation_stage): 

352 with self.session() as session: 

353 jobs = session.query(Job).filter(Job.stage == operation_stage.value) 

354 return [j.to_internal_job(self, no_result=True, action_browser_url=self._action_browser_url, 

355 instance_name=self._instance_name) for j in jobs] 

356 

357 def get_operation_request_metadata_by_name(self, operation_name): 

358 with self.session() as session: 

359 operation = self._get_operation(operation_name, session) 

360 if not operation: 

361 return None 

362 

363 return {'tool-name': operation.tool_name or '', 

364 'tool-version': operation.tool_version or '', 

365 'invocation-id': operation.invocation_id or '', 

366 'correlated-invocations-id': operation.correlated_invocations_id or ''} 

367 

368 @DurationMetric(DATA_STORE_CREATE_JOB_TIME_METRIC_NAME, instanced=True) 

369 def create_job(self, job): 

370 with self.session() as session: 

371 if self._get_job(job.name, session) is None: 

372 # Convert requirements values to sorted lists to make them json-serializable 

373 platform_requirements = job.platform_requirements 

374 convert_values_to_sorted_lists(platform_requirements) 

375 # Serialize the requirements 

376 platform_requirements_hash = hash_from_dict(platform_requirements) 

377 

378 session.add(Job( 

379 name=job.name, 

380 action=job.action.SerializeToString(), 

381 action_digest=digest_to_string(job.action_digest), 

382 do_not_cache=job.do_not_cache, 

383 priority=job.priority, 

384 operations=[], 

385 platform_requirements=platform_requirements_hash, 

386 stage=job.operation_stage.value, 

387 queued_timestamp=job.queued_timestamp_as_datetime, 

388 queued_time_duration=job.queued_time_duration.seconds, 

389 worker_start_timestamp=job.worker_start_timestamp_as_datetime, 

390 worker_completed_timestamp=job.worker_completed_timestamp_as_datetime 

391 )) 

392 

393 @DurationMetric(DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME, instanced=True) 

394 def queue_job(self, job_name): 

395 with self.session(sqlite_lock_immediately=True) as session: 

396 job = self._get_job(job_name, session, with_for_update=True) 

397 job.assigned = False 

398 

399 @DurationMetric(DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME, instanced=True) 

400 def update_job(self, job_name, changes, *, skip_notify=False): 

401 if "result" in changes: 

402 changes["result"] = digest_to_string(changes["result"]) 

403 if "action_digest" in changes: 

404 changes["action_digest"] = digest_to_string(changes["action_digest"]) 

405 

406 initial_values_for_metrics_use = {} 

407 

408 with self.session() as session: 

409 job = self._get_job(job_name, session) 

410 

411 # Keep track of the state right before we perform this update 

412 initial_values_for_metrics_use["stage"] = OperationStage(job.stage) 

413 

414 job.update(changes) 

415 if not skip_notify: 

416 self._notify_job_updated(job_name, session) 

417 

418 # Upon successful completion of the transaction above, publish metrics 

419 JobMetrics.publish_metrics_on_job_updates(initial_values_for_metrics_use, changes, self._instance_name) 

420 
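
A hedged example of an ``update_job`` call; the exact set of change keys the Job model accepts is defined in models.py rather than here, so the keys below are assumptions apart from "result" and "action_digest", which are special-cased above:

    # Hypothetical changes dict; "result" is converted to a digest string above.
    data_store.update_job(
        job_name,
        {"stage": OperationStage.COMPLETED.value,   # assumed column on the Job model
         "result": action_result_digest})           # a Digest message (placeholder)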

421 def _notify_job_updated(self, job_names, session): 

422 if self.engine.dialect.name == "postgresql": 

423 if isinstance(job_names, str): 

424 job_names = [job_names] 

425 for job_name in job_names: 

426 session.execute(f"NOTIFY job_updated, '{job_name}';") 

427 

428 def delete_job(self, job_name): 

429 if job_name in self.response_cache: 

430 del self.response_cache[job_name] 

431 

432 def wait_for_job_updates(self): 

433 self.__logger.info("Starting job watcher thread") 

434 if self.engine.dialect.name == "postgresql": 

435 self._listen_for_updates() 

436 else: 

437 self._poll_for_updates() 

438 

439 def _listen_for_updates(self): 

440 def _listen_loop(engine_conn: EngineConnection): 

441 try: 

442 # Get the DBAPI connection object from the SQLAlchemy Engine.Connection wrapper 

443 dbapi_conn = engine_conn.connection 

444 dbapi_conn.cursor().execute("LISTEN job_updated;") 

445 dbapi_conn.commit() 

446 except Exception: 

447 self.__logger.warning( 

448 "Could not start listening to DB for job updates", 

449 exc_info=True) 

450 # Let the context manager handle this 

451 raise 

452 

453 while self.watcher_keep_running: 

454 # Wait until the connection is ready for reading. Timeout after 5 seconds 

455 # and try again if there was nothing to read. If the connection becomes 

456 # readable, collect the notifications it has received and handle them. 

457 # 

458 # See http://initd.org/psycopg/docs/advanced.html#async-notify 

459 if select.select([dbapi_conn.connection], [], [], self.poll_interval) == ([], [], []): 

460 pass 

461 else: 

462 

463 try: 

464 dbapi_conn.connection.poll() 

465 except Exception: 

466 self.__logger.warning("Error while polling for job updates", exc_info=True) 

467 # Let the context manager handle this 

468 raise 

469 

470 while dbapi_conn.connection.notifies: 

471 notify = dbapi_conn.connection.notifies.pop() 

472 with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME, 

473 instanced=True, instance_name=self._instance_name): 

474 with self.watched_jobs_lock: 

475 spec = self.watched_jobs.get(notify.payload) 

476 if spec is not None: 

477 try: 

478 new_job = self.get_job_by_name(notify.payload) 

479 except Exception: 

480 self.__logger.warning( 

481 f"Couldn't get watched job=[{notify.payload}] from DB", 

482 exc_info=True) 

483 # Let the context manager handle this 

484 raise 

485 

486 # If the job doesn't exist or an exception was suppressed by 

487 # get_job_by_name, it returns None instead of the job 

488 if new_job is None: 

489 raise DatabaseError( 

490 f"get_job_by_name returned None for job=[{notify.payload}]") 

491 

492 new_state = JobState(new_job) 

493 if spec.last_state != new_state: 

494 spec.last_state = new_state 

495 spec.event.notify_change() 

496 

497 while self.watcher_keep_running: 

498 # Wait a few seconds if a database exception occurs and then try again 

499 # This could be a short disconnect 

500 try: 

501 # Use the session contextmanager 

502 # so that we can benefit from the common SQL error-handling 

503 with self.session(reraise=True) as session: 

504 # In our `LISTEN` call, we want to *bypass the ORM* 

505 # and *use the underlying Engine connection directly*. 

506 # (This is because using a `session.execute()` will 

507 # implicitly create a SQL transaction, causing 

508 # notifications to only be delivered when that transaction 

509 # is committed) 

510 _listen_loop(session.connection()) 

511 except Exception as e: 

511 self.__logger.warning(f"JobWatcher encountered exception: [{e}]; " 

513 f"Retrying in poll_interval=[{self.poll_interval}] seconds.") 

514 # Sleep for a bit so that we give enough time for the 

515 # database to potentially recover 

516 time.sleep(self.poll_interval) 

517 

518 def _get_watched_jobs(self): 

519 with self.session() as sess: 

520 jobs = sess.query(Job).filter( 

521 Job.name.in_(self.watched_jobs) 

522 ) 

523 return [job.to_internal_job(self) for job in jobs.all()] 

524 

525 def _poll_for_updates(self): 

526 def _poll_loop(): 

527 while self.watcher_keep_running: 

528 time.sleep(self.poll_interval) 

529 if self.watcher_keep_running: 

530 with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME, 

531 instanced=True, instance_name=self._instance_name): 

532 with self.watched_jobs_lock: 

533 if self.watcher_keep_running: 

534 try: 

535 watched_jobs = self._get_watched_jobs() 

536 except Exception as e: 

537 raise DatabaseError("Couldn't retrieve watched jobs from DB") from e 

538 

539 if watched_jobs is None: 

540 raise DatabaseError("_get_watched_jobs returned None") 

541 

542 for new_job in watched_jobs: 

543 if self.watcher_keep_running: 

544 spec = self.watched_jobs[new_job.name] 

545 new_state = JobState(new_job) 

546 if spec.last_state != new_state: 

547 spec.last_state = new_state 

548 spec.event.notify_change() 

549 

550 while self.watcher_keep_running: 

551 # Wait a few seconds if a database exception occurs and then try again 

552 try: 

553 _poll_loop() 

554 except DatabaseError as e: 

555 self.__logger.warning(f"JobWatcher encountered exception: [{e}]; " 

556 f"Retrying in poll_interval=[{self.poll_interval}] seconds.") 

557 # Sleep for a bit so that we give enough time for the 

558 # database to potentially recover 

559 time.sleep(self.poll_interval) 

560 

561 @DurationMetric(DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME, instanced=True) 

562 def store_response(self, job, commit_changes=True): 

563 digest = self.storage.put_message(job.execute_response) 

564 changes = {"result": digest, "status_code": job.execute_response.status.code} 

565 self.response_cache[job.name] = job.execute_response 

566 

567 if commit_changes: 

568 self.update_job(job.name, 

569 changes, 

570 skip_notify=True) 

571 return None 

572 else: 

573 # The caller will batch the changes and commit to db 

574 return changes 

575 

576 def _get_operation(self, operation_name, session): 

577 operations = session.query(Operation).filter_by(name=operation_name) 

578 return operations.first() 

579 

580 def get_operations_by_stage(self, operation_stage): 

581 with self.session() as session: 

582 operations = session.query(Operation) 

583 operations = operations.filter(Operation.job.has(stage=operation_stage.value)) 

584 operations = operations.all() 

585 # Return a set of job names here for now, to match the `MemoryDataStore` 

586 # implementation's behaviour 

587 return set(op.job.name for op in operations) 

588 

589 def _cancel_jobs_exceeding_execution_timeout(self, max_execution_timeout: int=None) -> None: 

590 if max_execution_timeout: 

591 stale_job_names = [] 

592 lazy_execution_timeout_threshold = datetime.utcnow() - timedelta(seconds=max_execution_timeout) 

593 

594 jobs_table = Job.__table__ 

595 operations_table = Operation.__table__ 

596 

597 with self.session(sqlite_lock_immediately=True) as session: 

598 # Get the full list of jobs exceeding execution timeout 

599 stale_jobs = session.query(Job).filter_by(stage=OperationStage.EXECUTING.value) 

600 stale_jobs = stale_jobs.filter(Job.worker_start_timestamp <= lazy_execution_timeout_threshold) 

601 stale_job_names = [job.name for job in stale_jobs.with_for_update().all()] 

602 

603 if stale_job_names: 

604 # Mark operations as cancelled 

605 stmt_mark_operations_cancelled = operations_table.update().where( 

606 operations_table.c.job_name.in_(stale_job_names) 

607 ).values(cancelled=True) 

608 session.execute(stmt_mark_operations_cancelled) 

609 

610 # Mark jobs as cancelled 

611 stmt_mark_jobs_cancelled = jobs_table.update().where( 

612 jobs_table.c.name.in_(stale_job_names) 

613 ).values(stage=OperationStage.COMPLETED.value, cancelled=True) 

614 session.execute(stmt_mark_jobs_cancelled) 

615 

616 # Notify all jobs updated 

617 self._notify_job_updated(stale_job_names, session) 

618 

619 if stale_job_names: 

620 self.__logger.info(f"Cancelled n=[{len(stale_job_names)}] jobs " 

621 f"with names={stale_job_names}" 

622 f"due to them exceeding execution_timeout=[" 

623 f"{max_execution_timeout}") 

624 

625 @DurationMetric(DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME, instanced=True) 

626 def list_operations(self, 

627 operation_filters: List[OperationFilter]=None, 

628 page_size: int=None, 

629 page_token: str=None, 

630 max_execution_timeout: int=None) -> Tuple[List[operations_pb2.Operation], str]: 

631 # Lazily timeout jobs as needed before returning the list! 

632 self._cancel_jobs_exceeding_execution_timeout(max_execution_timeout=max_execution_timeout) 

633 

634 # Build filters and sort order 

635 sort_keys = DEFAULT_SORT_KEYS 

636 custom_filters = None 

637 if operation_filters: 

638 # Extract custom sort order (if present) 

639 specified_sort_keys, non_sort_filters = extract_sort_keys(operation_filters) 

640 

641 # Only override sort_keys if there were sort keys actually present in the filter string 

642 if specified_sort_keys: 

643 sort_keys = specified_sort_keys 

644 # Attach the operation name as a sort key for a deterministic order 

645 # This will ensure that the ordering of results is consistent between queries 

646 if not any(sort_key.name == "name" for sort_key in sort_keys): 

647 sort_keys.append(SortKey(name="name", descending=False)) 

648 

649 # Finally, compile the non-sort filters into a filter list 

650 custom_filters = build_custom_filters(non_sort_filters) 

651 

652 sort_columns = build_sort_column_list(sort_keys) 

653 

654 with self.session() as session: 

655 results = session.query(Operation).join(Job, Operation.job_name == Job.name) 

656 

657 # Apply custom filters (if present) 

658 if custom_filters: 

659 results = results.filter(*custom_filters) 

660 

661 # Apply sort order 

662 results = results.order_by(*sort_columns) 

663 

664 # Apply pagination filter 

665 if page_token: 

666 page_filter = build_page_filter(page_token, sort_keys) 

667 results = results.filter(page_filter) 

668 if page_size: 

669 # We limit the number of operations we fetch to the page_size. However, we 

670 # fetch an extra operation to determine whether we need to provide a 

671 # next_page_token. 

672 results = results.limit(page_size + 1) 

673 

674 operations = list(results) 

675 

676 if not page_size or not operations: 

677 next_page_token = "" 

678 

679 # If the number of results we got is less than or equal to our page_size, 

680 # we're done with the operations listing and don't need to provide another 

681 # page token 

682 elif len(operations) <= page_size: 

683 next_page_token = "" 

684 else: 

685 # Drop the last operation since we have an extra 

686 operations.pop() 

687 # Our page token will be the last row of our set 

688 next_page_token = build_page_token(operations[-1], sort_keys) 

689 return [operation.to_protobuf(self) for operation in operations], next_page_token 

690 
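
A paging sketch built on the return contract above (a list of operations plus a ``next_page_token`` that is empty once the final page is reached); ``data_store`` is a placeholder for an SQLDataStore instance:

    # Iterate over all operations, 50 at a time.
    page_token = None
    while True:
        operations, page_token = data_store.list_operations(
            operation_filters=None,       # DEFAULT_SORT_KEYS apply when no filters are given
            page_size=50,
            page_token=page_token or None)
        for operation in operations:
            print(operation.name)         # each item is an operations_pb2.Operation
        if not page_token:
            break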

691 @DurationMetric(DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME, instanced=True) 

692 def create_operation(self, operation_name, job_name, request_metadata=None): 

693 with self.session() as session: 

694 operation = Operation( 

695 name=operation_name, 

696 job_name=job_name 

697 ) 

698 if request_metadata is not None: 

699 if request_metadata.tool_invocation_id: 

700 operation.invocation_id = request_metadata.tool_invocation_id 

701 if request_metadata.correlated_invocations_id: 

702 operation.correlated_invocations_id = request_metadata.correlated_invocations_id 

703 if request_metadata.tool_details: 

704 operation.tool_name = request_metadata.tool_details.tool_name 

705 operation.tool_version = request_metadata.tool_details.tool_version 

706 session.add(operation) 

707 

708 @DurationMetric(DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME, instanced=True) 

709 def update_operation(self, operation_name, changes): 

710 with self.session() as session: 

711 operation = self._get_operation(operation_name, session) 

712 operation.update(changes) 

713 

714 def delete_operation(self, operation_name): 

715 # Don't do anything. This function needs to exist but there's no 

716 # need to actually delete operations in this implementation. 

717 pass 

718 

719 def get_leases_by_state(self, lease_state): 

720 with self.session() as session: 

721 leases = session.query(Lease).filter_by(state=lease_state.value) 

722 leases = leases.all() 

723 # `lease.job_name` is the same as `lease.id` for a Lease protobuf 

724 return set(lease.job_name for lease in leases) 

725 

726 def get_metrics(self): 

727 # Skip publishing overall scheduler metrics if we have recently published them 

728 last_publish_time = self.__last_scheduler_metrics_publish_time 

729 time_since_publish = None 

730 if last_publish_time: 

731 time_since_publish = datetime.utcnow() - last_publish_time 

732 if time_since_publish and time_since_publish < self.__scheduler_metrics_publish_interval: 

733 # Published too recently, skip 

734 return None 

735 

736 def _get_query_leases_by_state(session, category): 

737 # Using func.count here to avoid generating a subquery in the WHERE 

738 # clause of the resulting query. 

739 # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count 

740 query = session.query(literal_column(category).label("category"), 

741 Lease.state.label("bucket"), 

742 func.count(Lease.id).label("value")) 

743 query = query.group_by(Lease.state) 

744 return query 

745 

746 def _cb_query_leases_by_state(leases_by_state): 

747 # The database only returns counts > 0, so fill in the gaps 

748 for state in LeaseState: 

749 if state.value not in leases_by_state: 

750 leases_by_state[state.value] = 0 

751 return leases_by_state 

752 

753 def _get_query_jobs_by_stage(session, category): 

754 # Using func.count here to avoid generating a subquery in the WHERE 

755 # clause of the resulting query. 

756 # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count 

757 query = session.query(literal_column(category).label("category"), 

758 Job.stage.label("bucket"), 

759 func.count(Job.name).label("value")) 

760 query = query.group_by(Job.stage) 

761 return query 

762 

763 def _cb_query_jobs_by_stage(jobs_by_stage): 

764 # The database only returns counts > 0, so fill in the gaps 

765 for stage in OperationStage: 

766 if stage.value not in jobs_by_stage: 

767 jobs_by_stage[stage.value] = 0 

768 return jobs_by_stage 

769 

770 metrics = {} 

771 try: 

772 with self.session() as session: 

773 # metrics to gather: (category_name, function_returning_query, callback_function) 

774 metrics_to_gather = [(MetricCategories.LEASES.value, _get_query_leases_by_state, 

775 _cb_query_leases_by_state), 

776 (MetricCategories.JOBS.value, _get_query_jobs_by_stage, 

777 _cb_query_jobs_by_stage)] 

778 

779 union_query = union(*[query_fn(session, f"'{category}'") 

780 for category, query_fn, _ in metrics_to_gather]) 

781 union_results = session.execute(union_query).fetchall() 

782 

783 grouped_results = {category: {} for category, _, _ in union_results} 

784 for category, bucket, value in union_results: 

785 grouped_results[category][bucket] = value 

786 

787 for category, _, category_cb in metrics_to_gather: 

788 metrics[category] = category_cb(grouped_results.setdefault(category, {})) 

789 except DatabaseError: 

790 self.__logger.warning("Unable to gather metrics due to a Database Error.") 

791 return {} 

792 

793 # This is only updated within the metrics asyncio loop; no race conditions 

794 self.__last_scheduler_metrics_publish_time = datetime.utcnow() 

795 

796 return metrics 

797 
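
The returned mapping, sketched for illustration: categories come from MetricCategories and buckets from the enums, with zero counts filled in by the callbacks above (the shape is a reading of the code; any numbers would be invented):

    metrics = data_store.get_metrics()
    # e.g. {
    #     MetricCategories.LEASES.value: {state.value: count, ...},  # one entry per LeaseState
    #     MetricCategories.JOBS.value:   {stage.value: count, ...},  # one entry per OperationStage
    # }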

798 def _create_lease(self, lease, session, job=None): 

799 if job is None: 

800 job = self._get_job(lease.id, session) 

801 job = job.to_internal_job(self) 

802 session.add(Lease( 

803 job_name=lease.id, 

804 state=lease.state, 

805 status=None, 

806 worker_name=job.worker_name 

807 )) 

808 

809 def create_lease(self, lease): 

810 with self.session() as session: 

811 self._create_lease(lease, session) 

812 

813 @DurationMetric(DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME, instanced=True) 

814 def update_lease(self, job_name, changes): 

815 initial_values_for_metrics_use = {} 

816 

817 with self.session() as session: 

818 job = self._get_job(job_name, session) 

819 try: 

820 lease = job.active_leases[0] 

821 except IndexError: 

822 return 

823 

824 # Keep track of the state right before we perform this update 

825 initial_values_for_metrics_use["state"] = lease.state 

826 

827 lease.update(changes) 

828 

829 # Upon successful completion of the transaction above, publish metrics 

830 JobMetrics.publish_metrics_on_lease_updates(initial_values_for_metrics_use, changes, self._instance_name) 

831 

832 def load_unfinished_jobs(self): 

833 with self.session() as session: 

834 jobs = session.query(Job) 

835 jobs = jobs.filter(Job.stage != OperationStage.COMPLETED.value) 

836 jobs = jobs.order_by(Job.priority) 

837 return [j.to_internal_job(self) for j in jobs.all()] 

838 

839 def assign_lease_for_next_job(self, capabilities, callback, timeout=None): 

840 """Return a list of leases for the highest priority jobs that can be run by a worker. 

841 

842 NOTE: Currently the list only ever has one or zero leases. 

843 

844 Query the jobs table to find queued jobs which match the capabilities of 

845 a given worker, and return the one with the highest priority. Takes a 

846 dictionary of worker capabilities to compare with job requirements. 

847 

848 :param capabilities: Dictionary of worker capabilities to compare 

849 with job requirements when finding a job. 

850 :type capabilities: dict 

851 :param callback: Function to run on the next runnable job, should return 

852 a list of leases. 

853 :type callback: function 

854 :param timeout: time to wait for new jobs, capped at 

855 MAX_JOB_BLOCK_TIME if longer. 

856 :type timeout: int 

857 :returns: List of leases 

858 

859 """ 

860 if not timeout: 

861 return self._assign_job_leases(capabilities, callback) 

862 

863 # Cap the timeout if it's larger than MAX_JOB_BLOCK_TIME 

864 if timeout: 

865 timeout = min(timeout, MAX_JOB_BLOCK_TIME) 

866 

867 start = time.time() 

868 while time.time() + self.connection_timeout + 1 - start < timeout: 

869 leases = self._assign_job_leases(capabilities, callback) 

870 if leases: 

871 return leases 

872 time.sleep(0.5) 

873 if self.connection_timeout > timeout: 

874 self.__logger.warning( 

875 "Not providing any leases to the worker because the database connection " 

876 f"timeout ({self.connection_timeout} s) is longer than the remaining " 

877 "time to handle the request. " 

878 "Increase the worker's timeout to solve this problem.") 

879 return [] 

880 
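
A hedged sketch of how a caller (presumably the bots service) might use the method above; the callback shown is a placeholder, since lease construction lives outside this file:

    def lease_callback(internal_job):
        # Placeholder: build and return a list with a single lease for the job.
        return [make_lease_for(internal_job)]     # hypothetical helper

    leases = data_store.assign_lease_for_next_job(
        {'OSFamily': ['Linux'], 'ISA': ['x86-64']},
        lease_callback,
        timeout=10)                               # capped at MAX_JOB_BLOCK_TIME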

881 def flatten_capabilities(self, capabilities: Dict[str, List[str]]) -> List[Tuple[str, str]]: 

882 """ Flatten a capabilities dictionary, assuming all of its values are lists. E.g. 

883 

884 {'OSFamily': ['Linux'], 'ISA': ['x86-32', 'x86-64']} 

885 

886 becomes 

887 

888 [('OSFamily', 'Linux'), ('ISA', 'x86-32'), ('ISA', 'x86-64')] """ 

889 return [ 

890 (name, value) for name, value_list in capabilities.items() 

891 for value in value_list 

892 ] 

893 

894 def get_partial_capabilities(self, capabilities: Dict[str, List[str]]) -> Iterable[Dict[str, List[str]]]: 

895 """ Given a capabilities dictionary with all values as lists, 

896 yield all partial capabilities dictionaries. """ 

897 CAPABILITIES_WARNING_THRESHOLD = 10 

898 

899 caps_flat = self.flatten_capabilities(capabilities) 

900 

901 if len(caps_flat) > CAPABILITIES_WARNING_THRESHOLD: 

902 self.__logger.warning( 

903 "A worker with a large capabilities dictionary has been connected. " 

904 f"Processing its capabilities may take a while. Capabilities: {capabilities}") 

905 

906 # Using the itertools powerset recipe, construct the powerset of the tuples 

907 capabilities_powerset = chain.from_iterable(combinations(caps_flat, r) for r in range(len(caps_flat) + 1)) 

908 for partial_capability_tuples in capabilities_powerset: 

909 partial_dict: Dict[str, List[str]] = {} 

910 

911 for tup in partial_capability_tuples: 

912 partial_dict.setdefault(tup[0], []).append(tup[1]) 

913 yield partial_dict 

914 
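
A worked example of the powerset expansion above, using the same capabilities shape as the ``flatten_capabilities`` docstring:

    caps = {'OSFamily': ['Linux'], 'ISA': ['x86-32', 'x86-64']}
    # flatten_capabilities(caps) gives three tuples, so the powerset has 2**3 = 8
    # partial dictionaries, e.g.:
    #   {}
    #   {'OSFamily': ['Linux']}
    #   {'ISA': ['x86-32', 'x86-64']}
    #   {'OSFamily': ['Linux'], 'ISA': ['x86-64']}
    partials = list(data_store.get_partial_capabilities(caps))
    assert len(partials) == 8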

915 def get_partial_capabilities_hashes(self, capabilities: Dict) -> List[str]: 

916 """ Given a list of configurations, obtain each partial configuration 

917 for each configuration, obtain the hash of each partial configuration, 

918 compile these into a list, and return the result. """ 

919 # Convert requirements values to sorted lists to make them json-serializable 

920 convert_values_to_sorted_lists(capabilities) 

921 

922 # Check to see if we've cached this value 

923 capabilities_digest = hash_from_dict(capabilities) 

924 try: 

925 return self.capabilities_cache[capabilities_digest] 

926 except KeyError: 

927 # On cache miss, expand the capabilities into each possible partial capabilities dictionary 

928 capabilities_list = [] 

929 for partial_capability in self.get_partial_capabilities(capabilities): 

930 capabilities_list.append(hash_from_dict(partial_capability)) 

931 

932 self.capabilities_cache[capabilities_digest] = capabilities_list 

933 return capabilities_list 

934 

935 @DurationMetric(BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME, instanced=True) 

936 def _assign_job_leases(self, capabilities, callback): 

937 # pylint: disable=singleton-comparison 

938 # Hash the capabilities 

939 capabilities_config_hashes = self.get_partial_capabilities_hashes(capabilities) 

940 leases = [] 

941 try: 

942 create_lease_start_time = None 

943 with self.session(sqlite_lock_immediately=True) as session: 

944 jobs = session.query(Job).with_for_update(skip_locked=True) 

945 jobs = jobs.filter(Job.stage == OperationStage.QUEUED.value) 

946 jobs = jobs.filter(Job.assigned != True) # noqa 

947 jobs = jobs.filter(Job.platform_requirements.in_(capabilities_config_hashes)) 

948 job = jobs.order_by(Job.priority, Job.queued_timestamp).first() 

949 # This worker can take this job if it can handle all of its configurations 

950 if job: 

951 internal_job = job.to_internal_job(self) 

952 leases = callback(internal_job) 

953 if leases: 

954 job.assigned = True 

955 job.worker_start_timestamp = internal_job.worker_start_timestamp_as_datetime 

956 create_lease_start_time = time.perf_counter() 

957 for lease in leases: 

958 self._create_lease(lease, session, job=internal_job) 

959 

960 # Calculate and publish the time taken to create leases. This is done explicitly 

961 # rather than using the DurationMetric helper since we need to measure the actual 

962 # execution time of the UPDATE and INSERT queries used in the lease assignment, and 

963 # these are only executed on exiting the contextmanager. 

964 if create_lease_start_time is not None: 

965 run_time = timedelta(seconds=time.perf_counter() - create_lease_start_time) 

966 metadata = None 

967 if self._instance_name is not None: 

968 metadata = {'instance-name': self._instance_name} 

969 publish_timer_metric(DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME, run_time, metadata=metadata) 

970 

971 except DatabaseError: 

972 self.__logger.warning("Will not assign any leases this time due to a Database Error.") 

973 

974 return leases 

975 

976 def _do_prune(self, job_max_age: timedelta, pruning_period: timedelta, limit: int) -> None: 

977 """ Running in a background thread, this method wakes up periodically and deletes older records 

978 from the jobs tables using configurable parameters """ 

979 

980 utc_last_prune_time = datetime.utcnow() 

981 while self.pruner_keep_running: 

982 utcnow = datetime.utcnow() 

983 if (utcnow - pruning_period) < utc_last_prune_time: 

984 self.__logger.info(f"Pruner thread sleeping for {pruning_period} (until {utcnow + pruning_period})") 

985 time.sleep(pruning_period.total_seconds()) 

986 continue 

987 

988 delete_before_datetime = utcnow - job_max_age 

989 try: 

990 with DurationMetric(DATA_STORE_PRUNER_DELETE_TIME_METRIC_NAME, 

991 instance_name=self._instance_name, 

992 instanced=True): 

993 num_rows = self._delete_jobs_prior_to(delete_before_datetime, limit) 

994 

995 self.__logger.info( 

996 f"Pruned {num_rows} row(s) from the jobs table older than {delete_before_datetime}") 

997 

998 if num_rows > 0: 

999 with Counter(metric_name=DATA_STORE_PRUNER_NUM_ROWS_DELETED_METRIC_NAME, 

1000 instance_name=self._instance_name) as num_rows_deleted: 

1001 num_rows_deleted.increment(num_rows) 

1002 

1003 except Exception: 

1004 self.__logger.exception("Caught exception while deleting jobs records") 

1005 

1006 finally: 

1007 # Update even if error occurred to avoid potentially infinitely retrying 

1008 utc_last_prune_time = utcnow 

1009 

1010 self.__logger.info("Exiting pruner thread") 

1011 

1012 def _delete_jobs_prior_to(self, delete_before_datetime: datetime, limit: int) -> int: 

1013 """ Deletes older records from the jobs tables constrained by `delete_before_datetime` and `limit` """ 

1014 

1015 jobs_table = Job.__table__ 

1016 with self.session(reraise=True) as session: 

1017 delete_stmt = jobs_table.delete().where( 

1018 jobs_table.c.name.in_( 

1019 sql.select([jobs_table.c.name]).where( 

1020 jobs_table.c.worker_completed_timestamp <= delete_before_datetime). 

1021 with_for_update(skip_locked=True). 

1022 limit(limit) 

1023 ) 

1024 ) 

1025 num_rows_deleted = session.execute(delete_stmt).rowcount 

1026 

1027 return num_rows_deleted
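
Roughly the statement the delete above compiles to on PostgreSQL, shown as a sketch (the exact SQL depends on the SQLAlchemy version, and the table name assumes the Job model maps to a table called ``jobs``):

    # DELETE FROM jobs WHERE jobs.name IN (
    #     SELECT jobs.name FROM jobs
    #     WHERE jobs.worker_completed_timestamp <= :delete_before_datetime
    #     LIMIT :limit FOR UPDATE SKIP LOCKED)
    #
    # Example call using the default pruning parameters defined in PruningOptions:
    num_rows = data_store._delete_jobs_prior_to(
        datetime.utcnow() - timedelta(days=30), 10000)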