Coverage for /builds/BuildGrid/buildgrid/buildgrid/browser/rest_api.py: 83.38%

397 statements  

coverage.py v7.4.1, created at 2024-06-11 15:37 +0000

# Copyright (C) 2021 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
import logging
import os
import shutil
import tarfile
import tempfile
from collections import namedtuple
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type

import aiofiles
from aiohttp import WSMsgType, web
from aiohttp_middlewares.annotations import UrlCollection
from grpc import RpcError, StatusCode
from grpc.aio import Call, Metadata

from buildgrid._app.cli import Context
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import (
    ActionResult,
    Digest,
    Directory,
    GetActionResultRequest,
    RequestMetadata,
)
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2_grpc_aio import ActionCacheStub
from buildgrid._protos.buildgrid.v2.query_build_events_pb2 import QueryEventStreamsRequest
from buildgrid._protos.buildgrid.v2.query_build_events_pb2_grpc_aio import QueryBuildEventsStub
from buildgrid._protos.google.bytestream.bytestream_pb2 import ReadRequest
from buildgrid._protos.google.bytestream.bytestream_pb2_grpc_aio import ByteStreamStub
from buildgrid._protos.google.longrunning import operations_pb2
from buildgrid._protos.google.longrunning.operations_pb2_grpc_aio import OperationsStub
from buildgrid.browser.utils import TARBALL_DIRECTORY_PREFIX, ResponseCache, get_cors_headers
from buildgrid.server.operations.filtering.interpreter import VALID_OPERATION_FILTERS, OperationFilterSpec
from buildgrid.server.operations.filtering.sanitizer import DatetimeValueSanitizer, SortKeyValueSanitizer
from buildgrid.server.request_metadata_utils import extract_request_metadata
from buildgrid.settings import BROWSER_MAX_CACHE_ENTRY_SIZE

LOGGER = logging.getLogger(__name__)


def query_build_events_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
    """Factory function which returns a handler for QueryEventStreams.

    The returned handler uses ``context.channel`` to send a QueryEventStreams
    request constructed based on the provided URL query parameters. Currently
    only querying by build_id (equivalent to correlated invocations ID) is
    supported.

    The handler returns a serialised QueryEventStreamsResponse, and raises a
    500 error in the case of some RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.

    """

    async def _query_build_events_handler(request: web.Request) -> web.Response:
        build_id = request.rel_url.query.get("build_id", ".*")

        stub = QueryBuildEventsStub(context.channel)  # type: ignore # Requires stub regen

        grpc_request = QueryEventStreamsRequest(build_id_pattern=build_id)

        try:
            grpc_response = await stub.QueryEventStreams(grpc_request)
        except RpcError as e:
            LOGGER.warning(e.details())
            raise web.HTTPInternalServerError()

        serialized_response = grpc_response.SerializeToString()
        return web.Response(body=serialized_response)

    return _query_build_events_handler

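# Illustrative sketch (not part of the original module): how a handler factory
# like `query_build_events_handler` might be wired into an aiohttp application.
# The route path and the way the `Context` is obtained are assumptions made for
# demonstration only, not necessarily the routes BuildGrid itself registers.
def _example_make_app(context: Context) -> web.Application:
    app = web.Application()
    # The factory is called once with the shared context; aiohttp then invokes
    # the returned coroutine for every request matching the route.
    app.router.add_get("/build_events", query_build_events_handler(context))
    return app
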

async def get_operation_filters_handler(request: web.Request) -> web.Response:
    """Return the available Operation filter keys."""

    def _generate_filter_spec(key: str, spec: OperationFilterSpec) -> Dict[str, Any]:
        comparators = ["<", "<=", "=", "!=", ">=", ">"]
        filter_type = "text"
        if isinstance(spec.sanitizer, SortKeyValueSanitizer):
            comparators = ["="]
        elif isinstance(spec.sanitizer, DatetimeValueSanitizer):
            filter_type = "datetime"

        ret = {
            "comparators": comparators,
            "description": spec.description,
            "key": key,
            "name": spec.name,
            "type": filter_type,
        }

        try:
            ret["values"] = spec.sanitizer.valid_values
        except NotImplementedError:
            pass
        return ret

    operation_filters = [_generate_filter_spec(key, spec) for key, spec in VALID_OPERATION_FILTERS.items()]
    return web.Response(text=json.dumps(operation_filters))

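# Illustrative sketch (not part of the original module): the JSON body returned
# by `get_operation_filters_handler` is a list of dictionaries shaped like the
# one below. The keys mirror `_generate_filter_spec` above; the concrete values
# shown here (a hypothetical datetime-typed "start_time" filter) are assumptions.
_EXAMPLE_FILTER_SPEC: Dict[str, Any] = {
    "comparators": ["<", "<=", "=", "!=", ">=", ">"],
    "description": "Time at which the operation started executing",
    "key": "start_time",
    "name": "Start time",
    "type": "datetime",
}
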

def list_operations_handler(
    context: Context, cache: ResponseCache
) -> Callable[[web.Request], Awaitable[web.Response]]:
    """Factory function which returns a handler for ListOperations.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a ListOperations request constructed based on the provided URL
    query parameters.

    The handler returns a serialised ListOperationsResponse, raises a 400
    error in the case of a bad filter or other invalid argument, or raises
    a 500 error in the case of some other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.

    """

    async def _list_operations_handler(request: web.Request) -> web.Response:
        filter_string = request.rel_url.query.get("q", "")
        page_token = request.rel_url.query.get("page_token", "")
        page_size_str = request.rel_url.query.get("page_size")
        page_size = 0
        if page_size_str is not None:
            page_size = int(page_size_str)

        LOGGER.info(
            f'Received ListOperations request, filter_string="{filter_string}" '
            f'page_token="{page_token}" page_size="{page_size}"'
        )
        stub = OperationsStub(context.operations_channel)  # type: ignore # Requires stub regen
        grpc_request = operations_pb2.ListOperationsRequest(
            name=context.instance_name, page_token=page_token, page_size=page_size, filter=filter_string
        )

        try:
            grpc_response = await stub.ListOperations(grpc_request)
        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.INVALID_ARGUMENT:
                raise web.HTTPBadRequest()
            raise web.HTTPInternalServerError()

        serialised_response = grpc_response.SerializeToString()
        return web.Response(body=serialised_response)

    return _list_operations_handler

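# Illustrative sketch (not part of the original module): calling the
# ListOperations endpoint from an aiohttp client. The base URL, route path and
# filter string are assumptions; the query parameter names ("q", "page_token",
# "page_size") match what `_list_operations_handler` reads above.
async def _example_list_operations(base_url: str) -> bytes:
    import aiohttp

    params = {"q": "stage=COMPLETED", "page_size": "50", "page_token": ""}
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url}/operations", params=params) as resp:
            # The response body is a serialised ListOperationsResponse proto.
            return await resp.read()
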

async def _get_operation(context: Context, request: web.Request) -> Tuple[operations_pb2.Operation, Call]:
    operation_name = f"{context.instance_name}/{request.match_info['name']}"

    stub = OperationsStub(context.operations_channel)  # type: ignore # Requires stub regen
    grpc_request = operations_pb2.GetOperationRequest(name=operation_name)

    try:
        call = stub.GetOperation(grpc_request)
        operation = await call
    except RpcError as e:
        LOGGER.warning(f"Error fetching operation: {e.details()}")
        if e.code() == StatusCode.INVALID_ARGUMENT:
            raise web.HTTPNotFound()
        raise web.HTTPInternalServerError()

    return operation, call


def get_operation_handler(context: Context, cache: ResponseCache) -> Callable[[web.Request], Awaitable[web.Response]]:
    """Factory function which returns a handler for GetOperation.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a GetOperation request constructed based on the path component of
    the URL.

    The handler returns a serialised Operation message, raises a 404 error in
    the case of an invalid operation name, or raises a 500 error in the case
    of some other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.

    """

    async def _get_operation_handler(request: web.Request) -> web.Response:
        name = request.match_info["name"]
        LOGGER.info(f'Received GetOperation request for "{name}"')
        operation, _ = await _get_operation(context, request)

        serialised_response = operation.SerializeToString()
        return web.Response(body=serialised_response)

    return _get_operation_handler


def get_operation_request_metadata_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
    """Factory function which returns a handler to get RequestMetadata.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a GetOperation request constructed based on the path component of
    the URL.

    The handler returns a serialised RequestMetadata proto message, retrieved
    from the trailing metadata of the GetOperation response. In the event of
    an invalid operation name it raises a 404 error, and raises a 500 error in
    the case of some other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.

    """

    async def _get_operation_request_metadata_handler(request: web.Request) -> web.Response:
        LOGGER.info(f'Received request for RequestMetadata for "{request.match_info["name"]}"')
        _, call = await _get_operation(context, request)
        metadata = await call.trailing_metadata()

        def extract_metadata(m: Metadata) -> RequestMetadata:
            # `m` contains a list of tuples, but `extract_request_metadata()`
            # expects objects with `key` and `value` attributes.
            MetadataTuple = namedtuple("MetadataTuple", ["key", "value"])
            return extract_request_metadata([MetadataTuple(entry[0], entry[1]) for entry in m])

        request_metadata = extract_metadata(metadata)
        return web.Response(body=request_metadata.SerializeToString())

    return _get_operation_request_metadata_handler


def cancel_operation_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
    """Factory function which returns a handler for CancelOperation.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a CancelOperation request constructed based on the path component of
    the URL.

    The handler raises a 404 error in the case of an invalid operation name,
    or a 500 error in the case of some other RPC error.

    On success, the response is empty.

    Args:
        context (Context): The context to use to send the gRPC request.

    """

    async def _cancel_operation_handler(request: web.Request) -> web.Response:
        LOGGER.info(f'Received CancelOperation request for "{request.match_info["name"]}"')
        operation_name = f"{context.instance_name}/{request.match_info['name']}"

        stub = OperationsStub(context.operations_channel)  # type: ignore # Requires stub regen
        grpc_request = operations_pb2.CancelOperationRequest(name=operation_name)

        try:
            await stub.CancelOperation(grpc_request)
            return web.Response()
        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.INVALID_ARGUMENT:
                raise web.HTTPNotFound()
            raise web.HTTPInternalServerError()

    return _cancel_operation_handler


async def _fetch_action_result(
    context: Context, request: web.Request, cache: ResponseCache, cache_key: str
) -> ActionResult:
    stub = ActionCacheStub(context.cache_channel)  # type: ignore # Requires stub regen
    digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"]))
    grpc_request = GetActionResultRequest(action_digest=digest, instance_name=context.instance_name)

    try:
        result = await stub.GetActionResult(grpc_request)
    except RpcError as e:
        LOGGER.warning(f"Failed to fetch ActionResult: [{e.details()}]")
        if e.code() == StatusCode.NOT_FOUND:
            raise web.HTTPNotFound()
        raise web.HTTPInternalServerError()

    await cache.store_action_result(cache_key, result)
    return result


def get_action_result_handler(
    context: Context, cache: ResponseCache, cache_capacity: int = 512
) -> Callable[[web.Request], Awaitable[web.Response]]:
    """Factory function which returns a handler for GetActionResult.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a GetActionResult request constructed based on the path components
    of the URL.

    The handler returns a serialised ActionResult message, raises a 404 error
    if there's no result cached, or raises a 500 error in the case of some
    other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.
        cache_capacity (int): The number of ActionResults to cache in memory
            to avoid hitting the actual ActionCache.

    """

    class FetchSpec:
        """Simple class used to store information about a GetActionResult request.

        A class is used here rather than a namedtuple since we need this state
        to be mutable.

        """

        def __init__(
            self, *, error: Optional[Exception], event: asyncio.Event, result: Optional[ActionResult], refcount: int
        ):
            self.error = error
            self.event = event
            self.result = result
            self.refcount = refcount

    in_flight_fetches: Dict[str, FetchSpec] = {}
    fetch_lock = asyncio.Lock()

    async def _get_action_result_handler(request: web.Request) -> web.Response:
        LOGGER.info(
            "Received GetActionResult request for "
            f'"{request.match_info["hash"]}/{request.match_info["size_bytes"]}"'
        )

        cache_key = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
        result = await cache.get_action_result(cache_key)

        if result is None:
            try:
                duplicate_request = False
                spec = None
                async with fetch_lock:
                    if cache_key in in_flight_fetches:
                        LOGGER.info(f"Deduplicating GetActionResult request for [{cache_key}]")
                        spec = in_flight_fetches[cache_key]
                        spec.refcount += 1
                        duplicate_request = True
                    else:
                        spec = FetchSpec(error=None, event=asyncio.Event(), result=None, refcount=1)
                        in_flight_fetches[cache_key] = spec

                if duplicate_request and spec:
                    # If we're a duplicate of an existing request, then wait for the
                    # existing request to finish fetching from the ActionCache.
                    await spec.event.wait()
                    if spec.error is not None:
                        raise spec.error
                    if spec.result is None:
                        # Note: this should be impossible, but let's guard against accidentally
                        # setting the event before the result is populated.
                        LOGGER.info(f"Result not set in deduplicated request for [{cache_key}]")
                        raise web.HTTPInternalServerError()
                    result = spec.result
                else:
                    try:
                        result = await _fetch_action_result(context, request, cache, cache_key)
                    except Exception as e:
                        async with fetch_lock:
                            if spec is not None:
                                spec.error = e
                        raise e

                    async with fetch_lock:
                        if spec is not None:
                            spec.result = result
                            spec.event.set()

            finally:
                async with fetch_lock:
                    # Decrement refcount now we're done with the result. If we're the
                    # last request interested in the result then remove it from the
                    # `in_flight_fetches` dictionary.
                    spec = in_flight_fetches.get(cache_key)
                    if spec is not None:
                        spec.refcount -= 1
                        if spec.refcount <= 0:
                            in_flight_fetches.pop(cache_key)

        return web.Response(body=result.SerializeToString())

    return _get_action_result_handler

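# Illustrative sketch (not part of the original module): the deduplication
# pattern used by `_get_action_result_handler` above, reduced to its core.
# Concurrent callers asking for the same key share one fetch: the first caller
# performs it, later callers wait on an asyncio.Event and reuse the stored
# result. `_example_slow_fetch` is a hypothetical stand-in for the ActionCache
# RPC; the real handler additionally refcounts entries and propagates errors.
async def _example_slow_fetch(key: str) -> bytes:
    await asyncio.sleep(0.1)
    return key.encode()


async def _example_deduplicated_fetch(
    key: str, pending: Dict[str, asyncio.Event], results: Dict[str, bytes]
) -> bytes:
    if key in pending:
        # A fetch for this key is already in flight: wait for it, reuse its result.
        await pending[key].wait()
        return results[key]
    pending[key] = asyncio.Event()
    try:
        results[key] = await _example_slow_fetch(key)
        return results[key]
    finally:
        # Wake any duplicate requests, then forget the in-flight entry.
        pending[key].set()
        pending.pop(key)
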

def get_blob_handler(
    context: Context, cache: ResponseCache, allow_all: bool = False, allowed_origins: UrlCollection = ()
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
    async def _get_blob_handler(request: web.Request) -> web.StreamResponse:
        digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"]))
        try:
            offset = int(request.rel_url.query.get("offset", "0"))
            limit = int(request.rel_url.query.get("limit", "0"))
        except ValueError:
            raise web.HTTPBadRequest()

        response = web.StreamResponse()

        # We need to explicitly set CORS headers here, because the middleware that
        # normally handles this only adds the header when the response is returned
        # from this function. However, when using a StreamResponse with chunked
        # encoding enabled, the client begins to receive the response when we call
        # `response.write()`. This leads to the request being disallowed due to the
        # missing response header for clients executing in browsers.
        cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all)
        response.headers.update(cors_headers)

        if request.rel_url.query.get("raw", "") == "1":
            response.headers["Content-type"] = "text/plain; charset=utf-8"
        else:
            # Setting the Content-Disposition header so that when
            # downloading a blob with a browser the default name uniquely
            # identifies its contents:
            filename = f"{request.match_info['hash']}_{request.match_info['size_bytes']}"

            # For partial reads also include the indices to prevent file mix-ups:
            if offset != 0 or limit != 0:
                filename += f"_chunk_{offset}-"
                if limit != 0:
                    filename += f"{offset + limit}"

            response.headers["Content-Disposition"] = f"Attachment;filename={filename}"

        prepared = False

        async def _callback(data: bytes) -> None:
            nonlocal prepared
            if not prepared:
                # Prepare for chunked encoding when the callback is first called,
                # so that we're sure we actually have some data before doing so.
                response.enable_chunked_encoding()
                await response.prepare(request)
                prepared = True
            await response.write(data)

        await _fetch_blob(context, cache, digest, callback=_callback, offset=offset, limit=limit)

        return response

    return _get_blob_handler


def _create_tarball(directory: str, name: str) -> bool:
    """Makes a tarball from a given directory.

    Returns True if the tarball was successfully created, and False if not.

    Args:
        directory (str): The directory to tar up.
        name (str): The name of the tarball to be produced.

    """
    try:
        with tarfile.open(name, "w:gz") as tarball:
            tarball.add(directory, arcname="")
    except Exception:
        return False
    return True


async def _fetch_blob(
    context: Context,
    cache: ResponseCache,
    digest: Digest,
    message_class: Optional[Type[Any]] = None,
    callback: Optional[Callable[[bytes], Awaitable[Any]]] = None,
    offset: int = 0,
    limit: int = 0,
) -> Any:
    """Fetch a blob from CAS.

    This function sends a ByteStream Read request for the given digest. If ``callback``
    is set then the callback is called with the data in each ReadResponse message and
    this function returns an empty bytes object once the response is finished. If
    ``message_class`` is set with no ``callback`` set then this function calls
    ``message_class.FromString`` on the fetched blob and returns the result.

    If neither ``callback`` nor ``message_class`` is set then this function just returns
    the raw blob that was fetched from CAS.

    Args:
        context (Context): The context to use to send the gRPC request.
        cache (ResponseCache): The response cache to check/update with the fetched blob.
        digest (Digest): The Digest of the blob to fetch from CAS.
        message_class (type): A class which implements a ``FromString`` class method.
            The method should take a bytes object expected to contain the blob fetched
            from CAS.
        callback (callable): A function or other callable to act on a subset of the
            blob contents.
        offset (int): Read offset to start reading the blob at. Defaults to 0, the start
            of the blob.
        limit (int): Maximum number of bytes to read from the blob. Defaults to 0, no
            limit.

    """
    cacheable = digest.size_bytes <= BROWSER_MAX_CACHE_ENTRY_SIZE
    resource_name = f"{context.instance_name}/blobs/{digest.hash}/{digest.size_bytes}"
    blob = None
    if cacheable:
        blob = await cache.get_blob(resource_name)
        if blob is not None and callback is not None:
            if limit > 0:
                try:
                    blob = blob[offset : offset + limit]
                except IndexError:
                    raise web.HTTPBadRequest()
            await callback(blob)

    if blob is None:
        stub = ByteStreamStub(context.cas_channel)  # type: ignore # Requires stub regen
        grpc_request = ReadRequest(resource_name=resource_name, read_offset=offset, read_limit=limit)

        blob = b""
        try:
            async for grpc_response in stub.Read(grpc_request):
                if grpc_response.data:
                    if callback is not None:
                        await callback(grpc_response.data)

                    if callback is None or cacheable:
                        blob += grpc_response.data
        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.NOT_FOUND:
                raise web.HTTPNotFound()
            raise web.HTTPInternalServerError()

        if cacheable:
            await cache.store_blob(resource_name, blob)

    if message_class is not None and callback is None:
        return message_class.FromString(blob)
    return blob

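# Illustrative sketch (not part of the original module): the three calling
# conventions `_fetch_blob` supports, assuming a `context`, `cache` and `digest`
# are already available. The callback body here is purely for demonstration.
async def _example_fetch_blob_usage(context: Context, cache: ResponseCache, digest: Digest) -> None:
    # 1. Raw bytes: no callback and no message_class.
    raw = await _fetch_blob(context, cache, digest)

    # 2. Parsed proto: message_class makes the helper call Directory.FromString.
    directory = await _fetch_blob(context, cache, digest, message_class=Directory)

    # 3. Streaming: each ReadResponse chunk is passed to the callback as it arrives.
    async def on_chunk(data: bytes) -> None:
        LOGGER.debug("received %d bytes", len(data))

    await _fetch_blob(context, cache, digest, callback=on_chunk)
    del raw, directory  # kept only to show the return values exist
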

async def _download_directory(
    context: Context, cache: ResponseCache, base: str, path: str, directory: Directory
) -> None:
    """Download the contents of a directory from CAS.

    This function takes a Directory message and downloads the directory
    contents defined in the message from CAS. Raises a 400 error if the
    directory contains a symlink which points to a location outside the
    initial directory.

    The directory is downloaded recursively depth-first.

    Args:
        context (Context): The context to use for making gRPC requests.
        cache (ResponseCache): The response cache to use when downloading the
            directory contents.
        base (str): The initial directory path, used to check symlinks don't
            escape into the wider filesystem.
        path (str): The path to download the directory into.
        directory (Directory): The Directory message to fetch the contents of.

    """
    for directory_node in directory.directories:
        dir_path = os.path.join(path, directory_node.name)
        os.mkdir(dir_path)
        child = await _fetch_blob(context, cache, directory_node.digest, message_class=Directory)
        await _download_directory(context, cache, base, dir_path, child)

    for file_node in directory.files:
        file_path = os.path.join(path, file_node.name)
        async with aiofiles.open(file_path, "wb") as f:
            await _fetch_blob(context, cache, file_node.digest, callback=f.write)
        if file_node.is_executable:
            os.chmod(file_path, 0o755)

    for link_node in directory.symlinks:
        link_path = os.path.join(path, link_node.name)
        target_path = os.path.realpath(link_node.target)
        target_relpath = os.path.relpath(base, target_path)
        if target_relpath.startswith(os.pardir):
            raise web.HTTPBadRequest(
                reason="Requested directory contains a symlink targeting a location outside the tarball"
            )

        os.symlink(link_node.target, link_path)


async def _tarball_from_directory(context: Context, cache: ResponseCache, directory: Directory, tmp_dir: str) -> str:
    """Construct a tarball of a directory stored in CAS.

    This function fetches the contents of the given directory message into a
    temporary directory, and then constructs a tarball of the directory. The
    path to this tarball is returned when construction is complete.

    Args:
        context (Context): The context to use to send the gRPC requests.
        cache (ResponseCache): The response cache to use when fetching the
            tarball contents.
        directory (Directory): The Directory message for the directory we're
            making a tarball of.
        tmp_dir (str): Path to a temporary directory to use for storing the
            directory contents and its tarball.

    """
    tarball_dir = tempfile.mkdtemp(dir=tmp_dir)
    tarball_path = os.path.join(tmp_dir, "directory.tar.gz")
    loop = asyncio.get_event_loop()

    # Fetch the contents of the directory into a temporary directory
    await _download_directory(context, cache, tarball_dir, tarball_dir, directory)

    # Make a tarball from that temporary directory
    # NOTE: We do this using loop.run_in_executor to avoid the
    # synchronous and blocking tarball construction
    tarball_result = await loop.run_in_executor(None, _create_tarball, tarball_dir, tarball_path)
    if not tarball_result:
        raise web.HTTPInternalServerError()
    return tarball_path


def get_tarball_handler(
    context: Context,
    cache: ResponseCache,
    allow_all: bool = False,
    allowed_origins: UrlCollection = (),
    tarball_dir: Optional[str] = None,
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
    """Factory function which returns a handler for tarball requests.

    This function also takes care of cleaning up old incomplete tarball constructions
    when given a named directory to do the construction in.

    The returned handler takes the hash and size_bytes of a Digest of a Directory
    message and constructs a tarball of the directory defined by the message.

    Args:
        context (Context): The context to use to send the gRPC requests.
        allow_all (bool): Whether or not to allow all CORS origins.
        allowed_origins (list): List of valid CORS origins.
        tarball_dir (str): Base directory to use for tarball construction.

    """

    class FetchSpec:
        """Simple class used to store information about a tarball request.

        A class is used here rather than a namedtuple since we need this state
        to be mutable.

        """

        def __init__(self, *, error: Optional[Exception], event: Optional[asyncio.Event], path: str, refcount: int):
            self.error = error
            self.event = event
            self.path = path
            self.refcount = refcount

    in_flight_fetches: Dict[str, FetchSpec] = {}
    fetch_lock = asyncio.Lock()

    # If we have a tarball directory to use, empty all existing tarball constructions from it
    # to provide some form of cleanup after a crash.
    if tarball_dir is not None:
        for path in os.listdir(tarball_dir):
            if path.startswith(TARBALL_DIRECTORY_PREFIX):
                shutil.rmtree(os.path.join(tarball_dir, path))

    async def _get_tarball_handler(request: web.Request) -> web.StreamResponse:
        digest_str = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
        LOGGER.info(f"Received request for a tarball from CAS for blob with digest [{digest_str}]")

        digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"]))

        tmp_dir = tempfile.mkdtemp(prefix=TARBALL_DIRECTORY_PREFIX, dir=tarball_dir)

        try:
            duplicate_request = False
            event = None
            async with fetch_lock:
                if digest_str in in_flight_fetches:
                    LOGGER.info(f"Deduplicating request for tarball of [{digest_str}]")
                    spec = in_flight_fetches[digest_str]
                    spec.refcount += 1
                    event = spec.event
                    duplicate_request = True
                else:
                    event = asyncio.Event()
                    in_flight_fetches[digest_str] = FetchSpec(error=None, event=event, path="", refcount=1)

            if duplicate_request and event:
                # If we're a duplicate of an existing request, then wait for the
                # existing request to finish tarball creation before reading the
                # path from the cache.
                await event.wait()
                spec = in_flight_fetches[digest_str]
                if spec.error is not None:
                    raise spec.error
                tarball_path = in_flight_fetches[digest_str].path
            else:
                try:
                    directory = await _fetch_blob(context, cache, digest, message_class=Directory)
                    tarball_path = await _tarball_from_directory(context, cache, directory, tmp_dir)
                except web.HTTPError as e:
                    in_flight_fetches[digest_str].error = e
                    if event:
                        event.set()
                    raise e
                except Exception as e:
                    LOGGER.debug("Unexpected error constructing tarball", exc_info=True)
                    in_flight_fetches[digest_str].error = e
                    if event:
                        event.set()
                    raise web.HTTPInternalServerError()

                # Update path in deduplication cache, and set event to notify
                # duplicate requests that the tarball is ready
                async with fetch_lock:
                    if event:
                        in_flight_fetches[digest_str].path = tarball_path
                        event.set()

            response = web.StreamResponse()

            # We need to explicitly set CORS headers here, because the middleware that
            # normally handles this only adds the header when the response is returned
            # from this function. However, when using a StreamResponse with chunked
            # encoding enabled, the client begins to receive the response when we call
            # `response.write()`. This leads to the request being disallowed due to the
            # missing response header for clients executing in browsers.
            cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all)
            response.headers.update(cors_headers)

            response.enable_chunked_encoding()
            await response.prepare(request)

            async with aiofiles.open(tarball_path, "rb") as tarball:
                await tarball.seek(0)
                chunk = await tarball.read(1024)
                while chunk:
                    await response.write(chunk)
                    chunk = await tarball.read(1024)
            return response

        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.NOT_FOUND:
                raise web.HTTPNotFound()
            raise web.HTTPInternalServerError()

        finally:
            cleanup = False
            async with fetch_lock:
                # Decrement refcount now we're done with the tarball. If we're the
                # last request interested in the tarball then remove it along with
                # its construction directory.
                spec = in_flight_fetches[digest_str]
                spec.refcount -= 1
                if spec.refcount <= 0:
                    cleanup = True
                    in_flight_fetches.pop(digest_str)
            if cleanup:
                shutil.rmtree(tmp_dir)

    return _get_tarball_handler

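# Illustrative sketch (not part of the original module): consuming the tarball
# endpoint with an aiohttp client and saving the stream to disk. The base URL
# and the "/tarball/{hash}/{size_bytes}" route layout are assumptions made for
# demonstration only.
async def _example_download_tarball(base_url: str, digest: Digest, dest_dir: str) -> str:
    import aiohttp

    url = f"{base_url}/tarball/{digest.hash}/{digest.size_bytes}"
    out_path = os.path.join(dest_dir, "directory.tar.gz")
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            async with aiofiles.open(out_path, "wb") as f:
                # The handler streams the tarball with chunked encoding.
                async for chunk in resp.content.iter_chunked(1024):
                    await f.write(chunk)
    return out_path
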

def logstream_handler(context: Context) -> Callable[[web.Request], Awaitable[Any]]:
    async def _logstream_handler(request: web.Request) -> Any:
        LOGGER.info("Received request for a LogStream websocket")
        stub = ByteStreamStub(context.logstream_channel)  # type: ignore # Requires stub regen
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for msg in ws:
            if msg.type == WSMsgType.BINARY:
                read_request = ReadRequest()
                read_request.ParseFromString(msg.data)

                read_request.resource_name = f"{read_request.resource_name}"
                try:
                    async for response in stub.Read(read_request):
                        serialized_response = response.SerializeToString()
                        if serialized_response:
                            ws_response = {
                                "resource_name": read_request.resource_name,
                                "data": response.data.decode("utf-8"),
                                "complete": False,
                            }
                            await ws.send_json(ws_response)
                    ws_response = {"resource_name": read_request.resource_name, "data": "", "complete": True}
                    await ws.send_json(ws_response)
                except RpcError as e:
                    LOGGER.warning(e.details())
                    if e.code() == StatusCode.NOT_FOUND:
                        ws_response = {
                            "resource_name": read_request.resource_name,
                            "data": "NOT_FOUND",
                            "complete": True,
                        }
                        await ws.send_json(ws_response)
                    ws_response = {"resource_name": read_request.resource_name, "data": "INTERNAL", "complete": True}
                    await ws.send_json(ws_response)

    return _logstream_handler
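# Illustrative sketch (not part of the original module): a websocket client for
# the LogStream endpoint above. The URL is an assumption; the message shapes
# follow `_logstream_handler`: the client sends a serialised ReadRequest as a
# binary frame and receives JSON frames carrying "resource_name", "data" and
# "complete" fields.
async def _example_follow_logstream(url: str, resource_name: str) -> None:
    import aiohttp

    read_request = ReadRequest(resource_name=resource_name)
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url) as ws:
            await ws.send_bytes(read_request.SerializeToString())
            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    payload = json.loads(msg.data)
                    print(payload["data"], end="")
                    if payload["complete"]:
                        break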