Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/sql/utils.py: 92.74%
179 statements
« prev ^ index » next — coverage.py v7.4.1, created at 2024-10-04 17:48 +0000
1# Copyright (C) 2020 Bloomberg LP
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# <http://www.apache.org/licenses/LICENSE-2.0>
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15""" Holds constants and utility functions for the SQL scheduler. """
18import operator
19import random
20from datetime import datetime, timedelta
21from threading import Lock
22from typing import Any, Dict, List, Optional, Tuple, cast
24from sqlalchemy.engine import Engine
25from sqlalchemy.orm.session import Session as SessionType
26from sqlalchemy.sql.expression import ClauseElement, and_, or_
27from sqlalchemy.sql.operators import ColumnOperators
28from sqlalchemy.sql.schema import Column
30from buildgrid.server.exceptions import InvalidArgumentError
31from buildgrid.server.logging import buildgrid_logger
32from buildgrid.server.operations.filtering import OperationFilter, SortKey
33from buildgrid.server.sql.models import (
34 ClientIdentityEntry,
35 JobEntry,
36 OperationEntry,
37 PlatformEntry,
38 RequestMetadataEntry,
39)
# Module-level structured logger for this module.
LOGGER = buildgrid_logger(__name__)

# Format used to serialize/deserialize datetime values that appear inside
# page tokens (see dump_list_operations_token_value and
# parse_list_operations_sort_value below).
DATETIME_FORMAT = "%Y-%m-%d-%H-%M-%S-%f"

# Maps the parameter names accepted in ListOperations filter/sort expressions
# to the ORM columns they refer to. Used both to build SQL filter clauses and
# ORDER BY column lists.
LIST_OPERATIONS_PARAMETER_MODEL_MAP = cast(
    Dict[str, Column[Any]],
    {
        "stage": JobEntry.stage,
        "name": OperationEntry.name,
        "queued_time": JobEntry.queued_timestamp,
        "start_time": JobEntry.worker_start_timestamp,
        "completed_time": JobEntry.worker_completed_timestamp,
        "invocation_id": RequestMetadataEntry.invocation_id,
        "correlated_invocations_id": RequestMetadataEntry.correlated_invocations_id,
        "tool_name": RequestMetadataEntry.tool_name,
        "tool_version": RequestMetadataEntry.tool_version,
        "action_mnemonic": RequestMetadataEntry.action_mnemonic,
        "target_id": RequestMetadataEntry.target_id,
        "configuration_id": RequestMetadataEntry.configuration_id,
        "action_digest": JobEntry.action_digest,
        "command": JobEntry.command,
        # "platform" filters compare against the key column; the matching value
        # column is looked up separately via "platform-value".
        "platform": PlatformEntry.key,
        "platform-value": PlatformEntry.value,
        "client_workflow": ClientIdentityEntry.workflow,
        "client_actor": ClientIdentityEntry.actor,
        "client_subject": ClientIdentityEntry.subject,
    },
)
def is_sqlite_connection_string(connection_string: str) -> bool:
    """Return True if the connection string targets an SQLite database."""
    # An empty/None string is never considered SQLite.
    return bool(connection_string) and connection_string.startswith("sqlite")
def is_psycopg2_connection_string(connection_string: str) -> bool:
    """Return True if the connection string uses PostgreSQL via psycopg2.

    Matches both the bare "postgresql:" scheme (psycopg2 is SQLAlchemy's
    default PostgreSQL driver) and the explicit "postgresql+psycopg2:" form.
    """
    if not connection_string:
        return False
    return connection_string.startswith(("postgresql:", "postgresql+psycopg2:"))
def is_sqlite_inmemory_connection_string(full_connection_string: str) -> bool:
    """Return True for SQLite connection strings that refer to an in-memory DB.

    Connection strings for in-memory SQLite (which we don't support) can look like:
        "sqlite:///file:memdb1?option=value&cache=shared&mode=memory",
        "sqlite:///file:memdb1?mode=memory&cache=shared",
        "sqlite:///file:memdb1?cache=shared&mode=memory",
        "sqlite:///file::memory:?cache=shared",
        "sqlite:///file::memory:",
        "sqlite:///:memory:",
        "sqlite:///",
        "sqlite://"
    ref: https://www.sqlite.org/inmemorydb.html
    Note that a user can also specify drivers, so the prefix could be 'sqlite+driver:///'.
    """
    # (inlined is_sqlite_connection_string check)
    if not full_connection_string or not full_connection_string.startswith("sqlite"):
        return False

    # Separate the URI path from any query-string options.
    base, question_mark, query = full_connection_string.partition("?")

    # Bare in-memory forms: explicit ":memory:" path, or no path at all.
    if base.endswith((":memory:", ":///", "://")):
        return True

    # A file-style URI can still be in-memory via the mode=memory option.
    if question_mark and "mode=memory" in query.split("&"):
        return True

    return False
