Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/sql/utils.py: 92.74%

179 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-10-04 17:48 +0000

1# Copyright (C) 2020 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15""" Holds constants and utility functions for the SQL scheduler. """ 

16 

17 

18import operator 

19import random 

20from datetime import datetime, timedelta 

21from threading import Lock 

22from typing import Any, Dict, List, Optional, Tuple, cast 

23 

24from sqlalchemy.engine import Engine 

25from sqlalchemy.orm.session import Session as SessionType 

26from sqlalchemy.sql.expression import ClauseElement, and_, or_ 

27from sqlalchemy.sql.operators import ColumnOperators 

28from sqlalchemy.sql.schema import Column 

29 

30from buildgrid.server.exceptions import InvalidArgumentError 

31from buildgrid.server.logging import buildgrid_logger 

32from buildgrid.server.operations.filtering import OperationFilter, SortKey 

33from buildgrid.server.sql.models import ( 

34 ClientIdentityEntry, 

35 JobEntry, 

36 OperationEntry, 

37 PlatformEntry, 

38 RequestMetadataEntry, 

39) 

40 

# Module-level logger, named after this module.
LOGGER = buildgrid_logger(__name__)


# Timestamp format used when (de)serializing datetime values in page tokens.
# See parse_list_operations_sort_value / dump_list_operations_token_value.
DATETIME_FORMAT = "%Y-%m-%d-%H-%M-%S-%f"


# Maps the filter/sort-key parameter names accepted by ListOperations to the
# ORM columns they correspond to. The cast gives a uniform Column[Any] value
# type for the heterogeneous model attributes.
LIST_OPERATIONS_PARAMETER_MODEL_MAP = cast(
    Dict[str, Column[Any]],
    {
        "stage": JobEntry.stage,
        "name": OperationEntry.name,
        "queued_time": JobEntry.queued_timestamp,
        "start_time": JobEntry.worker_start_timestamp,
        "completed_time": JobEntry.worker_completed_timestamp,
        "invocation_id": RequestMetadataEntry.invocation_id,
        "correlated_invocations_id": RequestMetadataEntry.correlated_invocations_id,
        "tool_name": RequestMetadataEntry.tool_name,
        "tool_version": RequestMetadataEntry.tool_version,
        "action_mnemonic": RequestMetadataEntry.action_mnemonic,
        "target_id": RequestMetadataEntry.target_id,
        "configuration_id": RequestMetadataEntry.configuration_id,
        "action_digest": JobEntry.action_digest,
        "command": JobEntry.command,
        "platform": PlatformEntry.key,
        "platform-value": PlatformEntry.value,
        "client_workflow": ClientIdentityEntry.workflow,
        "client_actor": ClientIdentityEntry.actor,
        "client_subject": ClientIdentityEntry.subject,
    },
)

71 

72 

def is_sqlite_connection_string(connection_string: str) -> bool:
    """Return True if the given connection string targets SQLite."""
    return bool(connection_string) and connection_string.startswith("sqlite")

77 

78 

def is_psycopg2_connection_string(connection_string: str) -> bool:
    """Return True if the given connection string uses the psycopg2 driver.

    Both the bare "postgresql:" scheme (whose default driver is psycopg2)
    and the explicit "postgresql+psycopg2:" scheme are accepted.
    """
    if not connection_string:
        return False
    return connection_string.startswith(("postgresql:", "postgresql+psycopg2:"))

86 

87 

def is_sqlite_inmemory_connection_string(full_connection_string: str) -> bool:
    """Return True if the connection string targets an in-memory SQLite database.

    In-memory SQLite connection strings (which we don't support) can take a
    number of forms, for example:
        "sqlite:///file:memdb1?option=value&cache=shared&mode=memory"
        "sqlite:///file:memdb1?mode=memory&cache=shared"
        "sqlite:///file::memory:?cache=shared"
        "sqlite:///file::memory:"
        "sqlite:///:memory:"
        "sqlite:///"
        "sqlite://"
    ref: https://www.sqlite.org/inmemorydb.html
    Note that a user can also specify drivers, so the prefix could become
    'sqlite+driver:///'.
    """
    if not is_sqlite_connection_string(full_connection_string):
        return False

    # Split off any "?key=value&..." query options before inspecting the path.
    base, question_mark, query = full_connection_string.partition("?")

    # A path ending in ":memory:", or an empty path, is always in-memory.
    if base.endswith((":memory:", ":///", "://")):
        return True

    # Otherwise "mode=memory" among the query options also selects in-memory.
    return bool(question_mark) and "mode=memory" in query.split("&")

