Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/browser/rest_api.py: 82.72%

405 statements  

coverage.py v7.4.1, created at 2024-10-04 17:48 +0000

1# Copyright (C) 2021 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS,

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import asyncio 

16import json 

17import os 

18import shutil 

19import tarfile 

20import tempfile 

21from collections import namedtuple 

22from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type 

23 

24import aiofiles 

25from aiohttp import WSMsgType, web 

26from aiohttp_middlewares.annotations import UrlCollection 

27from grpc import RpcError, StatusCode 

28from grpc.aio import Call, Metadata 

29 

30from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import ( 

31 ActionResult, 

32 Digest, 

33 Directory, 

34 GetActionResultRequest, 

35 RequestMetadata, 

36) 

37from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2_grpc_aio import ActionCacheStub 

38from buildgrid._protos.buildgrid.v2.query_build_events_pb2 import QueryEventStreamsRequest 

39from buildgrid._protos.buildgrid.v2.query_build_events_pb2_grpc_aio import QueryBuildEventsStub 

40from buildgrid._protos.google.bytestream.bytestream_pb2 import ReadRequest 

41from buildgrid._protos.google.bytestream.bytestream_pb2_grpc_aio import ByteStreamStub 

42from buildgrid._protos.google.longrunning import operations_pb2 

43from buildgrid._protos.google.longrunning.operations_pb2_grpc_aio import OperationsStub 

44from buildgrid.server.app.cli import Context 

45from buildgrid.server.browser.utils import TARBALL_DIRECTORY_PREFIX, ResponseCache, get_cors_headers 

46from buildgrid.server.logging import buildgrid_logger 

47from buildgrid.server.metadata import extract_request_metadata, extract_trailing_client_identity 

48from buildgrid.server.operations.filtering.interpreter import VALID_OPERATION_FILTERS, OperationFilterSpec 

49from buildgrid.server.operations.filtering.sanitizer import DatetimeValueSanitizer, SortKeyValueSanitizer 

50from buildgrid.server.settings import BROWSER_MAX_CACHE_ENTRY_SIZE 

51 

52LOGGER = buildgrid_logger(__name__) 

53 

54 

55def query_build_events_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]: 

56 """Factory function which returns a handler for QueryEventStreams. 

57 

58 The returned handler uses ``context.channel`` to send a QueryEventStreams 

59 request constructed based on the provided URL query parameters. Currently 

60 only querying by build_id (equivalent to correlated invocations ID) is 

61 supported. 

62 

63 The handler returns a serialised QueryEventStreamsResponse, and raises a 

64 500 error in the case of an RPC error.

65 

66 Args: 

67 context (Context): The context to use to send the gRPC request. 

68 

69 """ 

70 

71 async def _query_build_events_handler(request: web.Request) -> web.Response: 

72 build_id = request.rel_url.query.get("build_id", ".*") 

73 

74 stub = QueryBuildEventsStub(context.channel) # type: ignore # Requires stub regen 

75 

76 grpc_request = QueryEventStreamsRequest(build_id_pattern=build_id) 

77 

78 try: 

79 grpc_response = await stub.QueryEventStreams(grpc_request) 

80 except RpcError as e: 

81 LOGGER.warning(e.details()) 

82 raise web.HTTPInternalServerError() 

83 

84 serialized_response = grpc_response.SerializeToString() 

85 return web.Response(body=serialized_response) 

86 

87 return _query_build_events_handler 
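
# A minimal wiring sketch: how the factory above could be mounted on an aiohttp
# application. The route path and the way the Context is obtained are
# illustrative assumptions, not necessarily the routes BuildGrid registers.
def _example_build_events_app(context: Context) -> web.Application:
    app = web.Application()
    app.router.add_get("/build_events", query_build_events_handler(context))
    return app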

88 

89 

90async def get_operation_filters_handler(request: web.Request) -> web.Response: 

91 """Return the available Operation filter keys.""" 

92 

93 def _generate_filter_spec(key: str, spec: OperationFilterSpec) -> Dict[str, Any]: 

94 comparators = ["<", "<=", "=", "!=", ">=", ">"] 

95 filter_type = "text" 

96 if isinstance(spec.sanitizer, SortKeyValueSanitizer): 

97 comparators = ["="] 

98 elif isinstance(spec.sanitizer, DatetimeValueSanitizer): 

99 filter_type = "datetime" 

100 

101 ret = { 

102 "comparators": comparators, 

103 "description": spec.description, 

104 "key": key, 

105 "name": spec.name, 

106 "type": filter_type, 

107 } 

108 

109 try: 

110 ret["values"] = spec.sanitizer.valid_values 

111 except NotImplementedError: 

112 pass 

113 return ret 

114 

115 operation_filters = [_generate_filter_spec(key, spec) for key, spec in VALID_OPERATION_FILTERS.items()] 

116 return web.Response(text=json.dumps(operation_filters)) 
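
# A sketch of consuming the JSON produced above, assuming an aiohttp
# ClientSession and a base URL for the browser API; the example entry in the
# comment is hypothetical, the real keys come from VALID_OPERATION_FILTERS.
async def _example_print_operation_filters(session: Any, base_url: str) -> None:
    # Each element looks roughly like:
    #   {"key": "stage", "name": "Stage", "type": "text",
    #    "comparators": ["="], "description": "...", "values": [...]}
    async with session.get(f"{base_url}/operation_filters") as resp:
        for spec in json.loads(await resp.text()):
            print(spec["key"], spec["type"], spec["comparators"])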

117 

118 

119def list_operations_handler( 

120 context: Context, cache: ResponseCache 

121) -> Callable[[web.Request], Awaitable[web.Response]]: 

122 """Factory function which returns a handler for ListOperations. 

123 

124 The returned handler uses ``context.channel`` and ``context.instance_name`` 

125 to send a ListOperations request constructed based on the provided URL 

126 query parameters. 

127 

128 The handler returns a serialised ListOperationsResponse, raises a 400 

129 error in the case of a bad filter or other invalid argument, or raises 

130 a 500 error in the case of some other RPC error. 

131 

132 Args: 

133 context (Context): The context to use to send the gRPC request. 

134 

135 """ 

136 

137 async def _list_operations_handler(request: web.Request) -> web.Response: 

138 filter_string = request.rel_url.query.get("q", "") 

139 page_token = request.rel_url.query.get("page_token", "") 

140 page_size_str = request.rel_url.query.get("page_size") 

141 page_size = 0 

142 if page_size_str is not None: 

143 page_size = int(page_size_str) 

144 

145 LOGGER.info( 

146 "Received ListOperations request.", 

147 tags=dict(filter_string=filter_string, page_token=page_token, page_size=page_size), 

148 ) 

149 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen 

150 grpc_request = operations_pb2.ListOperationsRequest( 

151 name=context.instance_name, page_token=page_token, page_size=page_size, filter=filter_string 

152 ) 

153 

154 try: 

155 grpc_response = await stub.ListOperations(grpc_request) 

156 except RpcError as e: 

157 LOGGER.warning(e.details()) 

158 if e.code() == StatusCode.INVALID_ARGUMENT: 

159 raise web.HTTPBadRequest() 

160 raise web.HTTPInternalServerError() 

161 

162 serialised_response = grpc_response.SerializeToString() 

163 return web.Response(body=serialised_response) 

164 

165 return _list_operations_handler 
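
# A client-side sketch for the handler above, assuming an aiohttp ClientSession
# and a base URL; the filter string and path are illustrative assumptions. The
# body decodes as a ListOperationsResponse, as the docstring describes.
async def _example_list_operations(session: Any, base_url: str) -> operations_pb2.ListOperationsResponse:
    params = {"q": "stage = COMPLETED", "page_size": "10"}
    async with session.get(f"{base_url}/operations", params=params) as resp:
        resp.raise_for_status()
        return operations_pb2.ListOperationsResponse.FromString(await resp.read())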

166 

167 

168async def _get_operation(context: Context, request: web.Request) -> Tuple[operations_pb2.Operation, Call]: 

169 operation_name = f"{context.instance_name}/{request.match_info['name']}" 

170 

171 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen 

172 grpc_request = operations_pb2.GetOperationRequest(name=operation_name) 

173 

174 try: 

175 call = stub.GetOperation(grpc_request) 

176 operation = await call 

177 except RpcError as e: 

178 LOGGER.warning(f"Error fetching operation: {e.details()}") 

179 if e.code() == StatusCode.INVALID_ARGUMENT: 

180 raise web.HTTPNotFound() 

181 raise web.HTTPInternalServerError() 

182 

183 return operation, call 

184 

185 

186def get_operation_handler(context: Context, cache: ResponseCache) -> Callable[[web.Request], Awaitable[web.Response]]: 

187 """Factory function which returns a handler for GetOperation. 

188 

189 The returned handler uses ``context.channel`` and ``context.instance_name`` 

190 to send a GetOperation request constructed based on the path component of 

191 the URL. 

192 

193 The handler returns a serialised Operation message, raises a 404 error in

194 the case of an invalid operation name, or raises a 500 error in the case 

195 of some other RPC error. 

196 

197 Args: 

198 context (Context): The context to use to send the gRPC request. 

199 

200 """ 

201 

202 async def _get_operation_handler(request: web.Request) -> web.Response: 

203 name = request.match_info["name"] 

204 LOGGER.info("Received GetOperation request.", tags=dict(name=name)) 

205 operation, _ = await _get_operation(context, request) 

206 

207 serialised_response = operation.SerializeToString() 

208 return web.Response(body=serialised_response) 

209 

210 return _get_operation_handler 
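
# A client-side sketch for the handler above: the response body is a serialised
# Operation message. The URL layout and the aiohttp ClientSession passed in are
# illustrative assumptions.
async def _example_get_operation(session: Any, base_url: str, name: str) -> operations_pb2.Operation:
    async with session.get(f"{base_url}/operations/{name}") as resp:
        resp.raise_for_status()
        return operations_pb2.Operation.FromString(await resp.read())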

211 

212 

213def get_operation_request_metadata_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]: 

214 """Factory function which returns a handler to get RequestMetadata. 

215 

216 The returned handler uses ``context.channel`` and ``context.instance_name`` 

217 to send a GetOperation request constructed based on the path component of 

218 the URL. 

219 

220 The handler returns a serialised RequestMetadata proto message, retrieved 

221 from the trailing metadata of the GetOperation response. In the event of 

222 an invalid operation name it raises a 404 error, and raises a 500 error in 

223 the case of some other RPC error. 

224 

225 Args: 

226 context (Context): The context to use to send the gRPC request. 

227 

228 """ 

229 

230 async def _get_operation_request_metadata_handler(request: web.Request) -> web.Response: 

231 LOGGER.info("Received request for RequestMetadata.", tags=dict(name=str(request.match_info["name"]))) 

232 _, call = await _get_operation(context, request) 

233 metadata = await call.trailing_metadata() 

234 

235 def extract_metadata(m: Metadata) -> RequestMetadata: 

236 # `m` contains a list of tuples, but `extract_request_metadata()` 

237 # expects entries with `key` and `value` attributes.

238 MetadataTuple = namedtuple("MetadataTuple", ["key", "value"]) 

239 return extract_request_metadata([MetadataTuple(entry[0], entry[1]) for entry in m]) 

240 

241 request_metadata = extract_metadata(metadata) 

242 return web.Response(body=request_metadata.SerializeToString()) 

243 

244 return _get_operation_request_metadata_handler 
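
# A sketch of decoding the RequestMetadata returned by the handler above,
# assuming an aiohttp ClientSession; the endpoint path is an assumption.
async def _example_get_request_metadata(session: Any, base_url: str, name: str) -> RequestMetadata:
    async with session.get(f"{base_url}/operations/{name}/request_metadata") as resp:
        resp.raise_for_status()
        return RequestMetadata.FromString(await resp.read())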

245 

246 

247def get_operation_client_identity_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]: 

248 """Factory function which returns a handler to get ClientIdentity metadata. 

249 

250 The returned handler uses ``context.channel`` and ``context.instance_name`` 

251 to send a GetOperation request constructed based on the path component of 

252 the URL. 

253 

254 The handler returns a serialised ClientIdentity proto message, retrieved 

255 from the trailing metadata of the GetOperation response. In the event of 

256 an invalid operation name it raises a 404 error, and raises a 500 error in 

257 the case of some other RPC error. 

258 

259 Args: 

260 context (Context): The context to use to send the gRPC request. 

261 

262 """ 

263 

264 async def _get_operation_client_identity_handler(request: web.Request) -> web.Response: 

265 LOGGER.info("Received request for ClientIdentity.", tags=dict(name=str(request.match_info["name"])))

266 _, call = await _get_operation(context, request) 

267 metadata = await call.trailing_metadata() 

268 client_identity = extract_trailing_client_identity(metadata) 

269 return web.Response(body=client_identity.SerializeToString()) 

270 

271 return _get_operation_client_identity_handler 

272 

273 

274def cancel_operation_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]: 

275 """Factory function which returns a handler for CancelOperation. 

276 

277 The returned handler uses ``context.channel`` and ``context.instance_name`` 

278 to send a CancelOperation request constructed based on the path component of 

279 the URL. 

280 

281 The handler raises a 404 error in the case of an invalid operation name, 

282 or a 500 error in the case of some other RPC error. 

283 

284 On success, the response is empty. 

285 

286 Args: 

287 context (Context): The context to use to send the gRPC request. 

288 

289 """ 

290 

291 async def _cancel_operation_handler(request: web.Request) -> web.Response: 

292 LOGGER.info("Received CancelOperation request.", tags=dict(name=str(request.match_info["name"]))) 

293 operation_name = f"{context.instance_name}/{request.match_info['name']}" 

294 

295 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen 

296 grpc_request = operations_pb2.CancelOperationRequest(name=operation_name) 

297 

298 try: 

299 await stub.CancelOperation(grpc_request) 

300 return web.Response() 

301 except RpcError as e: 

302 LOGGER.warning(e.details()) 

303 if e.code() == StatusCode.INVALID_ARGUMENT: 

304 raise web.HTTPNotFound() 

305 raise web.HTTPInternalServerError() 

306 

307 return _cancel_operation_handler 
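
# A sketch of calling the cancellation endpoint above, assuming an aiohttp
# ClientSession; the HTTP method and path are assumptions. A successful
# cancellation returns an empty 200 response, an unknown name a 404.
async def _example_cancel_operation(session: Any, base_url: str, name: str) -> bool:
    async with session.post(f"{base_url}/operations/{name}/cancel") as resp:
        return resp.status == 200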

308 

309 

310async def _fetch_action_result( 

311 context: Context, request: web.Request, cache: ResponseCache, cache_key: str 

312) -> ActionResult: 

313 stub = ActionCacheStub(context.cache_channel) # type: ignore # Requires stub regen 

314 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"])) 

315 grpc_request = GetActionResultRequest(action_digest=digest, instance_name=context.instance_name) 

316 

317 try: 

318 result = await stub.GetActionResult(grpc_request) 

319 except RpcError as e: 

320 LOGGER.warning(f"Failed to fetch ActionResult: [{e.details()}]") 

321 if e.code() == StatusCode.NOT_FOUND: 

322 raise web.HTTPNotFound() 

323 raise web.HTTPInternalServerError() 

324 

325 await cache.store_action_result(cache_key, result) 

326 return result 

327 

328 

329def get_action_result_handler( 

330 context: Context, cache: ResponseCache, cache_capacity: int = 512 

331) -> Callable[[web.Request], Awaitable[web.Response]]: 

332 """Factory function which returns a handler for GetActionResult. 

333 

334 The returned handler uses ``context.channel`` and ``context.instance_name`` 

335 to send a GetActionResult request constructed based on the path components 

336 of the URL. 

337 

338 The handler returns a serialised ActionResult message, raises a 404 error 

339 if there's no result cached, or raises a 500 error in the case of some 

340 other RPC error. 

341 

342 Args: 

343 context (Context): The context to use to send the gRPC request. 

344 cache_capacity (int): The number of ActionResults to cache in memory 

345 to avoid hitting the actual ActionCache. 

346 

347 """ 

348 

349 class FetchSpec: 

350 """Simple class used to store information about a GetActionResult request. 

351 

352 A class is used here rather than a namedtuple since we need this state 

353 to be mutable. 

354 

355 """ 

356 

357 def __init__( 

358 self, *, error: Optional[Exception], event: asyncio.Event, result: Optional[ActionResult], refcount: int 

359 ): 

360 self.error = error 

361 self.event = event 

362 self.result = result 

363 self.refcount = refcount 

364 

365 in_flight_fetches: Dict[str, FetchSpec] = {} 

366 fetch_lock = asyncio.Lock() 

367 

368 async def _get_action_result_handler(request: web.Request) -> web.Response: 

369 cache_key = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}' 

370 LOGGER.info("Received GetActionResult request.", tags=dict(cache_key=cache_key)) 

371 

372 result = await cache.get_action_result(cache_key) 

373 

374 if result is None: 

375 try: 

376 duplicate_request = False 

377 spec = None 

378 async with fetch_lock: 

379 if cache_key in in_flight_fetches: 

380 LOGGER.info("Deduplicating GetActionResult request.", tags=dict(cache_key=cache_key)) 

381 spec = in_flight_fetches[cache_key] 

382 spec.refcount += 1 

383 duplicate_request = True 

384 else: 

385 spec = FetchSpec(error=None, event=asyncio.Event(), result=None, refcount=1) 

386 in_flight_fetches[cache_key] = spec 

387 

388 if duplicate_request and spec: 

389 # If we're a duplicate of an existing request, then wait for the 

390 # existing request to finish fetching from the ActionCache. 

391 await spec.event.wait() 

392 if spec.error is not None: 

393 raise spec.error 

394 if spec.result is None: 

395 # Note: this should be impossible, but let's guard against accidentally setting

396 # the event before the result is populated 

397 LOGGER.info("Result not set in deduplicated request.", tags=dict(cache_key=cache_key)) 

398 raise web.HTTPInternalServerError() 

399 result = spec.result 

400 else: 

401 try: 

402 result = await _fetch_action_result(context, request, cache, cache_key) 

403 except Exception as e: 

404 async with fetch_lock: 

405 if spec is not None: 

406 spec.error = e 

407 raise e 

408 

409 async with fetch_lock: 

410 if spec is not None: 

411 spec.result = result 

412 spec.event.set() 

413 

414 finally: 

415 async with fetch_lock: 

416 # Decrement refcount now we're done with the result. If we're the 

417 # last request interested in the result then remove it from the 

418 # `in_flight_fetches` dictionary. 

419 spec = in_flight_fetches.get(cache_key) 

420 if spec is not None: 

421 spec.refcount -= 1 

422 if spec.refcount <= 0: 

423 in_flight_fetches.pop(cache_key) 

424 

425 return web.Response(body=result.SerializeToString()) 

426 

427 return _get_action_result_handler 
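
# A client-side sketch for the handler above: the body is a serialised
# ActionResult looked up by action digest. The URL layout and the aiohttp
# ClientSession are assumptions.
async def _example_get_action_result(session: Any, base_url: str, digest: Digest) -> ActionResult:
    url = f"{base_url}/action_results/{digest.hash}/{digest.size_bytes}"
    async with session.get(url) as resp:
        resp.raise_for_status()
        return ActionResult.FromString(await resp.read())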

428 

429 

430def get_blob_handler( 

431 context: Context, cache: ResponseCache, allow_all: bool = False, allowed_origins: UrlCollection = () 

432) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: 

433 async def _get_blob_handler(request: web.Request) -> web.StreamResponse: 

434 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"])) 

435 try: 

436 offset = int(request.rel_url.query.get("offset", "0")) 

437 limit = int(request.rel_url.query.get("limit", "0")) 

438 except ValueError: 

439 raise web.HTTPBadRequest() 

440 

441 response = web.StreamResponse() 

442 

443 # We need to explicitly set CORS headers here, because the middleware that 

444 # normally handles this only adds the header when the response is returned 

445 # from this function. However, when using a StreamResponse with chunked 

446 # encoding enabled, the client begins to receive the response when we call 

447 # `response.write()`. This leads to the request being disallowed due to the 

448 # missing response header for clients executing in browsers.

449 cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all) 

450 response.headers.update(cors_headers) 

451 

452 if request.rel_url.query.get("raw", "") == "1": 

453 response.headers["Content-type"] = "text/plain; charset=utf-8" 

454 else: 

455 # Setting the Content-Disposition header so that when 

456 # downloading a blob with a browser the default name uniquely 

457 # identifies its contents: 

458 filename = f"{request.match_info['hash']}_{request.match_info['size_bytes']}" 

459 

460 # For partial reads also include the indices to prevent file mix-ups: 

461 if offset != 0 or limit != 0: 

462 filename += f"_chunk_{offset}-" 

463 if limit != 0: 

464 filename += f"{offset + limit}" 

465 

466 response.headers["Content-Disposition"] = f"Attachment;filename={filename}" 

467 

468 prepared = False 

469 

470 async def _callback(data: bytes) -> None: 

471 nonlocal prepared 

472 if not prepared: 

473 # Prepare for chunked encoding when the callback is first called, 

474 # so that we're sure we actually have some data before doing so. 

475 response.enable_chunked_encoding() 

476 await response.prepare(request) 

477 prepared = True 

478 await response.write(data) 

479 

480 await _fetch_blob(context, cache, digest, callback=_callback, offset=offset, limit=limit) 

481 

482 return response 

483 

484 return _get_blob_handler 
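
# A sketch of requesting a slice of a blob from the handler above, using the
# offset/limit/raw query parameters it reads; the path layout mirrors the
# hash/size_bytes match_info, while the base URL and aiohttp ClientSession are
# assumptions.
async def _example_read_blob_slice(session: Any, base_url: str, digest: Digest) -> bytes:
    params = {"offset": "0", "limit": "1024", "raw": "1"}
    async with session.get(f"{base_url}/blobs/{digest.hash}/{digest.size_bytes}", params=params) as resp:
        resp.raise_for_status()
        return await resp.read()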

485 

486 

487def _create_tarball(directory: str, name: str) -> bool: 

488 """Makes a tarball from a given directory. 

489 

490 Returns True if the tarball was successfully created, and False if not. 

491 

492 Args: 

493 directory (str): The directory to tar up. 

494 name (str): The name of the tarball to be produced. 

495 

496 """ 

497 try: 

498 with tarfile.open(name, "w:gz") as tarball: 

499 tarball.add(directory, arcname="") 

500 except Exception: 

501 return False 

502 return True 

503 

504 

505async def _fetch_blob( 

506 context: Context, 

507 cache: ResponseCache, 

508 digest: Digest, 

509 message_class: Optional[Type[Any]] = None, 

510 callback: Optional[Callable[[bytes], Awaitable[Any]]] = None, 

511 offset: int = 0, 

512 limit: int = 0, 

513) -> Any: 

514 """Fetch a blob from CAS. 

515 

516 This function sends a ByteStream Read request for the given digest. If ``callback`` 

517 is set then the callback is called with the data in each ReadResponse message and 

518 this function returns an empty bytes object once the response is finished. If 

519 ``message_class`` is set with no ``callback`` set then this function calls 

520 ``message_class.FromString`` on the fetched blob and returns the result. 

521 

522 If neither ``callback`` nor ``message_class`` is set then this function just returns

523 the raw blob that was fetched from CAS. 

524 

525 Args: 

526 context (Context): The context to use to send the gRPC request. 

527 cache (ResponseCache): The response cache to check/update with the fetched blob. 

528 digest (Digest): The Digest of the blob to fetch from CAS. 

529 message_class (type): A class which implements a ``FromString`` class method. 

530 The method should take a bytes object expected to contain the blob fetched 

531 from CAS. 

532 callback (callable): A function or other callable to act on a subset of the 

533 blob contents. 

534 offset (int): Read offset to start reading the blob at. Defaults to 0, the start 

535 of the blob. 

536 limit (int): Maximum number of bytes to read from the blob. Defaults to 0, no 

537 limit. 

538 

539 """ 

540 cacheable = digest.size_bytes <= BROWSER_MAX_CACHE_ENTRY_SIZE 

541 resource_name = f"{context.instance_name}/blobs/{digest.hash}/{digest.size_bytes}" 

542 blob = None 

543 if cacheable: 

544 blob = await cache.get_blob(resource_name) 

545 if blob is not None and callback is not None: 

546 if limit > 0: 

547 try: 

548 blob = blob[offset : offset + limit] 

549 except IndexError: 

550 raise web.HTTPBadRequest() 

551 await callback(blob) 

552 

553 if blob is None: 

554 stub = ByteStreamStub(context.cas_channel) # type: ignore # Requires stub regen 

555 grpc_request = ReadRequest(resource_name=resource_name, read_offset=offset, read_limit=limit) 

556 

557 blob = b"" 

558 try: 

559 async for grpc_response in stub.Read(grpc_request): 

560 if grpc_response.data: 

561 if callback is not None: 

562 await callback(grpc_response.data) 

563 

564 if callback is None or cacheable: 

565 blob += grpc_response.data 

566 except RpcError as e: 

567 LOGGER.warning(e.details()) 

568 if e.code() == StatusCode.NOT_FOUND: 

569 raise web.HTTPNotFound() 

570 raise web.HTTPInternalServerError() 

571 

572 if cacheable: 

573 await cache.store_blob(resource_name, blob) 

574 

575 if message_class is not None and callback is None: 

576 return message_class.FromString(blob) 

577 return blob 
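
# A sketch of the three usage modes described in the docstring above, assuming
# a configured Context, a ResponseCache, and a digest present in CAS.
async def _example_fetch_blob_modes(context: Context, cache: ResponseCache, digest: Digest) -> None:
    # 1. No callback or message_class: get the raw blob bytes.
    raw_blob = await _fetch_blob(context, cache, digest)

    # 2. message_class: deserialise the blob into a protobuf message.
    directory = await _fetch_blob(context, cache, digest, message_class=Directory)

    # 3. callback: stream each chunk of the blob through a coroutine instead.
    chunks = []

    async def _collect(data: bytes) -> None:
        chunks.append(data)

    await _fetch_blob(context, cache, digest, callback=_collect)
    # raw_blob, directory and chunks now hold the same CAS content in three forms.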

578 

579 

580async def _download_directory( 

581 context: Context, cache: ResponseCache, base: str, path: str, directory: Directory 

582) -> None: 

583 """Download the contents of a directory from CAS. 

584 

585 This function takes a Directory message and downloads the directory 

586 contents defined in the message from CAS. Raises a 400 error if the 

587 directory contains a symlink which points to a location outside the 

588 initial directory. 

589 

590 The directory is downloaded recursively depth-first. 

591 

592 Args: 

593 context (Context): The context to use for making gRPC requests. 

594 cache (ResponseCache): The response cache to use when downloading the 

595 directory contents. 

596 base (str): The initial directory path, used to check symlinks don't 

597 escape into the wider filesystem. 

598 path (str): The path to download the directory into. 

599 directory (Directory): The Directory message to fetch the contents of. 

600 

601 """ 

602 for directory_node in directory.directories: 

603 dir_path = os.path.join(path, directory_node.name) 

604 os.mkdir(dir_path) 

605 child = await _fetch_blob(context, cache, directory_node.digest, message_class=Directory) 

606 await _download_directory(context, cache, base, dir_path, child) 

607 

608 for file_node in directory.files: 

609 file_path = os.path.join(path, file_node.name) 

610 async with aiofiles.open(file_path, "wb") as f: 

611 await _fetch_blob(context, cache, file_node.digest, callback=f.write) 

612 if file_node.is_executable: 

613 os.chmod(file_path, 0o755) 

614 

615 for link_node in directory.symlinks: 

616 link_path = os.path.join(path, link_node.name) 

617 target_path = os.path.realpath(link_node.target) 

618 target_relpath = os.path.relpath(base, target_path) 

619 if target_relpath.startswith(os.pardir): 

620 raise web.HTTPBadRequest( 

621 reason="Requested directory contains a symlink targeting a location outside the tarball" 

622 ) 

623 

624 os.symlink(link_node.target, link_path) 

625 

626 

627async def _tarball_from_directory(context: Context, cache: ResponseCache, directory: Directory, tmp_dir: str) -> str: 

628 """Construct a tarball of a directory stored in CAS. 

629 

630 This function fetches the contents of the given directory message into a 

631 temporary directory, and then constructs a tarball of the directory. The 

632 path to this tarball is returned when construction is complete. 

633 

634 Args: 

635 context (Context): The context to use to send the gRPC requests. 

636 cache (ResponseCache): The response cache to use when fetching the 

637 tarball contents. 

638 directory (Directory): The Directory message for the directory we're 

639 making a tarball of. 

640 tmp_dir (str): Path to a temporary directory to use for storing the 

641 directory contents and its tarball. 

642 

643 """ 

644 tarball_dir = tempfile.mkdtemp(dir=tmp_dir) 

645 tarball_path = os.path.join(tmp_dir, "directory.tar.gz") 

646 loop = asyncio.get_event_loop() 

647 

648 # Fetch the contents of the directory into a temporary directory 

649 await _download_directory(context, cache, tarball_dir, tarball_dir, directory) 

650 

651 # Make a tarball from that temporary directory 

652 # NOTE: We do this using loop.run_in_executor to avoid the 

653 # synchronous and blocking tarball construction 

654 tarball_result = await loop.run_in_executor(None, _create_tarball, tarball_dir, tarball_path) 

655 if not tarball_result: 

656 raise web.HTTPInternalServerError() 

657 return tarball_path 
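
# A usage sketch for the helper above, assuming a configured Context, a
# ResponseCache, and a Directory message already fetched from CAS; the
# destination handling is illustrative.
async def _example_build_tarball(context: Context, cache: ResponseCache, directory: Directory) -> str:
    with tempfile.TemporaryDirectory() as tmp_dir:
        tarball_path = await _tarball_from_directory(context, cache, directory, tmp_dir)
        # Copy the result out before the temporary construction directory is removed.
        return shutil.copy(tarball_path, os.getcwd())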

658 

659 

660def get_tarball_handler( 

661 context: Context, 

662 cache: ResponseCache, 

663 allow_all: bool = False, 

664 allowed_origins: UrlCollection = (), 

665 tarball_dir: Optional[str] = None, 

666) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: 

667 """Factory function which returns a handler for tarball requests. 

668 

669 This function also takes care of cleaning up old incomplete tarball constructions 

670 when given a named directory to do the construction in. 

671 

672 The returned handler takes the hash and size_bytes of a Digest of a Directory 

673 message and constructs a tarball of the directory defined by the message. 

674 

675 Args: 

676 context (Context): The context to use to send the gRPC requests. 

677 allow_all (bool): Whether or not to allow all CORS origins. 

678 allowed_origins (list): List of valid CORS origins. 

679 tarball_dir (str): Base directory to use for tarball construction. 

680 

681 """ 

682 

683 class FetchSpec: 

684 """Simple class used to store information about a tarball request. 

685 

686 A class is used here rather than a namedtuple since we need this state 

687 to be mutable. 

688 

689 """ 

690 

691 def __init__(self, *, error: Optional[Exception], event: Optional[asyncio.Event], path: str, refcount: int): 

692 self.error = error 

693 self.event = event 

694 self.path = path 

695 self.refcount = refcount 

696 

697 in_flight_fetches: Dict[str, FetchSpec] = {} 

698 fetch_lock = asyncio.Lock() 

699 

700 # If we have a tarball directory to use, empty all existing tarball constructions from it 

701 # to provide some form of cleanup after a crash. 

702 if tarball_dir is not None: 

703 for path in os.listdir(tarball_dir): 

704 if path.startswith(TARBALL_DIRECTORY_PREFIX): 

705 shutil.rmtree(os.path.join(tarball_dir, path)) 

706 

707 async def _get_tarball_handler(request: web.Request) -> web.StreamResponse: 

708 digest_str = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}' 

709 LOGGER.info("Received request for a tarball from CAS for blob.", tags=dict(digest=digest_str)) 

710 

711 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"])) 

712 

713 tmp_dir = tempfile.mkdtemp(prefix=TARBALL_DIRECTORY_PREFIX, dir=tarball_dir) 

714 

715 try: 

716 duplicate_request = False 

717 event = None 

718 async with fetch_lock: 

719 if digest_str in in_flight_fetches: 

720 LOGGER.info("Deduplicating request for tarball.", tags=dict(digest=digest_str)) 

721 spec = in_flight_fetches[digest_str] 

722 spec.refcount += 1 

723 event = spec.event 

724 duplicate_request = True 

725 else: 

726 event = asyncio.Event() 

727 in_flight_fetches[digest_str] = FetchSpec(error=None, event=event, path="", refcount=1) 

728 

729 if duplicate_request and event: 

730 # If we're a duplicate of an existing request, then wait for the 

731 # existing request to finish tarball creation before reading the 

732 # path from the cache. 

733 await event.wait() 

734 spec = in_flight_fetches[digest_str] 

735 if spec.error is not None: 

736 raise spec.error 

737 tarball_path = in_flight_fetches[digest_str].path 

738 else: 

739 try: 

740 directory = await _fetch_blob(context, cache, digest, message_class=Directory) 

741 tarball_path = await _tarball_from_directory(context, cache, directory, tmp_dir) 

742 except web.HTTPError as e: 

743 in_flight_fetches[digest_str].error = e 

744 if event: 

745 event.set() 

746 raise e 

747 except Exception as e: 

748 LOGGER.debug("Unexpected error constructing tarball.", tags=dict(digest=digest_str), exc_info=True) 

749 in_flight_fetches[digest_str].error = e 

750 if event: 

751 event.set() 

752 raise web.HTTPInternalServerError() 

753 

754 # Update path in deduplication cache, and set event to notify 

755 # duplicate requests that the tarball is ready 

756 async with fetch_lock: 

757 if event: 

758 in_flight_fetches[digest_str].path = tarball_path 

759 event.set() 

760 

761 response = web.StreamResponse() 

762 

763 # We need to explicitly set CORS headers here, because the middleware that 

764 # normally handles this only adds the header when the response is returned 

765 # from this function. However, when using a StreamResponse with chunked 

766 # encoding enabled, the client begins to receive the response when we call 

767 # `response.write()`. This leads to the request being disallowed due to the 

768 # missing response header for clients executing in browsers.

769 cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all) 

770 response.headers.update(cors_headers) 

771 

772 response.enable_chunked_encoding() 

773 await response.prepare(request) 

774 

775 async with aiofiles.open(tarball_path, "rb") as tarball: 

776 await tarball.seek(0) 

777 chunk = await tarball.read(1024) 

778 while chunk: 

779 await response.write(chunk) 

780 chunk = await tarball.read(1024) 

781 return response 

782 

783 except RpcError as e: 

784 LOGGER.warning(e.details()) 

785 if e.code() == StatusCode.NOT_FOUND: 

786 raise web.HTTPNotFound() 

787 raise web.HTTPInternalServerError() 

788 

789 finally: 

790 cleanup = False 

791 async with fetch_lock: 

792 # Decrement refcount now we're done with the tarball. If we're the 

793 # last request interested in the tarball then remove it along with 

794 # its construction directory. 

795 spec = in_flight_fetches[digest_str] 

796 spec.refcount -= 1 

797 if spec.refcount <= 0: 

798 cleanup = True 

799 in_flight_fetches.pop(digest_str) 

800 if cleanup: 

801 shutil.rmtree(tmp_dir) 

802 

803 return _get_tarball_handler 
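
# A sketch of streaming the tarball produced by the handler above to a local
# file, assuming an aiohttp ClientSession; the URL layout and chunk size are
# assumptions.
async def _example_download_tarball(session: Any, base_url: str, digest: Digest, dest: str) -> None:
    url = f"{base_url}/directories/{digest.hash}/{digest.size_bytes}/tarball"
    async with session.get(url) as resp:
        resp.raise_for_status()
        async with aiofiles.open(dest, "wb") as f:
            async for chunk in resp.content.iter_chunked(64 * 1024):
                await f.write(chunk)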

804 

805 

806def logstream_handler(context: Context) -> Callable[[web.Request], Awaitable[Any]]: 

807 async def _logstream_handler(request: web.Request) -> Any: 

808 LOGGER.info("Received request for a LogStream websocket.")

809 stub = ByteStreamStub(context.logstream_channel) # type: ignore # Requires stub regen 

810 ws = web.WebSocketResponse() 

811 await ws.prepare(request) 

812 

813 async for msg in ws: 

814 if msg.type == WSMsgType.BINARY: 

815 read_request = ReadRequest() 

816 read_request.ParseFromString(msg.data) 

817 

818 read_request.resource_name = f"{read_request.resource_name}" 

819 try: 

820 async for response in stub.Read(read_request): 

821 serialized_response = response.SerializeToString() 

822 if serialized_response: 

823 ws_response = { 

824 "resource_name": read_request.resource_name, 

825 "data": response.data.decode("utf-8"), 

826 "complete": False, 

827 } 

828 await ws.send_json(ws_response) 

829 ws_response = {"resource_name": read_request.resource_name, "data": "", "complete": True} 

830 await ws.send_json(ws_response) 

831 except RpcError as e: 

832 LOGGER.warning(e.details()) 

833 if e.code() == StatusCode.NOT_FOUND: 

834 ws_response = { 

835 "resource_name": read_request.resource_name, 

836 "data": "NOT_FOUND", 

837 "complete": True, 

838 } 

839 await ws.send_json(ws_response) 

840 ws_response = {"resource_name": read_request.resource_name, "data": "INTERNAL", "complete": True} 

841 await ws.send_json(ws_response) 

842 

843 return _logstream_handler
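
# A sketch of a client for the websocket handler above: send a serialised
# ByteStream ReadRequest as a binary frame, then read the JSON messages the
# handler emits until one arrives with "complete" set. The websocket URL and
# resource name format are assumptions.
import aiohttp


async def _example_follow_logstream(ws_url: str, resource_name: str) -> None:
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(ws_url) as ws:
            await ws.send_bytes(ReadRequest(resource_name=resource_name).SerializeToString())
            async for msg in ws:
                if msg.type != WSMsgType.TEXT:
                    continue
                payload = json.loads(msg.data)
                if payload["data"]:
                    print(payload["data"], end="")
                if payload["complete"]:
                    break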