Coverage for /builds/BuildGrid/buildgrid/buildgrid/browser/rest_api.py: 18.35%

387 statements  


# Copyright (C) 2021 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from base64 import b64encode
from collections import namedtuple
import logging
import os
import shutil
import tarfile
import tempfile
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type

import aiofiles
from aiohttp import web, WSMsgType
from grpc import RpcError, StatusCode
from grpc.aio import Call  # type: ignore

from buildgrid._app.cli import Context
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import (
    ActionResult,
    Digest,
    Directory,
    GetActionResultRequest
)
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2_grpc import ActionCacheStub
from buildgrid._protos.buildgrid.v2.query_build_events_pb2 import QueryEventStreamsRequest
from buildgrid._protos.buildgrid.v2.query_build_events_pb2_grpc import QueryBuildEventsStub
from buildgrid._protos.google.bytestream.bytestream_pb2_grpc import ByteStreamStub
from buildgrid._protos.google.bytestream.bytestream_pb2 import ReadRequest
from buildgrid._protos.google.longrunning.operations_pb2_grpc import OperationsStub
from buildgrid._protos.google.longrunning import operations_pb2
from buildgrid.browser.utils import TARBALL_DIRECTORY_PREFIX, ResponseCache, get_cors_headers
from buildgrid.server.request_metadata_utils import extract_request_metadata
from buildgrid.settings import BROWSER_MAX_CACHE_ENTRY_SIZE, BROWSER_OPERATION_CACHE_MAX_LENGTH


LOGGER = logging.getLogger(__name__)


def query_build_events_handler(context: Context) -> Callable:
    """Factory function which returns a handler for QueryEventStreams.

    The returned handler uses ``context.channel`` to send a QueryEventStreams
    request constructed based on the provided URL query parameters. Currently
    only querying by build_id (equivalent to correlated invocations ID) is
    supported.

    The handler returns a serialised QueryEventStreamsResponse, and raises a
    500 error in the case of some RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.

    """
    async def _query_build_events_handler(request: web.Request) -> web.Response:
        build_id = request.rel_url.query.get('build_id', '.*')

        stub = QueryBuildEventsStub(context.channel)

        grpc_request = QueryEventStreamsRequest(
            build_id_pattern=build_id
        )

        try:
            grpc_response = await stub.QueryEventStreams(grpc_request)
        except RpcError as e:
            LOGGER.warning(e.details())
            raise web.HTTPInternalServerError()

        serialized_response = grpc_response.SerializeToString()
        return web.Response(body=serialized_response)
    return _query_build_events_handler

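# Illustrative sketch (editor's addition, not part of the original module):
# the factory functions in this file each return an aiohttp handler, so
# wiring them into an application might look roughly like the following.
# The route paths, and how `context` and `cache` are constructed, are
# assumptions here; the real routing is configured outside this file:
#
#     app = web.Application()
#     app.add_routes([
#         web.get('/build_events', query_build_events_handler(context)),
#         web.get('/operations', list_operations_handler(context, cache)),
#         web.get('/operations/{name}', get_operation_handler(context, cache)),
#     ])
#     web.run_app(app)
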

def list_operations_handler(context: Context, cache: ResponseCache) -> Callable:
    """Factory function which returns a handler for ListOperations.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a ListOperations request constructed based on the provided URL
    query parameters.

    The handler returns a serialised ListOperationsResponse, raises a 400
    error in the case of a bad filter or other invalid argument, or raises
    a 500 error in the case of some other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.
        cache (ResponseCache): The response cache used to store the fetched
            operations.

    """
    async def _list_operations_handler(request: web.Request) -> web.Response:
        filter_string = request.rel_url.query.get('q', '')
        page_token = request.rel_url.query.get('page_token')
        page_size_str = request.rel_url.query.get('page_size')
        page_size = None
        if page_size_str is not None:
            page_size = int(page_size_str)

        LOGGER.info(f'Received ListOperations request, filter_string="{filter_string}" '
                    f'page_token="{page_token}" page_size="{page_size}"')
        stub = OperationsStub(context.operations_channel)
        grpc_request = operations_pb2.ListOperationsRequest(
            name=context.instance_name,
            page_token=page_token,
            page_size=page_size,
            filter=filter_string)

        try:
            grpc_response = await stub.ListOperations(grpc_request)
        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.INVALID_ARGUMENT:
                raise web.HTTPBadRequest()
            raise web.HTTPInternalServerError()

        if grpc_response.operations:
            for i, operation in enumerate(grpc_response.operations):
                # Cache at most BROWSER_OPERATION_CACHE_MAX_LENGTH of the
                # returned operations.
                if i >= BROWSER_OPERATION_CACHE_MAX_LENGTH:
                    break
                await cache.store_operation(operation.name, operation)

        serialised_response = grpc_response.SerializeToString()
        return web.Response(body=serialised_response)
    return _list_operations_handler

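# Illustrative client-side sketch (editor's addition): assuming the handler
# above is mounted at `/operations` (the URL and filter string here are
# hypothetical), a client could fetch a page of operations and deserialise
# the protobuf response like this:
#
#     import aiohttp
#     from buildgrid._protos.google.longrunning import operations_pb2
#
#     async def fetch_page(session: aiohttp.ClientSession, page_token: str = ''):
#         params = {'q': '', 'page_size': '50'}
#         if page_token:
#             params['page_token'] = page_token
#         async with session.get('http://localhost:8080/operations', params=params) as resp:
#             resp.raise_for_status()
#             return operations_pb2.ListOperationsResponse.FromString(await resp.read())
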

async def _get_operation(context: Context, request: web.Request) -> Tuple[operations_pb2.Operation, Call]:
    operation_name = f"{context.instance_name}/{request.match_info['name']}"

    stub = OperationsStub(context.operations_channel)
    grpc_request = operations_pb2.GetOperationRequest(name=operation_name)

    try:
        call = stub.GetOperation(grpc_request)
        operation = await call
    except RpcError as e:
        LOGGER.warning(f"Error fetching operation: {e.details()}")
        if e.code() == StatusCode.INVALID_ARGUMENT:
            raise web.HTTPNotFound()
        raise web.HTTPInternalServerError()

    return operation, call


def get_operation_handler(context: Context, cache: ResponseCache) -> Callable:
    """Factory function which returns a handler for GetOperation.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a GetOperation request constructed based on the path component of
    the URL.

    The handler returns a serialised Operation message, raises a 404 error in
    the case of an invalid operation name, or raises a 500 error in the case
    of some other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.
        cache (ResponseCache): The response cache to check before sending the
            gRPC request.

    """
    async def _get_operation_handler(request: web.Request) -> web.Response:
        name = request.match_info["name"]
        LOGGER.info(f'Received GetOperation request for "{name}"')

        # Check the cache to see if we already fetched this operation recently
        operation = await cache.get_operation(name)

        # Fall back to sending an actual gRPC request
        if operation is None:
            operation, _ = await _get_operation(context, request)
            await cache.store_operation(name, operation)

        serialised_response = operation.SerializeToString()
        return web.Response(body=serialised_response)
    return _get_operation_handler


def get_operation_request_metadata_handler(context: Context) -> Callable:
    """Factory function which returns a handler to get RequestMetadata.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a GetOperation request constructed based on the path component of
    the URL.

    The handler returns a serialised RequestMetadata proto message, retrieved
    from the trailing metadata of the GetOperation response. In the event of
    an invalid operation name it raises a 404 error, and raises a 500 error in
    the case of some other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.

    """
    async def _get_operation_request_metadata_handler(request: web.Request) -> web.Response:
        LOGGER.info(f'Received request for RequestMetadata for "{request.match_info["name"]}"')
        _, call = await _get_operation(context, request)
        metadata = await call.trailing_metadata()

        def extract_metadata(m):
            # `m` is a list of plain tuples, but `extract_request_metadata()`
            # expects objects with `key` and `value` attributes.
            MetadataTuple = namedtuple('MetadataTuple', ['key', 'value'])
            m = [MetadataTuple(entry[0], entry[1]) for entry in m]
            return extract_request_metadata(m)

        request_metadata = extract_metadata(metadata)
        return web.Response(body=request_metadata.SerializeToString())
    return _get_operation_request_metadata_handler

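# Illustrative sketch (editor's addition): gRPC trailing metadata arrives as
# plain 2-tuples, and the `extract_metadata` adapter above converts them to
# attribute-style pairs. For example (the metadata key and value here are
# hypothetical):
#
#     from collections import namedtuple
#
#     raw = [('requestmetadata-bin', b'\x0a\x00')]
#     MetadataTuple = namedtuple('MetadataTuple', ['key', 'value'])
#     adapted = [MetadataTuple(key, value) for key, value in raw]
#     assert adapted[0].key == 'requestmetadata-bin'
#     assert adapted[0].value == b'\x0a\x00'
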

def cancel_operation_handler(context: Context) -> Callable:
    """Factory function which returns a handler for CancelOperation.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a CancelOperation request constructed based on the path component of
    the URL.

    The handler raises a 404 error in the case of an invalid operation name,
    or a 500 error in the case of some other RPC error.

    On success, the response is empty.

    Args:
        context (Context): The context to use to send the gRPC request.

    """
    async def _cancel_operation_handler(request: web.Request) -> web.Response:
        LOGGER.info(f'Received CancelOperation request for "{request.match_info["name"]}"')
        operation_name = f"{context.instance_name}/{request.match_info['name']}"

        stub = OperationsStub(context.operations_channel)
        grpc_request = operations_pb2.CancelOperationRequest(name=operation_name)

        try:
            await stub.CancelOperation(grpc_request)
            return web.Response()
        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.INVALID_ARGUMENT:
                raise web.HTTPNotFound()
            raise web.HTTPInternalServerError()

    return _cancel_operation_handler

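# Illustrative client-side sketch (editor's addition): assuming the cancel
# handler is mounted at `/operations/{name}/cancel` (the path, method, and
# port are hypothetical; routing lives outside this file), cancellation is a
# single request with an empty body on success:
#
#     async def cancel(session: aiohttp.ClientSession, name: str) -> None:
#         url = f'http://localhost:8080/operations/{name}/cancel'
#         async with session.post(url) as resp:
#             resp.raise_for_status()
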

async def _fetch_action_result(
    context: Context,
    request: web.Request,
    cache: ResponseCache,
    cache_key: str
) -> ActionResult:
    stub = ActionCacheStub(context.cache_channel)
    digest = Digest(
        hash=request.match_info["hash"],
        size_bytes=int(request.match_info["size_bytes"])
    )
    grpc_request = GetActionResultRequest(
        action_digest=digest, instance_name=context.instance_name)

    try:
        result = await stub.GetActionResult(grpc_request)
    except RpcError as e:
        LOGGER.warning(f"Failed to fetch ActionResult: [{e.details()}]")
        if e.code() == StatusCode.NOT_FOUND:
            raise web.HTTPNotFound()
        raise web.HTTPInternalServerError()

    await cache.store_action_result(cache_key, result)

    return result


def get_action_result_handler(context: Context, cache: ResponseCache, cache_capacity: int = 512) -> Callable:
    """Factory function which returns a handler for GetActionResult.

    The returned handler uses ``context.channel`` and ``context.instance_name``
    to send a GetActionResult request constructed based on the path components
    of the URL.

    The handler returns a serialised ActionResult message, raises a 404 error
    if there's no result cached, or raises a 500 error in the case of some
    other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.
        cache (ResponseCache): The response cache to check before hitting
            the actual ActionCache.
        cache_capacity (int): The number of ActionResults to cache in memory
            to avoid hitting the actual ActionCache.

    """
    class FetchSpec:
        """Simple class used to store information about a GetActionResult request.

        A class is used here rather than a namedtuple since we need this state
        to be mutable.

        """
        def __init__(self, *,
                     error: Optional[Exception],
                     event: asyncio.Event,
                     result: Optional[ActionResult],
                     refcount: int):
            self.error = error
            self.event = event
            self.result = result
            self.refcount = refcount

    in_flight_fetches: Dict[str, FetchSpec] = {}
    fetch_lock = asyncio.Lock()

    async def _get_action_result_handler(request: web.Request) -> web.Response:
        LOGGER.info(
            'Received GetActionResult request for '
            f'"{request.match_info["hash"]}/{request.match_info["size_bytes"]}"'
        )

        cache_key = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
        result = await cache.get_action_result(cache_key)

        if result is None:
            try:
                duplicate_request = False
                spec = None
                async with fetch_lock:
                    if cache_key in in_flight_fetches:
                        LOGGER.info(f'Deduplicating GetActionResult request for [{cache_key}]')
                        spec = in_flight_fetches[cache_key]
                        spec.refcount += 1
                        duplicate_request = True
                    else:
                        spec = FetchSpec(
                            error=None,
                            event=asyncio.Event(),
                            result=None,
                            refcount=1
                        )
                        in_flight_fetches[cache_key] = spec

                if duplicate_request and spec:
                    # If we're a duplicate of an existing request, then wait for the
                    # existing request to finish fetching from the ActionCache.
                    await spec.event.wait()
                    if spec.error is not None:
                        raise spec.error
                    if spec.result is None:
                        # Note: this should be impossible, but let's guard against
                        # accidentally setting the event before the result is populated
                        LOGGER.info(f'Result not set in deduplicated request for [{cache_key}]')
                        raise web.HTTPInternalServerError()
                    result = spec.result
                else:
                    try:
                        result = await _fetch_action_result(context, request, cache, cache_key)
                    except Exception as e:
                        async with fetch_lock:
                            if spec is not None:
                                spec.error = e
                                # Wake any duplicate requests so they can
                                # observe the error rather than waiting forever.
                                spec.event.set()
                        raise e

                    async with fetch_lock:
                        if spec is not None:
                            spec.result = result
                            spec.event.set()

            finally:
                async with fetch_lock:
                    # Decrement refcount now we're done with the result. If we're the
                    # last request interested in the result then remove it from the
                    # `in_flight_fetches` dictionary.
                    spec = in_flight_fetches.get(cache_key)
                    if spec is not None:
                        spec.refcount -= 1
                        if spec.refcount <= 0:
                            in_flight_fetches.pop(cache_key)

        return web.Response(body=result.SerializeToString())
    return _get_action_result_handler

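# Editor's note: the FetchSpec/refcount machinery above is a "single-flight"
# pattern: concurrent requests for the same key share one upstream fetch. A
# minimal standalone sketch of the same idea, assuming any async `fetch`
# coroutine (this is not the module's implementation):
#
#     import asyncio
#
#     _in_flight: dict = {}
#
#     async def single_flight(key, fetch):
#         # Duplicate callers await the future created by the first caller.
#         if key in _in_flight:
#             return await _in_flight[key]
#         fut = asyncio.get_running_loop().create_future()
#         _in_flight[key] = fut
#         try:
#             result = await fetch(key)
#             fut.set_result(result)
#             return result
#         except Exception as exc:
#             fut.set_exception(exc)
#             raise
#         finally:
#             _in_flight.pop(key, None)
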

def get_blob_handler(
    context: Context,
    cache: ResponseCache,
    allow_all: bool = False,
    allowed_origins: Sequence = ()
) -> Callable:
    """Factory function which returns a handler for fetching blobs from CAS.

    The returned handler streams the blob identified by the ``hash`` and
    ``size_bytes`` path components, honouring the optional ``offset``,
    ``limit``, and ``raw`` query parameters.

    """
    async def _get_blob_handler(request: web.Request) -> web.StreamResponse:
        digest = Digest(
            hash=request.match_info['hash'],
            size_bytes=int(request.match_info['size_bytes'])
        )
        try:
            offset = int(request.rel_url.query.get('offset', '0'))
            limit = int(request.rel_url.query.get('limit', '0'))
        except ValueError:
            raise web.HTTPBadRequest()

        response = web.StreamResponse()

        # We need to explicitly set CORS headers here, because the middleware that
        # normally handles this only adds the header when the response is returned
        # from this function. However, when using a StreamResponse with chunked
        # encoding enabled, the client begins to receive the response when we call
        # `response.write()`. This leads to the request being disallowed due to the
        # missing response header for clients executing in browsers.
        cors_headers = get_cors_headers(
            request.headers.get('Origin'), allowed_origins, allow_all)
        response.headers.update(cors_headers)

        if request.rel_url.query.get('raw', '') == '1':
            response.headers['Content-type'] = 'text/plain; charset=utf-8'
        else:
            # Setting the Content-Disposition header so that when
            # downloading a blob with a browser the default name uniquely
            # identifies its contents:
            filename = f"{request.match_info['hash']}_{request.match_info['size_bytes']}"

            # For partial reads also include the indices to prevent file mix-ups:
            if offset != 0 or limit != 0:
                filename += f"_chunk_{offset}-"
                if limit != 0:
                    filename += f"{offset + limit}"

            response.headers['Content-Disposition'] = f'Attachment;filename={filename}'

        prepared = False

        async def _callback(data):
            nonlocal prepared
            if not prepared:
                # Prepare for chunked encoding when the callback is first called,
                # so that we're sure we actually have some data before doing so.
                response.enable_chunked_encoding()
                await response.prepare(request)
                prepared = True
            await response.write(data)

        await _fetch_blob(context, cache, digest, callback=_callback, offset=offset, limit=limit)

        return response
    return _get_blob_handler

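# Illustrative client-side sketch (editor's addition): the handler above
# honours `offset`/`limit` query parameters for partial reads, and `raw=1`
# to serve the blob as inline text. Assuming a hypothetical mount point of
# `/blobs/{hash}/{size_bytes}`, fetching bytes 100-299 of a blob could look
# like:
#
#     async with session.get(f'http://localhost:8080/blobs/{h}/{s}',
#                            params={'offset': '100', 'limit': '200'}) as resp:
#         resp.raise_for_status()
#         chunk = await resp.read()
#
# The attachment filename would then include the range, e.g.
# "<hash>_<size_bytes>_chunk_100-300".
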

def _create_tarball(directory: str, name: str) -> bool:
    """Makes a tarball from a given directory.

    Returns True if the tarball was successfully created, and False if not.

    Args:
        directory (str): The directory to tar up.
        name (str): The name of the tarball to be produced.

    """
    try:
        with tarfile.open(name, 'w:gz') as tarball:
            tarball.add(directory, arcname="")
    except Exception:
        return False
    return True

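# Illustrative usage (editor's addition): `_create_tarball` is synchronous
# and blocking, which is why `_tarball_from_directory` below runs it via
# `loop.run_in_executor`. Called directly (the paths here are hypothetical):
#
#     ok = _create_tarball('/tmp/dir-to-pack', '/tmp/dir-to-pack.tar.gz')
#     if not ok:
#         raise RuntimeError('tarball creation failed')
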

async def _fetch_blob(
    context: Context,
    cache: ResponseCache,
    digest: Digest,
    message_class: Optional[Type] = None,
    callback: Optional[Callable] = None,
    offset: int = 0,
    limit: int = 0
) -> Any:
    """Fetch a blob from CAS.

    This function sends a ByteStream Read request for the given digest. If ``callback``
    is set then the callback is called with the data in each ReadResponse message and
    this function returns an empty bytes object once the response is finished. If
    ``message_class`` is set with no ``callback`` set then this function calls
    ``message_class.FromString`` on the fetched blob and returns the result.

    If neither ``callback`` nor ``message_class`` is set then this function just
    returns the raw blob that was fetched from CAS.

    Args:
        context (Context): The context to use to send the gRPC request.
        cache (ResponseCache): The response cache to check/update with the fetched blob.
        digest (Digest): The Digest of the blob to fetch from CAS.
        message_class (type): A class which implements a ``FromString`` class method.
            The method should take a bytes object expected to contain the blob fetched
            from CAS.
        callback (callable): A function or other callable to act on a subset of the
            blob contents.
        offset (int): Read offset to start reading the blob at. Defaults to 0, the start
            of the blob.
        limit (int): Maximum number of bytes to read from the blob. Defaults to 0, no
            limit.

    """
    cacheable = digest.size_bytes <= BROWSER_MAX_CACHE_ENTRY_SIZE
    resource_name = f"{context.instance_name}/blobs/{digest.hash}/{digest.size_bytes}"
    blob = None
    if cacheable:
        blob = await cache.get_blob(resource_name)
        if blob is not None and callback is not None:
            # Slicing a bytes object can't raise here; out-of-range indices
            # just yield an empty result.
            if limit > 0:
                blob = blob[offset:offset + limit]
            await callback(blob)

    if blob is None:
        stub = ByteStreamStub(context.cas_channel)
        grpc_request = ReadRequest(
            resource_name=resource_name,
            read_offset=offset,
            read_limit=limit
        )

        blob = b''
        try:
            async for grpc_response in stub.Read(grpc_request):
                if grpc_response.data:
                    if callback is not None:
                        await callback(grpc_response.data)

                    if callback is None or cacheable:
                        blob += grpc_response.data
        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.NOT_FOUND:
                raise web.HTTPNotFound()
            raise web.HTTPInternalServerError()

        if cacheable:
            await cache.store_blob(resource_name, blob)  # type: ignore

    if message_class is not None and callback is None:
        return message_class.FromString(blob)
    return blob

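# Illustrative usage (editor's addition): the three modes of `_fetch_blob`
# described in its docstring. The digest and file path here are hypothetical:
#
#     # 1. Raw bytes:
#     data = await _fetch_blob(context, cache, digest)
#
#     # 2. Deserialised protobuf, via a class with a `FromString` method:
#     directory = await _fetch_blob(context, cache, digest, message_class=Directory)
#
#     # 3. Streaming, via a callback (returns b'' unless the blob was cacheable):
#     async with aiofiles.open('/tmp/blob', 'wb') as f:
#         await _fetch_blob(context, cache, digest, callback=f.write)
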

async def _download_directory(
    context: Context,
    cache: ResponseCache,
    base: str,
    path: str,
    directory: Directory
) -> None:
    """Download the contents of a directory from CAS.

    This function takes a Directory message and downloads the directory
    contents defined in the message from CAS. Raises a 400 error if the
    directory contains a symlink which points to a location outside the
    initial directory.

    The directory is downloaded recursively depth-first.

    Args:
        context (Context): The context to use for making gRPC requests.
        cache (ResponseCache): The response cache to use when downloading the
            directory contents.
        base (str): The initial directory path, used to check symlinks don't
            escape into the wider filesystem.
        path (str): The path to download the directory into.
        directory (Directory): The Directory message to fetch the contents of.

    """
    for directory_node in directory.directories:
        dir_path = os.path.join(path, directory_node.name)
        os.mkdir(dir_path)
        child = await _fetch_blob(context, cache, directory_node.digest, message_class=Directory)
        await _download_directory(context, cache, base, dir_path, child)

    for file_node in directory.files:
        file_path = os.path.join(path, file_node.name)
        async with aiofiles.open(file_path, 'wb') as f:
            await _fetch_blob(context, cache, file_node.digest, callback=f.write)
        if file_node.is_executable:
            os.chmod(file_path, 0o755)

    for link_node in directory.symlinks:
        link_path = os.path.join(path, link_node.name)
        # Resolve the link target relative to the directory containing the
        # symlink, then check that the resolved path doesn't escape the
        # initial directory.
        target_path = os.path.realpath(os.path.join(path, link_node.target))
        target_relpath = os.path.relpath(target_path, base)
        if target_relpath.startswith(os.pardir):
            raise web.HTTPBadRequest(
                reason="Requested directory contains a symlink targeting a location outside the tarball")

        os.symlink(link_node.target, link_path)


async def _tarball_from_directory(context: Context, cache: ResponseCache, directory: Directory, tmp_dir: str) -> str:
    """Construct a tarball of a directory stored in CAS.

    This function fetches the contents of the given directory message into a
    temporary directory, and then constructs a tarball of the directory. The
    path to this tarball is returned when construction is complete.

    Args:
        context (Context): The context to use to send the gRPC requests.
        cache (ResponseCache): The response cache to use when fetching the
            tarball contents.
        directory (Directory): The Directory message for the directory we're
            making a tarball of.
        tmp_dir (str): Path to a temporary directory to use for storing the
            directory contents and its tarball.

    """
    tarball_dir = tempfile.mkdtemp(dir=tmp_dir)
    tarball_path = os.path.join(tmp_dir, 'directory.tar.gz')
    loop = asyncio.get_event_loop()

    # Fetch the contents of the directory into a temporary directory
    await _download_directory(context, cache, tarball_dir, tarball_dir, directory)

    # Make a tarball from that temporary directory
    # NOTE: We do this using loop.run_in_executor to avoid the
    # synchronous and blocking tarball construction
    tarball_result = await loop.run_in_executor(
        None, _create_tarball, tarball_dir, tarball_path)
    if not tarball_result:
        raise web.HTTPInternalServerError()
    return tarball_path


def get_tarball_handler(context: Context, cache: ResponseCache, allow_all: bool = False,
                        allowed_origins: Sequence = (), tarball_dir: Optional[str] = None) -> Callable:
    """Factory function which returns a handler for tarball requests.

    This function also takes care of cleaning up old incomplete tarball constructions
    when given a named directory to do the construction in.

    The returned handler takes the hash and size_bytes of a Digest of a Directory
    message and constructs a tarball of the directory defined by the message.

    Args:
        context (Context): The context to use to send the gRPC requests.
        cache (ResponseCache): The response cache to use when fetching the
            contents of the requested directory.
        allow_all (bool): Whether or not to allow all CORS origins.
        allowed_origins (list): List of valid CORS origins.
        tarball_dir (str): Base directory to use for tarball construction.

    """
    class FetchSpec:
        """Simple class used to store information about a tarball request.

        A class is used here rather than a namedtuple since we need this state
        to be mutable.

        """
        def __init__(self, *,
                     error: Optional[Exception],
                     event: Optional[asyncio.Event],
                     path: str,
                     refcount: int):
            self.error = error
            self.event = event
            self.path = path
            self.refcount = refcount

    in_flight_fetches: Dict[str, FetchSpec] = {}
    fetch_lock = asyncio.Lock()

    # If we have a tarball directory to use, empty all existing tarball constructions from it
    # to provide some form of cleanup after a crash.
    if tarball_dir is not None:
        for path in os.listdir(tarball_dir):
            if path.startswith(TARBALL_DIRECTORY_PREFIX):
                shutil.rmtree(os.path.join(tarball_dir, path))

    async def _get_tarball_handler(request: web.Request) -> web.StreamResponse:
        digest_str = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
        LOGGER.info(f'Received request for a tarball from CAS for blob with digest [{digest_str}]')

        digest = Digest(
            hash=request.match_info['hash'],
            size_bytes=int(request.match_info['size_bytes'])
        )

        tmp_dir = tempfile.mkdtemp(prefix=TARBALL_DIRECTORY_PREFIX, dir=tarball_dir)

        try:
            duplicate_request = False
            event = None
            async with fetch_lock:
                if digest_str in in_flight_fetches:
                    LOGGER.info(f'Deduplicating request for tarball of [{digest_str}]')
                    spec = in_flight_fetches[digest_str]
                    spec.refcount += 1
                    event = spec.event
                    duplicate_request = True
                else:
                    event = asyncio.Event()
                    in_flight_fetches[digest_str] = FetchSpec(
                        error=None,
                        event=event,
                        path='',
                        refcount=1
                    )

            if duplicate_request and event:
                # If we're a duplicate of an existing request, then wait for the
                # existing request to finish tarball creation before reading the
                # path from the cache.
                await event.wait()
                spec = in_flight_fetches[digest_str]
                if spec.error is not None:
                    raise spec.error
                tarball_path = in_flight_fetches[digest_str].path
            else:
                try:
                    directory = await _fetch_blob(context, cache, digest, message_class=Directory)
                    tarball_path = await _tarball_from_directory(context, cache, directory, tmp_dir)
                except web.HTTPError as e:
                    in_flight_fetches[digest_str].error = e
                    if event:
                        event.set()
                    raise e
                except Exception as e:
                    LOGGER.debug("Unexpected error constructing tarball", exc_info=True)
                    in_flight_fetches[digest_str].error = e
                    if event:
                        event.set()
                    raise web.HTTPInternalServerError()

                # Update path in deduplication cache, and set event to notify
                # duplicate requests that the tarball is ready
                async with fetch_lock:
                    if event:
                        in_flight_fetches[digest_str].path = tarball_path
                        event.set()

            response = web.StreamResponse()

            # We need to explicitly set CORS headers here, because the middleware that
            # normally handles this only adds the header when the response is returned
            # from this function. However, when using a StreamResponse with chunked
            # encoding enabled, the client begins to receive the response when we call
            # `response.write()`. This leads to the request being disallowed due to the
            # missing response header for clients executing in browsers.
            cors_headers = get_cors_headers(
                request.headers.get('Origin'), allowed_origins, allow_all)
            response.headers.update(cors_headers)

            response.enable_chunked_encoding()
            await response.prepare(request)

            async with aiofiles.open(tarball_path, 'rb') as tarball:
                await tarball.seek(0)
                chunk = await tarball.read(1024)
                while chunk:
                    await response.write(chunk)
                    chunk = await tarball.read(1024)
            return response

        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.NOT_FOUND:
                raise web.HTTPNotFound()
            raise web.HTTPInternalServerError()

        finally:
            cleanup = False
            async with fetch_lock:
                # Decrement refcount now we're done with the tarball. If we're the
                # last request interested in the tarball then remove it along with
                # its construction directory.
                spec = in_flight_fetches[digest_str]
                spec.refcount -= 1
                if spec.refcount <= 0:
                    cleanup = True
                    in_flight_fetches.pop(digest_str)
            if cleanup:
                shutil.rmtree(tmp_dir)
    return _get_tarball_handler

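# Illustrative client-side sketch (editor's addition): assuming the tarball
# handler is mounted at `/tarballs/{hash}/{size_bytes}` (a hypothetical path;
# routing lives outside this file), streaming the result to disk might look
# like:
#
#     async with session.get(f'http://localhost:8080/tarballs/{h}/{s}') as resp:
#         resp.raise_for_status()
#         with open('directory.tar.gz', 'wb') as f:
#             async for chunk in resp.content.iter_chunked(1024):
#                 f.write(chunk)
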

def logstream_handler(context: Context) -> Callable:
    """Factory function which returns a websocket handler for reading LogStreams."""
    async def _logstream_handler(request):
        LOGGER.info('Received request for a LogStream websocket')
        stub = ByteStreamStub(context.logstream_channel)
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for msg in ws:
            if msg.type == WSMsgType.BINARY:
                read_request = ReadRequest()
                read_request.ParseFromString(msg.data)

                read_request.resource_name = f'{context.instance_name}/{read_request.resource_name}'
                try:
                    async for response in stub.Read(read_request):
                        serialized_response = response.SerializeToString()
                        if serialized_response:
                            ws_response = {
                                "resource_name": read_request.resource_name,
                                "data": b64encode(serialized_response).decode('utf-8'),
                                "complete": False
                            }
                            await ws.send_json(ws_response)
                    ws_response = {
                        "resource_name": read_request.resource_name,
                        "data": "",
                        "complete": True
                    }
                    await ws.send_json(ws_response)
                except RpcError as e:
                    LOGGER.warning(e.details())
                    if e.code() == StatusCode.NOT_FOUND:
                        ws_response = {
                            "resource_name": read_request.resource_name,
                            "data": "NOT_FOUND",
                            "complete": True
                        }
                        await ws.send_json(ws_response)
                    else:
                        ws_response = {
                            "resource_name": read_request.resource_name,
                            "data": "INTERNAL",
                            "complete": True
                        }
                        await ws.send_json(ws_response)
    return _logstream_handler
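
# Illustrative client-side sketch (editor's addition): a client drives the
# websocket above by sending a serialised ByteStream ReadRequest and reading
# JSON frames until `complete` is true. The URL here is hypothetical:
#
#     import aiohttp
#     from buildgrid._protos.google.bytestream.bytestream_pb2 import ReadRequest
#
#     async def tail_logstream(resource_name: str) -> None:
#         async with aiohttp.ClientSession() as session:
#             async with session.ws_connect('http://localhost:8080/logstream') as ws:
#                 req = ReadRequest(resource_name=resource_name)
#                 await ws.send_bytes(req.SerializeToString())
#                 async for msg in ws:
#                     frame = msg.json()
#                     if frame['complete']:
#                         break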