115 

116 

class SQLPoolDisposeHelper:
    """Helper class for disposing of SQL session connections.

    On selected transient connection errors (currently only configured for
    PostgreSQL), the failing connection is invalidated and, at most once per
    configured interval, the whole engine pool is disposed so that new
    requests get fresh connections.
    """

    def __init__(
        self,
        cooldown_time_in_secs: int,
        cooldown_jitter_base_in_secs: int,
        min_time_between_dispose_in_minutes: int,
        sql_engine: Engine,
    ) -> None:
        # Seconds the pool is considered inactive after a disposal.
        self._cooldown_time_in_secs = cooldown_time_in_secs
        # Jitter bound (+/- seconds) added to the cooldown to spread out retries.
        self._cooldown_jitter_base_in_secs = cooldown_jitter_base_in_secs
        # Rate limit: never dispose the pool more often than this.
        self._min_time_between_dispose_in_minutes = min_time_between_dispose_in_minutes
        self._last_pool_dispose_time: Optional[datetime] = None
        # Guards reads/writes of _last_pool_dispose_time across threads.
        self._last_pool_dispose_time_lock = Lock()
        self._sql_engine = sql_engine
        # Exception types that should trigger pool disposal; empty for
        # dialects where no such exceptions are configured.
        self._dispose_pool_on_exceptions: Tuple[Any, ...] = tuple()
        if self._sql_engine.dialect.name == "postgresql":
            # Imported lazily so non-PostgreSQL deployments don't need psycopg2.
            import psycopg2

            self._dispose_pool_on_exceptions = (psycopg2.errors.ReadOnlySqlTransaction, psycopg2.errors.AdminShutdown)

    def check_dispose_pool(self, session: SessionType, e: Exception) -> bool:
        """For selected exceptions invalidate the SQL session
        - returns True when a transient sql connection error is detected
        - returns False otherwise
        """

        # Only do this if the config is relevant
        if not self._dispose_pool_on_exceptions:
            return False

        # Make sure we have a SQL-related cause to check, otherwise skip
        if e.__cause__ and not isinstance(e.__cause__, Exception):
            return False

        cause_type = type(e.__cause__)
        # Let's see if this exception is related to known disconnect exceptions
        is_connection_error = cause_type in self._dispose_pool_on_exceptions
        if not is_connection_error:
            return False

        # Make sure this connection will not be re-used
        session.invalidate()
        LOGGER.info(
            "Detected a SQL exception related to the connection. Invalidating this connection.",
            tags=dict(exception=cause_type.__name__),
        )

        # Only allow disposal every self._min_time_between_dispose_in_minutes
        now = datetime.utcnow()
        only_if_after = None

        # Check if we should dispose the rest of the checked in connections
        with self._last_pool_dispose_time_lock:
            if self._last_pool_dispose_time:
                only_if_after = self._last_pool_dispose_time + timedelta(
                    minutes=self._min_time_between_dispose_in_minutes
                )
            # Connection invalidated, but a pool dispose happened recently:
            # skip the (expensive) dispose and just report the error.
            if only_if_after and now < only_if_after:
                return True

            # OK, we haven't disposed the pool recently
            self._last_pool_dispose_time = now
            LOGGER.info(
                "Disposing connection pool. New requests will have a fresh SQL connection.",
                tags=dict(cooldown_time_in_secs=self._cooldown_time_in_secs),
            )
            self._sql_engine.dispose()

        return True

    def time_until_active_pool(self) -> timedelta:
        """The time at which the pool is expected to become
        active after a pool disposal. This adds small amounts of jitter
        to help spread out load due to retrying clients
        """
        if self._last_pool_dispose_time:
            time_til_active = self._last_pool_dispose_time + timedelta(seconds=self._cooldown_time_in_secs)
            if datetime.utcnow() < time_til_active:
                # Still cooling down: report the cooldown plus random jitter so
                # retrying clients don't all come back at the same instant.
                return timedelta(
                    seconds=self._cooldown_time_in_secs
                    + random.uniform(-self._cooldown_jitter_base_in_secs, self._cooldown_jitter_base_in_secs)
                )
        # No recent disposal (or cooldown elapsed): pool is active now.
        return timedelta(seconds=0)

202 

203 

