Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/instance.py: 93.21%

265 statements  

coverage.py v7.4.1, created at 2024-03-28 16:20 +0000

# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Storage Instances
=================
Instances of CAS and ByteStream
"""

import logging
from datetime import timedelta
from typing import Iterable, Iterator, Optional, Sequence, Tuple

from cachetools import TTLCache
from grpc import RpcError

from buildgrid._exceptions import (
    InvalidArgumentError,
    NotFoundError,
    OutOfRangeError,
    PermissionDeniedError,
    RetriableError,
)
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import DESCRIPTOR as RE_DESCRIPTOR
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import (
    BatchReadBlobsResponse,
    BatchUpdateBlobsRequest,
    BatchUpdateBlobsResponse,
    Digest,
    DigestFunction,
    Directory,
    FindMissingBlobsResponse,
    GetTreeRequest,
    GetTreeResponse,
    SymlinkAbsolutePathStrategy,
    Tree,
)
from buildgrid._protos.google.bytestream import bytestream_pb2 as bs_pb2
from buildgrid._protos.google.rpc import code_pb2, status_pb2
from buildgrid.server.cas.storage.storage_abc import StorageABC, create_write_session
from buildgrid.server.metrics_names import (
    CAS_BATCH_READ_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BATCH_READ_BLOBS_SIZE_BYTES,
    CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES,
    CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME,
    CAS_BYTESTREAM_READ_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BYTESTREAM_READ_SIZE_BYTES,
    CAS_BYTESTREAM_READ_TIME_METRIC_NAME,
    CAS_BYTESTREAM_WRITE_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BYTESTREAM_WRITE_SIZE_BYTES,
    CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME,
    CAS_DOWNLOADED_BYTES_METRIC_NAME,
    CAS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME,
    CAS_GET_TREE_CACHE_HIT,
    CAS_GET_TREE_CACHE_MISS,
    CAS_GET_TREE_EXCEPTION_COUNT_METRIC_NAME,
    CAS_GET_TREE_TIME_METRIC_NAME,
    CAS_UPLOADED_BYTES_METRIC_NAME,
)
from buildgrid.server.metrics_utils import (
    Counter,
    Distribution,
    DurationMetric,
    ExceptionCounter,
    generator_method_duration_metric,
    generator_method_exception_counter,
)
from buildgrid.server.servicer import Instance
from buildgrid.settings import HASH, HASH_LENGTH, MAX_REQUEST_COUNT, STREAM_ERROR_RETRY_PERIOD
from buildgrid.utils import create_digest, get_unique_objects_by_attribute

LOGGER = logging.getLogger(__name__)

EMPTY_BLOB = b""
EMPTY_BLOB_DIGEST: Digest = create_digest(EMPTY_BLOB)
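# Illustrative note (not part of the original module): assuming the default SHA-256
# `HASH`, `EMPTY_BLOB_DIGEST` is the well-known digest of the empty byte string:
#
#     >>> import hashlib
#     >>> hashlib.sha256(b"").hexdigest()
#     'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
#     >>> EMPTY_BLOB_DIGEST.size_bytes
#     0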

class ContentAddressableStorageInstance(Instance):
    SERVICE_NAME = RE_DESCRIPTOR.services_by_name["ContentAddressableStorage"].full_name

    def __init__(
        self,
        storage: StorageABC,
        read_only: bool = False,
        tree_cache_size: Optional[int] = None,
        tree_cache_ttl_minutes: float = 60,
    ) -> None:
        self._storage = storage
        self.__read_only = read_only

        self._tree_cache: Optional[TTLCache[Tuple[str, int], Digest]] = None
        if tree_cache_size:
            self._tree_cache = TTLCache(tree_cache_size, tree_cache_ttl_minutes * 60)
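        # Note: the tree cache maps a root directory digest (hash, size_bytes) to the
        # digest of a full `Tree` message stored in CAS via `put_tree_cache()`, letting
        # `get_tree()` answer repeat requests without re-walking the directory graph.
        # The TTL is configured in minutes and converted to seconds for `TTLCache`.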

    # --- Public API ---

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def set_instance_name(self, instance_name: str) -> None:
        super().set_instance_name(instance_name)
        self._storage.set_instance_name(instance_name)

    def start(self) -> None:
        self._storage.start()

    def stop(self) -> None:
        self._storage.stop()

    def hash_type(self) -> "DigestFunction.Value.ValueType":
        return self._storage.hash_type()

    def max_batch_total_size_bytes(self) -> int:
        return self._storage.max_batch_total_size_bytes()

    def symlink_absolute_path_strategy(self) -> "SymlinkAbsolutePathStrategy.Value.ValueType":
        return self._storage.symlink_absolute_path_strategy()

    find_missing_blobs_ignored_exceptions = (RetriableError,)

    @DurationMetric(CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_FIND_MISSING_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=find_missing_blobs_ignored_exceptions,
    )
    def find_missing_blobs(self, blob_digests: Sequence[Digest]) -> FindMissingBlobsResponse:
        storage = self._storage
        blob_digests = list(get_unique_objects_by_attribute(blob_digests, "hash"))
        missing_blobs = storage.missing_blobs(blob_digests)

        num_blobs_in_request = len(blob_digests)
        if num_blobs_in_request > 0:
            num_blobs_missing = len(missing_blobs)
            percent_missing = float((num_blobs_missing / num_blobs_in_request) * 100)

            with Distribution(
                CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME, instance_name=self._instance_name
            ) as metric_num_requested:
                metric_num_requested.count = float(num_blobs_in_request)

            with Distribution(
                CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME, instance_name=self._instance_name
            ) as metric_num_missing:
                metric_num_missing.count = float(num_blobs_missing)

            with Distribution(
                CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME, instance_name=self._instance_name
            ) as metric_percent_missing:
                metric_percent_missing.count = percent_missing

        for digest in blob_digests:
            with Distribution(
                CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME, instance_name=self._instance_name
            ) as metric_requested_blob_size:
                metric_requested_blob_size.count = float(digest.size_bytes)

        for digest in missing_blobs:
            with Distribution(
                CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME, instance_name=self._instance_name
            ) as metric_missing_blob_size:
                metric_missing_blob_size.count = float(digest.size_bytes)

        return FindMissingBlobsResponse(missing_blob_digests=missing_blobs)
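    # Illustrative usage sketch (not part of the original module): a caller would
    # typically combine FindMissingBlobs with BatchUpdateBlobs to upload only the
    # blobs the server does not already have. `cas`, `digests` and `blob_data` are
    # hypothetical names.
    #
    #     missing = cas.find_missing_blobs(digests).missing_blob_digests
    #     requests = [
    #         BatchUpdateBlobsRequest.Request(digest=d, data=blob_data[d.hash])
    #         for d in missing
    #     ]
    #     cas.batch_update_blobs(requests)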

    batch_update_blobs_ignored_exceptions = (RetriableError,)

    @DurationMetric(CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_BATCH_UPDATE_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=batch_update_blobs_ignored_exceptions,
    )
    def batch_update_blobs(self, requests: Sequence[BatchUpdateBlobsRequest.Request]) -> BatchUpdateBlobsResponse:
        if self.__read_only:
            raise PermissionDeniedError(f"CAS instance {self._instance_name} is read-only")

        storage = self._storage
        store = []
        for request_proto in get_unique_objects_by_attribute(requests, "digest.hash"):
            store.append((request_proto.digest, request_proto.data))

            with Distribution(
                CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES, instance_name=self._instance_name
            ) as metric_blob_size:
                metric_blob_size.count = float(request_proto.digest.size_bytes)

        response = BatchUpdateBlobsResponse()
        statuses = storage.bulk_update_blobs(store)

        with Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name) as bytes_counter:
            for (digest, _), status in zip(store, statuses):
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)
                response_proto.status.CopyFrom(status)
                if response_proto.status.code == 0:
                    bytes_counter.increment(response_proto.digest.size_bytes)

        return response
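    # Note: the `zip(store, statuses)` above relies on `bulk_update_blobs()` returning
    # one `status_pb2.Status` per stored (digest, data) pair, in input order; only
    # writes reported as OK (code 0) count towards the uploaded-bytes metric.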

    batch_read_blobs_ignored_exceptions = (RetriableError,)

    @DurationMetric(CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_BATCH_READ_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=batch_read_blobs_ignored_exceptions,
    )
    def batch_read_blobs(self, digests: Sequence[Digest]) -> BatchReadBlobsResponse:
        storage = self._storage

        max_batch_size = storage.max_batch_total_size_bytes()

        # Only process unique digests
        good_digests = []
        bad_digests = []
        requested_bytes = 0
        for digest in get_unique_objects_by_attribute(digests, "hash"):
            if len(digest.hash) != HASH_LENGTH:
                bad_digests.append(digest)
            else:
                good_digests.append(digest)
                requested_bytes += digest.size_bytes

        if requested_bytes > max_batch_size:
            raise InvalidArgumentError(
                "Combined total size of blobs exceeds "
                "server limit. "
                f"({requested_bytes} > {max_batch_size} [byte])"
            )

        if len(good_digests) > 0:
            blobs_read = storage.bulk_read_blobs(good_digests)
        else:
            blobs_read = {}

        response = BatchReadBlobsResponse()

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name) as bytes_counter:
            for digest in good_digests:
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)

                if digest.hash in blobs_read and blobs_read[digest.hash] is not None:
                    response_proto.data = blobs_read[digest.hash]
                    status_code = code_pb2.OK
                    bytes_counter.increment(digest.size_bytes)

                    with Distribution(
                        CAS_BATCH_READ_BLOBS_SIZE_BYTES, instance_name=self._instance_name
                    ) as metric_blob_size:
                        metric_blob_size.count = float(digest.size_bytes)
                else:
                    status_code = code_pb2.NOT_FOUND

                response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

            for digest in bad_digests:
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)
                status_code = code_pb2.INVALID_ARGUMENT
                response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        return response
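    # Illustrative usage sketch (not part of the original module): since requests whose
    # combined blob size exceeds `max_batch_total_size_bytes()` are rejected with
    # InvalidArgumentError, a client would normally split its digests into batches
    # first. `cas` and `digests` are hypothetical names.
    #
    #     limit = cas.max_batch_total_size_bytes()
    #     batches, current, current_size = [], [], 0
    #     for digest in digests:
    #         if current and current_size + digest.size_bytes > limit:
    #             batches.append(current)
    #             current, current_size = [], 0
    #         current.append(digest)
    #         current_size += digest.size_bytes
    #     if current:
    #         batches.append(current)
    #     responses = [cas.batch_read_blobs(batch) for batch in batches]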

    def lookup_tree_cache(self, root_digest: Digest) -> Optional[Tree]:
        """Find full Tree from cache"""
        if self._tree_cache is None:
            return None
        tree = None
        if response_digest := self._tree_cache.get((root_digest.hash, root_digest.size_bytes)):
            tree = self._storage.get_message(response_digest, Tree)
            if tree is None:
                self._tree_cache.pop((root_digest.hash, root_digest.size_bytes))

        metric = CAS_GET_TREE_CACHE_HIT if tree is not None else CAS_GET_TREE_CACHE_MISS
        with Counter(metric) as counter:
            counter.increment(1)
        return tree

    def put_tree_cache(self, root_digest: Digest, root: Directory, children: Iterable[Directory]) -> None:
        """Put Tree with a full list of directories into CAS"""
        if self._tree_cache is None:
            return
        tree = Tree(root=root, children=children)
        tree_digest = self._storage.put_message(tree)
        self._tree_cache[(root_digest.hash, root_digest.size_bytes)] = tree_digest

    get_tree_ignored_exceptions = (NotFoundError, RetriableError)

    @DurationMetric(CAS_GET_TREE_TIME_METRIC_NAME, instanced=True)
    @generator_method_exception_counter(
        CAS_GET_TREE_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=get_tree_ignored_exceptions,
    )
    def get_tree(self, request: GetTreeRequest) -> Iterator[GetTreeResponse]:
        storage = self._storage

        if not request.page_size:
            request.page_size = MAX_REQUEST_COUNT

        if tree := self.lookup_tree_cache(request.root_digest):
            # Cache hit, yield responses based on page size
            directories = [tree.root]
            directories.extend(tree.children)
            yield from (
                GetTreeResponse(directories=directories[start : start + request.page_size])
                for start in range(0, len(directories), request.page_size)
            )
            return

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name) as bytes_counter:
            # From the spec, a NotFound response only occurs if the root directory is missing.
            root_directory = storage.get_message(request.root_digest, Directory)
            if not root_directory:
                raise NotFoundError(
                    f"Root digest not found: {request.root_digest.hash}/{request.root_digest.size_bytes}"
                )

            bytes_counter.increment(request.root_digest.size_bytes)

            results = [root_directory]
            offset = 0
            queue = [subdir.digest for subdir in root_directory.directories]
            while queue:
                bytes_counter.increment(sum(d.size_bytes for d in queue))
                blobs = storage.bulk_read_blobs(queue)

                directories = [Directory.FromString(b) for b in blobs.values()]
                queue = [subdir.digest for d in directories for subdir in d.directories]

                results.extend(directories)
                while len(results) - offset >= request.page_size:
                    yield GetTreeResponse(directories=results[offset : offset + request.page_size])
                    offset += request.page_size

            if len(results) - offset > 0:
                yield GetTreeResponse(directories=results[offset:])
            if results:
                self.put_tree_cache(request.root_digest, results[0], results[1:])
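    # Illustrative usage sketch (not part of the original module): GetTree responses
    # are paged, so a caller flattens the stream to recover the full directory listing.
    # `cas` and `root_digest` are hypothetical names.
    #
    #     request = GetTreeRequest(root_digest=root_digest)
    #     directories = [
    #         directory
    #         for response in cas.get_tree(request)
    #         for directory in response.directories
    #     ]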

class ByteStreamInstance(Instance):
    SERVICE_NAME = bs_pb2.DESCRIPTOR.services_by_name["ByteStream"].full_name

    BLOCK_SIZE = 1 * 1024 * 1024  # 1 MB block size

    def __init__(
        self,
        storage: StorageABC,
        read_only: bool = False,
        disable_overwrite_early_return: bool = False,
    ) -> None:
        self._storage = storage
        self._query_activity_timeout = 30

        self.__read_only = read_only

        # If set, prevents `ByteStream.Write()` from returning without
        # reading all the client's `WriteRequests` for a digest that is
        # already in storage (i.e. not follow the REAPI-specified
        # behavior).
        self.__disable_overwrite_early_return = disable_overwrite_early_return
        # (Should only be used to work around issues with implementations
        # that treat the server half-closing its end of the gRPC stream
        # as a HTTP/2 stream error.)

    # --- Public API ---

    def start(self) -> None:
        self._storage.start()

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def set_instance_name(self, instance_name: str) -> None:
        super().set_instance_name(instance_name)
        self._storage.set_instance_name(instance_name)

    bytestream_read_ignored_exceptions = (NotFoundError, RetriableError)

    @generator_method_duration_metric(CAS_BYTESTREAM_READ_TIME_METRIC_NAME)
    @generator_method_exception_counter(
        CAS_BYTESTREAM_READ_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=bytestream_read_ignored_exceptions,
    )
    def read_cas_blob(
        self, digest_hash: str, digest_size: str, read_offset: int, read_limit: int
    ) -> Iterator[bs_pb2.ReadResponse]:
        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = Digest(hash=digest_hash, size_bytes=int(digest_size))

        # Check the given read offset and limit.
        if read_offset < 0 or read_offset > digest.size_bytes:
            raise OutOfRangeError("Read offset out of range")

        elif read_limit == 0:
            bytes_remaining = digest.size_bytes - read_offset

        elif read_limit > 0:
            bytes_remaining = read_limit

        else:
            raise InvalidArgumentError("Negative read_limit is invalid")

        # Read the blob from storage and send its contents to the client.
        result = self._storage.get_blob(digest)
        if result is None:
            raise NotFoundError("Blob not found")

        try:
            if read_offset > 0:
                result.seek(read_offset)

            with Distribution(
                metric_name=CAS_BYTESTREAM_READ_SIZE_BYTES, instance_name=self._instance_name
            ) as metric_blob_size:
                metric_blob_size.count = float(digest.size_bytes)

            with Counter(
                metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name
            ) as bytes_counter:
                while bytes_remaining > 0:
                    block_data = result.read(min(self.BLOCK_SIZE, bytes_remaining))
                    yield bs_pb2.ReadResponse(data=block_data)
                    bytes_counter.increment(len(block_data))
                    bytes_remaining -= self.BLOCK_SIZE
        finally:
            result.close()
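    # Illustrative usage sketch (not part of the original module): the caller receives
    # the blob as a sequence of ReadResponse chunks of at most BLOCK_SIZE bytes and
    # concatenates their `data` fields; offset 0 and limit 0 request the whole blob.
    # `bytestream` and `digest` are hypothetical names.
    #
    #     chunks = bytestream.read_cas_blob(digest.hash, str(digest.size_bytes), 0, 0)
    #     blob = b"".join(response.data for response in chunks)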

    bytestream_write_ignored_exceptions = (NotFoundError, RetriableError)

    @DurationMetric(CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_BYTESTREAM_WRITE_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=bytestream_write_ignored_exceptions,
    )
    def write_cas_blob(
        self, digest_hash: str, digest_size: str, requests: Iterator[bs_pb2.WriteRequest]
    ) -> bs_pb2.WriteResponse:
        if self.__read_only:
            raise PermissionDeniedError(f"ByteStream instance {self._instance_name} is read-only")

        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = Digest(hash=digest_hash, size_bytes=int(digest_size))

        with Distribution(
            metric_name=CAS_BYTESTREAM_WRITE_SIZE_BYTES, instance_name=self._instance_name
        ) as metric_blob_size:
            metric_blob_size.count = float(digest.size_bytes)

        if self._storage.has_blob(digest):
            # According to the REAPI specification:
            # "When attempting an upload, if another client has already
            # completed the upload (which may occur in the middle of a single
            # upload if another client uploads the same blob concurrently),
            # the request will terminate immediately [...]".
            #
            # However, half-closing the stream can be problematic with some
            # intermediaries like HAProxy.
            # (https://github.com/haproxy/haproxy/issues/1219)
            #
            # If half-closing the stream is not allowed, we read and drop
            # all the client's messages before returning, still saving
            # the cost of a write to storage.
            if self.__disable_overwrite_early_return:
                try:
                    for request in requests:
                        if request.finish_write:
                            break
                        continue
                except RpcError:
                    msg = "ByteStream client disconnected whilst streaming requests, upload cancelled."
                    LOGGER.debug(msg)
                    raise RetriableError(msg, retry_period=timedelta(seconds=STREAM_ERROR_RETRY_PERIOD))

            return bs_pb2.WriteResponse(committed_size=digest.size_bytes)

        # Start the write session and write the first request's data.
        bytes_counter = Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name)
        write_session = create_write_session(digest)
        with bytes_counter, write_session:
            computed_hash = HASH()

            # Handle subsequent write requests.
            try:
                for request in requests:
                    write_session.write(request.data)

                    computed_hash.update(request.data)
                    bytes_counter.increment(len(request.data))

                    if request.finish_write:
                        break
            except RpcError:
                write_session.close()
                msg = "ByteStream client disconnected whilst streaming requests, upload cancelled."
                LOGGER.debug(msg)
                raise RetriableError(msg, retry_period=timedelta(seconds=STREAM_ERROR_RETRY_PERIOD))

            # Check that the data matches the provided digest.
            if bytes_counter.count != digest.size_bytes:
                raise NotImplementedError(
                    "Cannot close stream before finishing write, "
                    f"got {bytes_counter.count} bytes but expected {digest.size_bytes}"
                )

            if computed_hash.hexdigest() != digest.hash:
                raise InvalidArgumentError("Data does not match hash")

            self._storage.commit_write(digest, write_session)
            return bs_pb2.WriteResponse(committed_size=int(bytes_counter.count))
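# Illustrative usage sketch (not part of the original module): write_cas_blob consumes
# a stream of WriteRequest messages and stops at the first one with finish_write set,
# so a client chunks its payload and marks the last chunk. `bytestream`, `data` and
# `digest` are hypothetical names; the 1 MiB chunk size mirrors BLOCK_SIZE above.
#
#     def chunk_requests(data: bytes, block_size: int = 1024 * 1024):
#         if not data:
#             yield bs_pb2.WriteRequest(data=b"", write_offset=0, finish_write=True)
#             return
#         for offset in range(0, len(data), block_size):
#             yield bs_pb2.WriteRequest(
#                 data=data[offset : offset + block_size],
#                 write_offset=offset,
#                 finish_write=offset + block_size >= len(data),
#             )
#
#     response = bytestream.write_cas_blob(
#         digest.hash, str(digest.size_bytes), chunk_requests(data)
#     )
#     assert response.committed_size == digest.size_bytes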