class SQLPoolDisposeHelper:
    """Helper class for disposing of SQL session connections.

    When a known transient connection-level error is detected (currently only
    for PostgreSQL/psycopg2), the offending session is invalidated and — at
    most once per ``min_time_between_dispose_in_minutes`` — the engine's whole
    connection pool is disposed so new requests get fresh connections.
    """

    def __init__(
        self,
        cooldown_time_in_secs: int,
        cooldown_jitter_base_in_secs: int,
        min_time_between_dispose_in_minutes: int,
        sql_engine: Engine,
    ) -> None:
        # Base cooldown advertised to clients after a pool disposal.
        self._cooldown_time_in_secs = cooldown_time_in_secs
        # Jitter bound (+/-) mixed into the cooldown to spread out retries.
        self._cooldown_jitter_base_in_secs = cooldown_jitter_base_in_secs
        # Rate limit between two consecutive pool disposals.
        self._min_time_between_dispose_in_minutes = min_time_between_dispose_in_minutes
        self._last_pool_dispose_time: Optional[datetime] = None
        # Guards _last_pool_dispose_time so only one thread disposes the pool.
        self._last_pool_dispose_time_lock = Lock()
        self._sql_engine = sql_engine
        # Exception types (found as __cause__ of SQLAlchemy errors) that
        # trigger a pool disposal; only populated for PostgreSQL.
        self._dispose_pool_on_exceptions: Tuple[Any, ...] = tuple()
        if self._sql_engine.dialect.name == "postgresql":
            # Imported lazily so non-PostgreSQL deployments don't need psycopg2.
            import psycopg2

            self._dispose_pool_on_exceptions = (psycopg2.errors.ReadOnlySqlTransaction, psycopg2.errors.AdminShutdown)

    def check_dispose_pool(self, session: SessionType, e: Exception) -> bool:
        """For selected exceptions invalidate the SQL session
        - returns True when a transient sql connection error is detected
        - returns False otherwise
        """
        # Only do this if the config is relevant
        if not self._dispose_pool_on_exceptions:
            return False

        # Make sure we have a SQL-related cause to check, otherwise skip
        if e.__cause__ and not isinstance(e.__cause__, Exception):
            return False

        cause_type = type(e.__cause__)
        # Let's see if this exception is related to known disconnect exceptions
        is_connection_error = cause_type in self._dispose_pool_on_exceptions
        if not is_connection_error:
            return False

        # Make sure this connection will not be re-used
        session.invalidate()
        LOGGER.info(
            "Detected a SQL exception related to the connection. Invalidating this connection.",
            tags=dict(exception=cause_type.__name__),
        )

        # Only allow disposal every self._min_time_between_dispose_in_minutes
        now = datetime.utcnow()
        only_if_after = None

        # Check if we should dispose the rest of the checked in connections
        with self._last_pool_dispose_time_lock:
            if self._last_pool_dispose_time:
                only_if_after = self._last_pool_dispose_time + timedelta(
                    minutes=self._min_time_between_dispose_in_minutes
                )
            if only_if_after and now < only_if_after:
                # The pool was disposed recently; treat the error as handled
                # without disposing again.
                return True

            # OK, we haven't disposed the pool recently
            self._last_pool_dispose_time = now
            LOGGER.info(
                "Disposing connection pool. New requests will have a fresh SQL connection.",
                tags=dict(cooldown_time_in_secs=self._cooldown_time_in_secs),
            )
            self._sql_engine.dispose()

        return True

    def time_until_active_pool(self) -> timedelta:
        """The time at which the pool is expected to become
        active after a pool disposal. This adds small amounts of jitter
        to help spread out load due to retrying clients
        """
        if self._last_pool_dispose_time:
            time_til_active = self._last_pool_dispose_time + timedelta(seconds=self._cooldown_time_in_secs)
            if datetime.utcnow() < time_til_active:
                # Still cooling down: report the base cooldown +/- random jitter.
                return timedelta(
                    seconds=self._cooldown_time_in_secs
                    + random.uniform(-self._cooldown_jitter_base_in_secs, self._cooldown_jitter_base_in_secs)
                )
        return timedelta(seconds=0)