def strtobool(val: str) -> bool:
    """Convert a string representation of truth to a bool.

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
    are 'n', 'no', 'f', 'false', 'off', and '0'. Matching is
    case-insensitive.

    Raises:
        ValueError: If 'val' is anything else.
    """
    normalized = val.lower()
    truthy = {"y", "yes", "t", "true", "on", "1"}
    falsy = {"n", "no", "f", "false", "off", "0"}
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise ValueError(f"invalid truth value {normalized}")

217 

218 

219def parse_list_operations_sort_value(value: str, column: Column[Any]) -> Any: 

220 """Convert the string representation of a value to the proper Python type.""" 

221 python_type = column.property.columns[0].type.python_type 

222 if python_type == datetime: 

223 return datetime.strptime(value, DATETIME_FORMAT) 

224 elif python_type == bool: 

225 # Using this distutils function to cover a few different bool representations 

226 return strtobool(value) 

227 else: 

228 return python_type(value) 

229 

230 

def dump_list_operations_token_value(token_value: Any) -> str:
    """Convert a value to a string for use in the page_token.

    Datetimes are serialized using DATETIME_FORMAT so that they can be
    round-tripped by parse_list_operations_sort_value; everything else
    falls back to str().
    """
    if isinstance(token_value, datetime):
        return token_value.strftime(DATETIME_FORMAT)
    return str(token_value)

237 

238 

def build_pagination_clause_for_sort_key(
    sort_value: Any, previous_sort_values: List[Any], sort_keys: List[SortKey]
) -> ClauseElement:
    """Build part of a filter clause to figure out the starting point of the page given
    by the page_token. See the docstring of build_page_filter for more details.

    Raises:
        ValueError: If there are not more sort keys than previous sort values.
    """
    if len(sort_keys) <= len(previous_sort_values):
        raise ValueError("Not enough keys to unpack")

    # Pin every already-consumed sort key to its previous value (equality),
    # then break the tie with a strict comparison on the next sort key.
    clauses: List[Any] = [
        LIST_OPERATIONS_PARAMETER_MODEL_MAP[key.name] == value
        for key, value in zip(sort_keys, previous_sort_values)
    ]
    tiebreak_key = sort_keys[len(previous_sort_values)]
    tiebreak_col = LIST_OPERATIONS_PARAMETER_MODEL_MAP[tiebreak_key.name]
    clauses.append(tiebreak_col < sort_value if tiebreak_key.descending else tiebreak_col > sort_value)
    return and_(*clauses)

258 

259 

def build_page_filter(page_token: str, sort_keys: List[SortKey]) -> ClauseElement:
    """Build a filter to determine the starting point of the rows to fetch, based
    on the page_token.

    The page_token acts as a "cursor" and is directly tied to the sort order.
    Its format is Xval|Yval|Zval|..., one value per orderable column, where X
    is the primary sort key, Y breaks ties between X, and Z breaks ties
    between X and Y. The resulting filter clause is:

    (X > Xval) OR (X == XVal AND Y > Yval) OR (X == Xval AND Y == Yval AND Z > Zval) ...

    Raises:
        InvalidArgumentError: If the token does not have exactly one value
            per sort key.
    """
    token_elements = page_token.split("|")
    if len(token_elements) != len(sort_keys):
        # It is possible that an extra "|" was in the input
        # TODO: Handle extra "|"s somehow? Or at least allow escaping them
        raise InvalidArgumentError(
            f'Wrong number of "|"-separated elements in page token [{page_token}]. '
            f"Expected {len(sort_keys)}, got {len(token_elements)}."
        )

    clauses = []
    previous_values: List[Any] = []
    # Build the compound clause for each sort key in the token.
    for element, sort_key in zip(token_elements, sort_keys):
        column = LIST_OPERATIONS_PARAMETER_MODEL_MAP[sort_key.name]
        parsed_value = parse_list_operations_sort_value(element, column)
        clauses.append(build_pagination_clause_for_sort_key(parsed_value, previous_values, sort_keys))
        previous_values.append(parsed_value)

    return or_(*clauses)

294 

295 

