Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/browser/rest_api.py: 82.72%
405 statements
coverage.py v7.4.1, created at 2024-10-04 17:48 +0000
1# Copyright (C) 2021 Bloomberg LP
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# <http://www.apache.org/licenses/LICENSE-2.0>
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15import asyncio
16import json
17import os
18import shutil
19import tarfile
20import tempfile
21from collections import namedtuple
22from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type
24import aiofiles
25from aiohttp import WSMsgType, web
26from aiohttp_middlewares.annotations import UrlCollection
27from grpc import RpcError, StatusCode
28from grpc.aio import Call, Metadata
30from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import (
31 ActionResult,
32 Digest,
33 Directory,
34 GetActionResultRequest,
35 RequestMetadata,
36)
37from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2_grpc_aio import ActionCacheStub
38from buildgrid._protos.buildgrid.v2.query_build_events_pb2 import QueryEventStreamsRequest
39from buildgrid._protos.buildgrid.v2.query_build_events_pb2_grpc_aio import QueryBuildEventsStub
40from buildgrid._protos.google.bytestream.bytestream_pb2 import ReadRequest
41from buildgrid._protos.google.bytestream.bytestream_pb2_grpc_aio import ByteStreamStub
42from buildgrid._protos.google.longrunning import operations_pb2
43from buildgrid._protos.google.longrunning.operations_pb2_grpc_aio import OperationsStub
44from buildgrid.server.app.cli import Context
45from buildgrid.server.browser.utils import TARBALL_DIRECTORY_PREFIX, ResponseCache, get_cors_headers
46from buildgrid.server.logging import buildgrid_logger
47from buildgrid.server.metadata import extract_request_metadata, extract_trailing_client_identity
48from buildgrid.server.operations.filtering.interpreter import VALID_OPERATION_FILTERS, OperationFilterSpec
49from buildgrid.server.operations.filtering.sanitizer import DatetimeValueSanitizer, SortKeyValueSanitizer
50from buildgrid.server.settings import BROWSER_MAX_CACHE_ENTRY_SIZE
52LOGGER = buildgrid_logger(__name__)
55def query_build_events_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
56 """Factory function which returns a handler for QueryEventStreams.
58 The returned handler uses ``context.channel`` to send a QueryEventStreams
59 request constructed based on the provided URL query parameters. Currently
60 only querying by build_id (equivalent to the correlated invocations ID) is
61 supported.
63 The handler returns a serialised QueryEventStreamsResponse, and raises a
64 500 error in the case of an RPC error.
66 Args:
67 context (Context): The context to use to send the gRPC request.
69 """
71 async def _query_build_events_handler(request: web.Request) -> web.Response:
72 build_id = request.rel_url.query.get("build_id", ".*")
74 stub = QueryBuildEventsStub(context.channel) # type: ignore # Requires stub regen
76 grpc_request = QueryEventStreamsRequest(build_id_pattern=build_id)
78 try:
79 grpc_response = await stub.QueryEventStreams(grpc_request)
80 except RpcError as e:
81 LOGGER.warning(e.details())
82 raise web.HTTPInternalServerError()
84 serialized_response = grpc_response.SerializeToString()
85 return web.Response(body=serialized_response)
87 return _query_build_events_handler
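# --- Editor's example (not part of the original module) -----------------------
# The factory functions in this file all follow the same pattern: they close over
# a ``Context`` (and sometimes a ``ResponseCache``) and return an aiohttp handler.
# A minimal sketch of wiring a few of them into an aiohttp application follows.
# The route paths and the ``_example_browser_app`` name are assumptions made for
# illustration only; they are not the project's actual routing table.
def _example_browser_app(context: Context, cache: ResponseCache) -> web.Application:
    # Handlers defined later in this module are resolved when this is called.
    app = web.Application()
    app.add_routes(
        [
            web.get("/build_events", query_build_events_handler(context)),
            web.get("/operation_filters", get_operation_filters_handler),
            web.get("/operations", list_operations_handler(context, cache)),
        ]
    )
    return app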
90async def get_operation_filters_handler(request: web.Request) -> web.Response:
91 """Return the available Operation filter keys."""
93 def _generate_filter_spec(key: str, spec: OperationFilterSpec) -> Dict[str, Any]:
94 comparators = ["<", "<=", "=", "!=", ">=", ">"]
95 filter_type = "text"
96 if isinstance(spec.sanitizer, SortKeyValueSanitizer):
97 comparators = ["="]
98 elif isinstance(spec.sanitizer, DatetimeValueSanitizer):
99 filter_type = "datetime"
101 ret = {
102 "comparators": comparators,
103 "description": spec.description,
104 "key": key,
105 "name": spec.name,
106 "type": filter_type,
107 }
109 try:
110 ret["values"] = spec.sanitizer.valid_values
111 except NotImplementedError:
112 pass
113 return ret
115 operation_filters = [_generate_filter_spec(key, spec) for key, spec in VALID_OPERATION_FILTERS.items()]
116 return web.Response(text=json.dumps(operation_filters))
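# --- Editor's example (not part of the original module) -----------------------
# Each entry in the JSON list returned above has the shape shown below. The
# concrete key, name, and description are invented for illustration; the real
# values come from ``VALID_OPERATION_FILTERS``.
_EXAMPLE_FILTER_SPEC_ENTRY = {
    "comparators": ["<", "<=", "=", "!=", ">=", ">"],
    "description": "Time at which the operation was queued.",  # hypothetical
    "key": "queued_time",  # hypothetical
    "name": "Queued Time",  # hypothetical
    "type": "datetime",
}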
119def list_operations_handler(
120 context: Context, cache: ResponseCache
121) -> Callable[[web.Request], Awaitable[web.Response]]:
122 """Factory function which returns a handler for ListOperations.
124 The returned handler uses ``context.channel`` and ``context.instance_name``
125 to send a ListOperations request constructed based on the provided URL
126 query parameters.
128 The handler returns a serialised ListOperationsResponse, raises a 400
129 error in the case of a bad filter or other invalid argument, or raises
130 a 500 error in the case of some other RPC error.
132 Args:
133 context (Context): The context to use to send the gRPC request.
135 """
137 async def _list_operations_handler(request: web.Request) -> web.Response:
138 filter_string = request.rel_url.query.get("q", "")
139 page_token = request.rel_url.query.get("page_token", "")
140 page_size_str = request.rel_url.query.get("page_size")
141 page_size = 0
142 if page_size_str is not None:
143 page_size = int(page_size_str)
145 LOGGER.info(
146 "Received ListOperations request.",
147 tags=dict(filter_string=filter_string, page_token=page_token, page_size=page_size),
148 )
149 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen
150 grpc_request = operations_pb2.ListOperationsRequest(
151 name=context.instance_name, page_token=page_token, page_size=page_size, filter=filter_string
152 )
154 try:
155 grpc_response = await stub.ListOperations(grpc_request)
156 except RpcError as e:
157 LOGGER.warning(e.details())
158 if e.code() == StatusCode.INVALID_ARGUMENT:
159 raise web.HTTPBadRequest()
160 raise web.HTTPInternalServerError()
162 serialised_response = grpc_response.SerializeToString()
163 return web.Response(body=serialised_response)
165 return _list_operations_handler
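# --- Editor's example (not part of the original module) -----------------------
# A rough sketch of calling the handler above from a client. The ``/operations``
# path and base URL are assumptions for illustration; the query parameters
# (``q``, ``page_token``, ``page_size``) and the serialised ListOperationsResponse
# body are what the handler itself defines.
async def _example_list_operations(base_url: str) -> operations_pb2.ListOperationsResponse:
    import aiohttp

    params = {"q": "stage=COMPLETED", "page_size": "10"}  # hypothetical filter string
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url}/operations", params=params) as http_response:
            http_response.raise_for_status()
            body = await http_response.read()
    return operations_pb2.ListOperationsResponse.FromString(body)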
168async def _get_operation(context: Context, request: web.Request) -> Tuple[operations_pb2.Operation, Call]:
169 operation_name = f"{context.instance_name}/{request.match_info['name']}"
171 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen
172 grpc_request = operations_pb2.GetOperationRequest(name=operation_name)
174 try:
175 call = stub.GetOperation(grpc_request)
176 operation = await call
177 except RpcError as e:
178 LOGGER.warning(f"Error fetching operation: {e.details()}")
179 if e.code() == StatusCode.INVALID_ARGUMENT:
180 raise web.HTTPNotFound()
181 raise web.HTTPInternalServerError()
183 return operation, call
186def get_operation_handler(context: Context, cache: ResponseCache) -> Callable[[web.Request], Awaitable[web.Response]]:
187 """Factory function which returns a handler for GetOperation.
189 The returned handler uses ``context.channel`` and ``context.instance_name``
190 to send a GetOperation request constructed based on the path component of
191 the URL.
193 The handler returns a serialised Operation message, raises a 404 error in
194 the case of an invalid operation name, or raises a 500 error in the case
195 of some other RPC error.
197 Args:
198 context (Context): The context to use to send the gRPC request.
200 """
202 async def _get_operation_handler(request: web.Request) -> web.Response:
203 name = request.match_info["name"]
204 LOGGER.info("Received GetOperation request.", tags=dict(name=name))
205 operation, _ = await _get_operation(context, request)
207 serialised_response = operation.SerializeToString()
208 return web.Response(body=serialised_response)
210 return _get_operation_handler
213def get_operation_request_metadata_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
214 """Factory function which returns a handler to get RequestMetadata.
216 The returned handler uses ``context.channel`` and ``context.instance_name``
217 to send a GetOperation request constructed based on the path component of
218 the URL.
220 The handler returns a serialised RequestMetadata proto message, retrieved
221 from the trailing metadata of the GetOperation response. In the event of
222 an invalid operation name it raises a 404 error, and raises a 500 error in
223 the case of some other RPC error.
225 Args:
226 context (Context): The context to use to send the gRPC request.
228 """
230 async def _get_operation_request_metadata_handler(request: web.Request) -> web.Response:
231 LOGGER.info("Received request for RequestMetadata.", tags=dict(name=str(request.match_info["name"])))
232 _, call = await _get_operation(context, request)
233 metadata = await call.trailing_metadata()
235 def extract_metadata(m: Metadata) -> RequestMetadata:
236 # `m` contains a list of tuples, but `extract_request_metadata()`
237 # expects objects with `key` and `value` attributes.
238 MetadataTuple = namedtuple("MetadataTuple", ["key", "value"])
239 return extract_request_metadata([MetadataTuple(entry[0], entry[1]) for entry in m])
241 request_metadata = extract_metadata(metadata)
242 return web.Response(body=request_metadata.SerializeToString())
244 return _get_operation_request_metadata_handler
247def get_operation_client_identity_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
248 """Factory function which returns a handler to get ClientIdentity metadata.
250 The returned handler uses ``context.channel`` and ``context.instance_name``
251 to send a GetOperation request constructed based on the path component of
252 the URL.
254 The handler returns a serialised ClientIdentity proto message, retrieved
255 from the trailing metadata of the GetOperation response. In the event of
256 an invalid operation name it raises a 404 error, and raises a 500 error in
257 the case of some other RPC error.
259 Args:
260 context (Context): The context to use to send the gRPC request.
262 """
264 async def _get_operation_client_identity_handler(request: web.Request) -> web.Response:
265 LOGGER.info("Received request for ClientIdentity.", tags=dict(name=str(request.match_info["name"])))
266 _, call = await _get_operation(context, request)
267 metadata = await call.trailing_metadata()
268 client_identity = extract_trailing_client_identity(metadata)
269 return web.Response(body=client_identity.SerializeToString())
271 return _get_operation_client_identity_handler
274def cancel_operation_handler(context: Context) -> Callable[[web.Request], Awaitable[web.Response]]:
275 """Factory function which returns a handler for CancelOperation.
277 The returned handler uses ``context.channel`` and ``context.instance_name``
278 to send a CancelOperation request constructed based on the path component of
279 the URL.
281 The handler raises a 404 error in the case of an invalid operation name,
282 or a 500 error in the case of some other RPC error.
284 On success, the response is empty.
286 Args:
287 context (Context): The context to use to send the gRPC request.
289 """
291 async def _cancel_operation_handler(request: web.Request) -> web.Response:
292 LOGGER.info("Received CancelOperation request.", tags=dict(name=str(request.match_info["name"])))
293 operation_name = f"{context.instance_name}/{request.match_info['name']}"
295 stub = OperationsStub(context.operations_channel) # type: ignore # Requires stub regen
296 grpc_request = operations_pb2.CancelOperationRequest(name=operation_name)
298 try:
299 await stub.CancelOperation(grpc_request)
300 return web.Response()
301 except RpcError as e:
302 LOGGER.warning(e.details())
303 if e.code() == StatusCode.INVALID_ARGUMENT:
304 raise web.HTTPNotFound()
305 raise web.HTTPInternalServerError()
307 return _cancel_operation_handler
310async def _fetch_action_result(
311 context: Context, request: web.Request, cache: ResponseCache, cache_key: str
312) -> ActionResult:
313 stub = ActionCacheStub(context.cache_channel) # type: ignore # Requires stub regen
314 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"]))
315 grpc_request = GetActionResultRequest(action_digest=digest, instance_name=context.instance_name)
317 try:
318 result = await stub.GetActionResult(grpc_request)
319 except RpcError as e:
320 LOGGER.warning(f"Failed to fetch ActionResult: [{e.details()}]")
321 if e.code() == StatusCode.NOT_FOUND:
322 raise web.HTTPNotFound()
323 raise web.HTTPInternalServerError()
325 await cache.store_action_result(cache_key, result)
326 return result
329def get_action_result_handler(
330 context: Context, cache: ResponseCache, cache_capacity: int = 512
331) -> Callable[[web.Request], Awaitable[web.Response]]:
332 """Factory function which returns a handler for GetActionResult.
334 The returned handler uses ``context.channel`` and ``context.instance_name``
335 to send a GetActionResult request constructed based on the path components
336 of the URL.
338 The handler returns a serialised ActionResult message, raises a 404 error
339 if there's no result cached, or raises a 500 error in the case of some
340 other RPC error.
342 Args:
343 context (Context): The context to use to send the gRPC request.
344 cache_capacity (int): The number of ActionResults to cache in memory
345 to avoid hitting the actual ActionCache.
347 """
349 class FetchSpec:
350 """Simple class used to store information about a GetActionResult request.
352 A class is used here rather than a namedtuple since we need this state
353 to be mutable.
355 """
357 def __init__(
358 self, *, error: Optional[Exception], event: asyncio.Event, result: Optional[ActionResult], refcount: int
359 ):
360 self.error = error
361 self.event = event
362 self.result = result
363 self.refcount = refcount
365 in_flight_fetches: Dict[str, FetchSpec] = {}
366 fetch_lock = asyncio.Lock()
368 async def _get_action_result_handler(request: web.Request) -> web.Response:
369 cache_key = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
370 LOGGER.info("Received GetActionResult request.", tags=dict(cache_key=cache_key))
372 result = await cache.get_action_result(cache_key)
374 if result is None:
375 try:
376 duplicate_request = False
377 spec = None
378 async with fetch_lock:
379 if cache_key in in_flight_fetches:
380 LOGGER.info("Deduplicating GetActionResult request.", tags=dict(cache_key=cache_key))
381 spec = in_flight_fetches[cache_key]
382 spec.refcount += 1
383 duplicate_request = True
384 else:
385 spec = FetchSpec(error=None, event=asyncio.Event(), result=None, refcount=1)
386 in_flight_fetches[cache_key] = spec
388 if duplicate_request and spec:
389 # If we're a duplicate of an existing request, then wait for the
390 # existing request to finish fetching from the ActionCache.
391 await spec.event.wait()
392 if spec.error is not None:
393 raise spec.error
394 if spec.result is None:
395 # Note: this should be impossible, but let's guard against accidentally setting
396 # the event before the result is populated
397 LOGGER.info("Result not set in deduplicated request.", tags=dict(cache_key=cache_key))
398 raise web.HTTPInternalServerError()
399 result = spec.result
400 else:
401 try:
402 result = await _fetch_action_result(context, request, cache, cache_key)
403 except Exception as e:
404 async with fetch_lock:
405 if spec is not None:
406 spec.error = e
407 raise e
409 async with fetch_lock:
410 if spec is not None:
411 spec.result = result
412 spec.event.set()
414 finally:
415 async with fetch_lock:
416 # Decrement refcount now we're done with the result. If we're the
417 # last request interested in the result then remove it from the
418 # `in_flight_fetches` dictionary.
419 spec = in_flight_fetches.get(cache_key)
420 if spec is not None:
421 spec.refcount -= 1
422 if spec.refcount <= 0:
423 in_flight_fetches.pop(cache_key)
425 return web.Response(body=result.SerializeToString())
427 return _get_action_result_handler
430def get_blob_handler(
431 context: Context, cache: ResponseCache, allow_all: bool = False, allowed_origins: UrlCollection = ()
432) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
433 async def _get_blob_handler(request: web.Request) -> web.StreamResponse:
434 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"]))
435 try:
436 offset = int(request.rel_url.query.get("offset", "0"))
437 limit = int(request.rel_url.query.get("limit", "0"))
438 except ValueError:
439 raise web.HTTPBadRequest()
441 response = web.StreamResponse()
443 # We need to explicitly set CORS headers here, because the middleware that
444 # normally handles this only adds the header when the response is returned
445 # from this function. However, when using a StreamResponse with chunked
446 # encoding enabled, the client begins to receive the response when we call
447 # `response.write()`. This leads to the request being disallowed due to the
448 # missing response header for clients executing in browsers.
449 cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all)
450 response.headers.update(cors_headers)
452 if request.rel_url.query.get("raw", "") == "1":
453 response.headers["Content-type"] = "text/plain; charset=utf-8"
454 else:
455 # Setting the Content-Disposition header so that when
456 # downloading a blob with a browser the default name uniquely
457 # identifies its contents:
458 filename = f"{request.match_info['hash']}_{request.match_info['size_bytes']}"
460 # For partial reads also include the indices to prevent file mix-ups:
461 if offset != 0 or limit != 0:
462 filename += f"_chunk_{offset}-"
463 if limit != 0:
464 filename += f"{offset + limit}"
466 response.headers["Content-Disposition"] = f"Attachment;filename={filename}"
468 prepared = False
470 async def _callback(data: bytes) -> None:
471 nonlocal prepared
472 if not prepared:
473 # Prepare for chunked encoding when the callback is first called,
474 # so that we're sure we actually have some data before doing so.
475 response.enable_chunked_encoding()
476 await response.prepare(request)
477 prepared = True
478 await response.write(data)
480 await _fetch_blob(context, cache, digest, callback=_callback, offset=offset, limit=limit)
482 return response
484 return _get_blob_handler
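# --- Editor's example (not part of the original module) -----------------------
# A sketch of consuming the streaming blob endpoint above from a client, including
# the ``offset``/``limit`` query parameters the handler accepts. The ``/blobs/...``
# path and base URL are assumptions for illustration.
async def _example_download_blob(base_url: str, digest: Digest, offset: int = 0, limit: int = 0) -> bytes:
    import aiohttp

    url = f"{base_url}/blobs/{digest.hash}/{digest.size_bytes}"
    params = {"offset": str(offset), "limit": str(limit)}
    data = b""
    async with aiohttp.ClientSession() as session:
        async with session.get(url, params=params) as http_response:
            http_response.raise_for_status()
            # The handler uses chunked encoding, so read the body incrementally.
            async for chunk in http_response.content.iter_chunked(1024):
                data += chunk
    return data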
487def _create_tarball(directory: str, name: str) -> bool:
488 """Makes a tarball from a given directory.
490 Returns True if the tarball was successfully created, and False if not.
492 Args:
493 directory (str): The directory to tar up.
494 name (str): The name of the tarball to be produced.
496 """
497 try:
498 with tarfile.open(name, "w:gz") as tarball:
499 tarball.add(directory, arcname="")
500 except Exception:
501 return False
502 return True
505async def _fetch_blob(
506 context: Context,
507 cache: ResponseCache,
508 digest: Digest,
509 message_class: Optional[Type[Any]] = None,
510 callback: Optional[Callable[[bytes], Awaitable[Any]]] = None,
511 offset: int = 0,
512 limit: int = 0,
513) -> Any:
514 """Fetch a blob from CAS.
516 This function sends a ByteStream Read request for the given digest. If ``callback``
517 is set then the callback is called with the data in each ReadResponse message and
518 this function returns an empty bytes object once the response is finished. If
519 ``message_class`` is set with no ``callback`` set then this function calls
520 ``message_class.FromString`` on the fetched blob and returns the result.
522 If neither ``callback`` nor ``message_class`` is set then this function just returns
523 the raw blob that was fetched from CAS.
525 Args:
526 context (Context): The context to use to send the gRPC request.
527 cache (ResponseCache): The response cache to check/update with the fetched blob.
528 digest (Digest): The Digest of the blob to fetch from CAS.
529 message_class (type): A class which implements a ``FromString`` class method.
530 The method should take a bytes object expected to contain the blob fetched
531 from CAS.
532 callback (callable): A function or other callable to act on a subset of the
533 blob contents.
534 offset (int): Read offset to start reading the blob at. Defaults to 0, the start
535 of the blob.
536 limit (int): Maximum number of bytes to read from the blob. Defaults to 0, no
537 limit.
539 """
540 cacheable = digest.size_bytes <= BROWSER_MAX_CACHE_ENTRY_SIZE
541 resource_name = f"{context.instance_name}/blobs/{digest.hash}/{digest.size_bytes}"
542 blob = None
543 if cacheable:
544 blob = await cache.get_blob(resource_name)
545 if blob is not None and callback is not None:
546 if limit > 0:
547 try:
548 blob = blob[offset : offset + limit]
549 except IndexError:
550 raise web.HTTPBadRequest()
551 await callback(blob)
553 if blob is None:
554 stub = ByteStreamStub(context.cas_channel) # type: ignore # Requires stub regen
555 grpc_request = ReadRequest(resource_name=resource_name, read_offset=offset, read_limit=limit)
557 blob = b""
558 try:
559 async for grpc_response in stub.Read(grpc_request):
560 if grpc_response.data:
561 if callback is not None:
562 await callback(grpc_response.data)
564 if callback is None or cacheable:
565 blob += grpc_response.data
566 except RpcError as e:
567 LOGGER.warning(e.details())
568 if e.code() == StatusCode.NOT_FOUND:
569 raise web.HTTPNotFound()
570 raise web.HTTPInternalServerError()
572 if cacheable:
573 await cache.store_blob(resource_name, blob)
575 if message_class is not None and callback is None:
576 return message_class.FromString(blob)
577 return blob
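# --- Editor's example (not part of the original module) -----------------------
# A sketch of the three ways ``_fetch_blob`` can be used, assuming a ``context``
# and ``cache`` are already available. The digest and file path are placeholders.
async def _example_fetch_blob_modes(context: Context, cache: ResponseCache, digest: Digest) -> None:
    # 1. Parse the fetched blob into a protobuf message via ``message_class``.
    directory = await _fetch_blob(context, cache, digest, message_class=Directory)

    # 2. Stream the blob through a callback, e.g. writing it straight to a file.
    async with aiofiles.open("/tmp/example-blob", "wb") as f:  # hypothetical path
        await _fetch_blob(context, cache, digest, callback=f.write)

    # 3. Fetch raw bytes, optionally just a slice of the blob.
    raw = await _fetch_blob(context, cache, digest, offset=0, limit=1024)
    LOGGER.info("Fetched example blob.", tags=dict(directories=len(directory.directories), size=len(raw)))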
580async def _download_directory(
581 context: Context, cache: ResponseCache, base: str, path: str, directory: Directory
582) -> None:
583 """Download the contents of a directory from CAS.
585 This function takes a Directory message and downloads the directory
586 contents defined in the message from CAS. Raises a 400 error if the
587 directory contains a symlink which points to a location outside the
588 initial directory.
590 The directory is downloaded recursively depth-first.
592 Args:
593 context (Context): The context to use for making gRPC requests.
594 cache (ResponseCache): The response cache to use when downloading the
595 directory contents.
596 base (str): The initial directory path, used to check symlinks don't
597 escape into the wider filesystem.
598 path (str): The path to download the directory into.
599 directory (Directory): The Directory message to fetch the contents of.
601 """
602 for directory_node in directory.directories:
603 dir_path = os.path.join(path, directory_node.name)
604 os.mkdir(dir_path)
605 child = await _fetch_blob(context, cache, directory_node.digest, message_class=Directory)
606 await _download_directory(context, cache, base, dir_path, child)
608 for file_node in directory.files:
609 file_path = os.path.join(path, file_node.name)
610 async with aiofiles.open(file_path, "wb") as f:
611 await _fetch_blob(context, cache, file_node.digest, callback=f.write)
612 if file_node.is_executable:
613 os.chmod(file_path, 0o755)
615 for link_node in directory.symlinks:
616 link_path = os.path.join(path, link_node.name)
617 target_path = os.path.realpath(link_node.target)
618 target_relpath = os.path.relpath(base, target_path)
619 if target_relpath.startswith(os.pardir):
620 raise web.HTTPBadRequest(
621 reason="Requested directory contains a symlink targeting a location outside the tarball"
622 )
624 os.symlink(link_node.target, link_path)
627async def _tarball_from_directory(context: Context, cache: ResponseCache, directory: Directory, tmp_dir: str) -> str:
628 """Construct a tarball of a directory stored in CAS.
630 This function fetches the contents of the given directory message into a
631 temporary directory, and then constructs a tarball of the directory. The
632 path to this tarball is returned when construction is complete.
634 Args:
635 context (Context): The context to use to send the gRPC requests.
636 cache (ResponseCache): The response cache to use when fetching the
637 tarball contents.
638 directory (Directory): The Directory message for the directory we're
639 making a tarball of.
640 tmp_dir (str): Path to a temporary directory to use for storing the
641 directory contents and its tarball.
643 """
644 tarball_dir = tempfile.mkdtemp(dir=tmp_dir)
645 tarball_path = os.path.join(tmp_dir, "directory.tar.gz")
646 loop = asyncio.get_event_loop()
648 # Fetch the contents of the directory into a temporary directory
649 await _download_directory(context, cache, tarball_dir, tarball_dir, directory)
651 # Make a tarball from that temporary directory
652 # NOTE: We do this using loop.run_in_executor to avoid blocking the
653 # event loop with the synchronous tarball construction
654 tarball_result = await loop.run_in_executor(None, _create_tarball, tarball_dir, tarball_path)
655 if not tarball_result:
656 raise web.HTTPInternalServerError()
657 return tarball_path
660def get_tarball_handler(
661 context: Context,
662 cache: ResponseCache,
663 allow_all: bool = False,
664 allowed_origins: UrlCollection = (),
665 tarball_dir: Optional[str] = None,
666) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
667 """Factory function which returns a handler for tarball requests.
669 This function also takes care of cleaning up old incomplete tarball constructions
670 when given a named directory to do the construction in.
672 The returned handler takes the hash and size_bytes of a Digest of a Directory
673 message and constructs a tarball of the directory defined by the message.
675 Args:
676 context (Context): The context to use to send the gRPC requests.
677 allow_all (bool): Whether or not to allow all CORS origins.
678 allowed_origins (list): List of valid CORS origins.
679 tarball_dir (str): Base directory to use for tarball construction.
681 """
683 class FetchSpec:
684 """Simple class used to store information about a tarball request.
686 A class is used here rather than a namedtuple since we need this state
687 to be mutable.
689 """
691 def __init__(self, *, error: Optional[Exception], event: Optional[asyncio.Event], path: str, refcount: int):
692 self.error = error
693 self.event = event
694 self.path = path
695 self.refcount = refcount
697 in_flight_fetches: Dict[str, FetchSpec] = {}
698 fetch_lock = asyncio.Lock()
700 # If we have a tarball directory to use, remove any existing tarball constructions from it
701 # to provide some form of cleanup after a crash.
702 if tarball_dir is not None:
703 for path in os.listdir(tarball_dir):
704 if path.startswith(TARBALL_DIRECTORY_PREFIX):
705 shutil.rmtree(os.path.join(tarball_dir, path))
707 async def _get_tarball_handler(request: web.Request) -> web.StreamResponse:
708 digest_str = f'{request.match_info["hash"]}/{request.match_info["size_bytes"]}'
709 LOGGER.info("Received request for a tarball from CAS for blob.", tags=dict(digest=digest_str))
711 digest = Digest(hash=request.match_info["hash"], size_bytes=int(request.match_info["size_bytes"]))
713 tmp_dir = tempfile.mkdtemp(prefix=TARBALL_DIRECTORY_PREFIX, dir=tarball_dir)
715 try:
716 duplicate_request = False
717 event = None
718 async with fetch_lock:
719 if digest_str in in_flight_fetches:
720 LOGGER.info("Deduplicating request for tarball.", tags=dict(digest=digest_str))
721 spec = in_flight_fetches[digest_str]
722 spec.refcount += 1
723 event = spec.event
724 duplicate_request = True
725 else:
726 event = asyncio.Event()
727 in_flight_fetches[digest_str] = FetchSpec(error=None, event=event, path="", refcount=1)
729 if duplicate_request and event:
730 # If we're a duplicate of an existing request, then wait for the
731 # existing request to finish tarball creation before reading the
732 # path from the cache.
733 await event.wait()
734 spec = in_flight_fetches[digest_str]
735 if spec.error is not None:
736 raise spec.error
737 tarball_path = in_flight_fetches[digest_str].path
738 else:
739 try:
740 directory = await _fetch_blob(context, cache, digest, message_class=Directory)
741 tarball_path = await _tarball_from_directory(context, cache, directory, tmp_dir)
742 except web.HTTPError as e:
743 in_flight_fetches[digest_str].error = e
744 if event:
745 event.set()
746 raise e
747 except Exception as e:
748 LOGGER.debug("Unexpected error constructing tarball.", tags=dict(digest=digest_str), exc_info=True)
749 in_flight_fetches[digest_str].error = e
750 if event:
751 event.set()
752 raise web.HTTPInternalServerError()
754 # Update path in deduplication cache, and set event to notify
755 # duplicate requests that the tarball is ready
756 async with fetch_lock:
757 if event:
758 in_flight_fetches[digest_str].path = tarball_path
759 event.set()
761 response = web.StreamResponse()
763 # We need to explicitly set CORS headers here, because the middleware that
764 # normally handles this only adds the header when the response is returned
765 # from this function. However, when using a StreamResponse with chunked
766 # encoding enabled, the client begins to receive the response when we call
767 # `response.write()`. This leads to the request being disallowed due to the
768 # missing response header for clients executing in browsers.
769 cors_headers = get_cors_headers(request.headers.get("Origin"), allowed_origins, allow_all)
770 response.headers.update(cors_headers)
772 response.enable_chunked_encoding()
773 await response.prepare(request)
775 async with aiofiles.open(tarball_path, "rb") as tarball:
776 await tarball.seek(0)
777 chunk = await tarball.read(1024)
778 while chunk:
779 await response.write(chunk)
780 chunk = await tarball.read(1024)
781 return response
783 except RpcError as e:
784 LOGGER.warning(e.details())
785 if e.code() == StatusCode.NOT_FOUND:
786 raise web.HTTPNotFound()
787 raise web.HTTPInternalServerError()
789 finally:
790 cleanup = False
791 async with fetch_lock:
792 # Decrement refcount now we're done with the tarball. If we're the
793 # last request interested in the tarball then remove it along with
794 # its construction directory.
795 spec = in_flight_fetches[digest_str]
796 spec.refcount -= 1
797 if spec.refcount <= 0:
798 cleanup = True
799 in_flight_fetches.pop(digest_str)
800 if cleanup:
801 shutil.rmtree(tmp_dir)
803 return _get_tarball_handler
806def logstream_handler(context: Context) -> Callable[[web.Request], Awaitable[Any]]:
807 async def _logstream_handler(request: web.Request) -> Any:
808 LOGGER.info("Received request for a LogStream websocket.")
809 stub = ByteStreamStub(context.logstream_channel) # type: ignore # Requires stub regen
810 ws = web.WebSocketResponse()
811 await ws.prepare(request)
813 async for msg in ws:
814 if msg.type == WSMsgType.BINARY:
815 read_request = ReadRequest()
816 read_request.ParseFromString(msg.data)
818 read_request.resource_name = f"{read_request.resource_name}"
819 try:
820 async for response in stub.Read(read_request):
821 serialized_response = response.SerializeToString()
822 if serialized_response:
823 ws_response = {
824 "resource_name": read_request.resource_name,
825 "data": response.data.decode("utf-8"),
826 "complete": False,
827 }
828 await ws.send_json(ws_response)
829 ws_response = {"resource_name": read_request.resource_name, "data": "", "complete": True}
830 await ws.send_json(ws_response)
831 except RpcError as e:
832 LOGGER.warning(e.details())
833 if e.code() == StatusCode.NOT_FOUND:
834 ws_response = {
835 "resource_name": read_request.resource_name,
836 "data": "NOT_FOUND",
837 "complete": True,
838 }
839 await ws.send_json(ws_response)
840 ws_response = {"resource_name": read_request.resource_name, "data": "INTERNAL", "complete": True}
841 await ws.send_json(ws_response)
843 return _logstream_handler
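# --- Editor's example (not part of the original module) -----------------------
# A sketch of a client for the LogStream websocket above: send one serialised
# ReadRequest as a binary frame, then read JSON messages until one is marked
# complete. The websocket URL is an assumption for illustration.
async def _example_read_logstream(ws_url: str, resource_name: str) -> str:
    import aiohttp

    text = ""
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(ws_url) as ws:
            read_request = ReadRequest(resource_name=resource_name)
            await ws.send_bytes(read_request.SerializeToString())
            async for msg in ws:
                if msg.type != WSMsgType.TEXT:
                    break
                payload = json.loads(msg.data)
                text += payload["data"]
                if payload["complete"]:
                    break
    return text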