def strtobool(val: str) -> bool:
    """Convert a string representation of truth to True or False.

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
    are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
    'val' is anything else.
    """
    lowered = val.lower()
    if lowered in {"y", "yes", "t", "true", "on", "1"}:
        return True
    if lowered in {"n", "no", "f", "false", "off", "0"}:
        return False
    raise ValueError(f"invalid truth value {lowered}")
def parse_list_operations_sort_value(value: str, column: Column[Any]) -> Any:
    """Convert the string representation of a value to the proper Python type."""
    python_type = column.property.columns[0].type.python_type
    if python_type is datetime:
        # Datetimes in page tokens use the module-wide token format.
        return datetime.strptime(value, DATETIME_FORMAT)
    if python_type is bool:
        # strtobool covers a few different textual bool representations.
        return strtobool(value)
    return python_type(value)
def dump_list_operations_token_value(token_value: Any) -> str:
    """Serialize a value into its page_token string representation."""
    if isinstance(token_value, datetime):
        # Datetimes use the module-wide token format so they round-trip
        # through parse_list_operations_sort_value.
        return token_value.strftime(DATETIME_FORMAT)
    return str(token_value)
def build_pagination_clause_for_sort_key(
    sort_value: Any, previous_sort_values: List[Any], sort_keys: List[SortKey]
) -> ClauseElement:
    """Build part of a filter clause to figure out the starting point of the page given
    by the page_token. See the docstring of build_page_filter for more details."""
    if len(sort_keys) <= len(previous_sort_values):
        raise ValueError("Not enough keys to unpack")

    # Pin every already-consumed sort key to its previous value (tie-break).
    clauses: List[Any] = [
        LIST_OPERATIONS_PARAMETER_MODEL_MAP[sort_keys[idx].name] == prev_value
        for idx, prev_value in enumerate(previous_sort_values)
    ]
    # The next sort key gets a strict inequality in its sort direction.
    active_key = sort_keys[len(previous_sort_values)]
    active_col = LIST_OPERATIONS_PARAMETER_MODEL_MAP[active_key.name]
    clauses.append(active_col < sort_value if active_key.descending else active_col > sort_value)
    return and_(*clauses)
def build_page_filter(page_token: str, sort_keys: List[SortKey]) -> ClauseElement:
    """Build a filter to determine the starting point of the rows to fetch, based
    on the page_token.

    The page_token is directly related to the sort order, and in this way it acts as a
    "cursor." It is given in the format Xval|Yval|Zval|..., where each element is a value
    corresponding to an orderable column in the database. If the corresponding rows are
    X, Y, and Z, then X is the primary sort key, with Y breaking ties between X, and Z
    breaking ties between X and Y. The corresponding filter clause is then:

    (X > Xval) OR (X == XVal AND Y > Yval) OR (X == Xval AND Y == Yval AND Z > Zval) ...
    """
    # The token acts as a cursor: one "|"-separated element per sort key.
    token_elements = page_token.split("|")
    if len(token_elements) != len(sort_keys):
        # It is possible that an extra "|" was in the input
        # TODO: Handle extra "|"s somehow? Or at least allow escaping them
        raise InvalidArgumentError(
            f'Wrong number of "|"-separated elements in page token [{page_token}]. '
            f"Expected {len(sort_keys)}, got {len(token_elements)}."
        )

    per_key_clauses = []
    parsed_values: List[Any] = []
    # Build the compound clause for each sort key in the token.
    for raw_value, sort_key in zip(token_elements, sort_keys):
        column = LIST_OPERATIONS_PARAMETER_MODEL_MAP[sort_key.name]
        parsed = parse_list_operations_sort_value(raw_value, column)
        per_key_clauses.append(build_pagination_clause_for_sort_key(parsed, parsed_values, sort_keys))
        parsed_values.append(parsed)

    return or_(*per_key_clauses)
