# Copyright (C) 2019 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from contextlib import contextmanager
import logging
import os
import select
from threading import Thread, Lock
import time
from datetime import datetime, timedelta
from tempfile import NamedTemporaryFile
from itertools import chain, combinations
from typing import Any, Dict, Iterable, List, Tuple

from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine, event, func, text, union, literal_column
from sqlalchemy.orm.session import sessionmaker, Session as SessionType

from buildgrid._protos.google.longrunning import operations_pb2
from buildgrid._enums import LeaseState, MetricCategories, OperationStage
from buildgrid.server.metrics_names import (
    BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME,
    DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME,
    DATA_STORE_CREATE_JOB_TIME_METRIC_NAME,
    DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME,
    DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME,
    DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME,
    DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME,
    DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME,
    DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME,
    DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME,
    DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME,
    DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME,
    DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME,
    DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME
)
from buildgrid.server.metrics_utils import DurationMetric, publish_timer_metric
from buildgrid.server.operations.filtering import OperationFilter, SortKey, DEFAULT_SORT_KEYS
from buildgrid.server.persistence.interface import DataStoreInterface
from buildgrid.server.persistence.sql.models import digest_to_string, Job, Lease, Operation
from buildgrid.server.persistence.sql.utils import (
    build_page_filter,
    build_page_token,
    extract_sort_keys,
    build_custom_filters,
    build_sort_column_list
)
from buildgrid.settings import MAX_JOB_BLOCK_TIME, MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES
from buildgrid.utils import JobState, hash_from_dict, convert_values_to_sorted_lists

from buildgrid._exceptions import DatabaseError


Session = sessionmaker()


def sqlite_on_connect(conn, record):
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA synchronous=NORMAL")


class SQLDataStore(DataStoreInterface):

    def __init__(self, storage, *, connection_string=None, automigrate=False,
                 connection_timeout=5, poll_interval=1, **kwargs):
        super().__init__()
        self.__logger = logging.getLogger(__name__)
        self.__logger.info("Creating SQL scheduler with: "
                           f"automigrate=[{automigrate}], connection_timeout=[{connection_timeout}] "
                           f"poll_interval=[{poll_interval}], kwargs=[{kwargs}]")

        self.storage = storage
        self.response_cache = {}
        self.connection_timeout = connection_timeout
        self.poll_interval = poll_interval
        self.watcher = Thread(name="JobWatcher", target=self.wait_for_job_updates, daemon=True)
        self.watcher_keep_running = True
        self.__dispose_pool_on_exceptions: Tuple[Any, ...] = tuple()
        self.__last_pool_dispose_time = None
        self.__last_pool_dispose_time_lock = Lock()

        # Set-up temporary SQLite Database when connection string is not specified
        if not connection_string:
            tmpdbfile = NamedTemporaryFile(prefix='bgd-', suffix='.db')
            self._tmpdbfile = tmpdbfile  # Make sure to keep this tempfile for the lifetime of this object
            self.__logger.warning("No connection string specified for the DataStore, "
                                  f"will use SQLite with tempfile: [{tmpdbfile.name}]")
            automigrate = True  # since this is a temporary database, we always need to create it
            connection_string = f"sqlite:///{tmpdbfile.name}"

        self._create_sqlalchemy_engine(connection_string, automigrate, connection_timeout, **kwargs)

        # Make a test query against the database to ensure the connection is valid
        with self.session(reraise=True) as session:
            session.query(Job).first()

        self.watcher.start()

        self.capabilities_cache = {}

    def _create_sqlalchemy_engine(self, connection_string, automigrate, connection_timeout, **kwargs):
        self.automigrate = automigrate

        # Disallow sqlite in-memory because multi-threaded access to it is
        # complex and potentially problematic at best
        # ref: https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#threading-pooling-behavior
        if self._is_sqlite_inmemory_connection_string(connection_string):
            raise ValueError(
                f"Cannot use SQLite in-memory with BuildGrid (connection_string=[{connection_string}]). "
                "Use a file or leave the connection_string empty for a tempfile.")

        if connection_timeout is not None:
            if "connect_args" not in kwargs:
                kwargs["connect_args"] = {}
            if self._is_sqlite_connection_string(connection_string):
                kwargs["connect_args"]["timeout"] = connection_timeout
            else:
                kwargs["connect_args"]["connect_timeout"] = connection_timeout

        # Only pass the (known) kwargs that have been explicitly set by the user
        available_options = set([
            'pool_size', 'max_overflow', 'pool_timeout', 'pool_pre_ping',
            'pool_recycle', 'connect_args'
        ])
        kwargs_keys = set(kwargs.keys())
        if not kwargs_keys.issubset(available_options):
            unknown_options = kwargs_keys - available_options
            raise TypeError(f"Unknown keyword arguments: [{unknown_options}]")

        self.__logger.debug(f"SQLAlchemy additional kwargs: [{kwargs}]")

        self.engine = create_engine(connection_string, echo=False, **kwargs)
        Session.configure(bind=self.engine)

        if self.engine.dialect.name == "sqlite":
            event.listen(self.engine, "connect", sqlite_on_connect)

        self._configure_dialect_disposal_exceptions(self.engine.dialect.name)

        if self.automigrate:
            self._create_or_migrate_db(connection_string)

    def _is_sqlite_connection_string(self, connection_string):
        if connection_string:
            return connection_string.startswith("sqlite")
        return False

    def _is_sqlite_inmemory_connection_string(self, full_connection_string):
        if self._is_sqlite_connection_string(full_connection_string):
            # Valid connection_strings for in-memory SQLite which we don't support could look like:
            # "sqlite:///file:memdb1?option=value&cache=shared&mode=memory",
            # "sqlite:///file:memdb1?mode=memory&cache=shared",
            # "sqlite:///file:memdb1?cache=shared&mode=memory",
            # "sqlite:///file::memory:?cache=shared",
            # "sqlite:///file::memory:",
            # "sqlite:///:memory:",
            # "sqlite:///",
            # "sqlite://"
            # ref: https://www.sqlite.org/inmemorydb.html
            # Note that a user can also specify drivers, so prefix could become 'sqlite+driver:///'
            connection_string = full_connection_string

            uri_split_index = connection_string.find("?")
            if uri_split_index != -1:
                connection_string = connection_string[0:uri_split_index]

            if connection_string.endswith((":memory:", ":///", "://")):
                return True
            elif uri_split_index != -1:
                opts = full_connection_string[uri_split_index + 1:].split("&")
                if "mode=memory" in opts:
                    return True

        return False
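    # For illustration only (not an exhaustive list): connection strings such as
    # "sqlite:///:memory:", "sqlite://" and "sqlite:///file::memory:?cache=shared"
    # are treated as in-memory by the check above and rejected, while a file-backed
    # URL such as "sqlite:////var/lib/buildgrid/jobs.db" (a hypothetical path)
    # is accepted.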

    def __repr__(self):
        return f"SQL data store interface for `{repr(self.engine.url)}`"

    def activate_monitoring(self):
        # Don't do anything. This function needs to exist but there's no
        # need to actually toggle monitoring in this implementation.
        pass

    def deactivate_monitoring(self):
        # Don't do anything. This function needs to exist but there's no
        # need to actually toggle monitoring in this implementation.
        pass

    def _configure_dialect_disposal_exceptions(self, dialect: str):
        self.__dispose_pool_on_exceptions = self._get_dialect_disposal_exceptions(dialect)

    def _get_dialect_disposal_exceptions(self, dialect: str) -> Tuple[Any, ...]:
        dialect_errors: Tuple[Any, ...] = tuple()
        if dialect == 'postgresql':
            import psycopg2  # pylint: disable=import-outside-toplevel
            dialect_errors = (psycopg2.errors.ReadOnlySqlTransaction, psycopg2.errors.AdminShutdown)
        return dialect_errors

    def _create_or_migrate_db(self, connection_string):
        self.__logger.warning("Will attempt migration to latest version if needed.")

        config = Config()
        config.set_main_option("script_location", os.path.join(os.path.dirname(__file__), "alembic"))

        with self.engine.begin() as connection:
            config.attributes['connection'] = connection
            command.upgrade(config, "head")
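    # Note: the upgrade above is the programmatic equivalent of running
    # "alembic upgrade head" against this database, using the migration scripts
    # bundled in the adjacent "alembic" directory configured just before it.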

    @contextmanager
    def session(self, *, sqlite_lock_immediately=False, reraise=False):
        # Try to obtain a session
        try:
            session = Session()
            if sqlite_lock_immediately and session.bind.name == "sqlite":
                session.execute("BEGIN IMMEDIATE")
        except Exception as e:
            self.__logger.error("Unable to obtain a database session.", exc_info=True)
            raise DatabaseError("Unable to obtain a database session.") from e

        # Yield the session and catch exceptions that occur while using it
        # to roll-back if needed
        try:
            yield session
            session.commit()
        except Exception as e:
            self.__logger.error("Error committing database session. Rolling back.", exc_info=True)
            self._check_dispose_pool(session, e)
            try:
                session.rollback()
            except Exception:
                self.__logger.warning("Rollback error.", exc_info=True)

            if reraise:
                raise
        finally:
            session.close()
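    # Illustrative usage sketch: callers use session() as a context manager, which
    # commits on success and rolls back on error, e.g.
    #
    #     with self.session(reraise=True) as session:
    #         session.query(Job).first()
    #
    # With reraise=False (the default), exceptions raised inside the block are
    # logged and suppressed after the rollback attempt.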

    def _check_dispose_pool(self, session: SessionType, e: Exception):
        # Only do this if the config is relevant
        if not self.__dispose_pool_on_exceptions:
            return

        # Make sure we have a SQL-related cause to check, otherwise skip
        if e.__cause__ and isinstance(e.__cause__, Exception):
            cause_type = type(e.__cause__)
            # Only allow disposal every MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES
            now = datetime.utcnow()
            only_if_after = None

            # Let's see if this exception is related to known disconnect exceptions
            is_connection_error = cause_type in self.__dispose_pool_on_exceptions

            if is_connection_error:
                # Make sure this connection will not be re-used
                session.invalidate()
                self.__logger.info(
                    f'Detected a SQL exception=[{cause_type.__name__}] related to the connection. '
                    'Invalidating this connection.'
                )
                # Check if we should dispose the rest of the checked in connections
                with self.__last_pool_dispose_time_lock:
                    if self.__last_pool_dispose_time:
                        only_if_after = self.__last_pool_dispose_time + \
                            timedelta(minutes=MIN_TIME_BETWEEN_SQL_POOL_DISPOSE_MINUTES)
                    if only_if_after and now < only_if_after:
                        return

                    # OK, we haven't disposed the pool recently
                    self.__last_pool_dispose_time = now
                    self.engine.dispose()
                    self.__logger.info('Disposing pool checked in connections so that they get recreated')

    def _get_job(self, job_name, session, with_for_update=False):
        jobs = session.query(Job)
        if with_for_update:
            jobs = jobs.with_for_update()
        jobs = jobs.filter_by(name=job_name)

        job = jobs.first()
        if job:
            self.__logger.debug(f"Loaded job from db: name=[{job_name}], stage=[{job.stage}], result=[{job.result}]")

        return job

    def _check_job_timeout(self, job_internal, *, max_execution_timeout=None):
        """ Do a lazy check of the maximum allowed job timeout when a client tries
            to retrieve an existing job.
            Cancel the job and its related operations/leases if we detect on access
            that it has exceeded the timeout.

            Returns the `buildgrid.server.Job` object, possibly updated with `cancelled=True`.
        """
        if job_internal and max_execution_timeout and job_internal.worker_start_timestamp_as_datetime:
            if job_internal.operation_stage == OperationStage.EXECUTING:
                executing_duration = datetime.utcnow() - job_internal.worker_start_timestamp_as_datetime
                if executing_duration.total_seconds() >= max_execution_timeout:
                    self.__logger.warning(f"Job=[{job_internal}] has been executing for "
                                          f"executing_duration=[{executing_duration}], "
                                          f"max_execution_timeout=[{max_execution_timeout}]. "
                                          "Cancelling.")
                    job_internal.cancel_all_operations(data_store=self)
                    self.__logger.info(f"Job=[{job_internal}] has been cancelled.")
        return job_internal

    @DurationMetric(DATA_STORE_GET_JOB_BY_DIGEST_TIME_METRIC_NAME, instanced=True)
    def get_job_by_action(self, action_digest, *, max_execution_timeout=None):
        with self.session() as session:
            jobs = session.query(Job).filter_by(action_digest=digest_to_string(action_digest))
            jobs = jobs.filter(Job.stage != OperationStage.COMPLETED.value)
            job = jobs.first()
            if job:
                internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url,
                                                   instance_name=self._instance_name)
                return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout)
        return None

    @DurationMetric(DATA_STORE_GET_JOB_BY_NAME_TIME_METRIC_NAME, instanced=True)
    def get_job_by_name(self, name, *, max_execution_timeout=None):
        with self.session() as session:
            job = self._get_job(name, session)
            if job:
                internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url,
                                                   instance_name=self._instance_name)
                return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout)
        return None

    @DurationMetric(DATA_STORE_GET_JOB_BY_OPERATION_TIME_METRIC_NAME, instanced=True)
    def get_job_by_operation(self, operation_name, *, max_execution_timeout=None):
        with self.session() as session:
            operation = self._get_operation(operation_name, session)
            if operation and operation.job:
                job = operation.job
                internal_job = job.to_internal_job(self, action_browser_url=self._action_browser_url,
                                                   instance_name=self._instance_name)
                return self._check_job_timeout(internal_job, max_execution_timeout=max_execution_timeout)
        return None

    def get_all_jobs(self):
        with self.session() as session:
            jobs = session.query(Job).filter(Job.stage != OperationStage.COMPLETED.value)
            return [j.to_internal_job(self, action_browser_url=self._action_browser_url,
                                      instance_name=self._instance_name) for j in jobs]

    def get_jobs_by_stage(self, operation_stage):
        with self.session() as session:
            jobs = session.query(Job).filter(Job.stage == operation_stage.value)
            return [j.to_internal_job(self, no_result=True, action_browser_url=self._action_browser_url,
                                      instance_name=self._instance_name) for j in jobs]

    @DurationMetric(DATA_STORE_CREATE_JOB_TIME_METRIC_NAME, instanced=True)
    def create_job(self, job):
        with self.session() as session:
            if self._get_job(job.name, session) is None:
                # Convert requirements values to sorted lists to make them json-serializable
                platform_requirements = job.platform_requirements
                convert_values_to_sorted_lists(platform_requirements)
                # Serialize the requirements
                platform_requirements_hash = hash_from_dict(platform_requirements)

                session.add(Job(
                    name=job.name,
                    action=job.action.SerializeToString(),
                    action_digest=digest_to_string(job.action_digest),
                    do_not_cache=job.do_not_cache,
                    priority=job.priority,
                    operations=[],
                    platform_requirements=platform_requirements_hash,
                    stage=job.operation_stage.value,
                    queued_timestamp=job.queued_timestamp_as_datetime,
                    queued_time_duration=job.queued_time_duration.seconds,
                    worker_start_timestamp=job.worker_start_timestamp_as_datetime,
                    worker_completed_timestamp=job.worker_completed_timestamp_as_datetime
                ))

    @DurationMetric(DATA_STORE_QUEUE_JOB_TIME_METRIC_NAME, instanced=True)
    def queue_job(self, job_name):
        with self.session(sqlite_lock_immediately=True) as session:
            job = self._get_job(job_name, session, with_for_update=True)
            job.assigned = False

    @DurationMetric(DATA_STORE_UPDATE_JOB_TIME_METRIC_NAME, instanced=True)
    def update_job(self, job_name, changes, *, skip_notify=False):
        if "result" in changes:
            changes["result"] = digest_to_string(changes["result"])
        if "action_digest" in changes:
            changes["action_digest"] = digest_to_string(changes["action_digest"])

        with self.session() as session:
            job = self._get_job(job_name, session)
            job.update(changes)
            if not skip_notify:
                self._notify_job_updated(job_name, session)

    def _notify_job_updated(self, job_names, session):
        if self.engine.dialect.name == "postgresql":
            if isinstance(job_names, str):
                job_names = [job_names]
            for job_name in job_names:
                session.execute(f"NOTIFY job_updated, '{job_name}';")
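    # Note: these NOTIFY statements pair with the LISTEN issued in
    # _listen_for_updates() below; on PostgreSQL the job watcher thread is woken
    # with the job name as the notification payload instead of having to poll
    # the database for changes.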

    def delete_job(self, job_name):
        if job_name in self.response_cache:
            del self.response_cache[job_name]

    def wait_for_job_updates(self):
        self.__logger.info("Starting job watcher thread")
        if self.engine.dialect.name == "postgresql":
            self._listen_for_updates()
        else:
            self._poll_for_updates()

    def _listen_for_updates(self):
        def _listen_loop():
            try:
                conn = self.engine.connect()
                conn.execute(text("LISTEN job_updated;").execution_options(autocommit=True))
            except Exception as e:
                raise DatabaseError("Could not start listening to DB for job updates") from e

            while self.watcher_keep_running:
                # Wait until the connection is ready for reading. Time out after
                # poll_interval seconds and try again if there was nothing to read.
                # If the connection becomes readable, collect the notifications it
                # has received and handle them.
                #
                # See http://initd.org/psycopg/docs/advanced.html#async-notify
                if select.select([conn.connection], [], [], self.poll_interval) == ([], [], []):
                    pass
                else:
                    try:
                        conn.connection.poll()
                    except Exception as e:
                        raise DatabaseError("Error while polling for job updates") from e

                    while conn.connection.notifies:
                        notify = conn.connection.notifies.pop()
                        with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME,
                                            instanced=True, instance_name=self._instance_name):
                            with self.watched_jobs_lock:
                                spec = self.watched_jobs.get(notify.payload)
                                if spec is not None:
                                    try:
                                        new_job = self.get_job_by_name(notify.payload)
                                    except Exception as e:
                                        raise DatabaseError(
                                            f"Couldn't get watched job=[{notify.payload}] from DB") from e

                                    # If the job doesn't exist or an exception was suppressed by
                                    # get_job_by_name, it returns None instead of the job
                                    if new_job is None:
                                        raise DatabaseError(
                                            f"get_job_by_name returned None for job=[{notify.payload}]")

                                    new_state = JobState(new_job)
                                    if spec.last_state != new_state:
                                        spec.last_state = new_state
                                        spec.event.notify_change()

        while self.watcher_keep_running:
            # Wait a few seconds if a database exception occurs and then try again
            # This could be a short disconnect
            try:
                _listen_loop()
            except DatabaseError as e:
                self.__logger.warning(f"JobWatcher encountered exception: [{e}]; "
                                      f"Retrying in poll_interval=[{self.poll_interval}] seconds.")
                # Sleep for a bit so that we give enough time for the
                # database to potentially recover
                time.sleep(self.poll_interval)

    def _get_watched_jobs(self):
        with self.session() as sess:
            jobs = sess.query(Job).filter(
                Job.name.in_(self.watched_jobs)
            )
            return [job.to_internal_job(self) for job in jobs.all()]

    def _poll_for_updates(self):
        def _poll_loop():
            while self.watcher_keep_running:
                time.sleep(self.poll_interval)
                if self.watcher_keep_running:
                    with DurationMetric(DATA_STORE_CHECK_FOR_UPDATE_TIME_METRIC_NAME,
                                        instanced=True, instance_name=self._instance_name):
                        with self.watched_jobs_lock:
                            if self.watcher_keep_running:
                                try:
                                    watched_jobs = self._get_watched_jobs()
                                except Exception as e:
                                    raise DatabaseError("Couldn't retrieve watched jobs from DB") from e

                                if watched_jobs is None:
                                    raise DatabaseError("_get_watched_jobs returned None")

                                for new_job in watched_jobs:
                                    if self.watcher_keep_running:
                                        spec = self.watched_jobs[new_job.name]
                                        new_state = JobState(new_job)
                                        if spec.last_state != new_state:
                                            spec.last_state = new_state
                                            spec.event.notify_change()

        while self.watcher_keep_running:
            # Wait a few seconds if a database exception occurs and then try again
            try:
                _poll_loop()
            except DatabaseError as e:
                self.__logger.warning(f"JobWatcher encountered exception: [{e}]; "
                                      f"Retrying in poll_interval=[{self.poll_interval}] seconds.")
                # Sleep for a bit so that we give enough time for the
                # database to potentially recover
                time.sleep(self.poll_interval)

    @DurationMetric(DATA_STORE_STORE_RESPONSE_TIME_METRIC_NAME, instanced=True)
    def store_response(self, job, commit_changes=True):
        digest = self.storage.put_message(job.execute_response)
        changes = {"result": digest, "status_code": job.execute_response.status.code}
        self.response_cache[job.name] = job.execute_response

        if commit_changes:
            self.update_job(job.name,
                            changes,
                            skip_notify=True)
            return None
        else:
            # The caller will batch the changes and commit to db
            return changes
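    # Note: store_response() always caches the ExecuteResponse in memory under the
    # job name; with commit_changes=False the returned changes dict (result digest
    # and status code) is left for the caller to batch into a later job update.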

    def _get_operation(self, operation_name, session):
        operations = session.query(Operation).filter_by(name=operation_name)
        return operations.first()

    def get_operations_by_stage(self, operation_stage):
        with self.session() as session:
            operations = session.query(Operation)
            operations = operations.filter(Operation.job.has(stage=operation_stage.value))
            operations = operations.all()
            # Return a set of job names here for now, to match the `MemoryDataStore`
            # implementation's behaviour
            return set(op.job.name for op in operations)

    def _cancel_jobs_exceeding_execution_timeout(self, max_execution_timeout: int=None) -> None:
        if max_execution_timeout:
            stale_job_names = []
            lazy_execution_timeout_threshold = datetime.utcnow() - timedelta(seconds=max_execution_timeout)

            jobs_table = Job.__table__
            operations_table = Operation.__table__

            with self.session(sqlite_lock_immediately=True) as session:
                # Get the full list of jobs exceeding execution timeout
                stale_jobs = session.query(Job).filter_by(stage=OperationStage.EXECUTING.value)
                stale_jobs = stale_jobs.filter(Job.worker_start_timestamp <= lazy_execution_timeout_threshold)
                stale_job_names = [job.name for job in stale_jobs.with_for_update().all()]

                if stale_job_names:
                    # Mark operations as cancelled
                    stmt_mark_operations_cancelled = operations_table.update().where(
                        operations_table.c.job_name.in_(stale_job_names)
                    ).values(cancelled=True)
                    session.execute(stmt_mark_operations_cancelled)

                    # Mark jobs as cancelled
                    stmt_mark_jobs_cancelled = jobs_table.update().where(
                        jobs_table.c.name.in_(stale_job_names)
                    ).values(stage=OperationStage.COMPLETED.value, cancelled=True)
                    session.execute(stmt_mark_jobs_cancelled)

                    # Notify all jobs updated
                    self._notify_job_updated(stale_job_names, session)

            if stale_job_names:
                self.__logger.info(f"Cancelled n=[{len(stale_job_names)}] jobs "
                                   f"with names={stale_job_names} "
                                   f"due to them exceeding "
                                   f"execution_timeout=[{max_execution_timeout}]")

    @DurationMetric(DATA_STORE_LIST_OPERATIONS_TIME_METRIC_NAME, instanced=True)
    def list_operations(self,
                        operation_filters: List[OperationFilter]=None,
                        page_size: int=None,
                        page_token: str=None,
                        max_execution_timeout: int=None) -> Tuple[List[operations_pb2.Operation], str]:
        # Lazily timeout jobs as needed before returning the list!
        self._cancel_jobs_exceeding_execution_timeout(max_execution_timeout=max_execution_timeout)

        # Build filters and sort order
        sort_keys = DEFAULT_SORT_KEYS
        custom_filters = None
        if operation_filters:
            # Extract custom sort order (if present)
            specified_sort_keys, non_sort_filters = extract_sort_keys(operation_filters)

            # Only override sort_keys if there were sort keys actually present in the filter string
            if specified_sort_keys:
                sort_keys = specified_sort_keys
                # Attach the operation name as a sort key for a deterministic order
                # This will ensure that the ordering of results is consistent between queries
                if not any(sort_key.name == "name" for sort_key in sort_keys):
                    sort_keys.append(SortKey(name="name", descending=False))

            # Finally, compile the non-sort filters into a filter list
            custom_filters = build_custom_filters(non_sort_filters)

        sort_columns = build_sort_column_list(sort_keys)

        with self.session() as session:
            results = session.query(Operation).join(Job, Operation.job_name == Job.name)

            # Apply custom filters (if present)
            if custom_filters:
                results = results.filter(*custom_filters)

            # Apply sort order
            results = results.order_by(*sort_columns)

            # Apply pagination filter
            if page_token:
                page_filter = build_page_filter(page_token, sort_keys)
                results = results.filter(page_filter)
            if page_size:
                # We limit the number of operations we fetch to the page_size. However, we
                # fetch an extra operation to determine whether we need to provide a
                # next_page_token.
                results = results.limit(page_size + 1)

            operations = list(results)

            if not page_size or not operations:
                next_page_token = ""

            # If the number of results we got is less than or equal to our page_size,
            # we're done with the operations listing and don't need to provide another
            # page token
            elif len(operations) <= page_size:
                next_page_token = ""
            else:
                # Drop the last operation since we have an extra
                operations.pop()
                # Our page token will be the last row of our set
                next_page_token = build_page_token(operations[-1], sort_keys)
            return [operation.to_protobuf(self) for operation in operations], next_page_token
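    # Illustrative usage sketch (the "data_store" name is hypothetical): a caller
    # can page through all operations by feeding the returned token back in, e.g.
    #
    #     operations, token = data_store.list_operations(page_size=50)
    #     while token:
    #         more, token = data_store.list_operations(page_size=50, page_token=token)
    #         operations.extend(more)
    #
    # An empty next_page_token means the listing is complete.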

    @DurationMetric(DATA_STORE_CREATE_OPERATION_TIME_METRIC_NAME, instanced=True)
    def create_operation(self, operation_name, job_name, request_metadata=None):
        with self.session() as session:
            operation = Operation(
                name=operation_name,
                job_name=job_name
            )
            if request_metadata is not None:
                if request_metadata.tool_invocation_id:
                    operation.invocation_id = request_metadata.tool_invocation_id
                if request_metadata.correlated_invocations_id:
                    operation.correlated_invocations_id = request_metadata.correlated_invocations_id
                if request_metadata.tool_details:
                    operation.tool_name = request_metadata.tool_details.tool_name
                    operation.tool_version = request_metadata.tool_details.tool_version
            session.add(operation)

    @DurationMetric(DATA_STORE_UPDATE_OPERATION_TIME_METRIC_NAME, instanced=True)
    def update_operation(self, operation_name, changes):
        with self.session() as session:
            operation = self._get_operation(operation_name, session)
            operation.update(changes)

    def delete_operation(self, operation_name):
        # Don't do anything. This function needs to exist but there's no
        # need to actually delete operations in this implementation.
        pass

    def get_leases_by_state(self, lease_state):
        with self.session() as session:
            leases = session.query(Lease).filter_by(state=lease_state.value)
            leases = leases.all()
            # `lease.job_name` is the same as `lease.id` for a Lease protobuf
            return set(lease.job_name for lease in leases)

    def get_metrics(self):

        def _get_query_leases_by_state(session, category):
            # Using func.count here to avoid generating a subquery in the WHERE
            # clause of the resulting query.
            # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
            query = session.query(literal_column(category).label("category"),
                                  Lease.state.label("bucket"),
                                  func.count(Lease.id).label("value"))
            query = query.group_by(Lease.state)
            return query

        def _cb_query_leases_by_state(leases_by_state):
            # The database only returns counts > 0, so fill in the gaps
            for state in LeaseState:
                if state.value not in leases_by_state:
                    leases_by_state[state.value] = 0
            return leases_by_state

        def _get_query_operations_by_stage(session, category):
            # Using func.count here to avoid generating a subquery in the WHERE
            # clause of the resulting query.
            # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
            query = session.query(literal_column(category).label("category"),
                                  Job.stage.label("bucket"),
                                  func.count(Operation.name).label("value"))
            query = query.join(Job)
            query = query.group_by(Job.stage)
            return query

        def _cb_query_operations_by_stage(operations_by_stage):
            # The database only returns counts > 0, so fill in the gaps
            for stage in OperationStage:
                if stage.value not in operations_by_stage:
                    operations_by_stage[stage.value] = 0
            return operations_by_stage

        def _get_query_jobs_by_stage(session, category):
            # Using func.count here to avoid generating a subquery in the WHERE
            # clause of the resulting query.
            # https://docs.sqlalchemy.org/en/13/orm/query.html#sqlalchemy.orm.query.Query.count
            query = session.query(literal_column(category).label("category"),
                                  Job.stage.label("bucket"),
                                  func.count(Job.name).label("value"))
            query = query.group_by(Job.stage)
            return query

        def _cb_query_jobs_by_stage(jobs_by_stage):
            # The database only returns counts > 0, so fill in the gaps
            for stage in OperationStage:
                if stage.value not in jobs_by_stage:
                    jobs_by_stage[stage.value] = 0
            return jobs_by_stage

        metrics = {}
        try:
            with self.session() as session:
                # metrics to gather: (category_name, function_returning_query, callback_function)
                metrics_to_gather = [(MetricCategories.LEASES.value, _get_query_leases_by_state,
                                      _cb_query_leases_by_state),
                                     (MetricCategories.OPERATIONS.value, _get_query_operations_by_stage,
                                      _cb_query_operations_by_stage),
                                     (MetricCategories.JOBS.value, _get_query_jobs_by_stage,
                                      _cb_query_jobs_by_stage)]

                union_query = union(*[query_fn(session, f"'{category}'")
                                      for category, query_fn, _ in metrics_to_gather])
                union_results = session.execute(union_query).fetchall()

                grouped_results = {category: {} for category, _, _ in union_results}
                for category, bucket, value in union_results:
                    grouped_results[category][bucket] = value

                for category, _, category_cb in metrics_to_gather:
                    metrics[category] = category_cb(grouped_results.setdefault(category, {}))
        except DatabaseError:
            self.__logger.warning("Unable to gather metrics due to a Database Error.")
            return {}

        return metrics
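    # Note: get_metrics() returns a nested mapping of {category value: {bucket: count}},
    # i.e. lease counts bucketed by lease state and job/operation counts bucketed by
    # operation stage, with missing buckets filled in as zero by the callbacks above.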

    def _create_lease(self, lease, session, job=None):
        if job is None:
            job = self._get_job(lease.id, session)
            job = job.to_internal_job(self)
        session.add(Lease(
            job_name=lease.id,
            state=lease.state,
            status=None,
            worker_name=job.worker_name
        ))

    def create_lease(self, lease):
        with self.session() as session:
            self._create_lease(lease, session)

    @DurationMetric(DATA_STORE_UPDATE_LEASE_TIME_METRIC_NAME, instanced=True)
    def update_lease(self, job_name, changes):
        with self.session() as session:
            job = self._get_job(job_name, session)
            lease = job.active_leases[0]
            lease.update(changes)

    def load_unfinished_jobs(self):
        with self.session() as session:
            jobs = session.query(Job)
            jobs = jobs.filter(Job.stage != OperationStage.COMPLETED.value)
            jobs = jobs.order_by(Job.priority)
            return [j.to_internal_job(self) for j in jobs.all()]

    def assign_lease_for_next_job(self, capabilities, callback, timeout=None):
        """Return a list of leases for the highest priority jobs that can be run by a worker.

        NOTE: Currently the list only ever has one or zero leases.

        Query the jobs table to find queued jobs which match the capabilities of
        a given worker, and return the one with the highest priority. Takes a
        dictionary of worker capabilities to compare with job requirements.

        :param capabilities: Dictionary of worker capabilities to compare
            with job requirements when finding a job.
        :type capabilities: dict
        :param callback: Function to run on the next runnable job, should return
            a list of leases.
        :type callback: function
        :param timeout: time to wait for new jobs, caps if longer
            than MAX_JOB_BLOCK_TIME.
        :type timeout: int
        :returns: List of leases

        """
        if not timeout:
            return self._assign_job_leases(capabilities, callback)

        # Cap the timeout if it's larger than MAX_JOB_BLOCK_TIME
        if timeout:
            timeout = min(timeout, MAX_JOB_BLOCK_TIME)

        start = time.time()
        while time.time() + self.connection_timeout + 1 - start < timeout:
            leases = self._assign_job_leases(capabilities, callback)
            if leases:
                return leases
            time.sleep(0.5)
        if self.connection_timeout > timeout:
            self.__logger.warning(
                "Not providing any leases to the worker because the database connection "
                f"timeout ({self.connection_timeout} s) is longer than the remaining "
                "time to handle the request. "
                "Increase the worker's timeout to solve this problem.")
        return []

    def flatten_capabilities(self, capabilities: Dict[str, List[str]]) -> List[Tuple[str, str]]:
        """ Flatten a capabilities dictionary, assuming all of its values are lists. E.g.

            {'OSFamily': ['Linux'], 'ISA': ['x86-32', 'x86-64']}

            becomes

            [('OSFamily', 'Linux'), ('ISA', 'x86-32'), ('ISA', 'x86-64')] """
        return [
            (name, value) for name, value_list in capabilities.items()
            for value in value_list
        ]

    def get_partial_capabilities(self, capabilities: Dict[str, List[str]]) -> Iterable[Dict[str, List[str]]]:
        """ Given a capabilities dictionary with all values as lists,
            yield all partial capabilities dictionaries. """
        CAPABILITIES_WARNING_THRESHOLD = 10

        caps_flat = self.flatten_capabilities(capabilities)

        if len(caps_flat) > CAPABILITIES_WARNING_THRESHOLD:
            self.__logger.warning(
                "A worker with a large capabilities dictionary has been connected. "
                f"Processing its capabilities may take a while. Capabilities: {capabilities}")

        # Using the itertools powerset recipe, construct the powerset of the tuples
        capabilities_powerset = chain.from_iterable(combinations(caps_flat, r) for r in range(len(caps_flat) + 1))
        for partial_capability_tuples in capabilities_powerset:
            partial_dict: Dict[str, List[str]] = {}

            for tup in partial_capability_tuples:
                partial_dict.setdefault(tup[0], []).append(tup[1])
            yield partial_dict
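    # Worked example: for {'OSFamily': ['Linux'], 'ISA': ['x86-64']} the generator
    # above yields the four partial dictionaries
    #     {}, {'OSFamily': ['Linux']}, {'ISA': ['x86-64']},
    #     {'OSFamily': ['Linux'], 'ISA': ['x86-64']}
    # i.e. one dictionary per subset of the flattened (key, value) pairs.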

    def get_partial_capabilities_hashes(self, capabilities: Dict) -> List[str]:
        """ Given a capabilities dictionary, obtain each partial capabilities
            dictionary, hash each of them, compile the hashes into a list,
            and return the result. """
        # Convert requirements values to sorted lists to make them json-serializable
        convert_values_to_sorted_lists(capabilities)

        # Check to see if we've cached this value
        capabilities_digest = hash_from_dict(capabilities)
        try:
            return self.capabilities_cache[capabilities_digest]
        except KeyError:
            # On cache miss, expand the capabilities into each possible partial capabilities dictionary
            capabilities_list = []
            for partial_capability in self.get_partial_capabilities(capabilities):
                capabilities_list.append(hash_from_dict(partial_capability))

            self.capabilities_cache[capabilities_digest] = capabilities_list
            return capabilities_list
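    # Note: hashing every partial capability set is what allows _assign_job_leases()
    # below to match a job's single platform_requirements hash against any worker
    # whose capabilities are a superset of the job's requirements, using a simple
    # IN (...) filter rather than a structured comparison.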

    @DurationMetric(BOTS_ASSIGN_JOB_LEASES_TIME_METRIC_NAME, instanced=True)
    def _assign_job_leases(self, capabilities, callback):
        # pylint: disable=singleton-comparison
        # Hash the capabilities
        capabilities_config_hashes = self.get_partial_capabilities_hashes(capabilities)
        leases = []
        try:
            create_lease_start_time = None
            with self.session(sqlite_lock_immediately=True) as session:
                jobs = session.query(Job).with_for_update(skip_locked=True)
                jobs = jobs.filter(Job.stage == OperationStage.QUEUED.value)
                jobs = jobs.filter(Job.assigned != True)  # noqa
                jobs = jobs.filter(Job.platform_requirements.in_(capabilities_config_hashes))
                job = jobs.order_by(Job.priority, Job.queued_timestamp).first()
                # This worker can take this job if it can handle all of its configurations
                if job:
                    internal_job = job.to_internal_job(self)
                    leases = callback(internal_job)
                    if leases:
                        job.assigned = True
                        job.worker_start_timestamp = internal_job.worker_start_timestamp_as_datetime
                        create_lease_start_time = time.perf_counter()
                        for lease in leases:
                            self._create_lease(lease, session, job=internal_job)

            # Calculate and publish the time taken to create leases. This is done explicitly
            # rather than using the DurationMetric helper since we need to measure the actual
            # execution time of the UPDATE and INSERT queries used in the lease assignment, and
            # these are only executed on exiting the contextmanager.
            if create_lease_start_time is not None:
                run_time = timedelta(seconds=time.perf_counter() - create_lease_start_time)
                metadata = None
                if self._instance_name is not None:
                    metadata = {'instance-name': self._instance_name}
                publish_timer_metric(DATA_STORE_CREATE_LEASE_TIME_METRIC_NAME, run_time, metadata=metadata)

        except DatabaseError:
            self.__logger.warning("Will not assign any leases this time due to a Database Error.")

        return leases