
1# Copyright (C) 2018 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# 

15# pylint: disable=anomalous-backslash-in-string 

16 

17 

18from collections import namedtuple 

19from contextlib import contextmanager 

20import os 

21from typing import Dict, Set 

22import uuid 

23 

24import grpc 

25 

26from buildgrid._exceptions import NotFoundError 

27from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc 

28from buildgrid._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc 

29from buildgrid._protos.google.rpc import code_pb2 

30from buildgrid.client.capabilities import CapabilitiesInterface 

31from buildgrid.settings import HASH, MAX_REQUEST_SIZE, MAX_REQUEST_COUNT, BATCH_REQUEST_SIZE_THRESHOLD 

32from buildgrid.utils import create_digest, merkle_tree_maker 

33 

34 

35_FileRequest = namedtuple('FileRequest', ['digest', 'output_paths']) 

36 

37 

38class _CallCache: 

39 """Per-remote cache of calls known to return grpc.StatusCode.UNIMPLEMENTED."""

40 __calls: Dict[grpc.Channel, Set[str]] = {} 

41 

42 @classmethod 

43 def mark_unimplemented(cls, channel, name): 

44 if channel not in cls.__calls: 

45 cls.__calls[channel] = set() 

46 cls.__calls[channel].add(name) 

47 

48 @classmethod 

49 def unimplemented(cls, channel, name): 

50 if channel not in cls.__calls: 

51 return False 

52 return name in cls.__calls[channel] 

53 

54 

55class _CasBatchRequestSizesCache: 

56 """Cache that stores, for each remote, the limit of bytes that can 

57 be transferred using batches as well as a threshold that determines 

58 when a file can be fetched as part of a batch request. 

59 """ 

60 __cas_max_batch_transfer_size: Dict[grpc.Channel, Dict[str, int]] = {} 

61 __cas_batch_request_size_threshold: Dict[grpc.Channel, Dict[str, int]] = {} 

62 

63 @classmethod 

64 def max_effective_batch_size_bytes(cls, channel, instance_name): 

65 """Returns the maximum number of bytes that can be transferred 

66 using batch methods for the given remote. 

67 """ 

68 if channel not in cls.__cas_max_batch_transfer_size: 

69 cls.__cas_max_batch_transfer_size[channel] = dict() 

70 

71 if instance_name not in cls.__cas_max_batch_transfer_size[channel]: 

72 max_batch_size = cls._get_server_max_batch_total_size_bytes(channel, 

73 instance_name) 

74 

75 cls.__cas_max_batch_transfer_size[channel][instance_name] = max_batch_size 

76 

77 return cls.__cas_max_batch_transfer_size[channel][instance_name] 

78 

79 @classmethod 

80 def batch_request_size_threshold(cls, channel, instance_name): 

81 if channel not in cls.__cas_batch_request_size_threshold: 

82 cls.__cas_batch_request_size_threshold[channel] = dict() 

83 

84 if instance_name not in cls.__cas_batch_request_size_threshold[channel]: 

85 # Computing the threshold: 

86 max_batch_size = cls.max_effective_batch_size_bytes(channel, 

87 instance_name) 

88 threshold = BATCH_REQUEST_SIZE_THRESHOLD * max_batch_size 

89 

90 cls.__cas_batch_request_size_threshold[channel][instance_name] = threshold 

91 

92 return cls.__cas_batch_request_size_threshold[channel][instance_name] 

93 
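# Illustrative note on batch_request_size_threshold() above: the threshold is
# computed as BATCH_REQUEST_SIZE_THRESHOLD * max_effective_batch_size_bytes().
# Assuming (purely for illustration) that BATCH_REQUEST_SIZE_THRESHOLD is a
# ratio such as 0.25 and the effective max batch size is 4 MiB, blobs of up to
# about 1 MiB would be considered small enough to go into a batch request,
# while larger blobs would be transferred individually. The numbers here are
# assumptions, not values taken from buildgrid.settings.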

94 @classmethod 

95 def _get_server_max_batch_total_size_bytes(cls, channel, instance_name): 

96 """Returns the maximum number of bytes that can be effectively 

97 transferred using batches, considering the limits imposed by 

98 the server's configuration and by gRPC. 

99 """ 

100 try: 

101 capabilities_interface = CapabilitiesInterface(channel) 

102 server_capabilities = capabilities_interface.get_capabilities(instance_name) 

103 

104 cache_capabilities = server_capabilities.cache_capabilities 

105 

106 max_batch_total_size = cache_capabilities.max_batch_total_size_bytes 

107 # The server could set this value to 0 (no limit set). 

108 if max_batch_total_size: 

109 return min(max_batch_total_size, MAX_REQUEST_SIZE) 

110 except ConnectionError: 

111 pass 

112 

113 return MAX_REQUEST_SIZE 

114 

115 

116@contextmanager 

117def download(channel, instance=None, u_uid=None): 

118 """Context manager generator for the :class:`Downloader` class.""" 

119 downloader = Downloader(channel, instance=instance) 

120 try: 

121 yield downloader 

122 finally: 

123 downloader.close() 

124 

125 

126class Downloader: 

127 """Remote CAS files, directories and messages download helper. 

128 

129 The :class:`Downloader` class comes with a generator factory function that 

130 can be used together with the `with` statement for context management:: 

131 

132 from buildgrid.client.cas import download 

133 

134 with download(channel, instance='build') as downloader: 

135 downloader.get_message(message_digest) 

136 """ 

137 

138 def __init__(self, channel, instance=None): 

139 """Initializes a new :class:`Downloader` instance. 

140 

141 Args: 

142 channel (grpc.Channel): A gRPC channel to the CAS endpoint. 

143 instance (str, optional): the targeted instance's name. 

144 """ 

145 self.channel = channel 

146 

147 self.instance_name = instance 

148 

149 self.__bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel) 

150 self.__cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) 

151 

152 self.__file_requests = {} 

153 self.__file_request_count = 0 

154 self.__file_request_size = 0 

155 self.__file_response_size = 0 

156 

157 # --- Public API --- 

158 

159 def get_blob(self, digest): 

160 """Retrieves a blob from the remote CAS server. 

161 

162 Args: 

163 digest (:obj:`Digest`): the blob's digest to fetch. 

164 

165 Returns: 

166 bytearray: the fetched blob data or None if not found. 

167 """ 

168 try: 

169 blob = self._fetch_blob(digest) 

170 except FileNotFoundError: 

171 return None 

172 except ConnectionError: 

173 raise 

174 

175 return blob 

176 

177 def get_blobs(self, digests): 

178 """Retrieves a list of blobs from the remote CAS server. 

179 

180 Args: 

181 digests (list): list of :obj:`Digest`\ s for the blobs to fetch. 

182 

183 Returns: 

184 list: the fetched blob data list. 

185 """ 

186 # _fetch_blob_batch returns (data, digest) pairs. 

187 # We only want the data. 

188 return [result[0] for result in self._fetch_blob_batch(digests)] 

189 

190 def get_message(self, digest, message): 

191 """Retrieves a :obj:`Message` from the remote CAS server. 

192 

193 Args: 

194 digest (:obj:`Digest`): the message's digest to fetch. 

195 message (:obj:`Message`): an empty message to fill. 

196 

197 Returns: 

198 :obj:`Message`: the given `message`, filled with the fetched data or cleared if not found. 

199 """ 

200 try: 

201 message_blob = self._fetch_blob(digest) 

202 except NotFoundError: 

203 message_blob = None 

204 

205 if message_blob is not None: 

206 message.ParseFromString(message_blob) 

207 else: 

208 message.Clear() 

209 

210 return message 

211 

212 def get_messages(self, digests, messages): 

213 """Retrieves a list of :obj:`Message`\ s from the remote CAS server. 

214 

215 Note: 

216 The `digests` and `messages` list **must** contain the same number 

217 of elements. 

218 

219 Args: 

220 digests (list): list of :obj:`Digest`\ s for the messages to fetch. 

221 messages (list): list of empty :obj:`Message`\ s to fill. 

222 

223 Returns: 

224 list: the fetched and filled message list. 

225 """ 

226 assert len(digests) == len(messages) 

227 

228 # The individual empty messages might be of differing types, so we need 

229 # to set up a mapping 

230 digest_message_map = {digest.hash: message for (digest, message) in zip(digests, messages)} 

231 

232 batch_response = self._fetch_blob_batch(digests) 

233 

234 messages = [] 

235 for message_blob, message_digest in batch_response: 

236 message = digest_message_map[message_digest.hash] 

237 message.ParseFromString(message_blob) 

238 messages.append(message) 

239 

240 return messages 

241 
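    # Usage sketch for get_messages() (assumes a Downloader instance
    # `downloader` and Digests `action_digest` and `command_digest` obtained
    # elsewhere). The input messages are parsed in place, so they can be used
    # directly afterwards:
    #
    #     action = remote_execution_pb2.Action()
    #     command = remote_execution_pb2.Command()
    #     downloader.get_messages([action_digest, command_digest],
    #                             [action, command])
    #     # The returned list follows the server's response order, which may
    #     # differ from the input order.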

242 def download_file(self, digest, file_path, is_executable=False, queue=True): 

243 """Retrieves a file from the remote CAS server. 

244 

245 If queuing is allowed (`queue=True`), the download request **may** be 

246 deferred. An explicit call to :func:`~flush` can force the request to be 

247 sent immediately (along with the rest of the queued batch). 

248 

249 Args: 

250 digest (:obj:`Digest`): the file's digest to fetch. 

251 file_path (str): absolute or relative path to the local file to write. 

252 is_executable (bool): whether the file is executable or not. 

253 queue (bool, optional): whether or not the download request may be 

254 queued and submitted as part of a batch download request. Defaults 

255 to True. 

256 

257 Raises: 

258 FileNotFoundError: if `digest` is not present in the remote CAS server. 

259 OSError: if `file_path` cannot be created or written to. 

260 """ 

261 if not os.path.isabs(file_path): 

262 file_path = os.path.abspath(file_path) 

263 

264 if not queue or digest.size_bytes > self._queueable_file_size_threshold(): 

265 self._fetch_file(digest, file_path, is_executable=is_executable) 

266 else: 

267 self._queue_file(digest, file_path, is_executable=is_executable) 

268 
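    # Usage sketch for download_file() (assumes a valid grpc.Channel `channel`
    # and Digest objects `digest_a` and `digest_b`): queued requests are
    # batched until flush() or close() is called.
    #
    #     with download(channel, instance='build') as downloader:
    #         downloader.download_file(digest_a, '/tmp/out/a.txt')  # may be queued
    #         downloader.download_file(digest_b, '/tmp/out/b.txt')  # may be queued
    #         downloader.flush()  # forces any queued requests to be sent now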

269 def download_directory(self, digest, directory_path): 

270 """Retrieves a :obj:`Directory` from the remote CAS server. 

271 

272 Args: 

273 digest (:obj:`Digest`): the directory's digest to fetch. 

274 

275 Raises: 

276 FileNotFoundError: if `digest` is not present in the remote CAS server. 

277 FileExistsError: if `directory_path` already contains parts of the 

278 fetched directory's content. 

279 """ 

280 if not os.path.isabs(directory_path): 

281 directory_path = os.path.abspath(directory_path) 

282 

283 self._fetch_directory(digest, directory_path) 

284 

285 def flush(self): 

286 """Ensures any queued request gets sent.""" 

287 if self.__file_requests: 

288 self._fetch_file_batch(self.__file_requests) 

289 

290 self.__file_requests.clear() 

291 self.__file_request_count = 0 

292 self.__file_request_size = 0 

293 self.__file_response_size = 0 

294 

295 def close(self): 

296 """Closes the underlying connection stubs. 

297 

298 Note: 

299 This will always send pending requests before closing connections, 

300 if any. 

301 """ 

302 self.flush() 

303 

304 self.__bytestream_stub = None 

305 self.__cas_stub = None 

306 

307 # --- Private API --- 

308 

309 def _fetch_blob(self, digest): 

310 """Fetches a blob using ByteStream.Read()""" 

311 read_blob = bytearray() 

312 

313 if self.instance_name: 

314 resource_name = '/'.join([self.instance_name, 'blobs', 

315 digest.hash, str(digest.size_bytes)]) 

316 else: 

317 resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)]) 

318 

319 read_request = bytestream_pb2.ReadRequest() 

320 read_request.resource_name = resource_name 

321 read_request.read_offset = 0 

322 

323 try: 

324 # TODO: Handle connection loss/recovery 

325 for read_response in self.__bytestream_stub.Read(read_request): 

326 read_blob += read_response.data 

327 

328 assert len(read_blob) == digest.size_bytes 

329 

330 except grpc.RpcError as e: 

331 status_code = e.code() 

332 if status_code == grpc.StatusCode.NOT_FOUND: 

333 raise FileNotFoundError("Requested data does not exist on the remote.") 

334 

335 else: 

336 raise ConnectionError(e.details()) 

337 

338 return read_blob 

339 

340 def _fetch_blob_batch(self, digests): 

341 """Fetches blobs using ContentAddressableStorage.BatchReadBlobs(). 

342 Returns a list of (data, digest) pairs.""" 

343 batch_fetched = False 

344 read_blobs = [] 

345 

346 # First, try BatchReadBlobs(), unless it is already known to be unimplemented: 

347 if not _CallCache.unimplemented(self.channel, 'BatchReadBlobs'): 

348 batch_request = remote_execution_pb2.BatchReadBlobsRequest() 

349 batch_request.digests.extend(digests) 

350 if self.instance_name is not None: 

351 batch_request.instance_name = self.instance_name 

352 

353 try: 

354 batch_response = self.__cas_stub.BatchReadBlobs(batch_request) 

355 for response in batch_response.responses: 

356 assert response.digest in digests 

357 

358 read_blobs.append((response.data, response.digest)) 

359 

360 if response.status.code == code_pb2.NOT_FOUND: 

361 raise FileNotFoundError('Requested blob does not exist ' 

362 'on the remote.') 

363 if response.status.code != code_pb2.OK: 

364 raise ConnectionError('Error in CAS reply while fetching blob.') 

365 

366 batch_fetched = True 

367 

368 except grpc.RpcError as e: 

369 status_code = e.code() 

370 if status_code == grpc.StatusCode.UNIMPLEMENTED: 

371 _CallCache.mark_unimplemented(self.channel, 'BatchReadBlobs') 

372 

373 elif status_code == grpc.StatusCode.INVALID_ARGUMENT: 

374 read_blobs.clear() 

375 batch_fetched = False 

376 

377 else: 

378 raise ConnectionError(e.details()) 

379 

380 # Fall back to Read() if BatchReadBlobs() is not available: 

381 if not batch_fetched: 

382 for digest in digests: 

383 read_blobs.append((self._fetch_blob(digest), digest)) 

384 

385 return read_blobs 

386 

387 def _fetch_file(self, digest, file_path, is_executable=False): 

388 """Fetches a file using ByteStream.Read()""" 

389 if self.instance_name: 

390 resource_name = '/'.join([self.instance_name, 'blobs', 

391 digest.hash, str(digest.size_bytes)]) 

392 else: 

393 resource_name = '/'.join(['blobs', digest.hash, str(digest.size_bytes)]) 

394 

395 read_request = bytestream_pb2.ReadRequest() 

396 read_request.resource_name = resource_name 

397 read_request.read_offset = 0 

398 

399 os.makedirs(os.path.dirname(file_path), exist_ok=True) 

400 

401 with open(file_path, 'wb') as byte_file: 

402 # TODO: Handle connection loss/recovery 

403 for read_response in self.__bytestream_stub.Read(read_request): 

404 byte_file.write(read_response.data) 

405 

406 assert byte_file.tell() == digest.size_bytes 

407 

408 if is_executable: 

409 os.chmod(file_path, 0o755) # rwxr-xr-x / 755 

410 

411 def _queue_file(self, digest, file_path, is_executable=False): 

412 """Queues a file for later batch download""" 

413 batch_size_limit = self._max_effective_batch_size_bytes() 

414 

415 if self.__file_request_size + digest.ByteSize() > batch_size_limit: 

416 self.flush() 

417 elif self.__file_response_size + digest.size_bytes > batch_size_limit: 

418 self.flush() 

419 elif self.__file_request_count >= MAX_REQUEST_COUNT: 

420 self.flush() 

421 

422 output_path = (file_path, is_executable) 

423 

424 # When queueing a file we take into account the cases where 

425 # we might want to download the same digest to different paths. 

426 if digest.hash not in self.__file_requests: 

427 request = _FileRequest(digest=digest, output_paths=[output_path]) 

428 self.__file_requests[digest.hash] = request 

429 

430 self.__file_request_count += 1 

431 self.__file_request_size += digest.ByteSize() 

432 self.__file_response_size += digest.size_bytes 

433 else: 

434 # We already have that hash queued; we'll fetch the blob 

435 # once and write copies of it: 

436 self.__file_requests[digest.hash].output_paths.append(output_path) 

437 

438 def _fetch_file_batch(self, requests): 

439 """Sends queued data using ContentAddressableStorage.BatchReadBlobs(). 

440 

441 Takes a dictionary mapping digest.hash to _FileRequest as input. 

442 """ 

443 batch_digests = [request.digest for request in requests.values()] 

444 batch_responses = self._fetch_blob_batch(batch_digests) 

445 

446 for file_blob, file_digest in batch_responses: 

447 output_paths = requests[file_digest.hash].output_paths 

448 

449 for file_path, is_executable in output_paths: 

450 os.makedirs(os.path.dirname(file_path), exist_ok=True) 

451 

452 with open(file_path, 'wb') as byte_file: 

453 byte_file.write(file_blob) 

454 

455 if is_executable: 

456 os.chmod(file_path, 0o755) # rwxr-xr-x / 755 

457 

458 def _fetch_directory(self, digest, directory_path): 

459 """Fetches a directory using ContentAddressableStorage.GetTree()""" 

460 # Better fail early if the local root path cannot be created: 

461 os.makedirs(directory_path, exist_ok=True) 

462 

463 directories = {} 

464 directory_fetched = False 

465 # First, try GetTree() if not known to be unimplemented yet: 

466 if not _CallCache.unimplemented(self.channel, 'GetTree'): 

467 tree_request = remote_execution_pb2.GetTreeRequest() 

468 tree_request.root_digest.CopyFrom(digest) 

469 tree_request.page_size = MAX_REQUEST_COUNT 

470 if self.instance_name is not None: 

471 tree_request.instance_name = self.instance_name 

472 

473 try: 

474 for tree_response in self.__cas_stub.GetTree(tree_request): 

475 for directory in tree_response.directories: 

476 directory_blob = directory.SerializeToString() 

477 directory_hash = HASH(directory_blob).hexdigest() 

478 

479 directories[directory_hash] = directory 

480 

481 assert digest.hash in directories 

482 

483 directory = directories[digest.hash] 

484 self._write_directory(directory, directory_path, directories=directories) 

485 

486 directory_fetched = True 

487 except grpc.RpcError as e: 

488 status_code = e.code() 

489 if status_code == grpc.StatusCode.UNIMPLEMENTED: 

490 _CallCache.mark_unimplemented(self.channel, 'GetTree') 

491 

492 elif status_code == grpc.StatusCode.NOT_FOUND: 

493 raise FileNotFoundError("Requested directory does not exist on the remote.") 

494 

495 else: 

496 raise ConnectionError(e.details()) 

497 

498 # If no GetTree(), _write_directory() will use BatchReadBlobs() 

499 # if available or Read() if not. 

500 if not directory_fetched: 

501 directory = remote_execution_pb2.Directory() 

502 directory.ParseFromString(self._fetch_blob(digest)) 

503 

504 self._write_directory(directory, directory_path) 

505 

506 def _write_directory(self, root_directory, root_path, directories=None): 

507 """Generates a local directory structure""" 

508 os.makedirs(root_path, exist_ok=True) 

509 self._write_directory_recursively(root_directory, root_path, directories=directories) 

510 

511 def _write_directory_recursively(self, root_directory, root_path, directories=None): 

512 """Generate local directory recursively""" 

513 # i) Files: 

514 for file_node in root_directory.files: 

515 file_path = os.path.join(root_path, file_node.name) 

516 

517 if os.path.lexists(file_path): 

518 raise FileExistsError("'{}' already exists".format(file_path)) 

519 

520 self.download_file(file_node.digest, file_path, 

521 is_executable=file_node.is_executable) 

522 self.flush() 

523 

524 # ii) Directories: 

525 pending_directory_digests = [] 

526 pending_directory_paths = {} 

527 for directory_node in root_directory.directories: 

528 directory_hash = directory_node.digest.hash 

529 

530 # FIXME: Guard against ../ 

531 directory_path = os.path.join(root_path, directory_node.name) 

532 os.mkdir(directory_path) 

533 

534 if directories and directory_node.digest.hash in directories: 

535 # We already have the directory; just write it: 

536 directory = directories[directory_hash] 

537 

538 self._write_directory_recursively(directory, directory_path, 

539 directories=directories) 

540 else: 

541 # Gather all the directories that we need to get to 

542 # try fetching them in a single batch request: 

543 pending_directory_digests.append(directory_node.digest) 

544 pending_directory_paths[directory_hash] = directory_path 

545 

546 if pending_directory_paths: 

547 fetched_blobs = self._fetch_blob_batch(pending_directory_digests) 

548 

549 for (directory_blob, directory_digest) in fetched_blobs: 

550 directory = remote_execution_pb2.Directory() 

551 directory.ParseFromString(directory_blob) 

552 

553 # The server might not return the blobs in the same order as 

554 # they were requested, so we look each returned blob up by 

555 # the hash of its digest: 

556 # Guarantees about the reply ordering might change in the 

557 # specification at some point. 

558 # See: github.com/bazelbuild/remote-apis/issues/52 

559 

560 directory_path = pending_directory_paths[directory_digest.hash] 

561 

562 self._write_directory(directory, directory_path, 

563 directories=directories) 

564 

565 # iii) Symlinks: 

566 for symlink_node in root_directory.symlinks: 

567 symlink_path = os.path.join(root_path, symlink_node.name) 

568 os.symlink(symlink_node.target, symlink_path) 

569 

570 def _max_effective_batch_size_bytes(self): 

571 """Returns the effective maximum number of bytes that can be 

572 transferred using batches, considering gRPC maximum message size. 

573 """ 

574 return _CasBatchRequestSizesCache.max_effective_batch_size_bytes(self.channel, 

575 self.instance_name) 

576 

577 def _queueable_file_size_threshold(self): 

578 """Returns the size limit up to which files can be queued to 

579 be requested in a batch. 

580 """ 

581 return _CasBatchRequestSizesCache.batch_request_size_threshold(self.channel, 

582 self.instance_name) 

583 

584 

585@contextmanager 

586def upload(channel, instance=None, u_uid=None): 

587 """Context manager generator for the :class:`Uploader` class.""" 

588 uploader = Uploader(channel, instance=instance, u_uid=u_uid) 

589 try: 

590 yield uploader 

591 finally: 

592 uploader.close() 

593 

594 

595class Uploader: 

596 """Remote CAS files, directories and messages upload helper. 

597 

598 The :class:`Uploader` class comes with a generator factory function that can 

599 be used together with the `with` statement for context management:: 

600 

601 from buildgrid.client.cas import upload 

602 

603 with upload(channel, instance='build') as uploader: 

604 uploader.upload_file('/path/to/local/file') 

605 """ 

606 

607 def __init__(self, channel, instance=None, u_uid=None): 

608 """Initializes a new :class:`Uploader` instance. 

609 

610 Args: 

611 channel (grpc.Channel): A gRPC channel to the CAS endpoint. 

612 instance (str, optional): the targeted instance's name. 

613 u_uid (str, optional): a UUID for CAS transactions. 

614 """ 

615 self.channel = channel 

616 

617 self.instance_name = instance 

618 if u_uid is not None: 

619 self.u_uid = u_uid 

620 else: 

621 self.u_uid = str(uuid.uuid4()) 

622 

623 self.__bytestream_stub = bytestream_pb2_grpc.ByteStreamStub(self.channel) 

624 self.__cas_stub = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel) 

625 

626 self.__requests = {} 

627 self.__request_count = 0 

628 self.__request_size = 0 

629 

630 # --- Public API --- 

631 

632 def put_blob(self, blob, digest=None, queue=False): 

633 """Stores a blob into the remote CAS server. 

634 

635 If queuing is allowed (`queue=True`), the upload request **may** be 

636 deferred. An explicit call to :func:`~flush` can force the request to be 

637 sent immediately (along with the rest of the queued batch). 

638 

639 Args: 

640 blob (bytes): the blob's data. 

641 digest (:obj:`Digest`, optional): the blob's digest. 

642 queue (bool, optional): whether or not the upload request may be 

643 queued and submitted as part of a batch upload request. Defaults 

644 to False. 

645 

646 Returns: 

647 :obj:`Digest`: the sent blob's digest. 

648 """ 

649 

650 if not queue or len(blob) > self._queueable_file_size_threshold(): 

651 blob_digest = self._send_blob(blob, digest=digest) 

652 else: 

653 blob_digest = self._queue_blob(blob, digest=digest) 

654 

655 return blob_digest 

656 
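    # Usage sketch for put_blob() (assumes an Uploader instance `uploader` and
    # a bytes object `data`): small blobs can be queued and sent together in
    # one batch request.
    #
    #     digest_a = uploader.put_blob(data, queue=True)           # may be queued
    #     digest_b = uploader.put_blob(b'other data', queue=True)  # may be queued
    #     uploader.flush()  # both queued blobs have now been sent to the remote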

657 def put_message(self, message, digest=None, queue=False): 

658 """Stores a message into the remote CAS server. 

659 

660 If queuing is allowed (`queue=True`), the upload request **may** be 

661 deferred. An explicit call to :func:`~flush` can force the request to be 

662 sent immediately (along with the rest of the queued batch). 

663 

664 Args: 

665 message (:obj:`Message`): the message object. 

666 digest (:obj:`Digest`, optional): the message's digest. 

667 queue (bool, optional): whether or not the upload request may be 

668 queued and submitted as part of a batch upload request. Defaults 

669 to False. 

670 

671 Returns: 

672 :obj:`Digest`: the sent message's digest. 

673 """ 

674 message_blob = message.SerializeToString() 

675 

676 if not queue or len(message_blob) > self._queueable_file_size_threshold(): 

677 message_digest = self._send_blob(message_blob, digest=digest) 

678 else: 

679 message_digest = self._queue_blob(message_blob, digest=digest) 

680 

681 return message_digest 

682 

683 def upload_file(self, file_path, queue=True): 

684 """Stores a local file into the remote CAS storage. 

685 

686 If queuing is allowed (`queue=True`), the upload request **may** be 

687 deferred. An explicit call to :func:`~flush` can force the request to be 

688 sent immediately (along with the rest of the queued batch). 

689 

690 Args: 

691 file_path (str): absolute or relative path to a local file. 

692 queue (bool, optional): whether or not the upload request may be 

693 queued and submitted as part of a batch upload request. Defaults 

694 to True. 

695 

696 Returns: 

697 :obj:`Digest`: The digest of the file's content. 

698 

699 Raises: 

700 FileNotFoundError: If `file_path` does not exist. 

701 PermissionError: If `file_path` is not readable. 

702 """ 

703 if not os.path.isabs(file_path): 

704 file_path = os.path.abspath(file_path) 

705 

706 with open(file_path, 'rb') as byte_stream: 

707 file_bytes = byte_stream.read() 

708 

709 if not queue or len(file_bytes) > self._queueable_file_size_threshold(): 

710 file_digest = self._send_blob(file_bytes) 

711 else: 

712 file_digest = self._queue_blob(file_bytes) 

713 

714 return file_digest 

715 

716 def upload_directory(self, directory_path, queue=True): 

717 """Stores a local folder into the remote CAS storage. 

718 

719 If queuing is allowed (`queue=True`), the upload request **may** be 

720 deferred. An explicit call to :func:`~flush` can force the request to be 

721 sent immediately (along with the rest of the queued batch). 

722 

723 Args: 

724 directory_path (str): absolute or relative path to a local folder. 

725 queue (bool, optional): whether or not the upload requests may be 

726 queued and submitted as part of a batch upload request. Defaults 

727 to True. 

728 

729 Returns: 

730 :obj:`Digest`: The digest of the top :obj:`Directory`. 

731 

732 Raises: 

733 FileNotFoundError: If `directory_path` does not exist. 

734 PermissionError: If `directory_path` is not readable. 

735 """ 

736 if not os.path.isabs(directory_path): 

737 directory_path = os.path.abspath(directory_path) 

738 

739 last_directory_node = None 

740 

741 if not queue: 

742 for node, blob, _ in merkle_tree_maker(directory_path): 

743 if node.DESCRIPTOR is remote_execution_pb2.DirectoryNode.DESCRIPTOR: 

744 last_directory_node = node 

745 

746 self._send_blob(blob, digest=node.digest) 

747 

748 else: 

749 for node, blob, _ in merkle_tree_maker(directory_path): 

750 if node.DESCRIPTOR is remote_execution_pb2.DirectoryNode.DESCRIPTOR: 

751 last_directory_node = node 

752 

753 self._queue_blob(blob, digest=node.digest) 

754 

755 return last_directory_node.digest 

756 

757 def upload_tree(self, directory_path, queue=True): 

758 """Stores a local folder into the remote CAS storage as a :obj:`Tree`. 

759 

760 If queuing is allowed (`queue=True`), the upload request **may** be 

761 deferred. An explicit call to :func:`~flush` can force the request to be 

762 sent immediately (along with the rest of the queued batch). 

763 

764 Args: 

765 directory_path (str): absolute or relative path to a local folder. 

766 queue (bool, optional): whether or not the upload requests may be 

767 queued and submitted as part of a batch upload request. Defaults 

768 to True. 

769 

770 Returns: 

771 :obj:`Digest`: The digest of the :obj:`Tree`. 

772 

773 Raises: 

774 FileNotFoundError: If `directory_path` does not exist. 

775 PermissionError: If `directory_path` is not readable. 

776 """ 

777 if not os.path.isabs(directory_path): 

778 directory_path = os.path.abspath(directory_path) 

779 

780 directories = [] 

781 

782 if not queue: 

783 for node, blob, _ in merkle_tree_maker(directory_path): 

784 if node.DESCRIPTOR is remote_execution_pb2.DirectoryNode.DESCRIPTOR: 

785 # TODO: Get the Directory object from merkle_tree_maker(): 

786 directory = remote_execution_pb2.Directory() 

787 directory.ParseFromString(blob) 

788 directories.append(directory) 

789 

790 self._send_blob(blob, digest=node.digest) 

791 

792 else: 

793 for node, blob, _ in merkle_tree_maker(directory_path): 

794 if node.DESCRIPTOR is remote_execution_pb2.DirectoryNode.DESCRIPTOR: 

795 # TODO: Get the Directory object from merkle_tree_maker(): 

796 directory = remote_execution_pb2.Directory() 

797 directory.ParseFromString(blob) 

798 directories.append(directory) 

799 

800 self._queue_blob(blob, digest=node.digest) 

801 

802 tree = remote_execution_pb2.Tree() 

803 tree.root.CopyFrom(directories[-1]) 

804 tree.children.extend(directories[:-1]) 

805 

806 return self.put_message(tree, queue=queue) 

807 
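    # Usage sketch for upload_tree() (assumes a valid grpc.Channel `channel`
    # and a local directory path `local_dir`):
    #
    #     with upload(channel, instance='build') as uploader:
    #         tree_digest = uploader.upload_tree(local_dir, queue=True)
    #     # Any queued blobs, including the Tree message, are flushed when the
    #     # context manager calls close().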

808 def flush(self): 

809 """Ensures any queued request gets sent.""" 

810 if self.__requests: 

811 self._send_blob_batch(self.__requests) 

812 

813 self.__requests.clear() 

814 self.__request_count = 0 

815 self.__request_size = 0 

816 

817 def close(self): 

818 """Closes the underlying connection stubs. 

819 

820 Note: 

821 This will always send pending requests before closing connections, 

822 if any. 

823 """ 

824 self.flush() 

825 

826 self.__bytestream_stub = None 

827 self.__cas_stub = None 

828 

829 # --- Private API --- 

830 

831 def _send_blob(self, blob, digest=None): 

832 """Sends a memory block using ByteStream.Write()""" 

833 blob_digest = remote_execution_pb2.Digest() 

834 if digest is not None: 

835 blob_digest.CopyFrom(digest) 

836 else: 

837 blob_digest.hash = HASH(blob).hexdigest() 

838 blob_digest.size_bytes = len(blob) 

839 if self.instance_name: 

840 resource_name = '/'.join([self.instance_name, 'uploads', self.u_uid, 'blobs', 

841 blob_digest.hash, str(blob_digest.size_bytes)]) 

842 else: 

843 resource_name = '/'.join(['uploads', self.u_uid, 'blobs', 

844 blob_digest.hash, str(blob_digest.size_bytes)]) 

845 

846 def __write_request_stream(resource, content): 

847 offset = 0 

848 finished = False 

849 remaining = len(content) 

850 while not finished: 

851 chunk_size = min(remaining, MAX_REQUEST_SIZE) 

852 remaining -= chunk_size 

853 

854 request = bytestream_pb2.WriteRequest() 

855 request.resource_name = resource 

856 request.data = content[offset:offset + chunk_size] 

857 request.write_offset = offset 

858 request.finish_write = remaining <= 0 

859 

860 yield request 

861 

862 offset += chunk_size 

863 finished = request.finish_write 

864 

865 write_requests = __write_request_stream(resource_name, blob) 

866 # TODO: Handle connection loss/recovery using QueryWriteStatus() 

867 try: 

868 write_response = self.__bytestream_stub.Write(write_requests) 

869 except grpc.RpcError as e: 

870 raise ConnectionError(e.details()) 

871 

872 assert write_response.committed_size == blob_digest.size_bytes 

873 

874 return blob_digest 

875 

876 def _queue_blob(self, blob, digest=None): 

877 """Queues a memory block for later batch upload""" 

878 blob_digest = remote_execution_pb2.Digest() 

879 if digest is not None: 

880 blob_digest.CopyFrom(digest) 

881 else: 

882 blob_digest.hash = HASH(blob).hexdigest() 

883 blob_digest.size_bytes = len(blob) 

884 

885 # If we get here, the blob is known to be smaller than gRPC's 

886 # message size limit. 

887 # We'll make a single batch request as big as the server allows. 

888 batch_size_limit = self._max_effective_batch_size_bytes() 

889 

890 if self.__request_size + blob_digest.size_bytes > batch_size_limit: 

891 self.flush() 

892 elif self.__request_count >= MAX_REQUEST_COUNT: 

893 self.flush() 

894 

895 self.__requests[blob_digest.hash] = (blob, blob_digest) 

896 self.__request_count += 1 

897 self.__request_size += blob_digest.size_bytes 

898 

899 return blob_digest 

900 

901 def _send_blob_batch(self, batch): 

902 """Sends queued data using ContentAddressableStorage.BatchUpdateBlobs()""" 

903 batch_fetched = False 

904 written_digests = [] 

905 

906 # First, try BatchUpdateBlobs(), unless it is already known to be unimplemented: 

907 if not _CallCache.unimplemented(self.channel, 'BatchUpdateBlobs'): 

908 batch_request = remote_execution_pb2.BatchUpdateBlobsRequest() 

909 if self.instance_name is not None: 

910 batch_request.instance_name = self.instance_name 

911 

912 for blob, digest in batch.values(): 

913 request = batch_request.requests.add() 

914 request.digest.CopyFrom(digest) 

915 request.data = blob 

916 

917 try: 

918 batch_response = self.__cas_stub.BatchUpdateBlobs(batch_request) 

919 for response in batch_response.responses: 

920 assert response.digest.hash in batch 

921 

922 written_digests.append(response.digest) 

923 if response.status.code != code_pb2.OK: 

924 response.digest.Clear() 

925 

926 batch_fetched = True 

927 

928 except grpc.RpcError as e: 

929 status_code = e.code() 

930 if status_code == grpc.StatusCode.UNIMPLEMENTED: 

931 _CallCache.mark_unimplemented(self.channel, 'BatchUpdateBlobs') 

932 

933 elif status_code == grpc.StatusCode.INVALID_ARGUMENT: 

934 written_digests.clear() 

935 batch_fetched = False 

936 

937 else: 

938 raise ConnectionError(e.details()) 

939 

940 # Fall back to Write() if BatchUpdateBlobs() is not available: 

941 if not batch_fetched: 

942 for blob, digest in batch.values(): 

943 written_digests.append(self._send_blob(blob, digest=digest)) 

944 

945 return written_digests 

946 

947 def _max_effective_batch_size_bytes(self): 

948 """Returns the effective maximum number of bytes that can be 

949 transferred using batches, considering gRPC maximum message size. 

950 """ 

951 return _CasBatchRequestSizesCache.max_effective_batch_size_bytes(self.channel, 

952 self.instance_name) 

953 

954 def _queueable_file_size_threshold(self): 

955 """Returns the size limit up to which files can be queued to 

956 be requested in a batch. 

957 """ 

958 return _CasBatchRequestSizesCache.batch_request_size_threshold(self.channel, 

959 self.instance_name)