Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/instance.py: 93.02%

258 statements  

coverage.py v7.4.1, created at 2024-06-11 15:37 +0000

# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Storage Instances
=================
Instances of CAS and ByteStream
"""

import logging
from datetime import timedelta
from typing import Iterable, Iterator, Optional, Sequence, Tuple

from cachetools import TTLCache
from grpc import RpcError

from buildgrid._exceptions import (
    InvalidArgumentError,
    NotFoundError,
    OutOfRangeError,
    PermissionDeniedError,
    RetriableError,
)
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import DESCRIPTOR as RE_DESCRIPTOR
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import (
    BatchReadBlobsResponse,
    BatchUpdateBlobsRequest,
    BatchUpdateBlobsResponse,
    Digest,
    DigestFunction,
    Directory,
    FindMissingBlobsResponse,
    GetTreeRequest,
    GetTreeResponse,
    SymlinkAbsolutePathStrategy,
    Tree,
)
from buildgrid._protos.google.bytestream import bytestream_pb2 as bs_pb2
from buildgrid._protos.google.rpc import code_pb2, status_pb2
from buildgrid.server.cas.storage.storage_abc import StorageABC, create_write_session
from buildgrid.server.metrics_names import (
    CAS_BATCH_READ_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BATCH_READ_BLOBS_SIZE_BYTES,
    CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES,
    CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME,
    CAS_BYTESTREAM_READ_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BYTESTREAM_READ_SIZE_BYTES,
    CAS_BYTESTREAM_READ_TIME_METRIC_NAME,
    CAS_BYTESTREAM_WRITE_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BYTESTREAM_WRITE_SIZE_BYTES,
    CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME,
    CAS_DOWNLOADED_BYTES_METRIC_NAME,
    CAS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME,
    CAS_GET_TREE_CACHE_HIT,
    CAS_GET_TREE_CACHE_MISS,
    CAS_GET_TREE_EXCEPTION_COUNT_METRIC_NAME,
    CAS_GET_TREE_TIME_METRIC_NAME,
    CAS_UPLOADED_BYTES_METRIC_NAME,
)
from buildgrid.server.metrics_utils import (
    Counter,
    Distribution,
    DurationMetric,
    ExceptionCounter,
    generator_method_duration_metric,
    generator_method_exception_counter,
)
from buildgrid.server.servicer import Instance
from buildgrid.settings import HASH, HASH_LENGTH, MAX_REQUEST_COUNT, STREAM_ERROR_RETRY_PERIOD
from buildgrid.utils import create_digest, get_unique_objects_by_attribute

LOGGER = logging.getLogger(__name__)

EMPTY_BLOB = b""
EMPTY_BLOB_DIGEST: Digest = create_digest(EMPTY_BLOB)


class ContentAddressableStorageInstance(Instance):
    SERVICE_NAME = RE_DESCRIPTOR.services_by_name["ContentAddressableStorage"].full_name

    def __init__(
        self,
        storage: StorageABC,
        read_only: bool = False,
        tree_cache_size: Optional[int] = None,
        tree_cache_ttl_minutes: float = 60,
    ) -> None:
        self._storage = storage
        self.__read_only = read_only

        self._tree_cache: Optional[TTLCache[Tuple[str, int], Digest]] = None
        if tree_cache_size:
            self._tree_cache = TTLCache(tree_cache_size, tree_cache_ttl_minutes * 60)

    # --- Public API ---

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def set_instance_name(self, instance_name: str) -> None:
        super().set_instance_name(instance_name)
        self._storage.set_instance_name(instance_name)

    def start(self) -> None:
        self._storage.start()

    def stop(self) -> None:
        self._storage.stop()
        LOGGER.info(f"Stopped CAS instance for '{self._instance_name}'")

    def hash_type(self) -> "DigestFunction.Value.ValueType":
        return self._storage.hash_type()

    def max_batch_total_size_bytes(self) -> int:
        return self._storage.max_batch_total_size_bytes()

    def symlink_absolute_path_strategy(self) -> "SymlinkAbsolutePathStrategy.Value.ValueType":
        return self._storage.symlink_absolute_path_strategy()

    find_missing_blobs_ignored_exceptions = (RetriableError,)

    @DurationMetric(CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_FIND_MISSING_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=find_missing_blobs_ignored_exceptions,
    )
    def find_missing_blobs(self, blob_digests: Sequence[Digest]) -> FindMissingBlobsResponse:
        storage = self._storage
        blob_digests = list(get_unique_objects_by_attribute(blob_digests, "hash"))
        missing_blobs = storage.missing_blobs(blob_digests)

        num_blobs_in_request = len(blob_digests)
        if num_blobs_in_request > 0:
            num_blobs_missing = len(missing_blobs)
            percent_missing = float((num_blobs_missing / num_blobs_in_request) * 100)

            with Distribution(
                CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME, instance_name=self._instance_name
            ) as metric_num_requested:
                metric_num_requested.count = float(num_blobs_in_request)

            with Distribution(
                CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME, instance_name=self._instance_name
            ) as metric_num_missing:
                metric_num_missing.count = float(num_blobs_missing)

            with Distribution(
                CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME, instance_name=self._instance_name
            ) as metric_percent_missing:
                metric_percent_missing.count = percent_missing

        for digest in blob_digests:
            with Distribution(
                CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME, instance_name=self._instance_name
            ) as metric_requested_blob_size:
                metric_requested_blob_size.count = float(digest.size_bytes)

        for digest in missing_blobs:
            with Distribution(
                CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME, instance_name=self._instance_name
            ) as metric_missing_blob_size:
                metric_missing_blob_size.count = float(digest.size_bytes)

        return FindMissingBlobsResponse(missing_blob_digests=missing_blobs)

    batch_update_blobs_ignored_exceptions = (RetriableError,)

    @DurationMetric(CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_BATCH_UPDATE_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=batch_update_blobs_ignored_exceptions,
    )
    def batch_update_blobs(self, requests: Sequence[BatchUpdateBlobsRequest.Request]) -> BatchUpdateBlobsResponse:
        if self.__read_only:
            raise PermissionDeniedError(f"CAS instance {self._instance_name} is read-only")

        storage = self._storage
        store = []
        for request_proto in get_unique_objects_by_attribute(requests, "digest.hash"):
            store.append((request_proto.digest, request_proto.data))

            with Distribution(
                CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES, instance_name=self._instance_name
            ) as metric_blob_size:
                metric_blob_size.count = float(request_proto.digest.size_bytes)

        response = BatchUpdateBlobsResponse()
        statuses = storage.bulk_update_blobs(store)

        with Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name) as bytes_counter:
            for (digest, _), status in zip(store, statuses):
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)
                response_proto.status.CopyFrom(status)
                if response_proto.status.code == 0:
                    bytes_counter.increment(response_proto.digest.size_bytes)

        return response

    batch_read_blobs_ignored_exceptions = (RetriableError,)

    @DurationMetric(CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_BATCH_READ_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=batch_read_blobs_ignored_exceptions,
    )
    def batch_read_blobs(self, digests: Sequence[Digest]) -> BatchReadBlobsResponse:
        storage = self._storage

        max_batch_size = storage.max_batch_total_size_bytes()

        # Only process unique digests
        good_digests = []
        bad_digests = []
        requested_bytes = 0
        for digest in get_unique_objects_by_attribute(digests, "hash"):
            if len(digest.hash) != HASH_LENGTH:
                bad_digests.append(digest)
            else:
                good_digests.append(digest)
                requested_bytes += digest.size_bytes

        if requested_bytes > max_batch_size:
            raise InvalidArgumentError(
                "Combined total size of blobs exceeds "
                "server limit. "
                f"({requested_bytes} > {max_batch_size} [byte])"
            )

        if len(good_digests) > 0:
            blobs_read = storage.bulk_read_blobs(good_digests)
        else:
            blobs_read = {}

        response = BatchReadBlobsResponse()

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name) as bytes_counter:
            for digest in good_digests:
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)

                if digest.hash in blobs_read and blobs_read[digest.hash] is not None:
                    response_proto.data = blobs_read[digest.hash]
                    status_code = code_pb2.OK
                    bytes_counter.increment(digest.size_bytes)

                    with Distribution(
                        CAS_BATCH_READ_BLOBS_SIZE_BYTES, instance_name=self._instance_name
                    ) as metric_blob_size:
                        metric_blob_size.count = float(digest.size_bytes)
                else:
                    status_code = code_pb2.NOT_FOUND
                    LOGGER.info(f"Blob not found: {digest.hash}/{digest.size_bytes}, from BatchReadBlobs")

                response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        for digest in bad_digests:
            response_proto = response.responses.add()
            response_proto.digest.CopyFrom(digest)
            status_code = code_pb2.INVALID_ARGUMENT
            response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        return response

    def lookup_tree_cache(self, root_digest: Digest) -> Optional[Tree]:
        """Find full Tree from cache"""
        if self._tree_cache is None:
            return None
        tree = None
        if response_digest := self._tree_cache.get((root_digest.hash, root_digest.size_bytes)):
            tree = self._storage.get_message(response_digest, Tree)
            if tree is None:
                self._tree_cache.pop((root_digest.hash, root_digest.size_bytes))

        metric = CAS_GET_TREE_CACHE_HIT if tree is not None else CAS_GET_TREE_CACHE_MISS
        with Counter(metric) as counter:
            counter.increment(1)
        return tree

    def put_tree_cache(self, root_digest: Digest, root: Directory, children: Iterable[Directory]) -> None:
        """Put Tree with a full list of directories into CAS"""
        if self._tree_cache is None:
            return
        tree = Tree(root=root, children=children)
        tree_digest = self._storage.put_message(tree)
        self._tree_cache[(root_digest.hash, root_digest.size_bytes)] = tree_digest

    get_tree_ignored_exceptions = (NotFoundError, RetriableError)

    @generator_method_duration_metric(CAS_GET_TREE_TIME_METRIC_NAME)
    @generator_method_exception_counter(
        CAS_GET_TREE_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=get_tree_ignored_exceptions,
    )
    def get_tree(self, request: GetTreeRequest) -> Iterator[GetTreeResponse]:
        storage = self._storage

        if not request.page_size:
            request.page_size = MAX_REQUEST_COUNT

        if tree := self.lookup_tree_cache(request.root_digest):
            # Cache hit, yield responses based on page size
            directories = [tree.root]
            directories.extend(tree.children)
            yield from (
                GetTreeResponse(directories=directories[start : start + request.page_size])
                for start in range(0, len(directories), request.page_size)
            )
            return

        results = []
        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name) as bytes_counter:
            response = GetTreeResponse()
            for dir in storage.get_tree(request.root_digest):
                bytes_counter.increment(sum(directory.digest.size_bytes for directory in dir.directories))
                response.directories.append(dir)
                results.append(dir)
                if len(response.directories) >= request.page_size:
                    yield response
                    response.Clear()

            bytes_counter.increment(request.root_digest.size_bytes)
            if response.directories:
                yield response
            if results:
                self.put_tree_cache(request.root_digest, results[0], results[1:])
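
# On a tree-cache hit, get_tree() flattens the cached Tree into [root] + children
# and yields GetTreeResponse messages holding at most request.page_size directories
# each. Below is a minimal standalone sketch of that slicing logic, with plain
# strings standing in for Directory messages and a hypothetical paginate() helper;
# it is illustrative only and not part of this module.
#
#     from typing import Iterator, List, Sequence
#
#     def paginate(directories: Sequence[str], page_size: int) -> Iterator[List[str]]:
#         # Mirrors GetTreeResponse(directories=directories[start : start + page_size])
#         # from the cache-hit branch of get_tree().
#         for start in range(0, len(directories), page_size):
#             yield list(directories[start : start + page_size])
#
#     # Seven directories with a page size of three produce pages of 3, 3 and 1 entries.
#     assert [len(page) for page in paginate([f"dir-{i}" for i in range(7)], 3)] == [3, 3, 1]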

class ByteStreamInstance(Instance):
    SERVICE_NAME = bs_pb2.DESCRIPTOR.services_by_name["ByteStream"].full_name

    BLOCK_SIZE = 1 * 1024 * 1024  # 1 MB block size

    def __init__(
        self,
        storage: StorageABC,
        read_only: bool = False,
        disable_overwrite_early_return: bool = False,
    ) -> None:
        self._storage = storage
        self._query_activity_timeout = 30

        self.__read_only = read_only

        # If set, prevents `ByteStream.Write()` from returning without
        # reading all the client's `WriteRequests` for a digest that is
        # already in storage (i.e. not follow the REAPI-specified
        # behavior).
        self.__disable_overwrite_early_return = disable_overwrite_early_return
        # (Should only be used to work around issues with implementations
        # that treat the server half-closing its end of the gRPC stream
        # as a HTTP/2 stream error.)

    # --- Public API ---

    def start(self) -> None:
        self._storage.start()

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def set_instance_name(self, instance_name: str) -> None:
        super().set_instance_name(instance_name)
        self._storage.set_instance_name(instance_name)

    bytestream_read_ignored_exceptions = (NotFoundError, RetriableError)

    @generator_method_duration_metric(CAS_BYTESTREAM_READ_TIME_METRIC_NAME)
    @generator_method_exception_counter(
        CAS_BYTESTREAM_READ_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=bytestream_read_ignored_exceptions,
    )
    def read_cas_blob(self, digest: Digest, read_offset: int, read_limit: int) -> Iterator[bs_pb2.ReadResponse]:
        # Check the given read offset and limit.
        if read_offset < 0 or read_offset > digest.size_bytes:
            raise OutOfRangeError("Read offset out of range")

        elif read_limit == 0:
            bytes_remaining = digest.size_bytes - read_offset

        elif read_limit > 0:
            bytes_remaining = read_limit

        else:
            raise InvalidArgumentError("Negative read_limit is invalid")

        # Read the blob from storage and send its contents to the client.
        result = self._storage.get_blob(digest)
        if result is None:
            raise NotFoundError(f"Blob not found: {digest.hash}/{digest.size_bytes}, from Bytestream.Read")

        try:
            if read_offset > 0:
                result.seek(read_offset)

            with Distribution(
                metric_name=CAS_BYTESTREAM_READ_SIZE_BYTES, instance_name=self._instance_name
            ) as metric_blob_size:
                metric_blob_size.count = float(digest.size_bytes)

            with Counter(
                metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name
            ) as bytes_counter:
                while bytes_remaining > 0:
                    block_data = result.read(min(self.BLOCK_SIZE, bytes_remaining))
                    yield bs_pb2.ReadResponse(data=block_data)
                    bytes_counter.increment(len(block_data))
                    bytes_remaining -= self.BLOCK_SIZE
        finally:
            result.close()

    bytestream_write_ignored_exceptions = (NotFoundError, RetriableError)

    @DurationMetric(CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_BYTESTREAM_WRITE_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=bytestream_write_ignored_exceptions,
    )
    def write_cas_blob(
        self, digest_hash: str, digest_size: str, requests: Iterator[bs_pb2.WriteRequest]
    ) -> bs_pb2.WriteResponse:
        if self.__read_only:
            raise PermissionDeniedError(f"ByteStream instance {self._instance_name} is read-only")

        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = Digest(hash=digest_hash, size_bytes=int(digest_size))

        with Distribution(
            metric_name=CAS_BYTESTREAM_WRITE_SIZE_BYTES, instance_name=self._instance_name
        ) as metric_blob_size:
            metric_blob_size.count = float(digest.size_bytes)

        if self._storage.has_blob(digest):
            # According to the REAPI specification:
            # "When attempting an upload, if another client has already
            # completed the upload (which may occur in the middle of a single
            # upload if another client uploads the same blob concurrently),
            # the request will terminate immediately [...]".
            #
            # However, half-closing the stream can be problematic with some
            # intermediaries like HAProxy.
            # (https://github.com/haproxy/haproxy/issues/1219)
            #
            # If half-closing the stream is not allowed, we read and drop
            # all the client's messages before returning, still saving
            # the cost of a write to storage.
            if self.__disable_overwrite_early_return:
                try:
                    for request in requests:
                        if request.finish_write:
                            break
                        continue
                except RpcError:
                    msg = "ByteStream client disconnected whilst streaming requests, upload cancelled."
                    LOGGER.debug(msg)
                    raise RetriableError(msg, retry_period=timedelta(seconds=STREAM_ERROR_RETRY_PERIOD))

            return bs_pb2.WriteResponse(committed_size=digest.size_bytes)

        # Start the write session and write the first request's data.
        bytes_counter = Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name)
        write_session = create_write_session(digest)
        with bytes_counter, write_session:
            computed_hash = HASH()

            # Handle subsequent write requests.
            try:
                for request in requests:
                    write_session.write(request.data)

                    computed_hash.update(request.data)
                    bytes_counter.increment(len(request.data))

                    if request.finish_write:
                        break
            except RpcError:
                write_session.close()
                msg = "ByteStream client disconnected whilst streaming requests, upload cancelled."
                LOGGER.debug(msg)
                raise RetriableError(msg, retry_period=timedelta(seconds=STREAM_ERROR_RETRY_PERIOD))

            # Check that the data matches the provided digest.
            if bytes_counter.count != digest.size_bytes:
                raise NotImplementedError(
                    "Cannot close stream before finishing write, "
                    f"got {bytes_counter.count} bytes but expected {digest.size_bytes}"
                )

            if computed_hash.hexdigest() != digest.hash:
                raise InvalidArgumentError("Data does not match hash")

            self._storage.commit_write(digest, write_session)
            return bs_pb2.WriteResponse(committed_size=int(bytes_counter.count))
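
# write_cas_blob() streams WriteRequest chunks into a write session, recomputing the
# digest as data arrives, and only commits the session once both the received byte
# count and the recomputed hash match the digest supplied by the client. Below is a
# minimal standalone sketch of that final validation step, assuming the default
# SHA-256 hash behind buildgrid.settings.HASH and a hypothetical validate_upload()
# helper; it is illustrative only and not part of this module.
#
#     import hashlib
#     from typing import Iterable
#
#     def validate_upload(digest_hash: str, digest_size: int, chunks: Iterable[bytes]) -> None:
#         # Accumulate the hash and byte count much like write_cas_blob() does with
#         # computed_hash and bytes_counter, then apply the same two checks.
#         computed_hash = hashlib.sha256()
#         received = 0
#         for chunk in chunks:
#             computed_hash.update(chunk)
#             received += len(chunk)
#         if received != digest_size:
#             raise ValueError(f"got {received} bytes but expected {digest_size}")
#         if computed_hash.hexdigest() != digest_hash:
#             raise ValueError("Data does not match hash")
#
#     # A well-formed two-chunk upload passes both checks.
#     data = b"example blob"
#     validate_upload(hashlib.sha256(data).hexdigest(), len(data), [data[:4], data[4:]])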