Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/storage/sql.py: 92.68%

123 statements  


# Copyright (C) 2024 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
SQL Storage
==================

A CAS storage which stores blobs in a SQL database

"""

import itertools
from io import BytesIO
from typing import IO, Iterator, Sequence, TypedDict, cast

from sqlalchemy import delete, func, select
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm.exc import StaleDataError

from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest
from buildgrid._protos.google.rpc import code_pb2
from buildgrid._protos.google.rpc.status_pb2 import Status
from buildgrid.server.cas.storage.storage_abc import StorageABC
from buildgrid.server.decorators import timed
from buildgrid.server.exceptions import StorageFullError
from buildgrid.server.logging import buildgrid_logger
from buildgrid.server.metrics_names import METRIC
from buildgrid.server.sql.models import BlobEntry
from buildgrid.server.sql.provider import SqlProvider
from buildgrid.server.utils.digests import validate_digest_data

LOGGER = buildgrid_logger(__name__)


class DigestRow(TypedDict):
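    # Keys mirror the columns of the BlobEntry model; rows of this shape are passed
    # to the dialect-specific bulk INSERT statements below.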

    digest_hash: str
    digest_size_bytes: int
    data: bytes


class SQLStorage(StorageABC):
    TYPE = "SQL"

    def __init__(self, sql_provider: SqlProvider, *, sql_ro_provider: SqlProvider | None = None) -> None:
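        # sql_ro_provider optionally routes the read-only paths (has_blob, get_blob,
        # missing_blobs, bulk_read_blobs) to a separate provider, e.g. a read replica;
        # writes always go through sql_provider.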

        self._sql = sql_provider
        self._sql_ro = sql_ro_provider or sql_provider
        self._inclause_limit = self._sql.default_inlimit

        supported_dialects = ["postgresql", "sqlite"]

        if self._sql.dialect not in supported_dialects:
            raise RuntimeError(
                f"Unsupported dialect {self._sql.dialect}. "
                f"SQLStorage only supports the following dialects: {supported_dialects}"
            )


        # Make a test query against the database to ensure the connection is valid
        with self._sql.session() as session:
            session.query(BlobEntry).first()

    def _sqlite_bulk_insert(self, new_rows: list[DigestRow]) -> None:
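        # INSERT ... ON CONFLICT DO NOTHING: rows whose digest already exists are skipped
        # rather than updated. The PostgreSQL variant below behaves the same way; CAS blobs
        # are immutable for a given digest, so skipping duplicates is safe.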

        with self._sql.session() as session:
            session.execute(sqlite_insert(BlobEntry).values(new_rows).on_conflict_do_nothing())

    def _postgresql_bulk_insert(self, new_rows: list[DigestRow]) -> None:
        with self._sql.session() as session:
            session.execute(postgresql_insert(BlobEntry).values(new_rows).on_conflict_do_nothing())

    def _bulk_insert(self, digests: list[tuple[Digest, bytes]]) -> None:
        # Sort digests by hash to ensure a consistent insert order, which minimizes deadlocks
        # when concurrent BatchUpdateBlobs requests contain overlapping blobs
        new_rows: list[DigestRow] = [
            {"digest_hash": digest.hash, "digest_size_bytes": digest.size_bytes, "data": blob}
            for (digest, blob) in sorted(digests, key=lambda x: x[0].hash)
        ]

        if self._sql.dialect == "sqlite":
            self._sqlite_bulk_insert(new_rows)
        elif self._sql.dialect == "postgresql":
            self._postgresql_bulk_insert(new_rows)
        else:
            raise RuntimeError(f"Unsupported dialect {self._sql.dialect} for bulk_insert")

    @timed(METRIC.STORAGE.STAT_DURATION, type=TYPE)
    def has_blob(self, digest: Digest) -> bool:
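        # digest_hash is expected to be unique in the blobs table, so the count is 0 or 1.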

        statement = select(func.count(BlobEntry.digest_hash)).where(BlobEntry.digest_hash == digest.hash)
        with self._sql_ro.session() as session:
            return session.execute(statement).scalar() == 1

    @timed(METRIC.STORAGE.READ_DURATION, type=TYPE)
    def get_blob(self, digest: Digest) -> IO[bytes] | None:
        statement = select(BlobEntry.data).where(BlobEntry.digest_hash == digest.hash)
        with self._sql_ro.session() as session:
            result = session.execute(statement).scalar()
        if result is not None:
            return BytesIO(result)
        return None

    @timed(METRIC.STORAGE.DELETE_DURATION, type=TYPE)
    def delete_blob(self, digest: Digest) -> None:
        statement = delete(BlobEntry).where(BlobEntry.digest_hash == digest.hash)
        with self._sql.session() as session:
            # Set synchronize_session to false as we don't have any local session objects
            # to keep in sync
            session.execute(statement, execution_options={"synchronize_session": False})

    @timed(METRIC.STORAGE.WRITE_DURATION, type=TYPE)
    def commit_write(self, digest: Digest, write_session: IO[bytes]) -> None:
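        # Read the entire staged blob into memory and reuse the bulk upsert path, so
        # re-uploading an existing digest is effectively a no-op.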

        write_session.seek(0)
        blob = write_session.read()
        try:
            self._bulk_insert([(digest, blob)])
        except DBAPIError as error:
            # Error has pgcode attribute (Postgres only)
            if hasattr(error.orig, "pgcode"):
                # imported here to avoid global dependency on psycopg2
                from psycopg2.errors import DiskFull, Error, OutOfMemory

                # 53100 == DiskFull && 53200 == OutOfMemory
                original_exception = cast(Error, error.orig)
                if isinstance(original_exception, (DiskFull, OutOfMemory)):
                    raise StorageFullError(
                        f"Postgres Error: {original_exception.pgerror} ({original_exception.pgcode})"
                    ) from error
            raise


    def _partitioned_hashes(self, digests: Sequence[Digest]) -> Iterator[Iterator[str]]:
        """Given a long list of digests, split it into parts no larger than
        _inclause_limit and yield the hashes in each part.
        """
        for part_start in range(0, len(digests), self._inclause_limit):
            part_end = min(len(digests), part_start + self._inclause_limit)
            part_digests = itertools.islice(digests, part_start, part_end)
            yield (digest.hash for digest in part_digests)

    @timed(METRIC.STORAGE.BULK_STAT_DURATION, type=TYPE)
    def missing_blobs(self, digests: list[Digest]) -> list[Digest]:
        found_hashes = set()
        with self._sql_ro.session() as session:
            for part in self._partitioned_hashes(digests):
                stmt = select(BlobEntry.digest_hash).where(BlobEntry.digest_hash.in_(part))
                for row in session.execute(stmt):
                    found_hashes.add(row.digest_hash)

        return [digest for digest in digests if digest.hash not in found_hashes]

    @timed(METRIC.STORAGE.BULK_WRITE_DURATION, type=TYPE)
    def bulk_update_blobs(  # pylint: disable=arguments-renamed
        self, digest_blob_pairs: list[tuple[Digest, bytes]]
    ) -> list[Status]:
        """Implement the StorageABC's bulk_update_blobs method.

        The StorageABC interface takes in a list of digest/blob pairs and
        returns a list of results. The list of results MUST be ordered to
        correspond with the order of the input list."""
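        # Pairs that fail digest validation get an INVALID_ARGUMENT status but do not
        # prevent the remaining valid pairs from being inserted.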

        results = []

        pairs_to_insert = []
        for digest, blob in digest_blob_pairs:
            if validate_digest_data(digest, blob):
                results.append(Status(code=code_pb2.OK))
                pairs_to_insert.append((digest, blob))
            else:
                results.append(Status(code=code_pb2.INVALID_ARGUMENT, message="Data doesn't match hash"))

        self._bulk_insert(pairs_to_insert)
        return results

    @timed(METRIC.STORAGE.BULK_READ_DURATION, type=TYPE)
    def bulk_read_blobs(self, digests: list[Digest]) -> dict[str, bytes]:
        # Fetch the requested blobs that exist in the database; digests that are
        # missing are simply omitted from the result
        results: dict[str, bytes] = {}
        with self._sql_ro.session() as session:
            results = {
                digest_hash: data
                for part in self._partitioned_hashes(digests)
                for digest_hash, data in session.execute(
                    select(BlobEntry.digest_hash, BlobEntry.data).where(BlobEntry.digest_hash.in_(part))
                )
            }
        return results


    @timed(METRIC.STORAGE.BULK_DELETE_DURATION, type=TYPE)
    def bulk_delete(self, digests: list[Digest]) -> list[str]:
        hashes = [x.hash for x in digests]

        # Make sure we don't exceed maximum size of an IN clause
        n = self._inclause_limit
        hash_chunks = [hashes[i : i + n] for i in range(0, len(hashes), n)]

        # We will not raise, rollback, or log on StaleDataErrors.
        # These errors occur when we delete fewer rows than we were expecting.
        # This is fine, since the missing rows will get deleted eventually.
        # When running bulk_deletes concurrently, StaleDataErrors
        # occur too often to log.
        num_blobs_deleted = 0
        with self._sql.session(exceptions_to_not_rollback_on=[StaleDataError]) as session:
            for chunk in hash_chunks:
                # Do not wait for locks when deleting rows. Skip locked rows to
                # avoid deadlocks.
                stmt = delete(BlobEntry).where(
                    BlobEntry.digest_hash.in_(
                        select(BlobEntry.digest_hash)
                        .where(BlobEntry.digest_hash.in_(chunk))
                        .with_for_update(skip_locked=True)
                    )
                )
                # Set synchronize_session to false as we don't have any local session objects
                # to keep in sync
                num_blobs_deleted += session.execute(stmt, execution_options={"synchronize_session": False}).rowcount
        LOGGER.info(
            "blobs deleted from storage.", tags=dict(deleted_count=num_blobs_deleted, digest_count=len(digests))
        )

        # bulk_delete is typically expected to return the digests that were not deleted,
        # but delete only returns the number of rows deleted and not what was/wasn't
        # deleted. Getting this info would require extra queries, so assume that
        # everything was either deleted or already deleted. Failures will continue to throw
        return []