Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/instance.py: 87.64%

267 statements  

# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Storage Instances
=================
Instances of CAS and ByteStream
"""
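
# Illustrative usage sketch (not part of the original module): the two classes
# below are normally constructed from configuration and registered with a
# running BuildGrid server. The storage backend and the ``server`` object here
# are assumptions for the sake of the example.
#
#     storage = ...  # any CAS storage backend, e.g. an in-memory or disk store
#     cas = ContentAddressableStorageInstance(storage)
#     bytestream = ByteStreamInstance(storage)
#
#     cas.register_instance_with_server("main", server)
#     bytestream.register_instance_with_server("main", server)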

import logging
from typing import List

from buildgrid._exceptions import InvalidArgumentError, NotFoundError, OutOfRangeError, PermissionDeniedError
from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 as re_pb2
from buildgrid._protos.google.bytestream.bytestream_pb2 import (
    QueryWriteStatusResponse,
    ReadResponse,
    WriteRequest,
    WriteResponse
)
from buildgrid._protos.google.rpc import code_pb2, status_pb2
from buildgrid.server.metrics_utils import (
    Counter,
    DurationMetric,
    ExceptionCounter,
    Distribution,
    generator_method_duration_metric,
    generator_method_exception_counter
)
from buildgrid.server.metrics_names import (
    CAS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_DOWNLOADED_BYTES_METRIC_NAME,
    CAS_UPLOADED_BYTES_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES,
    CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_READ_BLOBS_SIZE_BYTES,
    CAS_GET_TREE_TIME_METRIC_NAME,
    CAS_BYTESTREAM_READ_TIME_METRIC_NAME,
    CAS_BYTESTREAM_READ_SIZE_BYTES,
    CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME,
    CAS_BYTESTREAM_WRITE_SIZE_BYTES,
    LOGSTREAM_WRITE_UPLOADED_BYTES_COUNT
)
from buildgrid.settings import HASH, HASH_LENGTH, MAX_REQUEST_SIZE, MAX_REQUEST_COUNT, MAX_LOGSTREAM_CHUNK_SIZE
from buildgrid.utils import create_digest, get_unique_objects_by_attribute


def _write_empty_digest_to_storage(storage):
    if not storage:
        return
    # Some clients skip uploading a blob with size 0.
    # We pre-insert the empty blob to make sure that it is available.
    empty_digest = create_digest(b'')
    if not storage.has_blob(empty_digest):
        session = storage.begin_write(empty_digest)
        storage.commit_write(empty_digest, session)
    # This check is useful to confirm the CAS is functioning correctly
    # but also to ensure that the access timestamp on the index is
    # bumped sufficiently often so that there is less of a chance of
    # the empty blob being evicted prematurely.
    if not storage.get_blob(empty_digest):
        raise NotFoundError("Empty blob not found after writing it to"
                            " storage. Is your CAS configured correctly?")
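
# For reference (assuming the default SHA-256 hash configuration), the empty
# blob pre-inserted above has the digest
#     hash = 'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
#     size_bytes = 0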


class ContentAddressableStorageInstance:

    def __init__(self, storage, read_only=False):
        self.__logger = logging.getLogger(__name__)

        self._instance_name = None

        self.__storage = storage

        self.__read_only = read_only

    # --- Public API ---

    @property
    def instance_name(self):
        return self._instance_name

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def register_instance_with_server(self, instance_name, server):
        """Names and registers the CAS instance with a given server."""
        if self._instance_name is None:
            server.add_cas_instance(self, instance_name)

            self._instance_name = instance_name
            if not self.__storage.instance_name:
                self.__storage.set_instance_name(instance_name)

        else:
            raise AssertionError("Instance already registered")

    def setup_grpc(self):
        self.__storage.setup_grpc()

        # This is a check to ensure the CAS is functional, as well as make
        # sure that the empty digest is pre-populated (some tools don't
        # upload it, so we need to). It is done here since it needs to happen
        # after gRPC initialization in the case of a remote CAS backend.
        _write_empty_digest_to_storage(self.__storage)

    def hash_type(self):
        return self.__storage.hash_type()

    def max_batch_total_size_bytes(self):
        return self.__storage.max_batch_total_size_bytes()

    def symlink_absolute_path_strategy(self):
        return self.__storage.symlink_absolute_path_strategy()

    @DurationMetric(CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def find_missing_blobs(self, blob_digests):
        storage = self.__storage
        blob_digests = list(get_unique_objects_by_attribute(blob_digests, "hash"))
        missing_blobs = storage.missing_blobs(blob_digests)

        num_blobs_in_request = len(blob_digests)
        if num_blobs_in_request > 0:
            num_blobs_missing = len(missing_blobs)
            percent_missing = float((num_blobs_missing / num_blobs_in_request) * 100)

            with Distribution(CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME,
                              instance_name=self._instance_name) as metric_num_requested:
                metric_num_requested.count = float(num_blobs_in_request)

            with Distribution(CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME,
                              instance_name=self._instance_name) as metric_num_missing:
                metric_num_missing.count = float(num_blobs_missing)

            with Distribution(CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME,
                              instance_name=self._instance_name) as metric_percent_missing:
                metric_percent_missing.count = percent_missing
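            # Worked example: a request for 12 digests of which 3 are missing
            # records num_requested=12.0, num_missing=3.0 and
            # percent_missing = (3 / 12) * 100 = 25.0 in the metrics above.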

        for digest in blob_digests:
            with Distribution(CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME,
                              instance_name=self._instance_name) as metric_requested_blob_size:
                metric_requested_blob_size.count = float(digest.size_bytes)

        for digest in missing_blobs:
            with Distribution(CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME,
                              instance_name=self._instance_name) as metric_missing_blob_size:
                metric_missing_blob_size.count = float(digest.size_bytes)

        return re_pb2.FindMissingBlobsResponse(missing_blob_digests=missing_blobs)

    @DurationMetric(CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def batch_update_blobs(self, requests):
        if self.__read_only:
            raise PermissionDeniedError(f"CAS instance {self._instance_name} is read-only")

        storage = self.__storage
        store = []
        for request_proto in get_unique_objects_by_attribute(requests, "digest.hash"):
            store.append((request_proto.digest, request_proto.data))

            with Distribution(CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES,
                              instance_name=self._instance_name) as metric_blob_size:
                metric_blob_size.count = float(request_proto.digest.size_bytes)

        response = re_pb2.BatchUpdateBlobsResponse()
        statuses = storage.bulk_update_blobs(store)

        with Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            for (digest, _), status in zip(store, statuses):
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)
                response_proto.status.CopyFrom(status)
                if response_proto.status.code == 0:
                    bytes_counter.increment(response_proto.digest.size_bytes)

        return response

    @DurationMetric(CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def batch_read_blobs(self, digests):
        storage = self.__storage

        response = re_pb2.BatchReadBlobsResponse()

        max_batch_size = storage.max_batch_total_size_bytes()

        # Only process unique digests
        good_digests = []
        bad_digests = []
        requested_bytes = 0
        for digest in get_unique_objects_by_attribute(digests, "hash"):
            if len(digest.hash) != HASH_LENGTH:
                bad_digests.append(digest)
            else:
                good_digests.append(digest)
                requested_bytes += digest.size_bytes

        if requested_bytes > max_batch_size:
            raise InvalidArgumentError('Combined total size of blobs exceeds '
                                       'server limit. '
                                       f'({requested_bytes} > {max_batch_size} [byte])')

        blobs_read = {}
        if len(good_digests) > 0:
            blobs_read = storage.bulk_read_blobs(good_digests)

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            for digest in good_digests:
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)

                if digest.hash in blobs_read and blobs_read[digest.hash] is not None:
                    response_proto.data = blobs_read[digest.hash].read()
                    status_code = code_pb2.OK
                    bytes_counter.increment(digest.size_bytes)

                    with Distribution(CAS_BATCH_READ_BLOBS_SIZE_BYTES,
                                      instance_name=self._instance_name) as metric_blob_size:
                        metric_blob_size.count = float(digest.size_bytes)
                else:
                    status_code = code_pb2.NOT_FOUND

                response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        for digest in bad_digests:
            response_proto = response.responses.add()
            response_proto.digest.CopyFrom(digest)
            status_code = code_pb2.INVALID_ARGUMENT
            response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        return response

    @DurationMetric(CAS_GET_TREE_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def get_tree(self, request):
        storage = self.__storage

        response = re_pb2.GetTreeResponse()
        page_size = request.page_size

        if not request.page_size:
            request.page_size = MAX_REQUEST_COUNT

        root_digest = request.root_digest
        page_size = request.page_size

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            def __get_tree(node_digest):
                nonlocal response, page_size, request

                if not page_size:
                    page_size = request.page_size
                    yield response
                    response = re_pb2.GetTreeResponse()

                if response.ByteSize() >= (MAX_REQUEST_SIZE):
                    yield response
                    response = re_pb2.GetTreeResponse()

                directory_from_digest = storage.get_message(
                    node_digest, re_pb2.Directory)

                bytes_counter.increment(node_digest.size_bytes)

                page_size -= 1
                response.directories.extend([directory_from_digest])

                for directory in directory_from_digest.directories:
                    yield from __get_tree(directory.digest)

                yield response
                response = re_pb2.GetTreeResponse()

            yield from __get_tree(root_digest)


class ByteStreamInstance:

    BLOCK_SIZE = 1 * 1024 * 1024  # 1 MB block size

    def __init__(self, storage=None, read_only=False, stream_storage=None,
                 disable_overwrite_early_return=False):
        self._logger = logging.getLogger(__name__)

        self._instance_name = None

        self.__storage = storage
        self._stream_store = stream_storage
        self._query_activity_timeout = 30

        self.__read_only = read_only

        # If set, prevents `ByteStream.Write()` from returning without
        # reading all the client's `WriteRequests` for a digest that is
        # already in storage (i.e. not follow the REAPI-specified
        # behavior).
        self.__disable_overwrite_early_return = disable_overwrite_early_return
        # (Should only be used to work around issues with implementations
        # that treat the server half-closing its end of the gRPC stream
        # as a HTTP/2 stream error.)
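        # Illustrative (hypothetical deployment choice): behind such a proxy
        # the instance might be constructed as
        #     ByteStreamInstance(storage, disable_overwrite_early_return=True)
        # accepting that duplicate uploads are drained in full before replying.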

    # --- Public API ---

    @property
    def instance_name(self):
        return self._instance_name

    @instance_name.setter
    def instance_name(self, instance_name):
        self._instance_name = instance_name

    def setup_grpc(self):
        if self.__storage:
            self.__storage.setup_grpc()

        # This is a check to ensure the CAS is functional, as well as make
        # sure that the empty digest is pre-populated (some tools don't
        # upload it, so we need to). It is done here since it needs to happen
        # after gRPC initialization in the case of a remote CAS backend.
        _write_empty_digest_to_storage(self.__storage)

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def register_instance_with_server(self, instance_name, server):
        """Names and registers the byte-stream instance with a given server."""
        if self._instance_name is None:
            server.add_bytestream_instance(self, instance_name)

            self._instance_name = instance_name
            if self.__storage is not None and not self.__storage.instance_name:
                self.__storage.set_instance_name(instance_name)
        else:
            raise AssertionError("Instance already registered")

    def disconnect_logstream_reader(self, read_name: str):
        self._logger.info(f"Disconnecting reader from [{read_name}].")
        try:
            self._stream_store.streaming_client_left(read_name)
        except NotFoundError as e:
            self._logger.debug(f"Did not disconnect reader: {str(e)}")

    @generator_method_duration_metric(CAS_BYTESTREAM_READ_TIME_METRIC_NAME)
    @generator_method_exception_counter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def read_cas_blob(self, digest_hash, digest_size, read_offset, read_limit):
        # pylint: disable=no-else-raise
        if self.__storage is None:
            raise InvalidArgumentError(
                "ByteStream instance not configured for use with CAS.")

        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = re_pb2.Digest(hash=digest_hash, size_bytes=int(digest_size))

        # Check the given read offset and limit.
        if read_offset < 0 or read_offset > digest.size_bytes:
            raise OutOfRangeError("Read offset out of range")

        elif read_limit == 0:
            bytes_remaining = digest.size_bytes - read_offset

        elif read_limit > 0:
            bytes_remaining = read_limit

        else:
            raise InvalidArgumentError("Negative read_limit is invalid")

        # Read the blob from storage and send its contents to the client.
        result = self.__storage.get_blob(digest)
        if result is None:
            raise NotFoundError("Blob not found")

        elif result.seekable():
            result.seek(read_offset)

        else:
            result.read(read_offset)

        with Distribution(metric_name=CAS_BYTESTREAM_READ_SIZE_BYTES,
                          instance_name=self._instance_name) as metric_blob_size:
            metric_blob_size.count = float(digest.size_bytes)

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            while bytes_remaining > 0:
                block_data = result.read(min(self.BLOCK_SIZE, bytes_remaining))
                yield ReadResponse(data=block_data)
                bytes_counter.increment(len(block_data))
                bytes_remaining -= self.BLOCK_SIZE
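        # Worked example: with the 1 MiB BLOCK_SIZE above, a 3.5 MiB blob read
        # from offset 0 with read_limit 0 is streamed as four ReadResponse
        # messages: three full 1 MiB blocks followed by one 0.5 MiB block.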

    def read_logstream(self, resource_name, context):
        if self._stream_store is None:
            raise InvalidArgumentError(
                "ByteStream instance not configured for use with LogStream.")

        stream_iterator = self._stream_store.read_stream_bytes_blocking_iterator(
            resource_name, max_chunk_size=MAX_LOGSTREAM_CHUNK_SIZE, offset=0)

        self._stream_store.new_client_streaming(resource_name)

        for message in stream_iterator:
            if not context.is_active():
                break
            yield ReadResponse(data=message)

    @DurationMetric(CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def write_cas_blob(self, digest_hash, digest_size, requests):
        if self.__read_only:
            raise PermissionDeniedError(
                f"ByteStream instance {self._instance_name} is read-only")

        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = re_pb2.Digest(hash=digest_hash, size_bytes=int(digest_size))

        with Distribution(metric_name=CAS_BYTESTREAM_WRITE_SIZE_BYTES,
                          instance_name=self._instance_name) as metric_blob_size:
            metric_blob_size.count = float(digest.size_bytes)

        if self.__storage.has_blob(digest):
            # According to the REAPI specification:
            # "When attempting an upload, if another client has already
            # completed the upload (which may occur in the middle of a single
            # upload if another client uploads the same blob concurrently),
            # the request will terminate immediately [...]".
            #
            # However, half-closing the stream can be problematic with some
            # intermediaries like HAProxy.
            # (https://github.com/haproxy/haproxy/issues/1219)
            #
            # If half-closing the stream is not allowed, we read and drop
            # all the client's messages before returning, still saving
            # the cost of a write to storage.
            if self.__disable_overwrite_early_return:
                for request in requests:
                    if request.finish_write:
                        break
                    continue

            return WriteResponse(committed_size=digest.size_bytes)

        write_session = self.__storage.begin_write(digest)

        # Start the write session and write the first request's data.
        with Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name) as bytes_counter:
            computed_hash = HASH()

            # Handle subsequent write requests.
            for request in requests:
                write_session.write(request.data)

                computed_hash.update(request.data)
                bytes_counter.increment(len(request.data))

                if request.finish_write:
                    break

            # Check that the data matches the provided digest.
            if bytes_counter.count != digest.size_bytes:
                raise NotImplementedError(
                    "Cannot close stream before finishing write")

            if computed_hash.hexdigest() != digest.hash:
                raise InvalidArgumentError("Data does not match hash")

            self.__storage.commit_write(digest, write_session)

            return WriteResponse(committed_size=int(bytes_counter.count))

    def write_logstream(self, resource_name: str, first_request: WriteRequest,
                        requests: List[WriteRequest]) -> WriteResponse:
        if self._stream_store is None:
            raise InvalidArgumentError(
                "ByteStream instance not configured for use with LogStream.")

        with Counter(metric_name=LOGSTREAM_WRITE_UPLOADED_BYTES_COUNT,
                     instance_name=self.instance_name) as bytes_counter:
            self._stream_store.append_to_stream(resource_name, first_request.data)
            bytes_counter.increment(len(first_request.data))

            for request in requests:
                self._stream_store.append_to_stream(resource_name, request.data)
                bytes_counter.increment(len(request.data))

            self._stream_store.append_to_stream(resource_name, None, mark_finished=True)
            return WriteResponse(committed_size=int(bytes_counter.count))

    def query_logstream_status(self, resource_name: str,
                               context) -> QueryWriteStatusResponse:
        if self._stream_store is None:
            raise InvalidArgumentError(
                "ByteStream instance not configured for use with LogStream.")

        while context.is_active():
            if self._stream_store.wait_for_streaming_clients(resource_name, self._query_activity_timeout):
                streamlength = self._stream_store.writeable_stream_length(resource_name)
                return QueryWriteStatusResponse(committed_size=streamlength.length, complete=streamlength.finished)

        raise NotFoundError(f"Stream [{resource_name}] didn't become writeable.")