Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/storage/remote.py: 92.62%
122 statements
# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16"""
17RemoteStorage
18==================
20Forwwards storage requests to a remote storage.
21"""

import io
import logging
from tempfile import NamedTemporaryFile
from typing import IO, Any, Sequence

import grpc

from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest
from buildgrid._protos.google.rpc import code_pb2, status_pb2
from buildgrid._protos.google.rpc.status_pb2 import Status
from buildgrid.server.client.authentication import ClientCredentials
from buildgrid.server.client.cas import download, upload
from buildgrid.server.client.channel import setup_channel
from buildgrid.server.context import current_instance
from buildgrid.server.decorators import timed
from buildgrid.server.exceptions import GrpcUninitializedError, NotFoundError
from buildgrid.server.logging import buildgrid_logger
from buildgrid.server.metadata import metadata_list
from buildgrid.server.metrics_names import METRIC
from buildgrid.server.settings import HASH, MAX_IN_MEMORY_BLOB_SIZE_BYTES

from .storage_abc import StorageABC

LOGGER = buildgrid_logger(__name__)


class RemoteStorage(StorageABC):
    TYPE = "Remote"

    def __init__(
        self,
        remote: str,
        instance_name: str | None = None,
        channel_options: Sequence[tuple[str, Any]] | None = None,
        credentials: ClientCredentials | None = None,
        retries: int = 0,
        max_backoff: int = 64,
        request_timeout: float | None = None,
    ) -> None:
        self._remote_instance_name = instance_name
        self._remote = remote
        self._channel_options = channel_options
        if credentials is None:
            credentials = {}
        self.credentials = credentials
        self.retries = retries
        self.max_backoff = max_backoff
        self._request_timeout = request_timeout

        self._stub_cas: remote_execution_pb2_grpc.ContentAddressableStorageStub | None = None
        self.channel: grpc.Channel | None = None
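
    # The optional ``credentials`` mapping is consulted when the gRPC channel is
    # created in start(); the keys read there are "auth-token",
    # "token-refresh-seconds", "tls-client-key", "tls-client-cert" and
    # "tls-server-cert".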

    def start(self) -> None:
        if self.channel is None:
            self.channel, *_ = setup_channel(
                self._remote,
                auth_token=self.credentials.get("auth-token"),
                auth_token_refresh_seconds=self.credentials.get("token-refresh-seconds"),
                client_key=self.credentials.get("tls-client-key"),
                client_cert=self.credentials.get("tls-client-cert"),
                server_cert=self.credentials.get("tls-server-cert"),
                timeout=self._request_timeout,
            )

        if self._stub_cas is None:
            self._stub_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)

    def stop(self) -> None:
        if self.channel is not None:
            self.channel.close()

    @property
    def remote_instance_name(self) -> str:
        if self._remote_instance_name is not None:
            return self._remote_instance_name
        return current_instance()

    @timed(METRIC.STORAGE.STAT_DURATION, type=TYPE)
    def has_blob(self, digest: Digest) -> bool:
        LOGGER.debug("Checking for blob.", tags=dict(digest=digest))
        if not self.missing_blobs([digest]):
            return True
        return False

    @timed(METRIC.STORAGE.READ_DURATION, type=TYPE)
    def get_blob(self, digest: Digest) -> IO[bytes] | None:
        if self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        LOGGER.debug("Getting blob.", tags=dict(digest=digest))
        with download(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as downloader:
            if digest.size_bytes > MAX_IN_MEMORY_BLOB_SIZE_BYTES:
                # Avoid storing the large blob completely in memory.
                temp_file = NamedTemporaryFile(delete=True)
                try:
                    downloader.download_file(digest, temp_file.name, queue=False)
                except NotFoundError:
                    return None
                # TODO Fix this: incompatible type "_TemporaryFileWrapper[bytes]"; expected "RawIOBase"
                reader = io.BufferedReader(temp_file)  # type: ignore[arg-type]
                reader.seek(0)
                return reader
            else:
                blob = downloader.get_blob(digest)
                if blob is not None:
                    return io.BytesIO(blob)
                else:
                    return None

    def delete_blob(self, digest: Digest) -> None:
        """The REAPI doesn't have a deletion method, so we can't support
        deletion for remote storage.
        """
        raise NotImplementedError("Deletion is not supported for remote storage!")

    def bulk_delete(self, digests: list[Digest]) -> list[str]:
        """The REAPI doesn't have a deletion method, so we can't support
        bulk deletion for remote storage.
        """
        raise NotImplementedError("Bulk deletion is not supported for remote storage!")

    @timed(METRIC.STORAGE.WRITE_DURATION, type=TYPE)
    def commit_write(self, digest: Digest, write_session: IO[bytes]) -> None:
        if self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        write_session.seek(0)
        LOGGER.debug("Writing blob.", tags=dict(digest=digest))
        with upload(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as uploader:
            uploader.put_blob(write_session, digest=digest)

    @timed(METRIC.STORAGE.BULK_STAT_DURATION, type=TYPE)
    def missing_blobs(self, digests: list[Digest]) -> list[Digest]:
        if self._stub_cas is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        # Avoid expensive string creation.
        if LOGGER.is_enabled_for(logging.DEBUG):
            if len(digests) > 100:
                LOGGER.debug(f"Missing blobs request for: {digests[:100]} (truncated)")
            else:
                LOGGER.debug(f"Missing blobs request for: {digests}")

        request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=self.remote_instance_name)

        for blob in digests:
            request_digest = request.blob_digests.add()
            request_digest.hash = blob.hash
            request_digest.size_bytes = blob.size_bytes

        response = self._stub_cas.FindMissingBlobs(request, metadata=metadata_list())

        return list(response.missing_blob_digests)

    @timed(METRIC.STORAGE.BULK_WRITE_DURATION, type=TYPE)
    def bulk_update_blobs(self, blobs: list[tuple[Digest, bytes]]) -> list[Status]:
        if self._stub_cas is None or self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        sent_digests = []
        with upload(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as uploader:
            for digest, blob in blobs:
                if len(blob) != digest.size_bytes or HASH(blob).hexdigest() != digest.hash:
                    # Blobs that fail size or hash verification are recorded as an
                    # empty Digest, which maps to an UNKNOWN status below.
                    sent_digests.append(remote_execution_pb2.Digest())
                else:
                    sent_digests.append(uploader.put_blob(io.BytesIO(blob), digest=digest, queue=True))

        assert len(sent_digests) == len(blobs)

        return [
            status_pb2.Status(code=code_pb2.OK) if d.ByteSize() > 0 else status_pb2.Status(code=code_pb2.UNKNOWN)
            for d in sent_digests
        ]

    @timed(METRIC.STORAGE.BULK_READ_DURATION, type=TYPE)
    def bulk_read_blobs(self, digests: list[Digest]) -> dict[str, bytes]:
        if self._stub_cas is None or self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        # Avoid expensive string creation.
        if LOGGER.is_enabled_for(logging.DEBUG):
            LOGGER.debug(f"Bulk read blobs request for: {digests}")

        with download(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as downloader:
            results = downloader.get_available_blobs(digests)
            # Transform the list of (data, digest) pairs into the expected hash-to-blob map.
            return {digest.hash: data for data, digest in results}
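

# The block below is an illustrative usage sketch rather than part of the storage
# backend itself: the endpoint and instance name are assumptions, and the calls
# only succeed against a reachable CAS server.
if __name__ == "__main__":
    storage = RemoteStorage("cas.example.com:50051", instance_name="main")
    storage.start()
    try:
        data = b"example blob"
        # HASH is the configured digest function (see its use in bulk_update_blobs).
        digest = Digest(hash=HASH(data).hexdigest(), size_bytes=len(data))
        if not storage.has_blob(digest):
            storage.bulk_update_blobs([(digest, data)])
        reader = storage.get_blob(digest)
        if reader is not None:
            print(reader.read())
    finally:
        storage.stop()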