Coverage for /builds/BuildGrid/buildgrid/buildgrid/browser/rest_api.py: 82.95%

387 statements  

coverage.py v7.2.7, created at 2023-06-05 15:37 +0000

1# Copyright (C) 2021 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import asyncio 

16import logging 

17import os 

18import shutil 

19import tarfile 

20import tempfile 

21from base64 import b64encode 

22from collections import namedtuple 

23from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type 

24 

25import aiofiles 

26from aiohttp import WSMsgType, web 

27from aiohttp_middlewares.annotations import UrlCollection 

28from grpc import RpcError, StatusCode 

29from grpc.aio import Call, Metadata 

30 

31from buildgrid._app.cli import Context 

32from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import ( 

33 ActionResult, 

34 Digest, 

35 Directory, 

36 GetActionResultRequest, 

37 RequestMetadata, 

38) 

39from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2_grpc_aio import ActionCacheStub 

40from buildgrid._protos.buildgrid.v2.query_build_events_pb2 import QueryEventStreamsRequest 

41from buildgrid._protos.buildgrid.v2.query_build_events_pb2_grpc_aio import QueryBuildEventsStub 

42from buildgrid._protos.google.bytestream.bytestream_pb2 import ReadRequest 

43from buildgrid._protos.google.bytestream.bytestream_pb2_grpc_aio import ByteStreamStub 

44from buildgrid._protos.google.longrunning import operations_pb2 

45from buildgrid._protos.google.longrunning.operations_pb2_grpc_aio import OperationsStub 

46from buildgrid.browser.utils import TARBALL_DIRECTORY_PREFIX, ResponseCache, get_cors_headers 

47from buildgrid.server.request_metadata_utils import extract_request_metadata 

48from buildgrid.settings import BROWSER_MAX_CACHE_ENTRY_SIZE, BROWSER_OPERATION_CACHE_MAX_LENGTH 

49 

50LOGGER = logging.getLogger(__name__) 

51 

52 

53def query_build_events_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]: 

54 """Factory function which returns a handler for QueryEventStreams. 

55 

56 The returned handler uses ``context.channel`` to send a QueryEventStreams 

57 request constructed based on the provided URL query parameters. Currently 

58 only querying by build_id (equivalent to correlated invocations ID) is 

59 supported. 

60 

61 The handler returns a serialised QueryEventStreamsResponse, and raises a 

62 500 error in the case of some RPC error. 

63 

64 Args: 

65 context (Context): The context to use to send the gRPC request. 

66 

67 """ 

68 

69 async def _query_build_events_handler(request: web.Request) -> web.Response: 

70 build_id = request.rel_url.query.get("build_id", ".*") 

71 

72 stub = QueryBuildEventsStub(context.channel) # type: ignore # Requires stub regen 

73 

74 grpc_request = QueryEventStreamsRequest(build_id_pattern=build_id) 

75 

76 try: 

77 grpc_response = await stub.QueryEventStreams(grpc_request) 

78 except RpcError as e: 

79 LOGGER.warning(e.details()) 

80 raise web.HTTPInternalServerError() 

81 

82 serialized_response = grpc_response.SerializeToString() 

83 return web.Response(body=serialized_response) 

84 

85 return _query_build_events_handler 

86 
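# --- Illustrative sketch, not part of rest_api.py: one way the factory above might be
# --- mounted on an aiohttp application. The route path and the pre-built `context`
# --- are assumptions made for the example, not taken from this module.
def _example_build_events_app(context: Context) -> web.Application:
    app = web.Application()
    # GET /build_events?build_id=<pattern> -> serialised QueryEventStreamsResponse body
    app.add_routes([web.get("/build_events", query_build_events_handler(context))])
    return app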

87 

88def list_operations_handler( 

89 context: Context, cache: ResponseCache 

90) -> Callable[[web.Request], Awaitable[web.Response]]: 

91 """Factory function which returns a handler for ListOperations. 

92 

93 The returned handler uses ``context.channel`` and ``context.instance_name`` 

94 to send a ListOperations request constructed based on the provided URL 

95 query parameters. 

96 

97 The handler returns a serialised ListOperationsResponse, raises a 400 

98 error in the case of a bad filter or other invalid argument, or raises 

99 a 500 error in the case of some other RPC error. 

100 

101 Args: 

102 context (Context): The context to use to send the gRPC request. 

103 

104 """ 

105 

106 async def _list_operations_handler(request: web.Request) -> web.Response: 

107 filter_string = request.rel_url.query.get("q", "") 

108 page_token = request.rel_url.query.get("page_token", "") 

109 page_size_str = request.rel_url.query.get("page_size") 

110 page_size = 0 

111 if page_size_str is not None: 

112 page_size = int(page_size_str) 

113 

114 LOGGER.info( 

115 f'Received ListOperations request, filter_string="{filter_string}" ' 

116 f'page_token="{page_token}" page_size="{page_size}"' 

117 ) 

118 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen 

119 grpc_request = operations_pb2.ListOperationsRequest( 

120 name=context.instance_name, page_token=page_token, page_size=page_size, filter=filter_string 

121 ) 

122 

123 try: 

124 grpc_response = await stub.ListOperations(grpc_request) 

125 except RpcError as e: 

126 LOGGER.warning(e.details()) 

127 if e.code() == StatusCode.INVALID_ARGUMENT: 

128 raise web.HTTPBadRequest() 

129 raise web.HTTPInternalServerError() 

130 

131 if grpc_response.operations: 

132 for i, operation in enumerate(grpc_response.operations): 

133 if i > BROWSER_OPERATION_CACHE_MAX_LENGTH: 

134 break 

135 await cache.store_operation(operation.name, operation) 

136 

137 serialised_response = grpc_response.SerializeToString() 

138 return web.Response(body=serialised_response) 

139 

140 return _list_operations_handler 

141 
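# --- Illustrative sketch, not part of rest_api.py: a client decoding the serialised
# --- ListOperationsResponse returned by the handler above. The URL and the use of the
# --- `requests` library are assumptions made for the example.
def _example_list_operations_client() -> None:
    import requests  # assumed to be available in the client environment

    raw = requests.get(
        "http://localhost:8080/operations",  # assumed endpoint for the handler above
        params={"q": "", "page_size": 50},   # maps to filter_string/page_size above
    ).content
    response = operations_pb2.ListOperationsResponse.FromString(raw)
    for op in response.operations:
        print(op.name, op.done)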

142 

143async def _get_operation(context: Context, request: web.Request) -> Tuple[operations_pb2.Operation, Call]: 

144 operation_name = f"{context.instance_name}/{request.match_info['name']}" 

145 

146 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen 

147 grpc_request = operations_pb2.GetOperationRequest(name=operation_name) 

148 

149 try: 

150 call = stub.GetOperation(grpc_request) 

151 operation = await call 

152 except RpcError as e: 

153 LOGGER.warning(f"Error fetching operation: {e.details()}") 

154 if e.code() == StatusCode.INVALID_ARGUMENT: 

155 raise web.HTTPNotFound() 

156 raise web.HTTPInternalServerError() 

157 

158 return operation, call 

159 

160 

161def get_operation_handler(context: Context, cache: ResponseCache) -> Callable[[web.Request], Awaitable[web.Response]]: 

162 """Factory function which returns a handler for GetOperation. 

163 

164 The returned handler uses ``context.channel`` and ``context.instance_name`` 

165 to send a GetOperation request constructed based on the path component of 

166 the URL. 

167 

168 The handler returns a serialised Operation message, raises a 404 error in 

169 the case of an invalid operation name, or raises a 500 error in the case 

170 of some other RPC error. 

171 

172 Args: 

173 context (Context): The context to use to send the gRPC request. 

174 

175 """ 

176 

177 async def _get_operation_handler(request: web.Request) -> web.Response: 

178 name = request.match_info["name"] 

179 LOGGER.info(f'Received GetOperation request for "{name}"') 

180 

181 # Check the cache to see if we already fetched this operation recently 

182 operation = await cache.get_operation(name) 

183 

184 # Fall back to sending an actual gRPC request 

185 if operation is None: 

186 operation, _ = await _get_operation(context, request) 

187 await cache.store_operation(name, operation) 

188 

189 serialised_response = operation.SerializeToString() 

190 return web.Response(body=serialised_response) 

191 

192 return _get_operation_handler 

193 
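# --- Illustrative sketch, not part of rest_api.py: turning the handler's response body
# --- back into an Operation message on the client side.
def _example_decode_operation(body: bytes) -> operations_pb2.Operation:
    operation = operations_pb2.Operation()
    operation.ParseFromString(body)
    return operation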

194 

195def get_operation_request_metadata_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]: 

196 """Factory function which returns a handler to get RequestMetadata. 

197 

198 The returned handler uses ``context.channel`` and ``context.instance_name`` 

199 to send a GetOperation request constructed based on the path component of 

200 the URL. 

201 

202 The handler returns a serialised RequestMetadata proto message, retrieved 

203 from the trailing metadata of the GetOperation response. In the event of 

204 an invalid operation name it raises a 404 error, and raises a 500 error in 

205 the case of some other RPC error. 

206 

207 Args: 

208 context (Context): The context to use to send the gRPC request. 

209 

210 """ 

211 

212 async def _get_operation_request_metadata_handler(request: web.Request) -> web.Response: 

213 LOGGER.info(f'Received request for RequestMetadata for "{request.match_info["name"]}"') 

214 _, call = await _get_operation(context, request) 

215 metadata = await call.trailing_metadata() 

216 

217 def extract_metadata(m: Metadata) -> RequestMetadata: 

218 # `m` contains a list of tuples, but `extract_request_metadata()` 

219 # expects entries with `key` and `value` attributes. 

220 MetadataTuple = namedtuple("MetadataTuple", ["key", "value"]) 

221 return extract_request_metadata([MetadataTuple(entry[0], entry[1]) for entry in m]) 

222 

223 request_metadata = extract_metadata(metadata) 

224 return web.Response(body=request_metadata.SerializeToString()) 

225 

226 return _get_operation_request_metadata_handler 

227 
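# --- Illustrative sketch, not part of rest_api.py: decoding the RequestMetadata body
# --- returned by the handler above on the client side.
def _example_decode_request_metadata(body: bytes) -> RequestMetadata:
    metadata = RequestMetadata()
    metadata.ParseFromString(body)
    return metadata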

228 

229def cancel_operation_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]: 

230 """Factory function which returns a handler for CancelOperation. 

231 

232 The returned handler uses ``context.channel`` and ``context.instance_name`` 

233 to send a CancelOperation request constructed based on the path component of 

234 the URL. 

235 

236 The handler raises a 404 error in the case of an invalid operation name, 

237 or a 500 error in the case of some other RPC error. 

238 

239 On success, the response is empty. 

240 

241 Args: 

242 context (Context): The context to use to send the gRPC request. 

243 

244 """ 

245 

246 async def _cancel_operation_handler(request: web.Request) -> web.Response: 

247 LOGGER.info(f'Received CancelOperation request for "{request.match_info["name"]}"') 

248 operation_name = f"{context.instance_name}/{request.match_info['name']}" 

249 

250 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen 

251 grpc_request = operations_pb2.CancelOperationRequest(name=operation_name) 

252 

253 try: 

254 await stub.CancelOperation(grpc_request) 

255 return web.Response() 

256 except RpcError as e: 

257 LOGGER.warning(e.details()) 

258 if e.code() == StatusCode.INVALID_ARGUMENT: 

259 raise web.HTTPNotFound() 

260 raise web.HTTPInternalServerError() 

261 

262 return _cancel_operation_handler 

263 

264 

265async def _fetch_action_result( 

266 context: Context, request: web.Request, cache: ResponseCache, cache_key: str 

267) -> ActionResult: 

268 stub = ActionCacheStub(context.cache_channel) # type: ignore # Requires stub regen 

269 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"])) 

270 grpc_request = GetActionResultRequest(action_digest=digest, instance_name=context.instance_name) 

271 

272 try: 

273 result = await stub.GetActionResult(grpc_request) 

274 except RpcError as e: 

275 LOGGER.warning(f"Failed to fetch ActionResult: [{e.details()}]") 

276 if e.code() == StatusCode.NOT_FOUND: 

277 raise web.HTTPNotFound() 

278 raise web.HTTPInternalServerError() 

279 

280 await cache.store_action_result(cache_key, result) 

281 return result 

282 

283 

284def get_action_result_handler( 

285 context: Context, cache: ResponseCache, cache_capacity: int = 512 

286) -> Callable[[web.Request], Awaitable[web.Response]]: 

287 """Factory function which returns a handler for GetActionResult. 

288 

289 The returned handler uses ``context.channel`` and ``context.instance_name`` 

290 to send a GetActionResult request constructed based on the path components 

291 of the URL. 

292 

293 The handler returns a serialised ActionResult message, raises a 404 error 

294 if there's no result cached, or raises a 500 error in the case of some 

295 other RPC error. 

296 

297 Args: 

298 context (Context): The context to use to send the gRPC request. 

299 cache_capacity (int): The number of ActionResults to cache in memory 

300 to avoid hitting the actual ActionCache. 

301 

302 """ 

303 

304 class FetchSpec: 

305 """Simple class used to store information about a GetActionResult request. 

306 

307 A class is used here rather than a namedtuple since we need this state 

308 to be mutable. 

309 

310 """ 

311 

312 def __init__( 

313 self, *, error: Optional[Exception], event: asyncio.Event, result: Optional[ActionResult], refcount: int 

314 ): 

315 self.error = error 

316 self.event = event 

317 self.result = result 

318 self.refcount = refcount 

319 

320 in_flight_fetches: Dict[str, FetchSpec] = {} 

321 fetch_lock = asyncio.Lock() 

322 

323 async def _get_action_result_handler(request: web.Request) -> web.Response: 

324 LOGGER.info( 

325 "Received GetActionResult request for " 

326 f'"{request.match_info["hash"]}/{request.match_info["size_bytes"]}"' 

327 ) 

328 

329 cache_key = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}' 

330 result = await cache.get_action_result(cache_key) 

331 

332 if result is None: 

333 try: 

334 duplicate_request = False 

335 spec = None 

336 async with fetch_lock: 

337 if cache_key in in_flight_fetches: 

338 LOGGER.info(f"Deduplicating GetActionResult request for [{cache_key}]") 

339 spec = in_flight_fetches[cache_key] 

340 spec.refcount += 1 

341 duplicate_request = True 

342 else: 

343 spec = FetchSpec(error=None, event=asyncio.Event(), result=None, refcount=1) 

344 in_flight_fetches[cache_key] = spec 

345 

346 if duplicate_request and spec: 

347 # If we're a duplicate of an existing request, then wait for the 

348 # existing request to finish fetching from the ActionCache. 

349 await spec.event.wait() 

350 if spec.error is not None: 

351 raise spec.error 

352 if spec.result is None: 

353 # Note: this should be impossible, but let's guard against accidentally setting 

354 # the event before the result is populated 

355 LOGGER.info(f"Result not set in deduplicated request for [{cache_key}]") 

356 raise web.HTTPInternalServerError() 

357 result = spec.result 

358 else: 

359 try: 

360 result = await _fetch_action_result(context, request, cache, cache_key) 

361 except Exception as e: 

362 async with fetch_lock: 

363 if spec is not None: 

364 spec.error = e 

365 raise e 

366 

367 async with fetch_lock: 

368 if spec is not None: 

369 spec.result = result 

370 spec.event.set() 

371 

372 finally: 

373 async with fetch_lock: 

374 # Decrement refcount now we're done with the result. If we're the 

375 # last request interested in the result then remove it from the 

376 # `in_flight_fetches` dictionary. 

377 spec = in_flight_fetches.get(cache_key) 

378 if spec is not None: 

379 spec.refcount -= 1 

380 if spec.refcount <= 0: 

381 in_flight_fetches.pop(cache_key) 

382 

383 return web.Response(body=result.SerializeToString()) 

384 

385 return _get_action_result_handler 

386 
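# --- Illustrative sketch, not part of rest_api.py: the deduplication pattern used in the
# --- handler above, reduced to its core. Concurrent callers for the same key share one
# --- fetch: the first caller performs it, later callers wait on an asyncio.Event.
async def _example_deduplicated_fetch(
    key: str,
    fetch: Callable[[], Awaitable[bytes]],
    in_flight: Dict[str, Tuple[asyncio.Event, Dict[str, Any]]],
    lock: asyncio.Lock,
) -> bytes:
    async with lock:
        entry = in_flight.get(key)
        owner = entry is None
        if owner:
            entry = (asyncio.Event(), {})
            in_flight[key] = entry
    event, box = entry
    if owner:
        try:
            box["result"] = await fetch()
        except Exception as exc:
            box["error"] = exc
            raise
        finally:
            event.set()
            async with lock:
                in_flight.pop(key, None)
    else:
        # Duplicate request: wait for the owner to finish, then reuse its outcome.
        await event.wait()
        if "error" in box:
            raise box["error"]
    return box["result"]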

387 

388def get_blob_handler( 

389 context: Context, cache: ResponseCache, allow_all: bool = False, allowed_origins: UrlCollection = () 

390) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: 

391 async def _get_blob_handler(request: web.Request) -> web.StreamResponse: 

392 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"])) 

393 try: 

394 offset = int(request.rel_url.query.get("offset", "0")) 

395 limit = int(request.rel_url.query.get("limit", "0")) 

396 except ValueError: 

397 raise web.HTTPBadRequest() 

398 

399 response = web.StreamResponse() 

400 

401 # We need to explicitly set CORS headers here, because the middleware that 

402 # normally handles this only adds the header when the response is returned 

403 # from this function. However, when using a StreamResponse with chunked 

404 # encoding enabled, the client begins to receive the response when we call 

405 # `response.write()`. This leads to the request being disallowed due to the 

406 # missing response header for clients executing in browsers. 

407 cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all) 

408 response.headers.update(cors_headers) 

409 

410 if request.rel_url.query.get("raw", "") == "1": 

411 response.headers["Content-type"] = "text/plain; charset=utf-8" 

412 else: 

413 # Setting the Content-Disposition header so that when 

414 # downloading a blob with a browser the default name uniquely 

415 # identifies its contents: 

416 filename = f"{request.match_info['hash']}_{request.match_info['size_bytes']}" 

417 

418 # For partial reads also include the indices to prevent file mix-ups: 

419 if offset != 0 or limit != 0: 

420 filename += f"_chunk_{offset}-" 

421 if limit != 0: 

422 filename += f"{offset + limit}" 

423 

424 response.headers["Content-Disposition"] = f"Attachment;filename={filename}" 

425 

426 prepared = False 

427 

428 async def _callback(data: bytes) -> None: 

429 nonlocal prepared 

430 if not prepared: 

431 # Prepare for chunked encoding when the callback is first called, 

432 # so that we're sure we actually have some data before doing so. 

433 response.enable_chunked_encoding() 

434 await response.prepare(request) 

435 prepared = True 

436 await response.write(data) 

437 

438 await _fetch_blob(context, cache, digest, callback=_callback, offset=offset, limit=limit) 

439 

440 return response 

441 

442 return _get_blob_handler 

443 
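# --- Illustrative sketch, not part of rest_api.py: how a client might use the blob
# --- endpoint above. The URL layout and the `requests` library are assumptions.
def _example_download_blob_chunk(hash_: str, size_bytes: int) -> bytes:
    import requests  # assumed to be available in the client environment

    # offset/limit select a byte range; raw=1 would ask for text/plain instead of a download
    return requests.get(
        f"http://localhost:8080/blobs/{hash_}/{size_bytes}",
        params={"offset": 0, "limit": 1024},
    ).content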

444 

445def _create_tarball(directory: str, name: str) -> bool: 

446 """Makes a tarball from a given directory. 

447 

448 Returns True if the tarball was successfully created, and False if not. 

449 

450 Args: 

451 directory (str): The directory to tar up. 

452 name (str): The name of the tarball to be produced. 

453 

454 """ 

455 try: 

456 with tarfile.open(name, "w:gz") as tarball: 

457 tarball.add(directory, arcname="") 

458 except Exception: 

459 return False 

460 return True 

461 

462 

463async def _fetch_blob( 

464 context: Context, 

465 cache: ResponseCache, 

466 digest: Digest, 

467 message_class: Optional[Type[Any]] = None, 

468 callback: Optional[Callable[[bytes], Awaitable[Any]]] = None, 

469 offset: int = 0, 

470 limit: int = 0, 

471) -> Any: 

472 """Fetch a blob from CAS. 

473 

474 This function sends a ByteStream Read request for the given digest. If ``callback`` 

475 is set then the callback is called with the data in each ReadResponse message and 

476 this function returns an empty bytes object once the response is finished. If 

477 ``message_class`` is set with no ``callback`` set then this function calls 

478 ``message_class.FromString`` on the fetched blob and returns the result. 

479 

480 If neither ``callback`` or ``message_class`` are set then this function just returns 

481 the raw blob that was fetched from CAS. 

482 

483 Args: 

484 context (Context): The context to use to send the gRPC request. 

485 cache (ResponseCache): The response cache to check/update with the fetched blob. 

486 digest (Digest): The Digest of the blob to fetch from CAS. 

487 message_class (type): A class which implements a ``FromString`` class method. 

488 The method should take a bytes object expected to contain the blob fetched 

489 from CAS. 

490 callback (callable): A function or other callable to act on a subset of the 

491 blob contents. 

492 offset (int): Read offset to start reading the blob at. Defaults to 0, the start 

493 of the blob. 

494 limit (int): Maximum number of bytes to read from the blob. Defaults to 0, no 

495 limit. 

496 

497 """ 

498 cacheable = digest.size_bytes <= BROWSER_MAX_CACHE_ENTRY_SIZE 

499 resource_name = f"{context.instance_name}/blobs/{digest.hash}/{digest.size_bytes}" 

500 blob = None 

501 if cacheable: 

502 blob = await cache.get_blob(resource_name) 

503 if blob is not None and callback is not None: 

504 if limit > 0: 

505 try: 

506 blob = blob[offset : offset + limit] 

507 except IndexError: 

508 raise web.HTTPBadRequest() 

509 await callback(blob) 

510 

511 if blob is None: 

512 stub = ByteStreamStub(context.cas_channel) # type: ignore # Requires stub regen 

513 grpc_request = ReadRequest(resource_name=resource_name, read_offset=offset, read_limit=limit) 

514 

515 blob = b"" 

516 try: 

517 async for grpc_response in stub.Read(grpc_request): 

518 if grpc_response.data: 

519 if callback is not None: 

520 await callback(grpc_response.data) 

521 

522 if callback is None or cacheable: 

523 blob += grpc_response.data 

524 except RpcError as e: 

525 LOGGER.warning(e.details()) 

526 if e.code() == StatusCode.NOT_FOUND: 

527 raise web.HTTPNotFound() 

528 raise web.HTTPInternalServerError() 

529 

530 if cacheable: 

531 await cache.store_blob(resource_name, blob) 

532 

533 if message_class is not None and callback is None: 

534 return message_class.FromString(blob) 

535 return blob 

536 
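# --- Illustrative sketch, not part of rest_api.py: the three calling modes of _fetch_blob
# --- described in its docstring. `context`, `cache` and `digest` are assumed to exist already.
async def _example_fetch_blob_modes(context: Context, cache: ResponseCache, digest: Digest) -> None:
    # 1. No callback and no message_class: the raw blob bytes are returned.
    raw_blob = await _fetch_blob(context, cache, digest)

    # 2. message_class only: the blob is parsed with Directory.FromString and returned.
    directory = await _fetch_blob(context, cache, digest, message_class=Directory)

    # 3. callback only: each chunk of ReadResponse data is passed to the callback as it arrives.
    chunks = []

    async def _collect(data: bytes) -> None:
        chunks.append(data)

    await _fetch_blob(context, cache, digest, callback=_collect)
    print(len(raw_blob), len(directory.files), len(chunks))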

537 

538async def _download_directory( 

539 context: Context, cache: ResponseCache, base: str, path: str, directory: Directory 

540) -> None: 

541 """Download the contents of a directory from CAS. 

542 

543 This function takes a Directory message and downloads the directory 

544 contents defined in the message from CAS. Raises a 400 error if the 

545 directory contains a symlink which points to a location outside the 

546 initial directory. 

547 

548 The directory is downloaded recursively depth-first. 

549 

550 Args: 

551 context (Context): The context to use for making gRPC requests. 

552 cache (ResponseCache): The response cache to use when downloading the 

553 directory contents. 

554 base (str): The initial directory path, used to check symlinks don't 

555 escape into the wider filesystem. 

556 path (str): The path to download the directory into. 

557 directory (Directory): The Directory message to fetch the contents of. 

558 

559 """ 

560 for directory_node in directory.directories: 

561 dir_path = os.path.join(path, directory_node.name) 

562 os.mkdir(dir_path) 

563 child = await _fetch_blob(context, cache, directory_node.digest, message_class=Directory) 

564 await _download_directory(context, cache, base, dir_path, child) 

565 

566 for file_node in directory.files: 

567 file_path = os.path.join(path, file_node.name) 

568 async with aiofiles.open(file_path, "wb") as f: 

569 await _fetch_blob(context, cache, file_node.digest, callback=f.write) 

570 if file_node.is_executable: 

571 os.chmod(file_path, 0o755) 

572 

573 for link_node in directory.symlinks: 

574 link_path = os.path.join(path, link_node.name) 

575 target_path = os.path.realpath(link_node.target) 

576 target_relpath = os.path.relpath(base, target_path) 

577 if target_relpath.startswith(os.pardir): 

578 raise web.HTTPBadRequest( 

579 reason="Requested directory contains a symlink targeting a location outside the tarball" 

580 ) 

581 

582 os.symlink(link_node.target, link_path) 

583 

584 

585async def _tarball_from_directory(context: Context, cache: ResponseCache, directory: Directory, tmp_dir: str) -> str: 

586 """Construct a tarball of a directory stored in CAS. 

587 

588 This function fetches the contents of the given directory message into a 

589 temporary directory, and then constructs a tarball of the directory. The 

590 path to this tarball is returned when construction is complete. 

591 

592 Args: 

593 context (Context): The context to use to send the gRPC requests. 

594 cache (ResponseCache): The response cache to use when fetching the 

595 tarball contents. 

596 directory (Directory): The Directory message for the directory we're 

597 making a tarball of. 

598 tmp_dir (str): Path to a temporary directory to use for storing the 

599 directory contents and its tarball. 

600 

601 """ 

602 tarball_dir = tempfile.mkdtemp(dir=tmp_dir) 

603 tarball_path = os.path.join(tmp_dir, "directory.tar.gz") 

604 loop = asyncio.get_event_loop() 

605 

606 # Fetch the contents of the directory into a temporary directory 

607 await _download_directory(context, cache, tarball_dir, tarball_dir, directory) 

608 

609 # Make a tarball from that temporary directory 

610 # NOTE: We do this using loop.run_in_executor to avoid the 

611 # synchronous and blocking tarball construction 

612 tarball_result = await loop.run_in_executor(None, _create_tarball, tarball_dir, tarball_path) 

613 if not tarball_result: 

614 raise web.HTTPInternalServerError() 

615 return tarball_path 

616 

617 

618def get_tarball_handler( 

619 context: Context, 

620 cache: ResponseCache, 

621 allow_all: bool = False, 

622 allowed_origins: UrlCollection = (), 

623 tarball_dir: Optional[str] = None, 

624) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: 

625 """Factory function which returns a handler for tarball requests. 

626 

627 This function also takes care of cleaning up old incomplete tarball constructions 

628 when given a named directory to do the construction in. 

629 

630 The returned handler takes the hash and size_bytes of a Digest of a Directory 

631 message and constructs a tarball of the directory defined by the message. 

632 

633 Args: 

634 context (Context): The context to use to send the gRPC requests. 

635 allow_all (bool): Whether or not to allow all CORS origins. 

636 allowed_origins (list): List of valid CORS origins. 

637 tarball_dir (str): Base directory to use for tarball construction. 

638 

639 """ 

640 

641 class FetchSpec: 

642 """Simple class used to store information about a tarball request. 

643 

644 A class is used here rather than a namedtuple since we need this state 

645 to be mutable. 

646 

647 """ 

648 

649 def __init__(self, *, error: Optional[Exception], event: Optional[asyncio.Event], path: str, refcount: int): 

650 self.error = error 

651 self.event = event 

652 self.path = path 

653 self.refcount = refcount 

654 

655 in_flight_fetches: Dict[str, FetchSpec] = {} 

656 fetch_lock = asyncio.Lock() 

657 

658 # If we have a tarball directory to use, empty all existing tarball constructions from it 

659 # to provide some form of cleanup after a crash. 

660 if tarball_dir is not None: 

661 for path in os.listdir(tarball_dir): 

662 if path.startswith(TARBALL_DIRECTORY_PREFIX): 

663 shutil.rmtree(os.path.join(tarball_dir, path)) 

664 

665 async def _get_tarball_handler(request: web.Request) -> web.StreamResponse: 

666 digest_str = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}' 

667 LOGGER.info(f"Received request for a tarball from CAS for blob with digest [{digest_str}]") 

668 

669 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"])) 

670 

671 tmp_dir = tempfile.mkdtemp(prefix=TARBALL_DIRECTORY_PREFIX, dir=tarball_dir) 

672 

673 try: 

674 duplicate_request = False 

675 event = None 

676 async with fetch_lock: 

677 if digest_str in in_flight_fetches: 

678 LOGGER.info(f"Deduplicating request for tarball of [{digest_str}]") 

679 spec = in_flight_fetches[digest_str] 

680 spec.refcount += 1 

681 event = spec.event 

682 duplicate_request = True 

683 else: 

684 event = asyncio.Event() 

685 in_flight_fetches[digest_str] = FetchSpec(error=None, event=event, path="", refcount=1) 

686 

687 if duplicate_request and event: 

688 # If we're a duplicate of an existing request, then wait for the 

689 # existing request to finish tarball creation before reading the 

690 # path from the cache. 

691 await event.wait() 

692 spec = in_flight_fetches[digest_str] 

693 if spec.error is not None: 

694 raise spec.error 

695 tarball_path = in_flight_fetches[digest_str].path 

696 else: 

697 try: 

698 directory = await _fetch_blob(context, cache, digest, message_class=Directory) 

699 tarball_path = await _tarball_from_directory(context, cache, directory, tmp_dir) 

700 except web.HTTPError as e: 

701 in_flight_fetches[digest_str].error = e 

702 if event: 

703 event.set() 

704 raise e 

705 except Exception as e: 

706 LOGGER.debug("Unexpected error constructing tarball", exc_info=True) 

707 in_flight_fetches[digest_str].error = e 

708 if event: 

709 event.set() 

710 raise web.HTTPInternalServerError() 

711 

712 # Update path in deduplication cache, and set event to notify 

713 # duplicate requests that the tarball is ready 

714 async with fetch_lock: 

715 if event: 

716 in_flight_fetches[digest_str].path = tarball_path 

717 event.set() 

718 

719 response = web.StreamResponse() 

720 

721 # We need to explicitly set CORS headers here, because the middleware that 

722 # normally handles this only adds the header when the response is returned 

723 # from this function. However, when using a StreamResponse with chunked 

724 # encoding enabled, the client begins to receive the response when we call 

725 # `response.write()`. This leads to the request being disallowed due to the 

726 # missing response header for clients executing in browsers. 

727 cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all) 

728 response.headers.update(cors_headers) 

729 

730 response.enable_chunked_encoding() 

731 await response.prepare(request) 

732 

733 async with aiofiles.open(tarball_path, "rb") as tarball: 

734 await tarball.seek(0) 

735 chunk = await tarball.read(1024) 

736 while chunk: 

737 await response.write(chunk) 

738 chunk = await tarball.read(1024) 

739 return response 

740 

741 except RpcError as e: 

742 LOGGER.warning(e.details()) 

743 if e.code() == StatusCode.NOT_FOUND: 

744 raise web.HTTPNotFound() 

745 raise web.HTTPInternalServerError() 

746 

747 finally: 

748 cleanup = False 

749 async with fetch_lock: 

750 # Decrement refcount now we're done with the tarball. If we're the 

751 # last request interested in the tarball then remove it along with 

752 # its construction directory. 

753 spec = in_flight_fetches[digest_str] 

754 spec.refcount -= 1 

755 if spec.refcount <= 0: 

756 cleanup = True 

757 in_flight_fetches.pop(digest_str) 

758 if cleanup: 

759 shutil.rmtree(tmp_dir) 

760 

761 return _get_tarball_handler 

762 
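# --- Illustrative sketch, not part of rest_api.py: streaming the tarball produced by the
# --- handler above into a local file. The URL layout and the `requests` library are assumptions.
def _example_download_tarball(hash_: str, size_bytes: int, dest: str) -> None:
    import requests  # assumed to be available in the client environment

    with requests.get(f"http://localhost:8080/tarballs/{hash_}/{size_bytes}", stream=True) as resp:
        with open(dest, "wb") as out:
            for chunk in resp.iter_content(chunk_size=1024):
                out.write(chunk)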

763 

764def logstream_handler(context: Context) -> Callable[[web.Request], Awaitable[Any]]: 

765 async def _logstream_handler(request: web.Request) -> Any: 

766 LOGGER.info("Received request for a LogStream websocket") 

767 stub = ByteStreamStub(context.logstream_channel) # type: ignore # Requires stub regen 

768 ws = web.WebSocketResponse() 

769 await ws.prepare(request) 

770 

771 async for msg in ws: 

772 if msg.type == WSMsgType.BINARY: 

773 read_request = ReadRequest() 

774 read_request.ParseFromString(msg.data) 

775 

776 read_request.resource_name = f"{context.instance_name}/{read_request.resource_name}" 

777 try: 

778 async for response in stub.Read(read_request): 

779 serialized_response = response.SerializeToString() 

780 if serialized_response: 

781 ws_response = { 

782 "resource_name": read_request.resource_name, 

783 "data": b64encode(serialized_response).decode("utf-8"), 

784 "complete": False, 

785 } 

786 await ws.send_json(ws_response) 

787 ws_response = {"resource_name": read_request.resource_name, "data": "", "complete": True} 

788 await ws.send_json(ws_response) 

789 except RpcError as e: 

790 LOGGER.warning(e.details()) 

791 if e.code() == StatusCode.NOT_FOUND: 

792 ws_response = { 

793 "resource_name": read_request.resource_name, 

794 "data": "NOT_FOUND", 

795 "complete": True, 

796 } 

797 await ws.send_json(ws_response) 

798 ws_response = {"resource_name": read_request.resource_name, "data": "INTERNAL", "complete": True} 

799 await ws.send_json(ws_response) 

800 

801 return _logstream_handler
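
# --- Illustrative sketch, not part of rest_api.py: a client for the LogStream websocket
# --- above, written with aiohttp's client API. The endpoint URL is an assumption.
async def _example_logstream_client(resource_name: str) -> None:
    import aiohttp  # the module above only imports aiohttp.web, so pull in the client here

    read_request = ReadRequest(resource_name=resource_name)
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect("http://localhost:8080/logstream") as ws:
            # The handler expects a serialised ReadRequest in a binary frame...
            await ws.send_bytes(read_request.SerializeToString())
            # ...and replies with JSON frames carrying base64-encoded ReadResponse data.
            async for msg in ws:
                payload = msg.json()
                print(payload["resource_name"], payload["complete"])
                if payload["complete"]:
                    break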