Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/storage/sql.py: 92.68%
123 statements
coverage.py v7.4.1, created at 2025-02-11 15:07 +0000
# Copyright (C) 2024 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
SQL Storage
===========

A CAS storage which stores blobs in a SQL database.
"""
import itertools
from io import BytesIO
from typing import IO, Iterator, Sequence, TypedDict, cast

from sqlalchemy import delete, func, select
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm.exc import StaleDataError

from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest
from buildgrid._protos.google.rpc import code_pb2
from buildgrid._protos.google.rpc.status_pb2 import Status
from buildgrid.server.cas.storage.storage_abc import StorageABC
from buildgrid.server.decorators import timed
from buildgrid.server.exceptions import StorageFullError
from buildgrid.server.logging import buildgrid_logger
from buildgrid.server.metrics_names import METRIC
from buildgrid.server.sql.models import BlobEntry
from buildgrid.server.sql.provider import SqlProvider
from buildgrid.server.utils.digests import validate_digest_data

LOGGER = buildgrid_logger(__name__)
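
# A single row of blob data, as written by the bulk-insert helpers below; the
# keys mirror the BlobEntry column names.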
class DigestRow(TypedDict):
    digest_hash: str
    digest_size_bytes: int
    data: bytes


class SQLStorage(StorageABC):
    TYPE = "SQL"

    def __init__(self, sql_provider: SqlProvider, *, sql_ro_provider: SqlProvider | None = None) -> None:
        self._sql = sql_provider
        self._sql_ro = sql_ro_provider or sql_provider
        self._inclause_limit = self._sql.default_inlimit

        supported_dialects = ["postgresql", "sqlite"]

        if self._sql.dialect not in supported_dialects:
            raise RuntimeError(
                f"Unsupported dialect {self._sql.dialect}. "
                f"SQLStorage only supports the following dialects: {supported_dialects}"
            )

        # Make a test query against the database to ensure the connection is valid
        with self._sql.session() as session:
            session.query(BlobEntry).first()
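
    # "INSERT ... ON CONFLICT DO NOTHING" is only available on SQLAlchemy's
    # dialect-specific insert() constructs, so each supported dialect gets its
    # own bulk-insert helper.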
    def _sqlite_bulk_insert(self, new_rows: list[DigestRow]) -> None:
        with self._sql.session() as session:
            session.execute(sqlite_insert(BlobEntry).values(new_rows).on_conflict_do_nothing())

    def _postgresql_bulk_insert(self, new_rows: list[DigestRow]) -> None:
        with self._sql.session() as session:
            session.execute(postgresql_insert(BlobEntry).values(new_rows).on_conflict_do_nothing())

    def _bulk_insert(self, digests: list[tuple[Digest, bytes]]) -> None:
        # Sort digests by hash to ensure consistent order to minimize deadlocks
        # when BatchUpdateBlobs requests have overlapping blobs
        new_rows: list[DigestRow] = [
            {"digest_hash": digest.hash, "digest_size_bytes": digest.size_bytes, "data": blob}
            for (digest, blob) in sorted(digests, key=lambda x: x[0].hash)
        ]

        if self._sql.dialect == "sqlite":
            self._sqlite_bulk_insert(new_rows)
        elif self._sql.dialect == "postgresql":
            self._postgresql_bulk_insert(new_rows)
        else:
            raise RuntimeError(f"Unsupported dialect {self._sql.dialect} for bulk_insert")

    @timed(METRIC.STORAGE.STAT_DURATION, type=TYPE)
    def has_blob(self, digest: Digest) -> bool:
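        # Count rows matching the digest hash; each hash is stored at most once
        # (duplicate inserts are dropped via ON CONFLICT DO NOTHING), so the
        # count is either 0 or 1.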
        statement = select(func.count(BlobEntry.digest_hash)).where(BlobEntry.digest_hash == digest.hash)
        with self._sql_ro.session() as session:
            return session.execute(statement).scalar() == 1

    @timed(METRIC.STORAGE.READ_DURATION, type=TYPE)
    def get_blob(self, digest: Digest) -> IO[bytes] | None:
        statement = select(BlobEntry.data).where(BlobEntry.digest_hash == digest.hash)
        with self._sql_ro.session() as session:
            result = session.execute(statement).scalar()
        if result is not None:
            return BytesIO(result)
        return None

    @timed(METRIC.STORAGE.DELETE_DURATION, type=TYPE)
    def delete_blob(self, digest: Digest) -> None:
        statement = delete(BlobEntry).where(BlobEntry.digest_hash == digest.hash)
        with self._sql.session() as session:
            # Set synchronize_session to false as we don't have any local session objects
            # to keep in sync
            session.execute(statement, execution_options={"synchronize_session": False})

    @timed(METRIC.STORAGE.WRITE_DURATION, type=TYPE)
    def commit_write(self, digest: Digest, write_session: IO[bytes]) -> None:
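        # Rewind the staged write session and store its full contents as a
        # single blob row.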
        write_session.seek(0)
        blob = write_session.read()
        try:
            self._bulk_insert([(digest, blob)])
        except DBAPIError as error:
            # Error has pgcode attribute (Postgres only)
            if hasattr(error.orig, "pgcode"):
                # imported here to avoid global dependency on psycopg2
                from psycopg2.errors import DiskFull, Error, OutOfMemory

                # 53100 == DiskFull && 53200 == OutOfMemory
                original_exception = cast(Error, error.orig)
                if isinstance(original_exception, (DiskFull, OutOfMemory)):
                    raise StorageFullError(
                        f"Postgres Error: {original_exception.pgerror} ({original_exception.pgcode})"
                    ) from error
            raise

    def _partitioned_hashes(self, digests: Sequence[Digest]) -> Iterator[Iterator[str]]:
        """Given a long list of digests, split it into parts no larger than
        _inclause_limit and yield the hashes in each part.
        """
        for part_start in range(0, len(digests), self._inclause_limit):
            part_end = min(len(digests), part_start + self._inclause_limit)
            part_digests = itertools.islice(digests, part_start, part_end)
            yield (digest.hash for digest in part_digests)

    @timed(METRIC.STORAGE.BULK_STAT_DURATION, type=TYPE)
    def missing_blobs(self, digests: list[Digest]) -> list[Digest]:
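        # Collect the hashes that already exist in the database, then report any
        # requested digest whose hash was not found.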
        found_hashes = set()
        with self._sql_ro.session() as session:
            for part in self._partitioned_hashes(digests):
                stmt = select(BlobEntry.digest_hash).where(BlobEntry.digest_hash.in_(part))
                for row in session.execute(stmt):
                    found_hashes.add(row.digest_hash)

        return [digest for digest in digests if digest.hash not in found_hashes]

    @timed(METRIC.STORAGE.BULK_WRITE_DURATION, type=TYPE)
    def bulk_update_blobs(  # pylint: disable=arguments-renamed
        self, digest_blob_pairs: list[tuple[Digest, bytes]]
    ) -> list[Status]:
        """Implement the StorageABC's bulk_update_blobs method.

        The StorageABC interface takes in a list of digest/blob pairs and
        returns a list of results. The list of results MUST be ordered to
        correspond with the order of the input list."""
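        # Validate each blob against its digest: only matching pairs are written
        # to the database, but a status is appended for every input pair so the
        # result ordering matches the input ordering.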
        results = []

        pairs_to_insert = []
        for digest, blob in digest_blob_pairs:
            if validate_digest_data(digest, blob):
                results.append(Status(code=code_pb2.OK))
                pairs_to_insert.append((digest, blob))
            else:
                results.append(Status(code=code_pb2.INVALID_ARGUMENT, message="Data doesn't match hash"))

        self._bulk_insert(pairs_to_insert)
        return results

    @timed(METRIC.STORAGE.BULK_READ_DURATION, type=TYPE)
    def bulk_read_blobs(self, digests: list[Digest]) -> dict[str, bytes]:
        # Fetch all of the requested digests that are present in the database
        results: dict[str, bytes] = {}
        with self._sql_ro.session() as session:
            results = {
                digest_hash: data
                for part in self._partitioned_hashes(digests)
                for digest_hash, data in session.execute(
                    select(BlobEntry.digest_hash, BlobEntry.data).where(BlobEntry.digest_hash.in_(part))
                )
            }
        return results

    @timed(METRIC.STORAGE.BULK_DELETE_DURATION, type=TYPE)
    def bulk_delete(self, digests: list[Digest]) -> list[str]:
        hashes = [x.hash for x in digests]

        # Make sure we don't exceed maximum size of an IN clause
        n = self._inclause_limit
        hash_chunks = [hashes[i : i + n] for i in range(0, len(hashes), n)]

        # We will not raise, rollback, or log on StaleDataErrors.
        # These errors occur when we delete fewer rows than we were expecting.
        # This is fine, since the missing rows will get deleted eventually.
        # When running bulk_deletes concurrently, StaleDataErrors
        # occur too often to log.
        num_blobs_deleted = 0
        with self._sql.session(exceptions_to_not_rollback_on=[StaleDataError]) as session:
            for chunk in hash_chunks:
                # Do not wait for locks when deleting rows. Skip locked rows to
                # avoid deadlocks.
                stmt = delete(BlobEntry).where(
                    BlobEntry.digest_hash.in_(
                        select(BlobEntry.digest_hash)
                        .where(BlobEntry.digest_hash.in_(chunk))
                        .with_for_update(skip_locked=True)
                    )
                )
                # Set synchronize_session to false as we don't have any local session objects
                # to keep in sync
                num_blobs_deleted += session.execute(stmt, execution_options={"synchronize_session": False}).rowcount
        LOGGER.info(
            "blobs deleted from storage.", tags=dict(deleted_count=num_blobs_deleted, digest_count=len(digests))
        )

        # bulk_delete is typically expected to return the digests that were not deleted,
        # but delete only returns the number of rows deleted and not what was/wasn't
        # deleted. Getting this info would require extra queries, so assume that
        # everything was either deleted or already deleted. Failures will continue to throw.
        return []