Coverage for /builds/BuildGrid/buildgrid/buildgrid/browser/rest_api.py: 87.11%


# Copyright (C) 2021 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from base64 import b64encode
from collections import namedtuple
import logging
import os
import shutil
import tarfile
import tempfile
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type

import aiofiles
from aiohttp import web, WSMsgType
from grpc import RpcError, StatusCode
from grpc.aio import Call  # type: ignore

from buildgrid._app.cli import Context
from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2_grpc import ActionCacheStub
from buildgrid._protos.google.bytestream.bytestream_pb2_grpc import ByteStreamStub
from buildgrid._protos.google.bytestream.bytestream_pb2 import ReadRequest
from buildgrid._protos.google.longrunning.operations_pb2_grpc import OperationsStub
from buildgrid._protos.google.longrunning import operations_pb2
from buildgrid.browser import utils
from buildgrid.server.request_metadata_utils import extract_request_metadata
from buildgrid.server.lru_inmemory_cache import LruInMemoryCache


LOGGER = logging.getLogger(__name__)


def list_operations_handler(context: Context) -> Callable:
    """Factory function which returns a handler for ListOperations.

    The returned handler uses ``context.operations_channel`` and
    ``context.instance_name`` to send a ListOperations request constructed
    based on the provided URL query parameters.

    The handler returns a serialised ListOperationsResponse, raises a 400
    error in the case of a bad filter or other invalid argument, or raises
    a 500 error in the case of some other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.

    """
    async def _list_operations_handler(request: web.Request) -> web.Response:
        filter_string = request.rel_url.query.get('q', '')
        page_token = request.rel_url.query.get('page_token')
        page_size_str = request.rel_url.query.get('page_size')
        page_size = None
        if page_size_str is not None:
            page_size = int(page_size_str)

        LOGGER.info(f'Received ListOperations request, filter_string="{filter_string}" '
                    f'page_token="{page_token}" page_size="{page_size}"')
        stub = OperationsStub(context.operations_channel)
        grpc_request = operations_pb2.ListOperationsRequest(
            name=context.instance_name,
            page_token=page_token,
            page_size=page_size,
            filter=filter_string)

        try:
            grpc_response = await stub.ListOperations(grpc_request)
        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.INVALID_ARGUMENT:
                raise web.HTTPBadRequest()
            raise web.HTTPInternalServerError()

        serialised_response = grpc_response.SerializeToString()
        return web.Response(body=serialised_response)
    return _list_operations_handler
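
# Illustrative sketch, not part of the original module: each factory in this
# file returns an aiohttp handler, so wiring one into an application looks
# roughly like the following. The route path and the construction of
# `context` are assumptions here; the real routing table lives in the
# browser application setup.
#
#     from aiohttp import web
#
#     app = web.Application()
#     app.add_routes([web.get('/operations', list_operations_handler(context))])
#     web.run_app(app)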


async def _get_operation(context: Context, request: web.Request) -> Tuple[operations_pb2.Operation, Call]:
    operation_name = f"{context.instance_name}/{request.match_info['name']}"

    stub = OperationsStub(context.operations_channel)
    grpc_request = operations_pb2.GetOperationRequest(name=operation_name)

    try:
        call = stub.GetOperation(grpc_request)
        operation = await call
    except RpcError as e:
        LOGGER.warning(e.details())
        if e.code() == StatusCode.INVALID_ARGUMENT:
            raise web.HTTPNotFound()
        raise web.HTTPInternalServerError()

    return operation, call


def get_operation_handler(context: Context) -> Callable:
    """Factory function which returns a handler for GetOperation.

    The returned handler uses ``context.operations_channel`` and
    ``context.instance_name`` to send a GetOperation request constructed
    based on the path component of the URL.

    The handler returns a serialised Operation message, raises a 404 error
    in the case of an invalid operation name, or raises a 500 error in the
    case of some other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.

    """
    async def _get_operation_handler(request: web.Request) -> web.Response:
        LOGGER.info(f'Received GetOperation request for "{request.match_info["name"]}"')
        operation, _ = await _get_operation(context, request)
        serialised_response = operation.SerializeToString()
        return web.Response(body=serialised_response)
    return _get_operation_handler


def get_operation_request_metadata_handler(context: Context) -> Callable:
    """Factory function which returns a handler to get RequestMetadata.

    The returned handler uses ``context.operations_channel`` and
    ``context.instance_name`` to send a GetOperation request constructed
    based on the path component of the URL.

    The handler returns a serialised RequestMetadata proto message, retrieved
    from the trailing metadata of the GetOperation response. In the event of
    an invalid operation name it raises a 404 error, and raises a 500 error in
    the case of some other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.

    """
    async def _get_operation_request_metadata_handler(request: web.Request) -> web.Response:
        LOGGER.info(f'Received request for RequestMetadata for "{request.match_info["name"]}"')
        _, call = await _get_operation(context, request)
        metadata = await call.trailing_metadata()

        def extract_metadata(m):
            # `m` is a sequence of tuples, but `extract_request_metadata()`
            # expects objects with `key` and `value` attributes.
            MetadataTuple = namedtuple('MetadataTuple', ['key', 'value'])
            m = [MetadataTuple(entry[0], entry[1]) for entry in m]
            return extract_request_metadata(m)

        request_metadata = extract_metadata(metadata)
        return web.Response(body=request_metadata.SerializeToString())
    return _get_operation_request_metadata_handler
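
# Illustrative sketch, an assumption rather than part of the original module:
# the response body above is a serialised RequestMetadata proto, so a client
# of this endpoint could decode it as follows (`response_bytes` is a
# hypothetical variable holding the HTTP response body):
#
#     metadata = remote_execution_pb2.RequestMetadata.FromString(response_bytes)
#     print(metadata.tool_details.tool_name, metadata.tool_invocation_id)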


def cancel_operation_handler(context: Context) -> Callable:
    """Factory function which returns a handler for CancelOperation.

    The returned handler uses ``context.operations_channel`` and
    ``context.instance_name`` to send a CancelOperation request constructed
    based on the path component of the URL.

    The handler raises a 404 error in the case of an invalid operation name,
    or a 500 error in the case of some other RPC error.

    On success, the response is empty.

    Args:
        context (Context): The context to use to send the gRPC request.

    """
    async def _cancel_operation_handler(request: web.Request) -> web.Response:
        LOGGER.info(f'Received CancelOperation request for "{request.match_info["name"]}"')
        operation_name = f"{context.instance_name}/{request.match_info['name']}"

        stub = OperationsStub(context.operations_channel)
        grpc_request = operations_pb2.CancelOperationRequest(name=operation_name)

        try:
            await stub.CancelOperation(grpc_request)
            return web.Response()
        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.INVALID_ARGUMENT:
                raise web.HTTPNotFound()
            raise web.HTTPInternalServerError()

    return _cancel_operation_handler


async def _fetch_action_result(context: Context, request: web.Request) -> bytes:
    stub = ActionCacheStub(context.cache_channel)
    digest = remote_execution_pb2.Digest(
        hash=request.match_info["hash"],
        size_bytes=int(request.match_info["size_bytes"])
    )
    grpc_request = remote_execution_pb2.GetActionResultRequest(
        action_digest=digest, instance_name=context.instance_name)

    try:
        result = await stub.GetActionResult(grpc_request)
    except RpcError as e:
        LOGGER.warning(e.details())
        if e.code() == StatusCode.NOT_FOUND:
            raise web.HTTPNotFound()
        raise web.HTTPInternalServerError()

    return result.SerializeToString()


def get_action_result_handler(context: Context, cache_capacity: int = 512) -> Callable:
    """Factory function which returns a handler for GetActionResult.

    The returned handler uses ``context.cache_channel`` and
    ``context.instance_name`` to send a GetActionResult request constructed
    based on the path components of the URL.

    The handler returns a serialised ActionResult message, raises a 404 error
    if there's no result cached, or raises a 500 error in the case of some
    other RPC error.

    Args:
        context (Context): The context to use to send the gRPC request.
        cache_capacity (int): The number of ActionResults to cache in memory
            to avoid hitting the actual ActionCache.

    """
    class FetchSpec:
        """Simple class used to store information about a GetActionResult request.

        A class is used here rather than a namedtuple since we need this state
        to be mutable.

        """
        def __init__(self, *,
                     error: Optional[Exception],
                     event: asyncio.Event,
                     result: Optional[bytes],
                     refcount: int):
            self.error = error
            self.event = event
            self.result = result
            self.refcount = refcount

    in_flight_fetches: Dict[str, FetchSpec] = {}
    fetch_lock = asyncio.Lock()

    result_cache = LruInMemoryCache(capacity=cache_capacity)

    async def _get_action_result_handler(request: web.Request) -> web.Response:
        LOGGER.info(
            'Received GetActionResult request for '
            f'"{request.match_info["hash"]}/{request.match_info["size_bytes"]}"'
        )

        cache_key = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
        serialized_result = result_cache.get(cache_key)

        if serialized_result is None:
            try:
                duplicate_request = False
                spec = None
                async with fetch_lock:
                    if cache_key in in_flight_fetches:
                        LOGGER.info(f'Deduplicating GetActionResult request for [{cache_key}]')
                        spec = in_flight_fetches[cache_key]
                        spec.refcount += 1
                        duplicate_request = True
                    else:
                        spec = FetchSpec(
                            error=None,
                            event=asyncio.Event(),
                            result=None,
                            refcount=1
                        )
                        in_flight_fetches[cache_key] = spec

                if duplicate_request and spec:
                    # If we're a duplicate of an existing request, then wait for the
                    # existing request to finish fetching from the ActionCache.
                    await spec.event.wait()
                    if spec.error is not None:
                        raise spec.error
                    if spec.result is None:
                        # Note: this should be impossible, but let's guard against
                        # accidentally setting the event before the result is populated.
                        LOGGER.info(f'Result not set in deduplicated request for [{cache_key}]')
                        raise web.HTTPInternalServerError()
                    serialized_result = spec.result
                else:
                    try:
                        serialized_result = await _fetch_action_result(context, request)

                    except Exception as e:
                        async with fetch_lock:
                            if spec is not None:
                                spec.error = e
                        raise e

                    result_cache.update(cache_key, serialized_result)
                    async with fetch_lock:
                        if spec is not None:
                            spec.result = serialized_result
                            spec.event.set()

            finally:
                async with fetch_lock:
                    # Decrement refcount now we're done with the result. If we're the
                    # last request interested in the result then remove it from the
                    # `in_flight_fetches` dictionary.
                    spec = in_flight_fetches.get(cache_key)
                    if spec is not None:
                        spec.refcount -= 1
                        if spec.refcount <= 0:
                            in_flight_fetches.pop(cache_key)

        return web.Response(body=serialized_result)
    return _get_action_result_handler
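
# A note on the deduplication above: the first request for a given digest
# creates a FetchSpec and performs the ActionCache fetch; concurrent requests
# for the same digest find the FetchSpec, bump its refcount, and block on
# `spec.event.wait()` until the owner publishes `spec.result` (or `spec.error`)
# and sets the event. A minimal sketch of the same pattern, with hypothetical
# names, assuming `fetch()` is the expensive coroutine being deduplicated:
#
#     async def deduplicated_fetch(key):
#         async with lock:
#             spec = in_flight.get(key)
#             if spec is None:
#                 spec = in_flight[key] = FetchSpec(
#                     error=None, event=asyncio.Event(), result=None, refcount=1)
#                 owner = True
#             else:
#                 spec.refcount += 1
#                 owner = False
#         if owner:
#             spec.result = await fetch(key)
#             spec.event.set()
#         else:
#             await spec.event.wait()
#         return spec.result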


def get_blob_handler(context: Context, allow_all: bool = False, allowed_origins: Sequence = ()) -> Callable:
    async def _get_blob_handler(request: web.Request) -> web.StreamResponse:
        stub = ByteStreamStub(context.cas_channel)
        resource_name = f"blobs/{request.match_info['hash']}/{request.match_info['size_bytes']}"
        grpc_request = ReadRequest(resource_name=f"{context.instance_name}/{resource_name}")
        try:
            grpc_request.read_offset = int(request.rel_url.query.get('offset', '0'))
            grpc_request.read_limit = int(request.rel_url.query.get('limit', '0'))
        except ValueError:
            raise web.HTTPBadRequest()

        response = web.StreamResponse()

        # We need to explicitly set CORS headers here, because the middleware that
        # normally handles this only adds the header when the response is returned
        # from this function. However, when using a StreamResponse with chunked
        # encoding enabled, the client begins to receive the response when we call
        # `response.write()`. This leads to the request being disallowed due to the
        # missing response header for clients executing in browsers.
        cors_headers = utils.get_cors_headers(
            request.headers.get('Origin'), allowed_origins, allow_all)
        response.headers.update(cors_headers)

        if request.rel_url.query.get('raw', default='') == '1':
            response.headers['Content-type'] = 'text/plain; charset=utf-8'
        else:
            # Set the Content-Disposition header so that when downloading a
            # blob with a browser the default name uniquely identifies its
            # contents:
            filename = f"{request.match_info['hash']}_{request.match_info['size_bytes']}"

            # For partial reads also include the indices to prevent file mix-ups:
            if grpc_request.read_offset != 0 or grpc_request.read_limit != 0:
                filename += f"_chunk_{grpc_request.read_offset}-"
                if grpc_request.read_limit != 0:
                    filename += f"{grpc_request.read_offset + grpc_request.read_limit}"

            response.headers['Content-Disposition'] = f'Attachment;filename={filename}'

        try:
            prepared = False
            async for grpc_response in stub.Read(grpc_request):
                if grpc_response.data:
                    if not prepared:
                        response.enable_chunked_encoding()
                        await response.prepare(request)
                        prepared = True
                    await response.write(grpc_response.data)
        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.NOT_FOUND:
                raise web.HTTPNotFound()
            raise web.HTTPInternalServerError()
        return response
    return _get_blob_handler


def _create_tarball(directory: str, name: str) -> bool:
    """Makes a tarball from a given directory.

    Returns True if the tarball was successfully created, and False if not.

    Args:
        directory (str): The directory to tar up.
        name (str): The name of the tarball to be produced.

    """
    try:
        with tarfile.open(name, 'w:gz') as tarball:
            tarball.add(directory, arcname="")
    except Exception:
        return False
    return True


async def _fetch_blob(
    context: Context,
    digest: remote_execution_pb2.Digest,
    message_class: Optional[Type] = None,
    callback: Optional[Callable] = None
) -> Any:
    """Fetch a blob from CAS.

    This function sends a ByteStream Read request for the given digest. If ``callback``
    is set then the callback is called with the data in each ReadResponse message and
    this function returns an empty bytes object once the response is finished. If
    ``message_class`` is set with no ``callback`` set then this function calls
    ``message_class.FromString`` on the fetched blob and returns the result.

    If neither ``callback`` nor ``message_class`` is set then this function just returns
    the raw blob that was fetched from CAS.

    Args:
        context (Context): The context to use to send the gRPC request.
        digest (Digest): The Digest of the blob to fetch from CAS.
        message_class (type): A class which implements a ``FromString`` class method.
            The method should take a bytes object expected to contain the blob fetched
            from CAS.
        callback (callable): A function or other callable to act on a subset of the
            blob contents.

    """
    stub = ByteStreamStub(context.cas_channel)
    resource_name = f"blobs/{digest.hash}/{digest.size_bytes}"
    grpc_request = ReadRequest(resource_name=f"{context.instance_name}/{resource_name}")
    blob = b''
    try:
        async for grpc_response in stub.Read(grpc_request):
            if grpc_response.data:
                if callback is not None:
                    await callback(grpc_response.data)
                else:
                    blob += grpc_response.data
    except RpcError as e:
        LOGGER.warning(e.details())
        if e.code() == StatusCode.NOT_FOUND:
            raise web.HTTPNotFound()
        raise web.HTTPInternalServerError()

    if message_class is not None and callback is None:
        return message_class.FromString(blob)
    return blob
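
# Illustrative sketch of the three call modes described in the docstring
# above (`digest` is assumed to be a populated Digest message):
#
#     # 1. Raw bytes:
#     blob = await _fetch_blob(context, digest)
#
#     # 2. Parsed into a message type:
#     directory = await _fetch_blob(
#         context, digest, message_class=remote_execution_pb2.Directory)
#
#     # 3. Streamed through an async callback, e.g. writing to a file:
#     async with aiofiles.open('/tmp/blob', 'wb') as f:
#         await _fetch_blob(context, digest, callback=f.write)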


async def _download_directory(
    base: str,
    path: str,
    directory: remote_execution_pb2.Directory,
    context: Context
) -> None:
    """Download the contents of a directory from CAS.

    This function takes a Directory message and downloads the directory
    contents defined in the message from CAS. Raises a 400 error if the
    directory contains a symlink which points to a location outside the
    initial directory.

    The directory is downloaded recursively depth-first.

    Args:
        base (str): The initial directory path, used to check symlinks don't
            escape into the wider filesystem.
        path (str): The path to download the directory into.
        directory (Directory): The Directory message to fetch the contents of.
        context (Context): The context to use for making gRPC requests.

    """
    for directory_node in directory.directories:
        dir_path = os.path.join(path, directory_node.name)
        os.mkdir(dir_path)
        child = await _fetch_blob(context, directory_node.digest, message_class=remote_execution_pb2.Directory)
        await _download_directory(base, dir_path, child, context)

    for file_node in directory.files:
        file_path = os.path.join(path, file_node.name)
        async with aiofiles.open(file_path, 'wb') as f:
            await _fetch_blob(context, file_node.digest, callback=f.write)
        if file_node.is_executable:
            os.chmod(file_path, 0o755)

    for link_node in directory.symlinks:
        link_path = os.path.join(path, link_node.name)
        # Resolve the target relative to the directory containing the link,
        # then check that it doesn't escape the initial directory.
        target_path = os.path.realpath(os.path.join(path, link_node.target))
        target_relpath = os.path.relpath(target_path, base)
        if target_relpath.startswith(os.pardir):
            raise web.HTTPBadRequest(
                reason="Requested directory contains a symlink targeting a location outside the tarball")

        os.symlink(link_node.target, link_path)
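
# Worked example of the escape check above (an illustration, not part of the
# original module). With base = path = '/tmp/dl' and a symlink whose target
# is '../secrets':
#
#     os.path.realpath('/tmp/dl/../secrets')      # '/tmp/secrets'
#     os.path.relpath('/tmp/secrets', '/tmp/dl')  # '../secrets'
#
# The relative path starts with os.pardir ('..'), so the request is rejected
# with a 400 error. A target of 'sub/file' resolves to '/tmp/dl/sub/file' and
# yields 'sub/file', which is allowed.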


async def _tarball_from_directory(directory: remote_execution_pb2.Directory, context: Context, tmp_dir: str) -> str:
    """Construct a tarball of a directory stored in CAS.

    This function fetches the contents of the given directory message into a
    temporary directory, and then constructs a tarball of the directory. The
    path to this tarball is returned when construction is complete.

    Args:
        directory (Directory): The Directory message for the directory we're
            making a tarball of.
        context (Context): The context to use to send the gRPC requests.
        tmp_dir (str): Path to a temporary directory to use for storing the
            directory contents and its tarball.

    """
    tarball_dir = tempfile.mkdtemp(dir=tmp_dir)
    tarball_path = os.path.join(tmp_dir, 'directory.tar.gz')
    loop = asyncio.get_event_loop()

    # Fetch the contents of the directory into a temporary directory
    await _download_directory(tarball_dir, tarball_dir, directory, context)

    # Make a tarball from that temporary directory
    # NOTE: We do this using loop.run_in_executor to avoid the
    # synchronous and blocking tarball construction
    tarball_result = await loop.run_in_executor(
        None, _create_tarball, tarball_dir, tarball_path)
    if not tarball_result:
        raise web.HTTPInternalServerError()
    return tarball_path


def get_tarball_handler(context: Context, allow_all: bool = False, allowed_origins: Sequence = (),
                        tarball_dir: Optional[str] = None) -> Callable:
    """Factory function which returns a handler for tarball requests.

    This function also takes care of cleaning up old incomplete tarball constructions
    when given a named directory to do the construction in.

    The returned handler takes the hash and size_bytes of a Digest of a Directory
    message and constructs a tarball of the directory defined by the message.

    Args:
        context (Context): The context to use to send the gRPC requests.
        allow_all (bool): Whether or not to allow all CORS origins.
        allowed_origins (list): List of valid CORS origins.
        tarball_dir (str): Base directory to use for tarball construction.

    """
    class FetchSpec:
        """Simple class used to store information about a tarball request.

        A class is used here rather than a namedtuple since we need this state
        to be mutable.

        """
        def __init__(self, *,
                     error: Optional[Exception],
                     event: Optional[asyncio.Event],
                     path: str,
                     refcount: int):
            self.error = error
            self.event = event
            self.path = path
            self.refcount = refcount

    in_flight_fetches: Dict[str, FetchSpec] = {}
    fetch_lock = asyncio.Lock()

    # If we have a tarball directory to use, empty all existing tarball constructions from it
    # to provide some form of cleanup after a crash.
    if tarball_dir is not None:
        for path in os.listdir(tarball_dir):
            if path.startswith(utils.TARBALL_DIRECTORY_PREFIX):
                shutil.rmtree(os.path.join(tarball_dir, path))

    async def _get_tarball_handler(request: web.Request) -> web.StreamResponse:
        digest_str = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
        LOGGER.info(f'Received request for a tarball from CAS for blob with digest [{digest_str}]')

        digest = remote_execution_pb2.Digest(
            hash=request.match_info['hash'],
            size_bytes=int(request.match_info['size_bytes'])
        )

        tmp_dir = tempfile.mkdtemp(prefix=utils.TARBALL_DIRECTORY_PREFIX, dir=tarball_dir)

        try:
            duplicate_request = False
            event = None
            async with fetch_lock:
                if digest_str in in_flight_fetches:
                    LOGGER.info(f'Deduplicating request for tarball of [{digest_str}]')
                    spec = in_flight_fetches[digest_str]
                    spec.refcount += 1
                    event = spec.event
                    duplicate_request = True
                else:
                    event = asyncio.Event()
                    in_flight_fetches[digest_str] = FetchSpec(
                        error=None,
                        event=event,
                        path='',
                        refcount=1
                    )

            if duplicate_request and event:
                # If we're a duplicate of an existing request, then wait for the
                # existing request to finish tarball creation before reading the
                # path from the cache.
                await event.wait()
                spec = in_flight_fetches[digest_str]
                if spec.error is not None:
                    raise spec.error
                tarball_path = in_flight_fetches[digest_str].path
            else:
                try:
                    directory = await _fetch_blob(context, digest, message_class=remote_execution_pb2.Directory)
                    tarball_path = await _tarball_from_directory(directory, context, tmp_dir)
                except web.HTTPError as e:
                    in_flight_fetches[digest_str].error = e
                    if event:
                        event.set()
                    raise e
                except Exception as e:
                    LOGGER.debug("Unexpected error constructing tarball", exc_info=True)
                    in_flight_fetches[digest_str].error = e
                    if event:
                        event.set()
                    raise web.HTTPInternalServerError()

                # Update path in deduplication cache, and set event to notify
                # duplicate requests that the tarball is ready
                async with fetch_lock:
                    if event:
                        in_flight_fetches[digest_str].path = tarball_path
                        event.set()

            response = web.StreamResponse()

            # We need to explicitly set CORS headers here, because the middleware that
            # normally handles this only adds the header when the response is returned
            # from this function. However, when using a StreamResponse with chunked
            # encoding enabled, the client begins to receive the response when we call
            # `response.write()`. This leads to the request being disallowed due to the
            # missing response header for clients executing in browsers.
            cors_headers = utils.get_cors_headers(
                request.headers.get('Origin'), allowed_origins, allow_all)
            response.headers.update(cors_headers)

            response.enable_chunked_encoding()
            await response.prepare(request)

            async with aiofiles.open(tarball_path, 'rb') as tarball:
                await tarball.seek(0)
                chunk = await tarball.read(1024)
                while chunk:
                    await response.write(chunk)
                    chunk = await tarball.read(1024)
            return response

        except RpcError as e:
            LOGGER.warning(e.details())
            if e.code() == StatusCode.NOT_FOUND:
                raise web.HTTPNotFound()
            raise web.HTTPInternalServerError()

        finally:
            cleanup = False
            async with fetch_lock:
                # Decrement refcount now we're done with the tarball. If we're the
                # last request interested in the tarball then remove it along with
                # its construction directory.
                spec = in_flight_fetches[digest_str]
                spec.refcount -= 1
                if spec.refcount <= 0:
                    cleanup = True
                    in_flight_fetches.pop(digest_str)
            if cleanup:
                shutil.rmtree(tmp_dir)
    return _get_tarball_handler
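
# Illustrative sketch, with an assumed route path and construction directory:
# the tarball handler is wired up like the other factories, with an optional
# persistent construction directory so partially-built tarballs can be
# cleaned up after a restart:
#
#     app.add_routes([web.get(
#         '/directory/{hash}/{size_bytes}/tarball',
#         get_tarball_handler(context, tarball_dir='/var/cache/bgd-tarballs'))])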


def logstream_handler(context: Context) -> Callable:
    async def _logstream_handler(request):
        LOGGER.info('Received request for a LogStream websocket')
        stub = ByteStreamStub(context.logstream_channel)
        ws = web.WebSocketResponse()
        await ws.prepare(request)

        async for msg in ws:
            if msg.type == WSMsgType.BINARY:
                read_request = ReadRequest()
                read_request.ParseFromString(msg.data)

                read_request.resource_name = f'{context.instance_name}/{read_request.resource_name}'
                try:
                    async for response in stub.Read(read_request):
                        serialized_response = response.SerializeToString()
                        if serialized_response:
                            ws_response = {
                                "resource_name": read_request.resource_name,
                                "data": b64encode(serialized_response).decode('utf-8'),
                                "complete": False
                            }
                            await ws.send_json(ws_response)
                    ws_response = {
                        "resource_name": read_request.resource_name,
                        "data": "",
                        "complete": True
                    }
                    await ws.send_json(ws_response)
                except RpcError as e:
                    LOGGER.warning(e.details())
                    if e.code() == StatusCode.NOT_FOUND:
                        ws_response = {
                            "resource_name": read_request.resource_name,
                            "data": "NOT_FOUND",
                            "complete": True
                        }
                        await ws.send_json(ws_response)
                    ws_response = {
                        "resource_name": read_request.resource_name,
                        "data": "INTERNAL",
                        "complete": True
                    }
                    await ws.send_json(ws_response)
    return _logstream_handler
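
# Illustrative client-side sketch, an assumption rather than part of the
# original module: the websocket expects a serialised ByteStream ReadRequest
# in a binary message, and replies with JSON frames until "complete" is true.
# The `url` and `resource_name` values are placeholders:
#
#     import aiohttp
#
#     async def tail_logstream(url, resource_name):
#         read_request = ReadRequest(resource_name=resource_name)
#         async with aiohttp.ClientSession() as session:
#             async with session.ws_connect(url) as ws:
#                 await ws.send_bytes(read_request.SerializeToString())
#                 async for msg in ws:
#                     frame = msg.json()
#                     print(frame["data"])
#                     if frame["complete"]:
#                         break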