Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/storage/sql.py: 92.79%

111 statements  

# Copyright (C) 2024 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
SQL Storage
==================

A CAS storage which stores blobs in a SQL database

"""

23 

24import itertools 

25import logging 

26from io import BytesIO 

27from typing import IO, Dict, Iterator, List, Optional, Sequence, Tuple, TypedDict 

28 

29from sqlalchemy import delete, func, select 

30from sqlalchemy.dialects.postgresql import insert as postgresql_insert 

31from sqlalchemy.dialects.sqlite import insert as sqlite_insert 

32from sqlalchemy.exc import DBAPIError 

33from sqlalchemy.orm.exc import StaleDataError 

34 

35from buildgrid._exceptions import StorageFullError 

36from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest 

37from buildgrid._protos.google.rpc import code_pb2 

38from buildgrid._protos.google.rpc.status_pb2 import Status 

39from buildgrid.server.cas.storage.storage_abc import StorageABC 

40from buildgrid.server.persistence.sql.models import BlobEntry 

41from buildgrid.server.sql.provider import SqlProvider 

42from buildgrid.utils import validate_digest_data 

43 

44LOGGER = logging.getLogger(__name__) 

45 

46 

class DigestRow(TypedDict):
    digest_hash: str
    digest_size_bytes: int
    data: bytes


class SQLStorage(StorageABC):
    def __init__(self, sql_provider: SqlProvider, *, sql_ro_provider: Optional[SqlProvider] = None) -> None:
        self._sql = sql_provider
        self._sql_ro = sql_ro_provider or sql_provider
        self._inclause_limit = self._sql.default_inlimit

        supported_dialects = ["postgresql", "sqlite"]

        if self._sql.dialect not in supported_dialects:
            raise RuntimeError(
                f"Unsupported dialect {self._sql.dialect}. "
                f"SQLStorage only supports the following dialects: {supported_dialects}"
            )

        # Make a test query against the database to ensure the connection is valid
        with self._sql.session() as session:
            session.query(BlobEntry).first()

    def _sqlite_bulk_insert(self, new_rows: List[DigestRow]) -> None:
        with self._sql.session() as session:
            session.execute(sqlite_insert(BlobEntry).values(new_rows).on_conflict_do_nothing())

    def _postgresql_bulk_insert(self, new_rows: List[DigestRow]) -> None:
        with self._sql.session() as session:
            session.execute(postgresql_insert(BlobEntry).values(new_rows).on_conflict_do_nothing())

    def _bulk_insert(self, digests: List[Tuple[Digest, bytes]]) -> None:
        new_rows: List[DigestRow] = [
            {"digest_hash": digest.hash, "digest_size_bytes": digest.size_bytes, "data": blob}
            for (digest, blob) in digests
        ]

        if self._sql.dialect == "sqlite":
            self._sqlite_bulk_insert(new_rows)
        elif self._sql.dialect == "postgresql":
            self._postgresql_bulk_insert(new_rows)
        else:
            raise RuntimeError(f"Unsupported dialect {self._sql.dialect} for bulk_insert")

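    # Illustration (not generated verbatim by the code above): both dialect-specific
    # helpers rely on SQLAlchemy's on_conflict_do_nothing(), so re-uploading a blob that
    # already exists is a no-op rather than an integrity error. The emitted SQL is roughly:
    #
    #   INSERT INTO <blobs table> (digest_hash, digest_size_bytes, data)
    #   VALUES (:digest_hash, :digest_size_bytes, :data)
    #   ON CONFLICT DO NOTHING
    #
    # (the actual table and column names come from the BlobEntry model).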

    def has_blob(self, digest: Digest) -> bool:
        statement = select(func.count(BlobEntry.digest_hash)).where(BlobEntry.digest_hash == digest.hash)
        with self._sql_ro.session() as session:
            return session.execute(statement).scalar() == 1

    def get_blob(self, digest: Digest) -> Optional[IO[bytes]]:
        statement = select(BlobEntry.data).where(BlobEntry.digest_hash == digest.hash)
        with self._sql_ro.session() as session:
            result = session.execute(statement).scalar()
        if result is not None:
            return BytesIO(result)
        return None

    def delete_blob(self, digest: Digest) -> None:
        statement = delete(BlobEntry).where(BlobEntry.digest_hash == digest.hash)
        with self._sql.session() as session:
            # Set synchronize_session to false as we don't have any local session objects
            # to keep in sync
            session.execute(statement, execution_options={"synchronize_session": False})

    def commit_write(self, digest: Digest, write_session: IO[bytes]) -> None:
        write_session.seek(0)
        blob = write_session.read()
        try:
            self._bulk_insert([(digest, blob)])
        except DBAPIError as error:
            # Error has pgcode attribute (Postgres only)
            if hasattr(error.orig, "pgcode"):
                # imported here to avoid global dependency on psycopg2
                from psycopg2.errors import DiskFull, OutOfMemory

                # 53100 == DiskFull && 53200 == OutOfMemory
                if isinstance(error.orig, (DiskFull, OutOfMemory)):
                    raise StorageFullError(f"Postgres Error: {error.orig.pgcode}") from error
            raise

    def _partitioned_hashes(self, digests: Sequence[Digest]) -> Iterator[Iterator[str]]:
        """Given a long list of digests, split it into parts no larger than
        _inclause_limit and yield the hashes in each part.
        """
        for part_start in range(0, len(digests), self._inclause_limit):
            part_end = min(len(digests), part_start + self._inclause_limit)
            part_digests = itertools.islice(digests, part_start, part_end)
            yield (digest.hash for digest in part_digests)

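    # Illustration (hypothetical values): with _inclause_limit == 2 and digests for the
    # hashes ["a", "b", "c"], _partitioned_hashes yields two generators producing
    # ("a", "b") and then ("c",), keeping every IN (...) clause below the configured limit.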

    def missing_blobs(self, digests: List[Digest]) -> List[Digest]:
        found_hashes = set()
        with self._sql_ro.session() as session:
            for part in self._partitioned_hashes(digests):
                stmt = select(BlobEntry.digest_hash).where(BlobEntry.digest_hash.in_(part))
                for row in session.execute(stmt):
                    found_hashes.add(row.digest_hash)

        return [digest for digest in digests if digest.hash not in found_hashes]

    def bulk_update_blobs(  # pylint: disable=arguments-renamed
        self, digest_blob_pairs: List[Tuple[Digest, bytes]]
    ) -> List[Status]:
        """Implement the StorageABC's bulk_update_blobs method.

        The StorageABC interface takes in a list of digest/blob pairs and
        returns a list of results. The list of results MUST be ordered to
        correspond with the order of the input list."""
        results = []

        pairs_to_insert = []
        for digest, blob in digest_blob_pairs:
            if validate_digest_data(digest, blob):
                results.append(Status(code=code_pb2.OK))
                pairs_to_insert.append((digest, blob))
            else:
                results.append(Status(code=code_pb2.INVALID_ARGUMENT, message="Data doesn't match hash"))

        self._bulk_insert(pairs_to_insert)
        return results

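    # Illustration (hypothetical digests): for the input [(good_digest, good_blob),
    # (bad_digest, bad_blob)] the returned list is [Status(OK), Status(INVALID_ARGUMENT)],
    # i.e. results[i] always describes digest_blob_pairs[i], and only the valid pairs
    # are passed on to _bulk_insert.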

    def bulk_read_blobs(self, digests: List[Digest]) -> Dict[str, bytes]:
        # Fetch all of the requested digests that are present in the database
        results: Dict[str, bytes] = {}
        with self._sql_ro.session() as session:
            results = {
                digest_hash: data
                for part in self._partitioned_hashes(digests)
                for digest_hash, data in session.execute(
                    select(BlobEntry.digest_hash, BlobEntry.data).where(BlobEntry.digest_hash.in_(part))
                )
            }
        return results

    def bulk_delete(self, digests: List[Digest]) -> List[str]:
        hashes = [x.hash for x in digests]

        # Make sure we don't exceed the maximum size of an IN clause
        n = self._inclause_limit
        hash_chunks = [hashes[i : i + n] for i in range(0, len(hashes), n)]

        # We will not raise, rollback, or log on StaleDataErrors.
        # These errors occur when we delete fewer rows than we were expecting.
        # This is fine, since the missing rows will get deleted eventually.
        # When running bulk_deletes concurrently, StaleDataErrors
        # occur too often to log.
        num_blobs_deleted = 0
        with self._sql.session(exceptions_to_not_rollback_on=[StaleDataError]) as session:
            for chunk in hash_chunks:
                # Do not wait for locks when deleting rows. Skip locked rows to
                # avoid deadlocks.
                stmt = delete(BlobEntry).where(
                    BlobEntry.digest_hash.in_(
                        select(BlobEntry.digest_hash)
                        .where(BlobEntry.digest_hash.in_(chunk))
                        .with_for_update(skip_locked=True)
                    )
                )
                # Set synchronize_session to false as we don't have any local session objects
                # to keep in sync
                num_blobs_deleted += session.execute(
                    stmt, execution_options={"synchronize_session": False}
                ).rowcount  # type: ignore
        LOGGER.info(f"{num_blobs_deleted}/{len(digests)} blobs deleted from storage")

        # bulk_delete is typically expected to return the digests that were not deleted,
        # but delete only returns the number of rows deleted and not what was/wasn't
        # deleted. Getting this info would require extra queries, so assume that
        # everything was either deleted or already deleted. Failures will continue to raise.
        return []
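
# Usage sketch (illustrative only, not part of the module): a SQLStorage is constructed
# from a SqlProvider and then used through the generic StorageABC interface. The provider
# configuration shown here (a SQLite connection string keyword) is an assumption for the
# example and may not match the real SqlProvider signature.
#
#   from io import BytesIO
#   from buildgrid.server.sql.provider import SqlProvider
#
#   provider = SqlProvider(connection_string="sqlite:///cas.db")  # hypothetical config
#   storage = SQLStorage(provider)
#
#   digest = Digest(hash="...", size_bytes=3)  # digest of the blob being written
#   storage.commit_write(digest, BytesIO(b"abc"))
#   assert storage.has_blob(digest)
#   blob_reader = storage.get_blob(digest)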