Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/client/cas.py: 88.33%

514 statements  

coverage.py v7.4.1, created at 2025-05-21 15:45 +0000

1# Copyright (C) 2018 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import os 

17import uuid 

18from collections import namedtuple 

19from contextlib import contextmanager 

20from functools import partial 

21from io import SEEK_END, BytesIO 

22from operator import attrgetter 

23from typing import IO, BinaryIO, Generator, Iterator, TypeVar, cast 

24 

25import grpc 

26from google.protobuf.message import Message 

27 

28from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import ( 

29 BatchReadBlobsRequest, 

30 BatchUpdateBlobsRequest, 

31 Digest, 

32 Directory, 

33 DirectoryNode, 

34 FileNode, 

35 GetTreeRequest, 

36 SymlinkNode, 

37 Tree, 

38) 

39from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2_grpc import ContentAddressableStorageStub 

40from buildgrid._protos.google.bytestream.bytestream_pb2 import ReadRequest, WriteRequest 

41from buildgrid._protos.google.bytestream.bytestream_pb2_grpc import ByteStreamStub 

42from buildgrid._protos.google.rpc import code_pb2 

43from buildgrid.server.client.capabilities import CapabilitiesInterface 

44from buildgrid.server.client.retrier import GrpcRetrier 

45from buildgrid.server.exceptions import NotFoundError 

46from buildgrid.server.metadata import metadata_list 

47from buildgrid.server.settings import BATCH_REQUEST_SIZE_THRESHOLD, HASH, MAX_REQUEST_COUNT, MAX_REQUEST_SIZE 

48from buildgrid.server.types import MessageType 

49from buildgrid.server.utils.digests import create_digest 

50 

51_FileRequest = namedtuple("_FileRequest", ["digest", "output_paths"]) 

52 

53 

54def create_digest_from_file(file_obj: BinaryIO) -> Digest: 

55 """Computed the :obj:`Digest` of a file-like object. 

56 

57 The :obj:`Digest` contains a hash of the file's contents and the size of 

58 those contents. This function only reads the content in chunks for hashing, 

59 so it is safe to use on large files. 

60 

61 Args: 

62 file_obj (BinaryIO): A file-like object of some kind. 

63 

64 Returns: 

65 :obj:`Digest`: The :obj:`Digest` for the given file object. 

66 """ 

67 digest = Digest() 

68 

69 # Make sure we're hashing from the start of the file 

70 file_obj.seek(0) 

71 

72 # Generate the file hash and keep track of the file size 

73 hasher = HASH() 

74 digest.size_bytes = 0 

75 for block in iter(partial(file_obj.read, 8192), b""): 

76 hasher.update(block) 

77 digest.size_bytes += len(block) 

78 digest.hash = hasher.hexdigest() 

79 

80 # Return to the start of the file ready for future reads 

81 file_obj.seek(0) 

82 return digest 

83 

84 
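# Illustrative sketch (hypothetical helper; uses only names imported above):
# hashing an in-memory stream with create_digest_from_file() yields a Digest
# whose size_bytes matches the payload length.
def _example_create_digest() -> Digest:
    payload = b"example blob contents"
    digest = create_digest_from_file(BytesIO(payload))
    assert digest.size_bytes == len(payload)
    return digest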

85def merkle_tree_maker(directory_path: str) -> Iterator[tuple[FileNode | DirectoryNode, BinaryIO, str]]: 

86 """Walks a local folder tree, generating :obj:`FileNode` and 

87 :obj:`DirectoryNode`. 

88 

89 Args: 

90 directory_path (str): absolute or relative path to a local directory. 

91 

92 Yields: 

93 :obj:`Message`, bytes, str: a tuple of either a :obj:`FileNode` or 

94 :obj:`DirectoryNode` message, the corresponding blob and the 

95 corresponding node path. 

96 """ 

97 directory_name = os.path.basename(directory_path) 

98 

99 # Actual generator: recursively yields FileNodes and DirectoryNodes: 

100 def __merkle_tree_maker(directory_path: str, directory_name: str) -> Generator[ 

101 tuple[FileNode | DirectoryNode, BinaryIO, str], 

102 None, 

103 tuple[FileNode | DirectoryNode, BinaryIO, str], 

104 ]: 

105 if not os.path.isabs(directory_path): 

106 directory_path = os.path.abspath(directory_path) 

107 

108 directory = Directory() 

109 

110 files, directories, symlinks = [], [], [] 

111 for directory_entry in os.scandir(directory_path): 

112 node_name, node_path = directory_entry.name, directory_entry.path 

113 

114 node: FileNode | DirectoryNode 

115 node_blob: BinaryIO 

116 if directory_entry.is_file(follow_symlinks=False): 

117 with open(directory_entry.path, "rb") as node_blob: 

118 node_digest = create_digest_from_file(node_blob) 

119 

120 node = FileNode() 

121 node.name = node_name 

122 node.digest.CopyFrom(node_digest) 

123 node.is_executable = os.access(node_path, os.X_OK) 

124 

125 files.append(node) 

126 

127 yield node, node_blob, node_path 

128 

129 elif directory_entry.is_dir(follow_symlinks=False): 

130 node, node_blob, _ = yield from __merkle_tree_maker(node_path, node_name) 

131 

132 directories.append(cast(DirectoryNode, node)) 

133 

134 yield node, node_blob, node_path 

135 

136 # Create a SymlinkNode: 

137 elif os.path.islink(directory_entry.path): 

138 node_target = os.readlink(directory_entry.path) 

139 

140 symlink_node = SymlinkNode() 

141 symlink_node.name = directory_entry.name 

142 symlink_node.target = node_target 

143 

144 symlinks.append(symlink_node) 

145 

146 files.sort(key=attrgetter("name")) 

147 directories.sort(key=attrgetter("name")) 

148 symlinks.sort(key=attrgetter("name")) 

149 

150 directory.files.extend(files) 

151 directory.directories.extend(directories) 

152 directory.symlinks.extend(symlinks) 

153 

154 node_data = directory.SerializeToString() 

155 node_digest = create_digest(node_data) 

156 

157 dir_node = DirectoryNode() 

158 dir_node.name = directory_name 

159 dir_node.digest.CopyFrom(node_digest) 

160 

161 return dir_node, BytesIO(node_data), directory_path 

162 

163 node, node_blob, node_path = yield from __merkle_tree_maker(directory_path, directory_name) 

164 

165 yield node, node_blob, node_path 

166 

167 
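# Illustrative sketch (hypothetical helper; assumes `directory_path` names an
# existing local directory): merkle_tree_maker() walks the tree and always ends
# with the DirectoryNode of the root, whose digest identifies the whole tree.
def _example_tree_root_digest(directory_path: str) -> Digest:
    last_node = None
    for node, _blob, _path in merkle_tree_maker(directory_path):
        last_node = node
    assert last_node is not None and last_node.DESCRIPTOR is DirectoryNode.DESCRIPTOR
    return last_node.digest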

168class _CallCache: 

169 """Per remote grpc.StatusCode.UNIMPLEMENTED call cache.""" 

170 

171 __calls: dict[grpc.Channel, set[str]] = {} 

172 

173 @classmethod 

174 def mark_unimplemented(cls, channel: grpc.Channel, name: str) -> None: 

175 if channel not in cls.__calls: 

176 cls.__calls[channel] = set() 

177 cls.__calls[channel].add(name) 

178 

179 @classmethod 

180 def unimplemented(cls, channel: grpc.Channel, name: str) -> bool: 

181 if channel not in cls.__calls: 

182 return False 

183 return name in cls.__calls[channel] 

184 

185 

186class _CasBatchRequestSizesCache: 

187 """Cache that stores, for each remote, the limit of bytes that can 

188 be transferred using batches as well as a threshold that determines 

189 when a file can be fetched as part of a batch request. 

190 """ 

191 

192 __cas_max_batch_transfer_size: dict[grpc.Channel, dict[str, int]] = {} 

193 __cas_batch_request_size_threshold: dict[grpc.Channel, dict[str, int]] = {} 

194 

195 @classmethod 

196 def max_effective_batch_size_bytes(cls, channel: grpc.Channel, instance_name: str) -> int: 

197 """Returns the maximum number of bytes that can be transferred 

198 using batch methods for the given remote. 

199 """ 

200 if channel not in cls.__cas_max_batch_transfer_size: 

201 cls.__cas_max_batch_transfer_size[channel] = {} 

202 

203 if instance_name not in cls.__cas_max_batch_transfer_size[channel]: 

204 max_batch_size = cls._get_server_max_batch_total_size_bytes(channel, instance_name) 

205 

206 cls.__cas_max_batch_transfer_size[channel][instance_name] = max_batch_size 

207 

208 return cls.__cas_max_batch_transfer_size[channel][instance_name] 

209 

210 @classmethod 

211 def batch_request_size_threshold(cls, channel: grpc.Channel, instance_name: str) -> int: 

212 if channel not in cls.__cas_batch_request_size_threshold: 

213 cls.__cas_batch_request_size_threshold[channel] = {} 

214 

215 if instance_name not in cls.__cas_batch_request_size_threshold[channel]: 

216 # Computing the threshold: 

217 max_batch_size = cls.max_effective_batch_size_bytes(channel, instance_name) 

218 threshold = int(BATCH_REQUEST_SIZE_THRESHOLD * max_batch_size) 

219 

220 cls.__cas_batch_request_size_threshold[channel][instance_name] = threshold 

221 

222 return cls.__cas_batch_request_size_threshold[channel][instance_name] 

223 

224 @classmethod 

225 def _get_server_max_batch_total_size_bytes(cls, channel: grpc.Channel, instance_name: str) -> int: 

226 """Returns the maximum number of bytes that can be effectively 

227 transferred using batches, considering the limits imposed by 

228 the server's configuration and by gRPC. 

229 """ 

230 try: 

231 capabilities_interface = CapabilitiesInterface(channel) 

232 server_capabilities = capabilities_interface.get_capabilities(instance_name) 

233 

234 cache_capabilities = server_capabilities.cache_capabilities 

235 

236 max_batch_total_size = cache_capabilities.max_batch_total_size_bytes 

237 # The server could set this value to 0 (no limit set). 

238 if max_batch_total_size: 

239 return min(max_batch_total_size, MAX_REQUEST_SIZE) 

240 except ConnectionError: 

241 pass 

242 

243 return MAX_REQUEST_SIZE 

244 

245 
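# Illustrative sketch (hypothetical helper; assumes an open grpc.Channel): the
# cached per-remote limits drive the queue-vs-direct decisions made below, the
# threshold being BATCH_REQUEST_SIZE_THRESHOLD times the effective batch maximum.
def _example_batch_limits(channel: grpc.Channel, instance_name: str = "") -> tuple[int, int]:
    max_batch = _CasBatchRequestSizesCache.max_effective_batch_size_bytes(channel, instance_name)
    threshold = _CasBatchRequestSizesCache.batch_request_size_threshold(channel, instance_name)
    return max_batch, threshold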

246T = TypeVar("T", bound=MessageType) 

247 

248 

249class Downloader: 

250 """Remote CAS files, directories and messages download helper. 

251 

252 The :class:`Downloader` class comes with a generator factory function that 

253 can be used together with the `with` statement for context management:: 

254 

255 from buildgrid.server.client.cas import download 

256 

257 with download(channel, instance='build') as downloader: 

258 downloader.get_message(message_digest, message) 

259 """ 

260 

261 def __init__( 

262 self, 

263 channel: grpc.Channel, 

264 instance: str | None = None, 

265 retries: int = 0, 

266 max_backoff: int = 64, 

267 should_backoff: bool = True, 

268 ): 

269 """Initializes a new :class:`Downloader` instance. 

270 

271 Args: 

272 channel (grpc.Channel): A gRPC channel to the CAS endpoint. 

273 instance (str, optional): the targeted instance's name. 

274 """ 

275 self.channel = channel 

276 

277 self.instance_name = instance or "" 

278 

279 self._grpc_retrier = GrpcRetrier(retries=retries, max_backoff=max_backoff, should_backoff=should_backoff) 

280 

281 self.__bytestream_stub: ByteStreamStub | None = ByteStreamStub(self.channel) 

282 self.__cas_stub: ContentAddressableStorageStub | None = ContentAddressableStorageStub(self.channel) 

283 

284 self.__file_requests: dict[str, _FileRequest] = {} 

285 self.__file_request_count = 0 

286 self.__file_request_size = 0 

287 self.__file_response_size = 0 

288 

289 # --- Public API --- 

290 

291 def get_blob(self, digest: Digest) -> bytearray | None: 

292 """Retrieves a blob from the remote CAS server. 

293 

294 Args: 

295 digest (:obj:`Digest`): the blob's digest to fetch. 

296 

297 Returns: 

298 bytearray: the fetched blob data or None if not found. 

299 """ 

300 try: 

301 blob = self._grpc_retrier.retry(self._fetch_blob, digest) 

302 except NotFoundError: 

303 return None 

304 

305 return blob 

306 

307 def get_blobs(self, digests: list[Digest]) -> list[bytes]: 

308 """Retrieves a list of blobs from the remote CAS server. 

309 

310 Args: 

311 digests (list): list of :obj:`Digest`s for the blobs to fetch. 

312 

313 Returns: 

314 list: the fetched blob data list. 

315 

316 Raises: 

317 NotFoundError: if a blob is not present in the remote CAS server. 

318 """ 

319 # _fetch_blob_batch returns (data, digest) pairs. 

320 # We only want the data. 

321 return [result[0] for result in self._grpc_retrier.retry(self._fetch_blob_batch, digests)] 

322 

323 def get_available_blobs(self, digests: list[Digest]) -> list[tuple[bytes, Digest]]: 

324 """Retrieves a list of blobs from the remote CAS server. 

325 

326 Skips blobs not found on the server without raising an error. 

327 

328 Args: 

329 digests (list): list of :obj:`Digest`s for the blobs to fetch. 

330 

331 Returns: 

332 list: the fetched blob data list as (data, digest) pairs 

333 """ 

334 return self._grpc_retrier.retry(self._fetch_blob_batch, digests, skip_unavailable=True) 

335 
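# Illustrative sketch (hypothetical helper; assumes `downloader` was obtained
# via download() below): get_available_blobs() silently skips blobs missing
# from the remote, whereas get_blobs() raises NotFoundError for them.
def _example_fetch_present_blobs(downloader: "Downloader", digests: list[Digest]) -> dict[str, bytes]:
    return {digest.hash: bytes(data) for data, digest in downloader.get_available_blobs(digests)}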

336 def get_message(self, digest: Digest, message: T) -> T: 

337 """Retrieves a :obj:`Message` from the remote CAS server. 

338 

339 Args: 

340 digest (:obj:`Digest`): the message's digest to fetch. 

341 message (:obj:`Message`): an empty message to fill. 

342 

343 Returns: 

344 :obj:`Message`: `message` filled or emptied if not found. 

345 """ 

346 try: 

347 message_blob = self._grpc_retrier.retry(self._fetch_blob, digest) 

348 except NotFoundError: 

349 message_blob = None 

350 

351 if message_blob is not None: 

352 message.ParseFromString(bytes(message_blob)) 

353 else: 

354 message.Clear() 

355 

356 return message 

357 
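# Illustrative sketch (hypothetical helper): get_message() parses the fetched
# blob into the supplied empty message, e.g. a Directory, and returns it
# (cleared if the digest is unknown to the remote).
def _example_get_directory_message(downloader: "Downloader", digest: Digest) -> Directory:
    return downloader.get_message(digest, Directory())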

358 def get_messages(self, digests: list[Digest], messages: list[Message]) -> list[Message]: 

359 """Retrieves a list of :obj:`Message`s from the remote CAS server. 

360 

361 Note: 

362 The `digests` and `messages` list **must** contain the same number 

363 of elements. 

364 

365 Args: 

366 digests (list): list of :obj:`Digest`s for the messages to fetch. 

367 messages (list): list of empty :obj:`Message`s to fill. 

368 

369 Returns: 

370 list: the fetched and filled message list. 

371 """ 

372 assert len(digests) == len(messages) 

373 

374 # The individual empty messages might be of differing types, so we need 

375 # to set up a mapping 

376 digest_message_map = {digest.hash: message for (digest, message) in zip(digests, messages)} 

377 

378 batch_response = self._grpc_retrier.retry(self._fetch_blob_batch, digests) 

379 

380 messages = [] 

381 for message_blob, message_digest in batch_response: 

382 message = digest_message_map[message_digest.hash] 

383 message.ParseFromString(message_blob) 

384 messages.append(message) 

385 

386 return messages 

387 

388 def download_file(self, digest: Digest, file_path: str, is_executable: bool = False, queue: bool = True) -> None: 

389 """Retrieves a file from the remote CAS server. 

390 

391 If queuing is allowed (`queue=True`), the download request **may** be 

392 deferred. An explicit call to :func:`~flush` can force the request to be 

393 sent immediately (along with the rest of the queued batch). 

394 

395 Args: 

396 digest (:obj:`Digest`): the file's digest to fetch. 

397 file_path (str): absolute or relative path to the local file to write. 

398 is_executable (bool): whether the file is executable or not. 

399 queue (bool, optional): whether or not the download request may be 

400 queued and submitted as part of a batch upload request. Defaults 

401 to True. 

402 

403 Raises: 

404 NotFoundError: if `digest` is not present in the remote CAS server. 

405 OSError: if `file_path` cannot be created or written. 

406 """ 

407 if not os.path.isabs(file_path): 

408 file_path = os.path.abspath(file_path) 

409 

410 if not queue or digest.size_bytes > self._queueable_file_size_threshold(): 

411 self._grpc_retrier.retry(self._fetch_file, digest, file_path, is_executable) 

412 else: 

413 self._queue_file(digest, file_path, is_executable=is_executable) 

414 
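# Illustrative sketch (hypothetical helper): queued download_file() calls are
# only guaranteed to be on disk after flush() (or close()), so several small
# files can be requested together and flushed once.
def _example_download_files(downloader: "Downloader", wanted: list[tuple[Digest, str]]) -> None:
    for file_digest, file_path in wanted:
        downloader.download_file(file_digest, file_path, queue=True)
    downloader.flush()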

415 def download_directory(self, digest: Digest, directory_path: str) -> None: 

416 """Retrieves a :obj:`Directory` from the remote CAS server. 

417 

418 Args: 

419 digest (:obj:`Digest`): the directory's digest to fetch. 

420 directory_path (str): the path to download to 

421 

422 Raises: 

423 NotFoundError: if `digest` is not present in the remote CAS server. 

424 FileExistsError: if `directory_path` already contains parts of the 

425 fetched directory's content. 

426 """ 

427 if not os.path.isabs(directory_path): 

428 directory_path = os.path.abspath(directory_path) 

429 

430 self._grpc_retrier.retry(self._fetch_directory, digest, directory_path) 

431 

432 def flush(self) -> None: 

433 """Ensures any queued request gets sent.""" 

434 if self.__file_requests: 

435 self._grpc_retrier.retry(self._fetch_file_batch, self.__file_requests) 

436 

437 self.__file_requests.clear() 

438 self.__file_request_count = 0 

439 self.__file_request_size = 0 

440 self.__file_response_size = 0 

441 

442 def close(self) -> None: 

443 """Closes the underlying connection stubs. 

444 

445 Note: 

446 This will always send pending requests before closing connections, 

447 if any. 

448 """ 

449 self.flush() 

450 

451 self.__bytestream_stub = None 

452 self.__cas_stub = None 

453 

454 # --- Private API --- 

455 

456 def _fetch_blob(self, digest: Digest) -> bytearray: 

457 """Fetches a blob using ByteStream.Read()""" 

458 

459 assert self.__bytestream_stub, "Downloader used after closing" 

460 

461 if self.instance_name: 

462 resource_name = "/".join([self.instance_name, "blobs", digest.hash, str(digest.size_bytes)]) 

463 else: 

464 resource_name = "/".join(["blobs", digest.hash, str(digest.size_bytes)]) 

465 

466 read_blob = bytearray() 

467 read_request = ReadRequest() 

468 read_request.resource_name = resource_name 

469 read_request.read_offset = 0 

470 

471 for read_response in self.__bytestream_stub.Read(read_request, metadata=metadata_list()): 

472 read_blob += read_response.data 

473 

474 assert len(read_blob) == digest.size_bytes 

475 return read_blob 

476 

477 def _fetch_blob_batch( 

478 self, digests: list[Digest], *, skip_unavailable: bool = False 

479 ) -> list[tuple[bytes, Digest]]: 

480 """Fetches blobs using ContentAddressableStorage.BatchReadBlobs() 

481 Returns (data, digest) pairs""" 

482 

483 assert self.__cas_stub, "Downloader used after closing" 

484 

485 batch_fetched = False 

486 read_blobs = [] 

487 

488 # First, try BatchReadBlobs(), unless it is already known to be unimplemented: 

489 if not _CallCache.unimplemented(self.channel, "BatchReadBlobs"): 

490 batch_request = BatchReadBlobsRequest() 

491 batch_request.digests.extend(digests) 

492 if self.instance_name is not None: 

493 batch_request.instance_name = self.instance_name 

494 

495 try: 

496 batch_response = self.__cas_stub.BatchReadBlobs(batch_request, metadata=metadata_list()) 

497 

498 for response in batch_response.responses: 

499 assert response.digest in digests 

500 

501 if response.status.code == code_pb2.OK: 

502 read_blobs.append((response.data, response.digest)) 

503 elif response.status.code == code_pb2.NOT_FOUND: 

504 if not skip_unavailable: 

505 raise NotFoundError("Requested blob does not exist on the remote.") 

506 else: 

507 raise ConnectionError("Error in CAS reply while fetching blob.") 

508 

509 batch_fetched = True 

510 

511 except grpc.RpcError as e: 

512 status_code = e.code() 

513 if status_code == grpc.StatusCode.UNIMPLEMENTED: 

514 _CallCache.mark_unimplemented(self.channel, "BatchReadBlobs") 

515 elif status_code == grpc.StatusCode.INVALID_ARGUMENT: 

516 read_blobs.clear() 

517 else: 

518 raise 

519 

520 # Fall back to Read() if BatchReadBlobs() is not available: 

521 if not batch_fetched: 

522 for digest in digests: 

523 blob = self._grpc_retrier.retry(self._fetch_blob, digest) 

524 read_blobs.append((blob, digest)) 

525 

526 return read_blobs 

527 

528 def _fetch_file(self, digest: Digest, file_path: str, is_executable: bool = False) -> None: 

529 """Fetches a file using ByteStream.Read()""" 

530 

531 assert self.__bytestream_stub, "Downloader used after closing" 

532 

533 if self.instance_name: 

534 resource_name = "/".join([self.instance_name, "blobs", digest.hash, str(digest.size_bytes)]) 

535 else: 

536 resource_name = "/".join(["blobs", digest.hash, str(digest.size_bytes)]) 

537 

538 os.makedirs(os.path.dirname(file_path), exist_ok=True) 

539 

540 read_request = ReadRequest() 

541 read_request.resource_name = resource_name 

542 read_request.read_offset = 0 

543 

544 with open(file_path, "wb") as byte_file: 

545 for read_response in self.__bytestream_stub.Read(read_request, metadata=metadata_list()): 

546 byte_file.write(read_response.data) 

547 

548 assert byte_file.tell() == digest.size_bytes 

549 

550 if is_executable: 

551 os.chmod(file_path, 0o755) # rwxr-xr-x / 755 

552 

553 def _queue_file(self, digest: Digest, file_path: str, is_executable: bool = False) -> None: 

554 """Queues a file for later batch download""" 

555 batch_size_limit = self._max_effective_batch_size_bytes() 

556 

557 if self.__file_request_size + digest.ByteSize() > batch_size_limit: 

558 self.flush() 

559 elif self.__file_response_size + digest.size_bytes > batch_size_limit: 

560 self.flush() 

561 elif self.__file_request_count >= MAX_REQUEST_COUNT: 

562 self.flush() 

563 

564 output_path = (file_path, is_executable) 

565 

566 # When queueing a file we take into account the cases where 

567 # we might want to download the same digest to different paths. 

568 if digest.hash not in self.__file_requests: 

569 request = _FileRequest(digest=digest, output_paths=[output_path]) 

570 self.__file_requests[digest.hash] = request 

571 

572 self.__file_request_count += 1 

573 self.__file_request_size += digest.ByteSize() 

574 self.__file_response_size += digest.size_bytes 

575 else: 

576 # We already have that hash queued; we'll fetch the blob 

577 # once and write copies of it: 

578 self.__file_requests[digest.hash].output_paths.append(output_path) 

579 

580 def _fetch_file_batch(self, requests: dict[str, _FileRequest]) -> None: 

581 """Sends queued data using ContentAddressableStorage.BatchReadBlobs(). 

582 

583 Takes a dictionary (digest.hash, _FileRequest) as input. 

584 """ 

585 batch_digests = [request.digest for request in requests.values()] 

586 batch_responses = self._fetch_blob_batch(batch_digests) 

587 

588 for file_blob, file_digest in batch_responses: 

589 output_paths = requests[file_digest.hash].output_paths 

590 

591 for file_path, is_executable in output_paths: 

592 os.makedirs(os.path.dirname(file_path), exist_ok=True) 

593 

594 with open(file_path, "wb") as byte_file: 

595 byte_file.write(file_blob) 

596 

597 if is_executable: 

598 os.chmod(file_path, 0o755) # rwxr-xr-x / 755 

599 

600 def _fetch_directory(self, digest: Digest, directory_path: str) -> None: 

601 """Fetches a file using ByteStream.GetTree()""" 

602 # Better fail early if the local root path cannot be created: 

603 

604 assert self.__cas_stub, "Downloader used after closing" 

605 

606 os.makedirs(directory_path, exist_ok=True) 

607 

608 directories = {} 

609 directory_fetched = False 

610 # First, try GetTree() if not known to be unimplemented yet: 

611 if not _CallCache.unimplemented(self.channel, "GetTree"): 

612 tree_request = GetTreeRequest() 

613 tree_request.root_digest.CopyFrom(digest) 

614 tree_request.page_size = MAX_REQUEST_COUNT 

615 if self.instance_name is not None: 

616 tree_request.instance_name = self.instance_name 

617 

618 try: 

619 for tree_response in self.__cas_stub.GetTree(tree_request): 

620 for directory in tree_response.directories: 

621 directory_blob = directory.SerializeToString() 

622 directory_hash = HASH(directory_blob).hexdigest() 

623 

624 directories[directory_hash] = directory 

625 

626 assert digest.hash in directories 

627 

628 directory = directories[digest.hash] 

629 self._write_directory(directory, directory_path, directories=directories) 

630 

631 directory_fetched = True 

632 except grpc.RpcError as e: 

633 status_code = e.code() 

634 if status_code == grpc.StatusCode.UNIMPLEMENTED: 

635 _CallCache.mark_unimplemented(self.channel, "GetTree") 

636 

637 else: 

638 raise 

639 

640 # If no GetTree(), _write_directory() will use BatchReadBlobs() 

641 # if available or Read() if not. 

642 if not directory_fetched: 

643 directory = Directory() 

644 directory.ParseFromString(self._grpc_retrier.retry(self._fetch_blob, digest)) 

645 

646 self._write_directory(directory, directory_path) 

647 

648 def _write_directory( 

649 self, root_directory: Directory, root_path: str, directories: dict[str, Directory] | None = None 

650 ) -> None: 

651 """Generates a local directory structure""" 

652 os.makedirs(root_path, exist_ok=True) 

653 self._write_directory_recursively(root_directory, root_path, directories=directories) 

654 

655 def _write_directory_recursively( 

656 self, root_directory: Directory, root_path: str, directories: dict[str, Directory] | None = None 

657 ) -> None: 

658 """Generate local directory recursively""" 

659 # i) Files: 

660 for file_node in root_directory.files: 

661 file_path = os.path.join(root_path, file_node.name) 

662 

663 if os.path.lexists(file_path): 

664 raise FileExistsError(f"'{file_path}' already exists") 

665 

666 self.download_file(file_node.digest, file_path, is_executable=file_node.is_executable) 

667 self.flush() 

668 

669 # ii) Directories: 

670 pending_directory_digests = [] 

671 pending_directory_paths = {} 

672 for directory_node in root_directory.directories: 

673 directory_hash = directory_node.digest.hash 

674 

675 # FIXME: Guard against ../ 

676 directory_path = os.path.join(root_path, directory_node.name) 

677 os.mkdir(directory_path) 

678 

679 if directories and directory_node.digest.hash in directories: 

680 # We already have the directory; just write it: 

681 directory = directories[directory_hash] 

682 

683 self._write_directory_recursively(directory, directory_path, directories=directories) 

684 else: 

685 # Gather all the directories that we need to get to 

686 # try fetching them in a single batch request: 

687 if directory_hash not in pending_directory_paths: 

688 pending_directory_digests.append(directory_node.digest) 

689 pending_directory_paths[directory_hash] = [directory_path] 

690 else: 

691 pending_directory_paths[directory_hash].append(directory_path) 

692 

693 if pending_directory_paths: 

694 fetched_blobs = self._grpc_retrier.retry(self._fetch_blob_batch, pending_directory_digests) 

695 

696 for directory_blob, directory_digest in fetched_blobs: 

697 directory = Directory() 

698 directory.ParseFromString(directory_blob) 

699 

700 # Assuming that the server might not return the blobs in 

701 # the same order as they were asked for, we read 

702 # the hashes of the returned blobs: 

703 # Guarantees for the reply orderings might change in 

704 # the specification at some point. 

705 # See: github.com/bazelbuild/remote-apis/issues/52 

706 

707 for directory_path in pending_directory_paths[directory_digest.hash]: 

708 self._write_directory(directory, directory_path, directories=directories) 

709 

710 # iii) Symlinks: 

711 for symlink_node in root_directory.symlinks: 

712 symlink_path = os.path.join(root_path, symlink_node.name) 

713 os.symlink(symlink_node.target, symlink_path) 

714 

715 def _max_effective_batch_size_bytes(self) -> int: 

716 """Returns the effective maximum number of bytes that can be 

717 transferred using batches, considering gRPC maximum message size. 

718 """ 

719 return _CasBatchRequestSizesCache.max_effective_batch_size_bytes(self.channel, self.instance_name) 

720 

721 def _queueable_file_size_threshold(self) -> int: 

722 """Returns the size limit up until which files can be queued to 

723 be requested in a batch. 

724 """ 

725 return _CasBatchRequestSizesCache.batch_request_size_threshold(self.channel, self.instance_name) 

726 

727 

728@contextmanager 

729def download( 

730 channel: grpc.Channel, 

731 instance: str | None = None, 

732 u_uid: str | None = None, 

733 retries: int = 0, 

734 max_backoff: int = 64, 

735 should_backoff: bool = True, 

736) -> Iterator[Downloader]: 

737 """Context manager generator for the :class:`Downloader` class.""" 

738 downloader = Downloader( 

739 channel, instance=instance, retries=retries, max_backoff=max_backoff, should_backoff=should_backoff 

740 ) 

741 try: 

742 yield downloader 

743 finally: 

744 downloader.close() 

745 

746 
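# Illustrative sketch (hypothetical helper; assumes an open grpc.Channel and an
# instance named 'build'): fetching a whole directory tree through the
# download() context manager.
def _example_download_tree(channel: grpc.Channel, root_digest: Digest, target_path: str) -> None:
    with download(channel, instance="build") as downloader:
        downloader.download_directory(root_digest, target_path)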

747class Uploader: 

748 """Remote CAS files, directories and messages upload helper. 

749 

750 The :class:`Uploader` class comes with a generator factory function that can 

751 be used together with the `with` statement for context management:: 

752 

753 from buildgrid.server.client.cas import upload 

754 

755 with upload(channel, instance='build') as uploader: 

756 uploader.upload_file('/path/to/local/file') 

757 """ 

758 

759 def __init__( 

760 self, 

761 channel: grpc.Channel, 

762 instance: str | None = None, 

763 u_uid: str | None = None, 

764 retries: int = 0, 

765 max_backoff: int = 64, 

766 should_backoff: bool = True, 

767 ): 

768 """Initializes a new :class:`Uploader` instance. 

769 

770 Args: 

771 channel (grpc.Channel): A gRPC channel to the CAS endpoint. 

772 instance (str, optional): the targeted instance's name. 

773 u_uid (str, optional): a UUID for CAS transactions. 

774 """ 

775 self.channel = channel 

776 

777 self.instance_name = instance or "" 

778 if u_uid is not None: 

779 self.u_uid = u_uid 

780 else: 

781 self.u_uid = str(uuid.uuid4()) 

782 

783 self._grpc_retrier = GrpcRetrier(retries=retries, max_backoff=max_backoff, should_backoff=should_backoff) 

784 

785 self.__bytestream_stub: ByteStreamStub | None = ByteStreamStub(self.channel) 

786 self.__cas_stub: ContentAddressableStorageStub | None = ContentAddressableStorageStub(self.channel) 

787 

788 self.__requests: dict[str, tuple[bytes, Digest]] = {} 

789 self.__request_count = 0 

790 self.__request_size = 0 

791 

792 # --- Public API --- 

793 

794 def put_blob( 

795 self, blob: IO[bytes], digest: Digest | None = None, queue: bool = False, length: int | None = None 

796 ) -> Digest: 

797 """Stores a blob into the remote CAS server. 

798 

799 If queuing is allowed (`queue=True`), the upload request **may** be 

800 deferred. An explicit call to :func:`~flush` can force the request to be 

801 sent immediately (along with the rest of the queued batch). 

802 

803 The caller should set at least one of ``digest`` or ``length`` to 

804 allow the uploader to skip determining the size of the blob using 

805 multiple seeks. 

806 

807 Args: 

808 blob (IO[bytes]): a file-like object containing the blob. 

809 digest (:obj:`Digest`, optional): the blob's digest. 

810 queue (bool, optional): whether or not the upload request may be 

811 queued and submitted as part of a batch upload request. Defaults 

812 to False. 

813 length (int, optional): The size of the blob, in bytes. If ``digest`` 

814 is also set, this is ignored in favour of ``digest.size_bytes``. 

815 

816 Returns: 

817 :obj:`Digest`: the sent blob's digest. 

818 """ 

819 if digest is not None: 

820 length = digest.size_bytes 

821 elif length is None: 

822 # If neither the digest nor the length was set, fall back to 

823 # seeking to the end of the object to find the length 

824 blob.seek(0, SEEK_END) 

825 length = blob.tell() 

826 blob.seek(0) 

827 

828 if not queue or length > self._queueable_file_size_threshold(): 

829 blob_digest = self._grpc_retrier.retry(self._send_blob, blob, digest) 

830 else: 

831 blob_digest = self._queue_blob(blob.read(), digest=digest) 

832 

833 return blob_digest 

834 
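# Illustrative sketch (hypothetical helper): put_blob() accepts any binary
# file-like object; passing `length` (or a digest) spares the uploader a seek
# to the end of the stream to measure it.
def _example_put_bytes(uploader: "Uploader", payload: bytes) -> Digest:
    return uploader.put_blob(BytesIO(payload), length=len(payload), queue=True)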

835 def put_message(self, message: Message, digest: Digest | None = None, queue: bool = False) -> Digest: 

836 """Stores a message into the remote CAS server. 

837 

838 If queuing is allowed (`queue=True`), the upload request **may** be 

839 deferred. An explicit call to :func:`~flush` can force the request to be 

840 sent immediately (along with the rest of the queued batch). 

841 

842 Args: 

843 message (:obj:`Message`): the message object. 

844 digest (:obj:`Digest`, optional): the message's digest. 

845 queue (bool, optional): whether or not the upload request may be 

846 queued and submitted as part of a batch upload request. Defaults 

847 to False. 

848 

849 Returns: 

850 :obj:`Digest`: the sent message's digest. 

851 """ 

852 message_blob = message.SerializeToString() 

853 

854 if not queue or len(message_blob) > self._queueable_file_size_threshold(): 

855 message_digest = self._grpc_retrier.retry(self._send_blob, BytesIO(message_blob), digest) 

856 else: 

857 message_digest = self._queue_blob(message_blob, digest=digest) 

858 

859 return message_digest 

860 

861 def upload_file(self, file_path: str, queue: bool = True) -> Digest: 

862 """Stores a local file into the remote CAS storage. 

863 

864 If queuing is allowed (`queue=True`), the upload request **may** be 

865 deferred. An explicit call to :func:`~flush` can force the request to be 

866 sent immediately (along with the rest of the queued batch). 

867 

868 Args: 

869 file_path (str): absolute or relative path to a local file. 

870 queue (bool, optional): whether or not the upload request may be 

871 queued and submitted as part of a batch upload request. Defaults 

872 to True. 

873 

874 Returns: 

875 :obj:`Digest`: The digest of the file's content. 

876 

877 Raises: 

878 FileNotFoundError: If `file_path` does not exist. 

879 PermissionError: If `file_path` is not readable. 

880 """ 

881 if not os.path.isabs(file_path): 

882 file_path = os.path.abspath(file_path) 

883 

884 with open(file_path, "rb") as file_object: 

885 if not queue or os.path.getsize(file_path) > self._queueable_file_size_threshold(): 

886 file_digest = self._grpc_retrier.retry(self._send_blob, file_object) 

887 else: 

888 file_digest = self._queue_blob(file_object.read()) 

889 

890 return file_digest 

891 

892 def upload_directory(self, directory_path: str, queue: bool = True) -> Digest: 

893 """Stores a local folder into the remote CAS storage. 

894 

895 If queuing is allowed (`queue=True`), the upload request **may** be 

896 deferred. An explicit call to :func:`~flush` can force the request to be 

897 sent immediately (along with the rest of the queued batch). 

898 

899 Args: 

900 directory_path (str): absolute or relative path to a local folder. 

901 queue (bool, optional): whether or not the upload requests may be 

902 queued and submitted as part of a batch upload request. Defaults 

903 to True. 

904 

905 Returns: 

906 :obj:`Digest`: The digest of the top :obj:`Directory`. 

907 

908 Raises: 

909 FileNotFoundError: If `directory_path` does not exist. 

910 PermissionError: If `directory_path` is not readable. 

911 """ 

912 if not os.path.isabs(directory_path): 

913 directory_path = os.path.abspath(directory_path) 

914 

915 if not queue: 

916 for node, blob, _ in merkle_tree_maker(directory_path): 

917 if node.DESCRIPTOR is DirectoryNode.DESCRIPTOR: 

918 last_directory_node = node 

919 

920 self._grpc_retrier.retry(self._send_blob, blob, node.digest) 

921 

922 else: 

923 for node, blob, _ in merkle_tree_maker(directory_path): 

924 if node.DESCRIPTOR is DirectoryNode.DESCRIPTOR: 

925 last_directory_node = node 

926 

927 if node.digest.size_bytes > self._queueable_file_size_threshold(): 

928 self._grpc_retrier.retry(self._send_blob, blob, node.digest) 

929 else: 

930 self._queue_blob(blob.read(), digest=node.digest) 

931 

932 return last_directory_node.digest 

933 
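# Illustrative sketch (hypothetical helper; assumes an open grpc.Channel):
# uploading a local output directory; queued blobs are flushed automatically
# when the upload() context manager closes the Uploader.
def _example_upload_outputs(channel: grpc.Channel, output_dir: str) -> Digest:
    with upload(channel, instance="build") as uploader:
        return uploader.upload_directory(output_dir, queue=True)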

934 def upload_tree(self, directory_path: str, queue: bool = True) -> Digest: 

935 """Stores a local folder into the remote CAS storage as a :obj:`Tree`. 

936 

937 If queuing is allowed (`queue=True`), the upload request **may** be 

938 deferred. An explicit call to :func:`~flush` can force the request to be 

939 sent immediately (along with the rest of the queued batch). 

940 

941 Args: 

942 directory_path (str): absolute or relative path to a local folder. 

943 queue (bool, optional): whether or not the upload requests may be 

944 queued and submitted as part of a batch upload request. Defaults 

945 to True. 

946 

947 Returns: 

948 :obj:`Digest`: The digest of the :obj:`Tree`. 

949 

950 Raises: 

951 FileNotFoundError: If `directory_path` does not exist. 

952 PermissionError: If `directory_path` is not readable. 

953 """ 

954 if not os.path.isabs(directory_path): 

955 directory_path = os.path.abspath(directory_path) 

956 

957 directories = [] 

958 

959 if not queue: 

960 for node, blob, _ in merkle_tree_maker(directory_path): 

961 if node.DESCRIPTOR is DirectoryNode.DESCRIPTOR: 

962 # TODO: Get the Directory object from merkle_tree_maker(): 

963 directory = Directory() 

964 directory.ParseFromString(blob.read()) 

965 directories.append(directory) 

966 

967 self._grpc_retrier.retry(self._send_blob, blob, node.digest) 

968 

969 else: 

970 for node, blob, _ in merkle_tree_maker(directory_path): 

971 if node.DESCRIPTOR is DirectoryNode.DESCRIPTOR: 

972 # TODO: Get the Directory object from merkle_tree_maker(): 

973 directory = Directory() 

974 directory.ParseFromString(blob.read()) 

975 directories.append(directory) 

976 

977 if node.digest.size_bytes > self._queueable_file_size_threshold(): 

978 self._grpc_retrier.retry(self._send_blob, blob, node.digest) 

979 else: 

980 self._queue_blob(blob.read(), digest=node.digest) 

981 

982 tree = Tree() 

983 tree.root.CopyFrom(directories[-1]) 

984 tree.children.extend(directories[:-1]) 

985 

986 return self.put_message(tree, queue=queue) 

987 

988 def flush(self) -> None: 

989 """Ensures any queued request gets sent.""" 

990 if self.__requests: 

991 self._grpc_retrier.retry(self._send_blob_batch, self.__requests) 

992 

993 self.__requests.clear() 

994 self.__request_count = 0 

995 self.__request_size = 0 

996 

997 def close(self) -> None: 

998 """Closes the underlying connection stubs. 

999 

1000 Note: 

1001 This will always send pending requests before closing connections, 

1002 if any. 

1003 """ 

1004 self.flush() 

1005 

1006 self.__bytestream_stub = None 

1007 self.__cas_stub = None 

1008 

1009 # --- Private API --- 

1010 

1011 def _send_blob(self, blob: BinaryIO, digest: Digest | None = None) -> Digest: 

1012 """Sends a memory block using ByteStream.Write()""" 

1013 

1014 assert self.__bytestream_stub, "Uploader used after closing" 

1015 

1016 blob.seek(0) 

1017 blob_digest = Digest() 

1018 if digest is not None: 

1019 blob_digest.CopyFrom(digest) 

1020 else: 

1021 blob_digest = create_digest_from_file(blob) 

1022 

1023 if self.instance_name: 

1024 resource_name = "/".join( 

1025 [self.instance_name, "uploads", self.u_uid, "blobs", blob_digest.hash, str(blob_digest.size_bytes)] 

1026 ) 

1027 else: 

1028 resource_name = "/".join(["uploads", self.u_uid, "blobs", blob_digest.hash, str(blob_digest.size_bytes)]) 

1029 

1030 def __write_request_stream(resource: str, content: BinaryIO) -> Iterator[WriteRequest]: 

1031 offset = 0 

1032 finished = False 

1033 remaining = blob_digest.size_bytes - offset 

1034 while not finished: 

1035 chunk_size = min(remaining, MAX_REQUEST_SIZE) 

1036 remaining -= chunk_size 

1037 

1038 request = WriteRequest() 

1039 request.resource_name = resource 

1040 request.data = content.read(chunk_size) 

1041 request.write_offset = offset 

1042 request.finish_write = remaining <= 0 

1043 

1044 yield request 

1045 

1046 offset += chunk_size 

1047 finished = request.finish_write 

1048 

1049 write_requests = __write_request_stream(resource_name, blob) 

1050 

1051 write_response = self.__bytestream_stub.Write(write_requests, metadata=metadata_list()) 

1052 

1053 assert write_response.committed_size == blob_digest.size_bytes 

1054 

1055 return blob_digest 

1056 
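# Illustrative sketch (hypothetical helper) of the chunking arithmetic used by
# __write_request_stream() above: a blob of `size_bytes` is streamed in
# ceil(size_bytes / MAX_REQUEST_SIZE) WriteRequests (at least one, even when
# empty), with finish_write set on the last chunk.
def _example_write_request_count(size_bytes: int) -> int:
    return max(1, -(-size_bytes // MAX_REQUEST_SIZE))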

1057 def _queue_blob(self, blob: bytes, digest: Digest | None = None) -> Digest: 

1058 """Queues a memory block for later batch upload""" 

1059 blob_digest = Digest() 

1060 if digest is not None: 

1061 blob_digest.CopyFrom(digest) 

1062 else: 

1063 blob_digest.hash = HASH(blob).hexdigest() 

1064 blob_digest.size_bytes = len(blob) 

1065 

1066 # If we are here queueing a file we know that its size is 

1067 # smaller than gRPC's message size limit. 

1068 # We'll make a single batch request as big as the server allows. 

1069 batch_size_limit = self._max_effective_batch_size_bytes() 

1070 

1071 if self.__request_size + blob_digest.size_bytes > batch_size_limit: 

1072 self.flush() 

1073 elif self.__request_count >= MAX_REQUEST_COUNT: 

1074 self.flush() 

1075 

1076 self.__requests[blob_digest.hash] = (blob, blob_digest) 

1077 self.__request_count += 1 

1078 self.__request_size += blob_digest.size_bytes 

1079 

1080 return blob_digest 

1081 

1082 def _send_blob_batch(self, batch: dict[str, tuple[bytes, Digest]]) -> list[Digest]: 

1083 """Sends queued data using ContentAddressableStorage.BatchUpdateBlobs()""" 

1084 

1085 assert self.__cas_stub, "Uploader used after closing" 

1086 

1087 batch_fetched = False 

1088 written_digests = [] 

1089 

1090 # First, try BatchUpdateBlobs(), unless it is already known to be unimplemented: 

1091 if not _CallCache.unimplemented(self.channel, "BatchUpdateBlobs"): 

1092 batch_request = BatchUpdateBlobsRequest() 

1093 if self.instance_name is not None: 

1094 batch_request.instance_name = self.instance_name 

1095 

1096 for blob, digest in batch.values(): 

1097 request = batch_request.requests.add() 

1098 request.digest.CopyFrom(digest) 

1099 request.data = blob 

1100 

1101 try: 

1102 batch_response = self.__cas_stub.BatchUpdateBlobs(batch_request, metadata=metadata_list()) 

1103 

1104 for response in batch_response.responses: 

1105 assert response.digest.hash in batch 

1106 

1107 written_digests.append(response.digest) 

1108 if response.status.code != code_pb2.OK: 

1109 response.digest.Clear() 

1110 

1111 batch_fetched = True 

1112 

1113 except grpc.RpcError as e: 

1114 status_code = e.code() 

1115 if status_code == grpc.StatusCode.UNIMPLEMENTED: 

1116 _CallCache.mark_unimplemented(self.channel, "BatchUpdateBlobs") 

1117 

1118 elif status_code == grpc.StatusCode.INVALID_ARGUMENT: 

1119 written_digests.clear() 

1120 batch_fetched = False 

1121 

1122 else: 

1123 raise 

1124 

1125 # Fall back to Write() if BatchUpdateBlobs() is not available: 

1126 if not batch_fetched: 

1127 for blob, digest in batch.values(): 

1128 written_digests.append(self._send_blob(BytesIO(blob))) 

1129 

1130 return written_digests 

1131 

1132 def _max_effective_batch_size_bytes(self) -> int: 

1133 """Returns the effective maximum number of bytes that can be 

1134 transferred using batches, considering gRPC maximum message size. 

1135 """ 

1136 return _CasBatchRequestSizesCache.max_effective_batch_size_bytes(self.channel, self.instance_name) 

1137 

1138 def _queueable_file_size_threshold(self) -> int: 

1139 """Returns the size limit up until which files can be queued to 

1140 be requested in a batch. 

1141 """ 

1142 return _CasBatchRequestSizesCache.batch_request_size_threshold(self.channel, self.instance_name) 

1143 

1144 

1145@contextmanager 

1146def upload( 

1147 channel: grpc.Channel, 

1148 instance: str | None = None, 

1149 u_uid: str | None = None, 

1150 retries: int = 0, 

1151 max_backoff: int = 64, 

1152 should_backoff: bool = True, 

1153) -> Iterator[Uploader]: 

1154 """Context manager generator for the :class:`Uploader` class.""" 

1155 uploader = Uploader( 

1156 channel, 

1157 instance=instance, 

1158 u_uid=u_uid, 

1159 retries=retries, 

1160 max_backoff=max_backoff, 

1161 should_backoff=should_backoff, 

1162 ) 

1163 try: 

1164 yield uploader 

1165 finally: 

1166 uploader.close()
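# Illustrative sketch (hypothetical helper; assumes an open grpc.Channel):
# uploading a single local file through the upload() context manager and
# returning its digest for later reference.
def _example_upload_file(channel: grpc.Channel, file_path: str) -> Digest:
    with upload(channel, instance="build") as uploader:
        return uploader.upload_file(file_path, queue=True)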