def build_page_token(operation: Column[Any], sort_keys: List[SortKey]) -> str:
    """Use the sort keys to build a page token from the given operation.

    Each sort key is resolved to its backing column; the corresponding value
    is read from the operation row (or its related job row, depending on
    which table the column belongs to), serialized with
    dump_list_operations_token_value, and the values are joined with "|".

    Raises:
        ValueError: If a sort key resolves to a column in a table other than
            "operations" or "jobs".
    """
    token_values = []
    for sort_key in sort_keys:
        col = LIST_OPERATIONS_PARAMETER_MODEL_MAP[sort_key.name]
        col_properties = col.property.columns[0]
        column_name = col_properties.name
        table_name = col_properties.table.name
        if table_name == "operations":
            token_value = getattr(operation, column_name)
        elif table_name == "jobs":
            token_value = getattr(operation.job, column_name)
        else:
            # Bug fix: this message was previously missing its "f" prefix
            # ("...f{table_name}..." was emitted literally), so the
            # placeholders were never interpolated.
            raise ValueError(
                f"Got invalid table {table_name} for sort key {sort_key.name} while building page_token"
            )

        token_values.append(dump_list_operations_token_value(token_value))

    next_page_token = "|".join(token_values)
    return next_page_token

315 

316 

def extract_sort_keys(operation_filters: List[OperationFilter]) -> Tuple[List[SortKey], List[OperationFilter]]:
    """Splits the operation filters into sort keys and non-sort filters, returning both as
    separate lists.

    Sort keys are specified with the "sort_order" parameter in the filter string. Multiple
    "sort_order"s can appear in the filter string, and all are extracted and returned.

    Raises:
        InvalidArgumentError: If a "sort_order" filter uses any operator
            other than equality.
    """
    sort_keys = []
    non_sort_filters = []
    for op_filter in operation_filters:
        if op_filter.parameter != "sort_order":
            non_sort_filters.append(op_filter)
            continue
        # sort_order entries only make sense with "=".
        if op_filter.operator != operator.eq:
            raise InvalidArgumentError('sort_order must be specified with the "=" operator.')
        sort_keys.append(op_filter.value)

    return sort_keys, non_sort_filters

334 

335 

def build_sort_column_list(sort_keys: List[SortKey]) -> List["ColumnOperators[Any]"]:
    """Convert the list of sort keys into a list of columns that can be
    passed to an order_by.

    This function checks the sort keys to ensure that they are in the
    parameter-model map and raises an InvalidArgumentError if they are not.
    """
    sort_columns: List["ColumnOperators[Any]"] = []
    for sort_key in sort_keys:
        try:
            column = LIST_OPERATIONS_PARAMETER_MODEL_MAP[sort_key.name]
        except KeyError:
            raise InvalidArgumentError(f"[{sort_key.name}] is not a valid sort key.")
        # Honour the requested direction for each key.
        sort_columns.append(column.desc() if sort_key.descending else column.asc())
    return sort_columns

353 

354 

def convert_filter_to_sql_filter(operation_filter: OperationFilter) -> ClauseElement:
    """Convert the internal representation of a filter to a representation that SQLAlchemy
    can understand. The return type is a "ColumnElement," per the end of this section in
    the SQLAlchemy docs: https://docs.sqlalchemy.org/en/13/core/tutorial.html#selecting-specific-columns

    This function assumes that the parser has appropriately converted the filter
    value to a Python type that can be compared to the parameter.

    Raises:
        InvalidArgumentError: If the filter's parameter is not in the
            parameter-model map.
    """
    parameter = operation_filter.parameter
    try:
        param = LIST_OPERATIONS_PARAMETER_MODEL_MAP[parameter]
    except KeyError:
        raise InvalidArgumentError(f"Invalid parameter: [{parameter}]")

    # "command" filters are substring matches rather than exact comparisons.
    if parameter == "command":
        pattern = f"%{operation_filter.value}%"
        if operation_filter.operator == operator.eq:
            return param.like(pattern)
        if operation_filter.operator == operator.ne:
            return param.notlike(pattern)  # type: ignore[no-any-return]

    # "platform" filters carry a "key:value" pair matched against two columns.
    if parameter == "platform":
        key, value = operation_filter.value.split(":", 1)
        value_column = LIST_OPERATIONS_PARAMETER_MODEL_MAP["platform-value"]
        return and_(param == key, operation_filter.operator(value_column, value))

    # Better type? Returning Any from function declared to return "ClauseElement"
    return operation_filter.operator(param, operation_filter.value)  # type: ignore[no-any-return]

380 

381 

def build_custom_filters(operation_filters: List[OperationFilter]) -> List[ClauseElement]:
    """Convert all non-platform filters to SQLAlchemy clause elements.

    Platform filters are excluded here because they are handled separately
    (they span two columns; see convert_filter_to_sql_filter).
    """
    sql_filters = []
    for op_filter in operation_filters:
        if op_filter.parameter == "platform":
            continue
        sql_filters.append(convert_filter_to_sql_filter(op_filter))
    return sql_filters