Coverage for /builds/BuildGrid/buildgrid/buildgrid/browser/rest_api.py: 82.95%
387 statements
coverage.py v7.2.7, created at 2023-06-05 15:37 +0000
1# Copyright (C) 2021 Bloomberg LP
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# <http://www.apache.org/licenses/LICENSE-2.0>
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import asyncio
16import logging
17import os
18import shutil
19import tarfile
20import tempfile
21from base64 import b64encode
22from collections import namedtuple
23from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type
25import aiofiles
26from aiohttp import WSMsgType, web
27from aiohttp_middlewares.annotations import UrlCollection
28from grpc import RpcError, StatusCode
29from grpc.aio import Call, Metadata
31from buildgrid._app.cli import Context
32from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import (
33 ActionResult,
34 Digest,
35 Directory,
36 GetActionResultRequest,
37 RequestMetadata,
38)
39from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2_grpc_aio import ActionCacheStub
40from buildgrid._protos.buildgrid.v2.query_build_events_pb2 import QueryEventStreamsRequest
41from buildgrid._protos.buildgrid.v2.query_build_events_pb2_grpc_aio import QueryBuildEventsStub
42from buildgrid._protos.google.bytestream.bytestream_pb2 import ReadRequest
43from buildgrid._protos.google.bytestream.bytestream_pb2_grpc_aio import ByteStreamStub
44from buildgrid._protos.google.longrunning import operations_pb2
45from buildgrid._protos.google.longrunning.operations_pb2_grpc_aio import OperationsStub
46from buildgrid.browser.utils import TARBALL_DIRECTORY_PREFIX, ResponseCache, get_cors_headers
47from buildgrid.server.request_metadata_utils import extract_request_metadata
48from buildgrid.settings import BROWSER_MAX_CACHE_ENTRY_SIZE, BROWSER_OPERATION_CACHE_MAX_LENGTH
50LOGGER = logging.getLogger(__name__)
53def query_build_events_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
54 """Factory function which returns a handler for QueryEventStreams.
56 The returned handler uses ``context.channel`` to send a QueryEventStreams
57 request constructed based on the provided URL query parameters. Currently
58 only querying by build_id (equivalent to the correlated invocations ID) is
59 supported.
61 The handler returns a serialised QueryEventStreamsResponse, and raises a
62 500 error in the case of some RPC error.
64 Args:
65 context (Context): The context to use to send the gRPC request.
67 """
69 async def _query_build_events_handler(request: web.Request) -> web.Response:
70 build_id = request.rel_url.query.get("build_id", ".*")
72 stub = QueryBuildEventsStub(context.channel) # type: ignore # Requires stub regen
74 grpc_request = QueryEventStreamsRequest(build_id_pattern=build_id)
76 try:
77 grpc_response = await stub.QueryEventStreams(grpc_request)
78 except RpcError as e:
79 LOGGER.warning(e.details())
80 raise web.HTTPInternalServerError()
82 serialized_response = grpc_response.SerializeToString()
83 return web.Response(body=serialized_response)
85 return _query_build_events_handler
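# Illustrative sketch (not part of this module): each factory in this file returns a
# plain aiohttp handler, so an application wires them up as routes. The paths below and
# the pre-built Context and ResponseCache are assumptions made for illustration only;
# the real route registration lives elsewhere in buildgrid.

def make_browser_app(context: Context, cache: ResponseCache) -> web.Application:
    app = web.Application()
    app.add_routes(
        [
            # Hypothetical paths, chosen only to show how the factories are called.
            web.get("/api/v1/build_events", query_build_events_handler(context)),
            web.get("/api/v1/operations", list_operations_handler(context, cache)),
            web.get("/api/v1/operations/{name}", get_operation_handler(context, cache)),
        ]
    )
    return app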
88def list_operations_handler(
89 context: Context, cache: ResponseCache
90) -> Callable[[web.Request], Awaitable[web.Response]]:
91 """Factory function which returns a handler for ListOperations.
93 The returned handler uses ``context.channel`` and ``context.instance_name``
94 to send a ListOperations request constructed based on the provided URL
95 query parameters.
97 The handler returns a serialised ListOperationsResponse, raises a 400
98 error in the case of a bad filter or other invalid argument, or raises
99 a 500 error in the case of some other RPC error.
101 Args:
102 context (Context): The context to use to send the gRPC request.
104 """
106 async def _list_operations_handler(request: web.Request) -> web.Response:
107 filter_string = request.rel_url.query.get("q", "")
108 page_token = request.rel_url.query.get("page_token", "")
109 page_size_str = request.rel_url.query.get("page_size")
110 page_size = 0
111 if page_size_str is not None:
112 page_size = int(page_size_str)
114 LOGGER.info(
115 f'Received ListOperations request, filter_string="{filter_string}" '
116 f'page_token="{page_token}" page_size="{page_size}"'
117 )
118 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen
119 grpc_request = operations_pb2.ListOperationsRequest(
120 name=context.instance_name, page_token=page_token, page_size=page_size, filter=filter_string
121 )
123 try:
124 grpc_response = await stub.ListOperations(grpc_request)
125 except RpcError as e:
126 LOGGER.warning(e.details())
127 if e.code() == StatusCode.INVALID_ARGUMENT:
128 raise web.HTTPBadRequest()
129 raise web.HTTPInternalServerError()
131 if grpc_response.operations:
132 for i, operation in enumerate(grpc_response.operations):
133 if i > BROWSER_OPERATION_CACHE_MAX_LENGTH:
134 break
135 await cache.store_operation(operation.name, operation)
137 serialised_response = grpc_response.SerializeToString()
138 return web.Response(body=serialised_response)
140 return _list_operations_handler
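# Illustrative sketch (not part of this module): the handler above replies with a
# serialised protobuf rather than JSON, so a scripted client decodes the body with
# ListOperationsResponse.FromString. The base URL and route are placeholders; the
# query parameters mirror the ones _list_operations_handler reads from the request.

async def fetch_operations(base_url: str, filter_string: str = "") -> operations_pb2.ListOperationsResponse:
    import aiohttp  # client-side dependency, not used by this module

    params = {"q": filter_string, "page_token": "", "page_size": "50"}
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url}/operations", params=params) as response:
            response.raise_for_status()
            body = await response.read()
    return operations_pb2.ListOperationsResponse.FromString(body)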
143async def _get_operation(context: Context, request: web.Request) -> Tuple[operations_pb2.Operation, Call]:
144 operation_name = f"{context.instance_name}/{request.match_info['name']}"
146 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen
147 grpc_request = operations_pb2.GetOperationRequest(name=operation_name)
149 try:
150 call = stub.GetOperation(grpc_request)
151 operation = await call
152 except RpcError as e:
153 LOGGER.warning(f"Error fetching operation: {e.details()}")
154 if e.code() == StatusCode.INVALID_ARGUMENT:
155 raise web.HTTPNotFound()
156 raise web.HTTPInternalServerError()
158 return operation, call
161def get_operation_handler(context: Context, cache: ResponseCache) -> Callable[[web.Request], Awaitable[web.Response]]:
162 """Factory function which returns a handler for GetOperation.
164 The returned handler uses ``context.channel`` and ``context.instance_name``
165 to send a GetOperation request constructed based on the path component of
166 the URL.
168 The handler returns a serialised Operation message, raises a 404 error in
169 the case of an invalid operation name, or raises a 500 error in the case
170 of some other RPC error.
172 Args:
173 context (Context): The context to use to send the gRPC request.
175 """
177 async def _get_operation_handler(request: web.Request) -> web.Response:
178 name = request.match_info["name"]
179 LOGGER.info(f'Received GetOperation request for "{name}"')
181 # Check the cache to see if we already fetched this operation recently
182 operation = await cache.get_operation(name)
184 # Fall back to sending an actual gRPC request
185 if operation is None:
186 operation, _ = await _get_operation(context, request)
187 await cache.store_operation(name, operation)
189 serialised_response = operation.SerializeToString()
190 return web.Response(body=serialised_response)
192 return _get_operation_handler
195def get_operation_request_metadata_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
196 """Factory function which returns a handler to get RequestMetadata.
198 The returned handler uses ``context.channel`` and ``context.instance_name``
199 to send a GetOperation request constructed based on the path component of
200 the URL.
202 The handler returns a serialised RequestMetadata proto message, retrieved
203 from the trailing metadata of the GetOperation response. In the event of
204 an invalid operation name it raises a 404 error, and raises a 500 error in
205 the case of some other RPC error.
207 Args:
208 context (Context): The context to use to send the gRPC request.
210 """
212 async def _get_operation_request_metadata_handler(request: web.Request) -> web.Response:
213 LOGGER.info(f'Received request for RequestMetadata for "{request.match_info["name"]}"')
214 _, call = await _get_operation(context, request)
215 metadata = await call.trailing_metadata()
217 def extract_metadata(m: Metadata) -> RequestMetadata:
218 # `m` contains a list of tuples, but `extract_request_metadata()`
219 # expects objects with `key` and `value` attributes.
220 MetadataTuple = namedtuple("MetadataTuple", ["key", "value"])
221 return extract_request_metadata([MetadataTuple(entry[0], entry[1]) for entry in m])
223 request_metadata = extract_metadata(metadata)
224 return web.Response(body=request_metadata.SerializeToString())
226 return _get_operation_request_metadata_handler
229def cancel_operation_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
230 """Factory function which returns a handler for CancelOperation.
232 The returned handler uses ``context.channel`` and ``context.instance_name``
233 to send a CancelOperation request constructed based on the path component of
234 the URL.
236 The handler raises a 404 error in the case of an invalid operation name,
237 or a 500 error in the case of some other RPC error.
239 On success, the response is empty.
241 Args:
242 context (Context): The context to use to send the gRPC request.
244 """
246 async def _cancel_operation_handler(request: web.Request) -> web.Response:
247 LOGGER.info(f'Received CancelOperation request for "{request.match_info["name"]}"')
248 operation_name = f"{context.instance_name}/{request.match_info['name']}"
250 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen
251 grpc_request = operations_pb2.CancelOperationRequest(name=operation_name)
253 try:
254 await stub.CancelOperation(grpc_request)
255 return web.Response()
256 except RpcError as e:
257 LOGGER.warning(e.details())
258 if e.code() == StatusCode.INVALID_ARGUMENT:
259 raise web.HTTPNotFound()
260 raise web.HTTPInternalServerError()
262 return _cancel_operation_handler
265async def _fetch_action_result(
266 context: Context, request: web.Request, cache: ResponseCache, cache_key: str
267) -> ActionResult:
268 stub = ActionCacheStub(context.cache_channel) # type: ignore # Requires stub regen
269 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"]))
270 grpc_request = GetActionResultRequest(action_digest=digest, instance_name=context.instance_name)
272 try:
273 result = await stub.GetActionResult(grpc_request)
274 except RpcError as e:
275 LOGGER.warning(f"Failed to fetch ActionResult: [{e.details()}]")
276 if e.code() == StatusCode.NOT_FOUND:
277 raise web.HTTPNotFound()
278 raise web.HTTPInternalServerError()
280 await cache.store_action_result(cache_key, result)
281 return result
284def get_action_result_handler(
285 context: Context, cache: ResponseCache, cache_capacity: int = 512
286) -> Callable[[web.Request], Awaitable[web.Response]]:
287 """Factory function which returns a handler for GetActionResult.
289 The returned handler uses ``context.channel`` and ``context.instance_name``
290 to send a GetActionResult request constructed based on the path components
291 of the URL.
293 The handler returns a serialised ActionResult message, raises a 404 error
294 if there's no result cached, or raises a 500 error in the case of some
295 other RPC error.
297 Args:
298 context (Context): The context to use to send the gRPC request.
299 cache_capacity (int): The number of ActionResults to cache in memory
300 to avoid hitting the actual ActionCache.
302 """
304 class FetchSpec:
305 """Simple class used to store information about a GetActionResult request.
307 A class is used here rather than a namedtuple since we need this state
308 to be mutable.
310 """
312 def __init__(
313 self, *, error: Optional[Exception], event: asyncio.Event, result: Optional[ActionResult], refcount: int
314 ):
315 self.error = error
316 self.event = event
317 self.result = result
318 self.refcount = refcount
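    # Shared per-handler state: maps the "<hash>/<size_bytes>" cache key of each
    # in-flight ActionCache fetch to its FetchSpec, guarded by fetch_lock below.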
320 in_flight_fetches: Dict[str, FetchSpec] = {}
321 fetch_lock = asyncio.Lock()
323 async def _get_action_result_handler(request: web.Request) -> web.Response:
324 LOGGER.info(
325 "Received GetActionResult request for "
326 f'"{request.match_info["hash"]}/{request.match_info["size_bytes"]}"'
327 )
329 cache_key = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
330 result = await cache.get_action_result(cache_key)
332 if result is None:
333 try:
334 duplicate_request = False
335 spec = None
336 async with fetch_lock:
337 if cache_key in in_flight_fetches:
338 LOGGER.info(f"Deduplicating GetActionResult request for [{cache_key}]")
339 spec = in_flight_fetches[cache_key]
340 spec.refcount += 1
341 duplicate_request = True
342 else:
343 spec = FetchSpec(error=None, event=asyncio.Event(), result=None, refcount=1)
344 in_flight_fetches[cache_key] = spec
346 if duplicate_request and spec:
347 # If we're a duplicate of an existing request, then wait for the
348 # existing request to finish fetching from the ActionCache.
349 await spec.event.wait()
350 if spec.error is not None:
351 raise spec.error
352 if spec.result is None:
353 # Note: this should be impossible, but let's guard against accidentally setting
354 # the event before the result is populated
355 LOGGER.info(f"Result not set in deduplicated request for [{cache_key}]")
356 raise web.HTTPInternalServerError()
357 result = spec.result
358 else:
359 try:
360 result = await _fetch_action_result(context, request, cache, cache_key)
361 except Exception as e:
362 async with fetch_lock:
363 if spec is not None:
364 spec.error = e
365 raise e
367 async with fetch_lock:
368 if spec is not None:
369 spec.result = result
370 spec.event.set()
372 finally:
373 async with fetch_lock:
374 # Decrement refcount now we're done with the result. If we're the
375 # last request interested in the result then remove it from the
376 # `in_flight_fetches` dictionary.
377 spec = in_flight_fetches.get(cache_key)
378 if spec is not None:
379 spec.refcount -= 1
380 if spec.refcount <= 0:
381 in_flight_fetches.pop(cache_key)
383 return web.Response(body=result.SerializeToString())
385 return _get_action_result_handler
388def get_blob_handler(
389 context: Context, cache: ResponseCache, allow_all: bool = False, allowed_origins: UrlCollection = ()
390) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
391 async def _get_blob_handler(request: web.Request) -> web.StreamResponse:
392 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"]))
393 try:
394 offset = int(request.rel_url.query.get("offset", "0"))
395 limit = int(request.rel_url.query.get("limit", "0"))
396 except ValueError:
397 raise web.HTTPBadRequest()
399 response = web.StreamResponse()
401 # We need to explicitly set CORS headers here, because the middleware that
402 # normally handles this only adds the header when the response is returned
403 # from this function. However, when using a StreamResponse with chunked
404 # encoding enabled, the client begins to receive the response when we call
405 # `response.write()`. This leads to the request being disallowed due to the
406 # missing response header for clients executing in browsers.
407 cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all)
408 response.headers.update(cors_headers)
410 if request.rel_url.query.get("raw", "") == "1":
411 response.headers["Content-type"] = "text/plain; charset=utf-8"
412 else:
413 # Setting the Content-Disposition header so that when
414 # downloading a blob with a browser the default name uniquely
415 # identifies its contents:
416 filename = f"{request.match_info['hash']}_{request.match_info['size_bytes']}"
418 # For partial reads also include the indices to prevent file mix-ups:
419 if offset != 0 or limit != 0:
420 filename += f"_chunk_{offset}-"
421 if limit != 0:
422 filename += f"{offset + limit}"
424 response.headers["Content-Disposition"] = f"Attachment;filename={filename}"
426 prepared = False
428 async def _callback(data: bytes) -> None:
429 nonlocal prepared
430 if not prepared:
431 # Prepare for chunked encoding when the callback is first called,
432 # so that we're sure we actually have some data before doing so.
433 response.enable_chunked_encoding()
434 await response.prepare(request)
435 prepared = True
436 await response.write(data)
438 await _fetch_blob(context, cache, digest, callback=_callback, offset=offset, limit=limit)
440 return response
442 return _get_blob_handler
445def _create_tarball(directory: str, name: str) -> bool:
446 """Makes a tarball from a given directory.
448 Returns True if the tarball was successfully created, and False if not.
450 Args:
451 directory (str): The directory to tar up.
452 name (str): The name of the tarball to be produced.
454 """
455 try:
456 with tarfile.open(name, "w:gz") as tarball:
457 tarball.add(directory, arcname="")
458 except Exception:
459 return False
460 return True
463async def _fetch_blob(
464 context: Context,
465 cache: ResponseCache,
466 digest: Digest,
467 message_class: Optional[Type[Any]] = None,
468 callback: Optional[Callable[[bytes], Awaitable[Any]]] = None,
469 offset: int = 0,
470 limit: int = 0,
471) -> Any:
472 """Fetch a blob from CAS.
474 This function sends a ByteStream Read request for the given digest. If ``callback``
475 is set then the callback is called with the data in each ReadResponse message and
476 this function returns an empty bytes object once the response is finished. If
477 ``message_class`` is set with no ``callback`` set then this function calls
478 ``message_class.FromString`` on the fetched blob and returns the result.
480 If neither ``callback`` nor ``message_class`` is set then this function just returns
481 the raw blob that was fetched from CAS.
483 Args:
484 context (Context): The context to use to send the gRPC request.
485 cache (ResponseCache): The response cache to check/update with the fetched blob.
486 digest (Digest): The Digest of the blob to fetch from CAS.
487 message_class (type): A class which implements a ``FromString`` class method.
488 The method should take a bytes object expected to contain the blob fetched
489 from CAS.
490 callback (callable): A function or other callable to act on a subset of the
491 blob contents.
492 offset (int): Read offset to start reading the blob at. Defaults to 0, the start
493 of the blob.
494 limit (int): Maximum number of bytes to read from the blob. Defaults to 0, no
495 limit.
497 """
498 cacheable = digest.size_bytes <= BROWSER_MAX_CACHE_ENTRY_SIZE
499 resource_name = f"{context.instance_name}/blobs/{digest.hash}/{digest.size_bytes}"
500 blob = None
501 if cacheable:
502 blob = await cache.get_blob(resource_name)
503 if blob is not None and callback is not None:
504 if limit > 0:
505 try:
506 blob = blob[offset : offset + limit]
507 except IndexError:
508 raise web.HTTPBadRequest()
509 await callback(blob)
511 if blob is None:
512 stub = ByteStreamStub(context.cas_channel) # type: ignore # Requires stub regen
513 grpc_request = ReadRequest(resource_name=resource_name, read_offset=offset, read_limit=limit)
515 blob = b""
516 try:
517 async for grpc_response in stub.Read(grpc_request):
518 if grpc_response.data:
519 if callback is not None:
520 await callback(grpc_response.data)
522 if callback is None or cacheable:
523 blob += grpc_response.data
524 except RpcError as e:
525 LOGGER.warning(e.details())
526 if e.code() == StatusCode.NOT_FOUND:
527 raise web.HTTPNotFound()
528 raise web.HTTPInternalServerError()
530 if cacheable:
531 await cache.store_blob(resource_name, blob)
533 if message_class is not None and callback is None:
534 return message_class.FromString(blob)
535 return blob
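# Illustrative sketch (not part of this module): the two calling conventions of
# _fetch_blob used below. Passing message_class parses the whole blob in one go,
# while passing a callback streams it chunk by chunk. The digest is a placeholder.

async def _fetch_blob_examples(context: Context, cache: ResponseCache) -> None:
    digest = Digest(hash="0" * 64, size_bytes=142)

    # Parse the fetched blob as a Directory message (as _download_directory does).
    directory = await _fetch_blob(context, cache, digest, message_class=Directory)
    LOGGER.info(f"Fetched Directory with {len(directory.files)} files")

    # Stream the raw bytes through an async callback (as the blob download handlers do).
    chunks = []

    async def _collect(data: bytes) -> None:
        chunks.append(data)

    await _fetch_blob(context, cache, digest, callback=_collect)
    LOGGER.info(f"Streamed {sum(len(c) for c in chunks)} bytes in {len(chunks)} chunks")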
538async def _download_directory(
539 context: Context, cache: ResponseCache, base: str, path: str, directory: Directory
540) -> None:
541 """Download the contents of a directory from CAS.
543 This function takes a Directory message and downloads the directory
544 contents defined in the message from CAS. Raises a 400 error if the
545 directory contains a symlink which points to a location outside the
546 initial directory.
548 The directory is downloaded recursively depth-first.
550 Args:
551 context (Context): The context to use for making gRPC requests.
552 cache (ResponseCache): The response cache to use when downloading the
553 directory contents.
554 base (str): The initial directory path, used to check symlinks don't
555 escape into the wider filesystem.
556 path (str): The path to download the directory into.
557 directory (Directory): The Directory message to fetch the contents of.
559 """
560 for directory_node in directory.directories:
561 dir_path = os.path.join(path, directory_node.name)
562 os.mkdir(dir_path)
563 child = await _fetch_blob(context, cache, directory_node.digest, message_class=Directory)
564 await _download_directory(context, cache, base, dir_path, child)
566 for file_node in directory.files:
567 file_path = os.path.join(path, file_node.name)
568 async with aiofiles.open(file_path, "wb") as f:
569 await _fetch_blob(context, cache, file_node.digest, callback=f.write)
570 if file_node.is_executable:
571 os.chmod(file_path, 0o755)
573 for link_node in directory.symlinks:
574 link_path = os.path.join(path, link_node.name)
575 target_path = os.path.realpath(link_node.target)
576 target_relpath = os.path.relpath(base, target_path)
577 if target_relpath.startswith(os.pardir):
578 raise web.HTTPBadRequest(
579 reason="Requested directory contains a symlink targeting a location outside the tarball"
580 )
582 os.symlink(link_node.target, link_path)
585async def _tarball_from_directory(context: Context, cache: ResponseCache, directory: Directory, tmp_dir: str) -> str:
586 """Construct a tarball of a directory stored in CAS.
588 This function fetches the contents of the given directory message into a
589 temporary directory, and then constructs a tarball of the directory. The
590 path to this tarball is returned when construction is complete.
592 Args:
593 context (Context): The context to use to send the gRPC requests.
594 cache (ResponseCache): The response cache to use when fetching the
595 tarball contents.
596 directory (Directory): The Directory message for the directory we're
597 making a tarball of.
598 tmp_dir (str): Path to a temporary directory to use for storing the
599 directory contents and its tarball.
601 """
602 tarball_dir = tempfile.mkdtemp(dir=tmp_dir)
603 tarball_path = os.path.join(tmp_dir, "directory.tar.gz")
604 loop = asyncio.get_event_loop()
606 # Fetch the contents of the directory into a temporary directory
607 await _download_directory(context, cache, tarball_dir, tarball_dir, directory)
609 # Make a tarball from that temporary directory
610 # NOTE: We do this using loop.run_in_executor to avoid blocking the
611 # event loop during the synchronous tarball construction
612 tarball_result = await loop.run_in_executor(None, _create_tarball, tarball_dir, tarball_path)
613 if not tarball_result:
614 raise web.HTTPInternalServerError()
615 return tarball_path
618def get_tarball_handler(
619 context: Context,
620 cache: ResponseCache,
621 allow_all: bool = False,
622 allowed_origins: UrlCollection = (),
623 tarball_dir: Optional[str] = None,
624) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
625 """Factory function which returns a handler for tarball requests.
627 This function also takes care of cleaning up old incomplete tarball constructions
628 when given a named directory to do the construction in.
630 The returned handler takes the hash and size_bytes of a Digest of a Directory
631 message and constructs a tarball of the directory defined by the message.
633 Args:
634 context (Context): The context to use to send the gRPC requests.
635 allow_all (bool): Whether or not to allow all CORS origins.
636 allowed_origins (list): List of valid CORS origins.
637 tarball_dir (str): Base directory to use for tarball construction.
639 """
641 class FetchSpec:
642 """Simple class used to store information about a tarball request.
644 A class is used here rather than a namedtuple since we need this state
645 to be mutable.
647 """
649 def __init__(self, *, error: Optional[Exception], event: Optional[asyncio.Event], path: str, refcount: int):
650 self.error = error
651 self.event = event
652 self.path = path
653 self.refcount = refcount
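    # Shared per-handler state: maps the "<hash>/<size_bytes>" digest string of each
    # in-flight tarball construction to its FetchSpec, guarded by fetch_lock below.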
655 in_flight_fetches: Dict[str, FetchSpec] = {}
656 fetch_lock = asyncio.Lock()
658 # If we have a tarball directory to use, remove any existing tarball construction
659 # directories from it to provide some form of cleanup after a crash.
660 if tarball_dir is not None:
661 for path in os.listdir(tarball_dir):
662 if path.startswith(TARBALL_DIRECTORY_PREFIX):
663 shutil.rmtree(os.path.join(tarball_dir, path))
665 async def _get_tarball_handler(request: web.Request) -> web.StreamResponse:
666 digest_str = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
667 LOGGER.info(f"Received request for a tarball from CAS for blob with digest [{digest_str}]")
669 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"]))
671 tmp_dir = tempfile.mkdtemp(prefix=TARBALL_DIRECTORY_PREFIX, dir=tarball_dir)
673 try:
674 duplicate_request = False
675 event = None
676 async with fetch_lock:
677 if digest_str in in_flight_fetches:
678 LOGGER.info(f"Deduplicating request for tarball of [{digest_str}]")
679 spec = in_flight_fetches[digest_str]
680 spec.refcount += 1
681 event = spec.event
682 duplicate_request = True
683 else:
684 event = asyncio.Event()
685 in_flight_fetches[digest_str] = FetchSpec(error=None, event=event, path="", refcount=1)
687 if duplicate_request and event:
688 # If we're a duplicate of an existing request, then wait for the
689 # existing request to finish tarball creation before reading the
690 # path from the cache.
691 await event.wait()
692 spec = in_flight_fetches[digest_str]
693 if spec.error is not None:
694 raise spec.error
695 tarball_path = in_flight_fetches[digest_str].path
696 else:
697 try:
698 directory = await _fetch_blob(context, cache, digest, message_class=Directory)
699 tarball_path = await _tarball_from_directory(context, cache, directory, tmp_dir)
700 except web.HTTPError as e:
701 in_flight_fetches[digest_str].error = e
702 if event:
703 event.set()
704 raise e
705 except Exception as e:
706 LOGGER.debug("Unexpected error constructing tarball", exc_info=True)
707 in_flight_fetches[digest_str].error = e
708 if event:
709 event.set()
710 raise web.HTTPInternalServerError()
712 # Update path in deduplication cache, and set event to notify
713 # duplicate requests that the tarball is ready
714 async with fetch_lock:
715 if event:
716 in_flight_fetches[digest_str].path = tarball_path
717 event.set()
719 response = web.StreamResponse()
721 # We need to explicitly set CORS headers here, because the middleware that
722 # normally handles this only adds the header when the response is returned
723 # from this function. However, when using a StreamResponse with chunked
724 # encoding enabled, the client begins to receive the response when we call
725 # `response.write()`. This leads to the request being disallowed due to the
726 # missing response header for clients executing in browsers.
727 cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all)
728 response.headers.update(cors_headers)
730 response.enable_chunked_encoding()
731 await response.prepare(request)
733 async with aiofiles.open(tarball_path, "rb") as tarball:
734 await tarball.seek(0)
735 chunk = await tarball.read(1024)
736 while chunk:
737 await response.write(chunk)
738 chunk = await tarball.read(1024)
739 return response
741 except RpcError as e:
742 LOGGER.warning(e.details())
743 if e.code() == StatusCode.NOT_FOUND:
744 raise web.HTTPNotFound()
745 raise web.HTTPInternalServerError()
747 finally:
748 cleanup = False
749 async with fetch_lock:
750 # Decrement refcount now we're done with the tarball. If we're the
751 # last request interested in the tarball then remove it along with
752 # its construction directory.
753 spec = in_flight_fetches[digest_str]
754 spec.refcount -= 1
755 if spec.refcount <= 0:
756 cleanup = True
757 in_flight_fetches.pop(digest_str)
758 if cleanup:
759 shutil.rmtree(tmp_dir)
761 return _get_tarball_handler
764def logstream_handler(context: Context) -> Callable[[web.Request], Awaitable[Any]]:
765 async def _logstream_handler(request: web.Request) -> Any:
766 LOGGER.info("Received request for a LogStream websocket")
767 stub = ByteStreamStub(context.logstream_channel) # type: ignore # Requires stub regen
768 ws = web.WebSocketResponse()
769 await ws.prepare(request)
771 async for msg in ws:
772 if msg.type == WSMsgType.BINARY:
773 read_request = ReadRequest()
774 read_request.ParseFromString(msg.data)
776 read_request.resource_name = f"{context.instance_name}/{read_request.resource_name}"
777 try:
778 async for response in stub.Read(read_request):
779 serialized_response = response.SerializeToString()
780 if serialized_response:
781 ws_response = {
782 "resource_name": read_request.resource_name,
783 "data": b64encode(serialized_response).decode("utf-8"),
784 "complete": False,
785 }
786 await ws.send_json(ws_response)
787 ws_response = {"resource_name": read_request.resource_name, "data": "", "complete": True}
788 await ws.send_json(ws_response)
789 except RpcError as e:
790 LOGGER.warning(e.details())
791 if e.code() == StatusCode.NOT_FOUND:
792 ws_response = {
793 "resource_name": read_request.resource_name,
794 "data": "NOT_FOUND",
795 "complete": True,
796 }
797 await ws.send_json(ws_response)
798 ws_response = {"resource_name": read_request.resource_name, "data": "INTERNAL", "complete": True}
799 await ws.send_json(ws_response)
801 return _logstream_handler
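# Illustrative sketch (not part of this module): a client for the LogStream websocket
# above. It sends a serialised ReadRequest as a binary frame, then consumes the JSON
# frames the handler emits until one arrives with "complete" set. The URL is a
# placeholder; the real route depends on how the application registers the handler.

async def follow_logstream(url: str, resource_name: str) -> None:
    import aiohttp  # client-side dependency, not used by this module

    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url) as ws:
            await ws.send_bytes(ReadRequest(resource_name=resource_name).SerializeToString())
            async for msg in ws:
                if msg.type != WSMsgType.TEXT:
                    continue
                payload = msg.json()
                if payload["data"]:
                    # Per the handler, "data" is a base64-encoded serialised ReadResponse.
                    print(payload["data"])
                if payload["complete"]:
                    break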