Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/instance.py: 84.44%

270 statements  

coverage.py v6.4.1, created at 2022-06-22 21:04 +0000

# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Storage Instances
=================
Instances of CAS and ByteStream
"""

import logging
from typing import List

from buildgrid._exceptions import InvalidArgumentError, NotFoundError, OutOfRangeError, PermissionDeniedError
from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 as re_pb2
from buildgrid._protos.google.bytestream.bytestream_pb2 import (
    QueryWriteStatusResponse,
    ReadResponse,
    WriteRequest,
    WriteResponse
)
from buildgrid._protos.google.rpc import code_pb2, status_pb2
from buildgrid.server.metrics_utils import (
    Counter,
    DurationMetric,
    ExceptionCounter,
    Distribution,
    generator_method_duration_metric,
    generator_method_exception_counter
)
from buildgrid.server.metrics_names import (
    CAS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_DOWNLOADED_BYTES_METRIC_NAME,
    CAS_UPLOADED_BYTES_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES,
    CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_READ_BLOBS_SIZE_BYTES,
    CAS_GET_TREE_TIME_METRIC_NAME,
    CAS_BYTESTREAM_READ_TIME_METRIC_NAME,
    CAS_BYTESTREAM_READ_SIZE_BYTES,
    CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME,
    CAS_BYTESTREAM_WRITE_SIZE_BYTES,
    LOGSTREAM_WRITE_UPLOADED_BYTES_COUNT
)
from buildgrid.settings import HASH, HASH_LENGTH, MAX_REQUEST_SIZE, MAX_REQUEST_COUNT, MAX_LOGSTREAM_CHUNK_SIZE
from buildgrid.utils import create_digest, get_unique_objects_by_attribute


def _write_empty_digest_to_storage(storage):
    if not storage:
        return
    # Some clients skip uploading a blob with size 0.
    # We pre-insert the empty blob to make sure that it is available.
    empty_digest = create_digest(b'')
    if not storage.has_blob(empty_digest):
        session = storage.begin_write(empty_digest)
        storage.commit_write(empty_digest, session)
    # This check is useful to confirm the CAS is functioning correctly
    # but also to ensure that the access timestamp on the index is
    # bumped sufficiently often so that there is less of a chance of
    # the empty blob being evicted prematurely.
    if not storage.get_blob(empty_digest):
        raise NotFoundError("Empty blob not found after writing it to"
                            " storage. Is your CAS configured correctly?")


class ContentAddressableStorageInstance:
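    """Instance of the ContentAddressableStorage service.

    Wraps a storage backend and implements the FindMissingBlobs,
    BatchUpdateBlobs, BatchReadBlobs and GetTree operations on top of it.
    """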

    def __init__(self, storage, read_only=False):
        self.__logger = logging.getLogger(__name__)

        self._instance_name = None

        self.__storage = storage

        self.__read_only = read_only

    # --- Public API ---

    @property
    def instance_name(self):
        return self._instance_name

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def register_instance_with_server(self, instance_name, server):
        """Names and registers the CAS instance with a given server."""
        if self._instance_name is None:
            server.add_cas_instance(self, instance_name)

            self._instance_name = instance_name
            if not self.__storage.instance_name:
                self.__storage.set_instance_name(instance_name)

        else:
            raise AssertionError("Instance already registered")

    def setup_grpc(self):
        self.__storage.setup_grpc()

        # This is a check to ensure the CAS is functional, as well as make
        # sure that the empty digest is pre-populated (some tools don't
        # upload it, so we need to). It is done here since it needs to happen
        # after gRPC initialization in the case of a remote CAS backend.
        _write_empty_digest_to_storage(self.__storage)

    def hash_type(self):
        return self.__storage.hash_type()

    def max_batch_total_size_bytes(self):
        return self.__storage.max_batch_total_size_bytes()

    def symlink_absolute_path_strategy(self):
        return self.__storage.symlink_absolute_path_strategy()

    @DurationMetric(CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def find_missing_blobs(self, blob_digests):
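        """Returns a ``FindMissingBlobsResponse`` listing the digests that are
        not present in storage, after deduplicating the requested digests and
        publishing the related metrics.
        """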

        storage = self.__storage
        blob_digests = list(get_unique_objects_by_attribute(blob_digests, "hash"))
        missing_blobs = storage.missing_blobs(blob_digests)

        num_blobs_in_request = len(blob_digests)
        if num_blobs_in_request > 0:
            num_blobs_missing = len(missing_blobs)
            percent_missing = float((num_blobs_missing / num_blobs_in_request) * 100)

            with Distribution(CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME,
                              instance_name=self._instance_name) as metric_num_requested:
                metric_num_requested.count = float(num_blobs_in_request)

            with Distribution(CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME,
                              instance_name=self._instance_name) as metric_num_missing:
                metric_num_missing.count = float(num_blobs_missing)

            with Distribution(CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME,
                              instance_name=self._instance_name) as metric_percent_missing:
                metric_percent_missing.count = percent_missing

        for digest in blob_digests:
            with Distribution(CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME,
                              instance_name=self._instance_name) as metric_requested_blob_size:
                metric_requested_blob_size.count = float(digest.size_bytes)

        for digest in missing_blobs:
            with Distribution(CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME,
                              instance_name=self._instance_name) as metric_missing_blob_size:
                metric_missing_blob_size.count = float(digest.size_bytes)

        return re_pb2.FindMissingBlobsResponse(missing_blob_digests=missing_blobs)

    @DurationMetric(CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def batch_update_blobs(self, requests):
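        """Stores the blobs from the given requests and returns a
        ``BatchUpdateBlobsResponse`` with one status per unique digest.

        Raises ``PermissionDeniedError`` if the instance is read-only.
        """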

        if self.__read_only:
            raise PermissionDeniedError(f"CAS instance {self._instance_name} is read-only")

        storage = self.__storage
        store = []
        for request_proto in get_unique_objects_by_attribute(requests, "digest.hash"):
            store.append((request_proto.digest, request_proto.data))

            with Distribution(CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES,
                              instance_name=self._instance_name) as metric_blob_size:
                metric_blob_size.count = float(request_proto.digest.size_bytes)

        response = re_pb2.BatchUpdateBlobsResponse()
        statuses = storage.bulk_update_blobs(store)

        with Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            for (digest, _), status in zip(store, statuses):
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)
                response_proto.status.CopyFrom(status)
                if response_proto.status.code == 0:
                    bytes_counter.increment(response_proto.digest.size_bytes)

        return response

    @DurationMetric(CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def batch_read_blobs(self, digests):
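        """Reads a batch of blobs from storage in a single response.

        Digests with an unexpected hash length are reported as
        ``INVALID_ARGUMENT``, blobs absent from storage as ``NOT_FOUND``, and
        the combined size of the requested blobs must not exceed the server's
        batch size limit.
        """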

        storage = self.__storage

        response = re_pb2.BatchReadBlobsResponse()

        max_batch_size = storage.max_batch_total_size_bytes()

        # Only process unique digests
        good_digests = []
        bad_digests = []
        requested_bytes = 0
        for digest in get_unique_objects_by_attribute(digests, "hash"):
            if len(digest.hash) != HASH_LENGTH:
                bad_digests.append(digest)
            else:
                good_digests.append(digest)
                requested_bytes += digest.size_bytes

        if requested_bytes > max_batch_size:
            raise InvalidArgumentError('Combined total size of blobs exceeds '
                                       'server limit. '
                                       f'({requested_bytes} > {max_batch_size} [byte])')

        blobs_read = {}
        if len(good_digests) > 0:
            blobs_read = storage.bulk_read_blobs(good_digests)

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            for digest in good_digests:
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)

                if digest.hash in blobs_read and blobs_read[digest.hash] is not None:
                    response_proto.data = blobs_read[digest.hash].read()
                    status_code = code_pb2.OK
                    bytes_counter.increment(digest.size_bytes)

                    with Distribution(CAS_BATCH_READ_BLOBS_SIZE_BYTES,
                                      instance_name=self._instance_name) as metric_blob_size:
                        metric_blob_size.count = float(digest.size_bytes)
                else:
                    status_code = code_pb2.NOT_FOUND

                response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        for digest in bad_digests:
            response_proto = response.responses.add()
            response_proto.digest.CopyFrom(digest)
            status_code = code_pb2.INVALID_ARGUMENT
            response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        for blob in blobs_read.values():
            blob.close()

        return response

    @DurationMetric(CAS_GET_TREE_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def get_tree(self, request):
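        """Yields ``GetTreeResponse`` messages covering the directory tree
        rooted at ``request.root_digest``, starting a new response whenever
        the page size or maximum message size is reached.
        """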

        storage = self.__storage

        response = re_pb2.GetTreeResponse()
        page_size = request.page_size

        if not request.page_size:
            request.page_size = MAX_REQUEST_COUNT

        root_digest = request.root_digest
        page_size = request.page_size

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            def __get_tree(node_digest):
                nonlocal response, page_size, request

                if not page_size:
                    page_size = request.page_size
                    yield response
                    response = re_pb2.GetTreeResponse()

                if response.ByteSize() >= (MAX_REQUEST_SIZE):
                    yield response
                    response = re_pb2.GetTreeResponse()

                directory_from_digest = storage.get_message(
                    node_digest, re_pb2.Directory)

                bytes_counter.increment(node_digest.size_bytes)

                page_size -= 1
                response.directories.extend([directory_from_digest])

                for directory in directory_from_digest.directories:
                    yield from __get_tree(directory.digest)

                yield response
                response = re_pb2.GetTreeResponse()

            yield from __get_tree(root_digest)


class ByteStreamInstance:
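    """Instance of the ByteStream service.

    Streams CAS blobs in and out of a storage backend in ``BLOCK_SIZE`` chunks
    and, when a stream store is configured, reads and writes LogStreams.
    """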

    BLOCK_SIZE = 1 * 1024 * 1024  # 1 MB block size

    def __init__(self, storage=None, read_only=False, stream_storage=None,
                 disable_overwrite_early_return=False):
        self._logger = logging.getLogger(__name__)

        self._instance_name = None

        self.__storage = storage
        self._stream_store = stream_storage
        self._query_activity_timeout = 30

        self.__read_only = read_only

        # If set, prevents `ByteStream.Write()` from returning without
        # reading all the client's `WriteRequests` for a digest that is
        # already in storage (i.e. not follow the REAPI-specified
        # behavior).
        self.__disable_overwrite_early_return = disable_overwrite_early_return
        # (Should only be used to work around issues with implementations
        # that treat the server half-closing its end of the gRPC stream
        # as a HTTP/2 stream error.)

    # --- Public API ---

    @property
    def instance_name(self):
        return self._instance_name

    @instance_name.setter
    def instance_name(self, instance_name):
        self._instance_name = instance_name

    def setup_grpc(self):
        if self.__storage:
            self.__storage.setup_grpc()

        # This is a check to ensure the CAS is functional, as well as make
        # sure that the empty digest is pre-populated (some tools don't
        # upload it, so we need to). It is done here since it needs to happen
        # after gRPC initialization in the case of a remote CAS backend.
        _write_empty_digest_to_storage(self.__storage)

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def register_instance_with_server(self, instance_name, server):
        """Names and registers the byte-stream instance with a given server."""
        if self._instance_name is None:
            server.add_bytestream_instance(self, instance_name)

            self._instance_name = instance_name
            if self.__storage is not None and not self.__storage.instance_name:
                self.__storage.set_instance_name(instance_name)
        else:
            raise AssertionError("Instance already registered")

    def disconnect_logstream_reader(self, read_name: str):
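        """Tells the stream store that a reader has left the given stream."""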

        self._logger.info(f"Disconnecting reader from [{read_name}].")
        try:
            self._stream_store.streaming_client_left(read_name)
        except NotFoundError as e:
            self._logger.debug(f"Did not disconnect reader: {str(e)}")

    @generator_method_duration_metric(CAS_BYTESTREAM_READ_TIME_METRIC_NAME)
    @generator_method_exception_counter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def read_cas_blob(self, digest_hash, digest_size, read_offset, read_limit):
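        """Yields the requested blob as a stream of ``ReadResponse`` chunks.

        Validates the digest, read offset and read limit before streaming the
        blob from storage in ``BLOCK_SIZE`` blocks.
        """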

        # pylint: disable=no-else-raise
        if self.__storage is None:
            raise InvalidArgumentError(
                "ByteStream instance not configured for use with CAS.")

        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = re_pb2.Digest(hash=digest_hash, size_bytes=int(digest_size))

        # Check the given read offset and limit.
        if read_offset < 0 or read_offset > digest.size_bytes:
            raise OutOfRangeError("Read offset out of range")

        elif read_limit == 0:
            bytes_remaining = digest.size_bytes - read_offset

        elif read_limit > 0:
            bytes_remaining = read_limit

        else:
            raise InvalidArgumentError("Negative read_limit is invalid")

        # Read the blob from storage and send its contents to the client.
        result = self.__storage.get_blob(digest)
        if result is None:
            raise NotFoundError("Blob not found")

        elif result.seekable():
            result.seek(read_offset)

        else:
            result.read(read_offset)

        with Distribution(metric_name=CAS_BYTESTREAM_READ_SIZE_BYTES,
                          instance_name=self._instance_name) as metric_blob_size:
            metric_blob_size.count = float(digest.size_bytes)

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            while bytes_remaining > 0:
                block_data = result.read(min(self.BLOCK_SIZE, bytes_remaining))
                yield ReadResponse(data=block_data)
                bytes_counter.increment(len(block_data))
                bytes_remaining -= self.BLOCK_SIZE

        result.close()

    def read_logstream(self, resource_name, context):
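        """Yields ``ReadResponse`` messages for a LogStream, blocking until new
        data is available and stopping once the gRPC context becomes inactive.
        """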

        if self._stream_store is None:
            raise InvalidArgumentError(
                "ByteStream instance not configured for use with LogStream.")

        stream_iterator = self._stream_store.read_stream_bytes_blocking_iterator(
            resource_name, max_chunk_size=MAX_LOGSTREAM_CHUNK_SIZE, offset=0)

        self._stream_store.new_client_streaming(resource_name)

        for message in stream_iterator:
            if not context.is_active():
                break
            yield ReadResponse(data=message)

    @DurationMetric(CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def write_cas_blob(self, digest_hash, digest_size, requests):
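        """Writes a blob into storage from a stream of ``WriteRequest`` messages.

        If the blob already exists the write is skipped (draining the client's
        messages first when early return is disabled); otherwise the data is
        verified against the provided digest before being committed.
        """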

        if self.__read_only:
            raise PermissionDeniedError(
                f"ByteStream instance {self._instance_name} is read-only")

        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = re_pb2.Digest(hash=digest_hash, size_bytes=int(digest_size))

        with Distribution(metric_name=CAS_BYTESTREAM_WRITE_SIZE_BYTES,
                          instance_name=self._instance_name) as metric_blob_size:
            metric_blob_size.count = float(digest.size_bytes)

        if self.__storage.has_blob(digest):
            # According to the REAPI specification:
            # "When attempting an upload, if another client has already
            # completed the upload (which may occur in the middle of a single
            # upload if another client uploads the same blob concurrently),
            # the request will terminate immediately [...]".
            #
            # However, half-closing the stream can be problematic with some
            # intermediaries like HAProxy.
            # (https://github.com/haproxy/haproxy/issues/1219)
            #
            # If half-closing the stream is not allowed, we read and drop
            # all the client's messages before returning, still saving
            # the cost of a write to storage.
            if self.__disable_overwrite_early_return:
                for request in requests:
                    if request.finish_write:
                        break
                    continue

            return WriteResponse(committed_size=digest.size_bytes)

        write_session = self.__storage.begin_write(digest)

        # Start the write session and write the first request's data.
        with Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            computed_hash = HASH()

            # Handle subsequent write requests.
            for request in requests:
                write_session.write(request.data)

                computed_hash.update(request.data)
                bytes_counter.increment(len(request.data))

                if request.finish_write:
                    break

            # Check that the data matches the provided digest.
            if bytes_counter.count != digest.size_bytes:
                raise NotImplementedError(
                    "Cannot close stream before finishing write")

            if computed_hash.hexdigest() != digest.hash:
                raise InvalidArgumentError("Data does not match hash")

            self.__storage.commit_write(digest, write_session)

            return WriteResponse(committed_size=int(bytes_counter.count))

    def write_logstream(self, resource_name: str, first_request: WriteRequest,
                        requests: List[WriteRequest]) -> WriteResponse:
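        """Appends the data from the given ``WriteRequest`` messages to a
        LogStream, marks the stream as finished and returns the committed size.
        """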

        if self._stream_store is None:
            raise InvalidArgumentError(
                "ByteStream instance not configured for use with LogStream.")

        with Counter(metric_name=LOGSTREAM_WRITE_UPLOADED_BYTES_COUNT,
                     instance_name=self.instance_name) as bytes_counter:
            self._stream_store.append_to_stream(resource_name, first_request.data)
            bytes_counter.increment(len(first_request.data))

            for request in requests:
                self._stream_store.append_to_stream(resource_name, request.data)
                bytes_counter.increment(len(request.data))

            self._stream_store.append_to_stream(resource_name, None, mark_finished=True)
            return WriteResponse(committed_size=int(bytes_counter.count))

    def query_logstream_status(self, resource_name: str,
                               context) -> QueryWriteStatusResponse:
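        """Waits for the LogStream to have streaming clients, then reports its
        committed size and completion state; raises ``NotFoundError`` if the
        gRPC context ends before that happens.
        """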

        if self._stream_store is None:
            raise InvalidArgumentError(
                "ByteStream instance not configured for use with LogStream.")

        while context.is_active():
            if self._stream_store.wait_for_streaming_clients(resource_name, self._query_activity_timeout):
                streamlength = self._stream_store.writeable_stream_length(resource_name)
                return QueryWriteStatusResponse(committed_size=streamlength.length, complete=streamlength.finished)

        raise NotFoundError(f"Stream [{resource_name}] didn't become writeable.")