def build_page_token(operation: Column[Any], sort_keys: List[SortKey]) -> str:
    """Use the sort keys to build a page token from the given operation.

    Each sort key is resolved to its ORM column; the value is read from the
    operation row itself (columns in the "operations" table) or from the
    related job row (columns in the "jobs" table), serialized, and joined
    with "|" into the cursor-style token consumed by build_page_filter.

    Raises ValueError if a sort key maps to a column in any other table.
    """
    token_values = []
    for sort_key in sort_keys:
        col = LIST_OPERATIONS_PARAMETER_MODEL_MAP[sort_key.name]
        col_properties = col.property.columns[0]
        column_name = col_properties.name
        table_name = col_properties.table.name
        if table_name == "operations":
            token_value = getattr(operation, column_name)
        elif table_name == "jobs":
            token_value = getattr(operation.job, column_name)
        else:
            # Bug fix: the original message lacked the f-string prefix, so the
            # placeholders were emitted literally instead of being interpolated.
            raise ValueError(
                f"Got invalid table {table_name} for sort key {sort_key.name} while building page_token"
            )

        token_values.append(dump_list_operations_token_value(token_value))
    next_page_token = "|".join(token_values)
    return next_page_token
def extract_sort_keys(operation_filters: List[OperationFilter]) -> Tuple[List[SortKey], List[OperationFilter]]:
    """Splits the operation filters into sort keys and non-sort filters, returning both as
    separate lists.

    Sort keys are specified with the "sort_order" parameter in the filter string. Multiple
    "sort_order"s can appear in the filter string, and all are extracted and returned."""
    sort_keys = []
    other_filters = []
    for current_filter in operation_filters:
        if current_filter.parameter != "sort_order":
            other_filters.append(current_filter)
            continue
        # sort_order is a pseudo-filter: it only supports equality.
        if current_filter.operator != operator.eq:
            raise InvalidArgumentError('sort_order must be specified with the "=" operator.')
        sort_keys.append(current_filter.value)

    return sort_keys, other_filters
def build_sort_column_list(sort_keys: List[SortKey]) -> List["ColumnOperators[Any]"]:
    """Convert the list of sort keys into a list of columns that can be
    passed to an order_by.

    This function checks the sort keys to ensure that they are in the
    parameter-model map and raises an InvalidArgumentError if they are not."""
    sort_columns: List["ColumnOperators[Any]"] = []
    for sort_key in sort_keys:
        try:
            column = LIST_OPERATIONS_PARAMETER_MODEL_MAP[sort_key.name]
        except KeyError:
            raise InvalidArgumentError(f"[{sort_key.name}] is not a valid sort key.")
        sort_columns.append(column.desc() if sort_key.descending else column.asc())
    return sort_columns
def convert_filter_to_sql_filter(operation_filter: OperationFilter) -> ClauseElement:
    """Convert the internal representation of a filter to a representation that SQLAlchemy
    can understand. The return type is a "ColumnElement," per the end of this section in
    the SQLAlchemy docs: https://docs.sqlalchemy.org/en/13/core/tutorial.html#selecting-specific-columns

    This function assumes that the parser has appropriately converted the filter
    value to a Python type that can be compared to the parameter."""
    try:
        param = LIST_OPERATIONS_PARAMETER_MODEL_MAP[operation_filter.parameter]
    except KeyError:
        raise InvalidArgumentError(f"Invalid parameter: [{operation_filter.parameter}]")

    # "command" equality/inequality is treated as a substring match; other
    # operators deliberately fall through to the generic comparison below.
    if operation_filter.parameter == "command":
        if operation_filter.operator == operator.eq:
            return param.like(f"%{operation_filter.value}%")
        if operation_filter.operator == operator.ne:
            return param.notlike(f"%{operation_filter.value}%")  # type: ignore[no-any-return]

    # "platform" filters carry a "key:value" payload and compare both the
    # platform key and value columns.
    if operation_filter.parameter == "platform":
        key, value = operation_filter.value.split(":", 1)
        value_column = LIST_OPERATIONS_PARAMETER_MODEL_MAP["platform-value"]
        return and_(param == key, operation_filter.operator(value_column, value))

    # Better type? Returning Any from function declared to return "ClauseElement"
    return operation_filter.operator(param, operation_filter.value)  # type: ignore[no-any-return]
def build_custom_filters(operation_filters: List[OperationFilter]) -> List[ClauseElement]:
    """Convert every non-platform filter into its SQLAlchemy clause.

    Platform filters are excluded here; they require a join and are handled
    separately by the caller."""
    non_platform_filters = (op for op in operation_filters if op.parameter != "platform")
    return [convert_filter_to_sql_filter(op) for op in non_platform_filters]