Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/instance.py: 93.21%
265 statements
coverage.py v7.4.1, created at 2024-03-28 16:20 +0000
# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Storage Instances
=================

Instances of CAS and ByteStream
"""

import logging
from datetime import timedelta
from typing import Iterable, Iterator, Optional, Sequence, Tuple

from cachetools import TTLCache
from grpc import RpcError

from buildgrid._exceptions import (
    InvalidArgumentError,
    NotFoundError,
    OutOfRangeError,
    PermissionDeniedError,
    RetriableError,
)
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import DESCRIPTOR as RE_DESCRIPTOR
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import (
    BatchReadBlobsResponse,
    BatchUpdateBlobsRequest,
    BatchUpdateBlobsResponse,
    Digest,
    DigestFunction,
    Directory,
    FindMissingBlobsResponse,
    GetTreeRequest,
    GetTreeResponse,
    SymlinkAbsolutePathStrategy,
    Tree,
)
from buildgrid._protos.google.bytestream import bytestream_pb2 as bs_pb2
from buildgrid._protos.google.rpc import code_pb2, status_pb2
from buildgrid.server.cas.storage.storage_abc import StorageABC, create_write_session
from buildgrid.server.metrics_names import (
    CAS_BATCH_READ_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BATCH_READ_BLOBS_SIZE_BYTES,
    CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES,
    CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME,
    CAS_BYTESTREAM_READ_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BYTESTREAM_READ_SIZE_BYTES,
    CAS_BYTESTREAM_READ_TIME_METRIC_NAME,
    CAS_BYTESTREAM_WRITE_EXCEPTION_COUNT_METRIC_NAME,
    CAS_BYTESTREAM_WRITE_SIZE_BYTES,
    CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME,
    CAS_DOWNLOADED_BYTES_METRIC_NAME,
    CAS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME,
    CAS_GET_TREE_CACHE_HIT,
    CAS_GET_TREE_CACHE_MISS,
    CAS_GET_TREE_EXCEPTION_COUNT_METRIC_NAME,
    CAS_GET_TREE_TIME_METRIC_NAME,
    CAS_UPLOADED_BYTES_METRIC_NAME,
)
from buildgrid.server.metrics_utils import (
    Counter,
    Distribution,
    DurationMetric,
    ExceptionCounter,
    generator_method_duration_metric,
    generator_method_exception_counter,
)
from buildgrid.server.servicer import Instance
from buildgrid.settings import HASH, HASH_LENGTH, MAX_REQUEST_COUNT, STREAM_ERROR_RETRY_PERIOD
from buildgrid.utils import create_digest, get_unique_objects_by_attribute

LOGGER = logging.getLogger(__name__)

EMPTY_BLOB = b""
EMPTY_BLOB_DIGEST: Digest = create_digest(EMPTY_BLOB)


class ContentAddressableStorageInstance(Instance):
    SERVICE_NAME = RE_DESCRIPTOR.services_by_name["ContentAddressableStorage"].full_name

    def __init__(
        self,
        storage: StorageABC,
        read_only: bool = False,
        tree_cache_size: Optional[int] = None,
        tree_cache_ttl_minutes: float = 60,
    ) -> None:
        self._storage = storage
        self.__read_only = read_only

        self._tree_cache: Optional[TTLCache[Tuple[str, int], Digest]] = None
        if tree_cache_size:
            self._tree_cache = TTLCache(tree_cache_size, tree_cache_ttl_minutes * 60)
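
    # A minimal construction sketch, kept as a comment so the module itself is
    # unchanged. `SomeStorage` is a hypothetical StorageABC implementation used
    # purely for illustration and is not defined in this file:
    #
    #     storage = SomeStorage()
    #     cas = ContentAddressableStorageInstance(storage, tree_cache_size=256)
    #     # With tree_cache_size set, GetTree results for up to 256 root digests
    #     # are cached; each entry expires after tree_cache_ttl_minutes * 60 seconds.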

    # --- Public API ---

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def set_instance_name(self, instance_name: str) -> None:
        super().set_instance_name(instance_name)
        self._storage.set_instance_name(instance_name)

    def start(self) -> None:
        self._storage.start()

    def stop(self) -> None:
        self._storage.stop()

    def hash_type(self) -> "DigestFunction.Value.ValueType":
        return self._storage.hash_type()

    def max_batch_total_size_bytes(self) -> int:
        return self._storage.max_batch_total_size_bytes()

    def symlink_absolute_path_strategy(self) -> "SymlinkAbsolutePathStrategy.Value.ValueType":
        return self._storage.symlink_absolute_path_strategy()

    find_missing_blobs_ignored_exceptions = (RetriableError,)

    @DurationMetric(CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_FIND_MISSING_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=find_missing_blobs_ignored_exceptions,
    )
    def find_missing_blobs(self, blob_digests: Sequence[Digest]) -> FindMissingBlobsResponse:
        storage = self._storage
        blob_digests = list(get_unique_objects_by_attribute(blob_digests, "hash"))
        missing_blobs = storage.missing_blobs(blob_digests)

        num_blobs_in_request = len(blob_digests)
        if num_blobs_in_request > 0:
            num_blobs_missing = len(missing_blobs)
            percent_missing = float((num_blobs_missing / num_blobs_in_request) * 100)

            with Distribution(
                CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME, instance_name=self._instance_name
            ) as metric_num_requested:
                metric_num_requested.count = float(num_blobs_in_request)

            with Distribution(
                CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME, instance_name=self._instance_name
            ) as metric_num_missing:
                metric_num_missing.count = float(num_blobs_missing)

            with Distribution(
                CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME, instance_name=self._instance_name
            ) as metric_percent_missing:
                metric_percent_missing.count = percent_missing

        for digest in blob_digests:
            with Distribution(
                CAS_FIND_MISSING_BLOBS_SIZE_BYTES_REQUESTED_METRIC_NAME, instance_name=self._instance_name
            ) as metric_requested_blob_size:
                metric_requested_blob_size.count = float(digest.size_bytes)

        for digest in missing_blobs:
            with Distribution(
                CAS_FIND_MISSING_BLOBS_SIZE_BYTES_MISSING_METRIC_NAME, instance_name=self._instance_name
            ) as metric_missing_blob_size:
                metric_missing_blob_size.count = float(digest.size_bytes)

        return FindMissingBlobsResponse(missing_blob_digests=missing_blobs)

    batch_update_blobs_ignored_exceptions = (RetriableError,)

    @DurationMetric(CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_BATCH_UPDATE_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=batch_update_blobs_ignored_exceptions,
    )
    def batch_update_blobs(self, requests: Sequence[BatchUpdateBlobsRequest.Request]) -> BatchUpdateBlobsResponse:
        if self.__read_only:
            raise PermissionDeniedError(f"CAS instance {self._instance_name} is read-only")

        storage = self._storage
        store = []
        for request_proto in get_unique_objects_by_attribute(requests, "digest.hash"):
            store.append((request_proto.digest, request_proto.data))

            with Distribution(
                CAS_BATCH_UPDATE_BLOBS_SIZE_BYTES, instance_name=self._instance_name
            ) as metric_blob_size:
                metric_blob_size.count = float(request_proto.digest.size_bytes)

        response = BatchUpdateBlobsResponse()
        statuses = storage.bulk_update_blobs(store)

        with Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name) as bytes_counter:
            for (digest, _), status in zip(store, statuses):
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)
                response_proto.status.CopyFrom(status)
                if response_proto.status.code == 0:
                    bytes_counter.increment(response_proto.digest.size_bytes)

        return response

    batch_read_blobs_ignored_exceptions = (RetriableError,)

    @DurationMetric(CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_BATCH_READ_BLOBS_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=batch_read_blobs_ignored_exceptions,
    )
    def batch_read_blobs(self, digests: Sequence[Digest]) -> BatchReadBlobsResponse:
        storage = self._storage

        max_batch_size = storage.max_batch_total_size_bytes()

        # Only process unique digests
        good_digests = []
        bad_digests = []
        requested_bytes = 0
        for digest in get_unique_objects_by_attribute(digests, "hash"):
            if len(digest.hash) != HASH_LENGTH:
                bad_digests.append(digest)
            else:
                good_digests.append(digest)
                requested_bytes += digest.size_bytes

        if requested_bytes > max_batch_size:
            raise InvalidArgumentError(
                "Combined total size of blobs exceeds "
                "server limit. "
                f"({requested_bytes} > {max_batch_size} [byte])"
            )

        if len(good_digests) > 0:
            blobs_read = storage.bulk_read_blobs(good_digests)
        else:
            blobs_read = {}

        response = BatchReadBlobsResponse()

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name) as bytes_counter:
            for digest in good_digests:
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)

                if digest.hash in blobs_read and blobs_read[digest.hash] is not None:
                    response_proto.data = blobs_read[digest.hash]
                    status_code = code_pb2.OK
                    bytes_counter.increment(digest.size_bytes)

                    with Distribution(
                        CAS_BATCH_READ_BLOBS_SIZE_BYTES, instance_name=self._instance_name
                    ) as metric_blob_size:
                        metric_blob_size.count = float(digest.size_bytes)
                else:
                    status_code = code_pb2.NOT_FOUND

                response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        for digest in bad_digests:
            response_proto = response.responses.add()
            response_proto.digest.CopyFrom(digest)
            status_code = code_pb2.INVALID_ARGUMENT
            response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        return response

    def lookup_tree_cache(self, root_digest: Digest) -> Optional[Tree]:
        """Find full Tree from cache"""
        if self._tree_cache is None:
            return None
        tree = None
        if response_digest := self._tree_cache.get((root_digest.hash, root_digest.size_bytes)):
            tree = self._storage.get_message(response_digest, Tree)
            if tree is None:
                self._tree_cache.pop((root_digest.hash, root_digest.size_bytes))

        metric = CAS_GET_TREE_CACHE_HIT if tree is not None else CAS_GET_TREE_CACHE_MISS
        with Counter(metric) as counter:
            counter.increment(1)
        return tree

    def put_tree_cache(self, root_digest: Digest, root: Directory, children: Iterable[Directory]) -> None:
        """Put Tree with a full list of directories into CAS"""
        if self._tree_cache is None:
            return
        tree = Tree(root=root, children=children)
        tree_digest = self._storage.put_message(tree)
        self._tree_cache[(root_digest.hash, root_digest.size_bytes)] = tree_digest

    get_tree_ignored_exceptions = (NotFoundError, RetriableError)

    @DurationMetric(CAS_GET_TREE_TIME_METRIC_NAME, instanced=True)
    @generator_method_exception_counter(
        CAS_GET_TREE_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=get_tree_ignored_exceptions,
    )
    def get_tree(self, request: GetTreeRequest) -> Iterator[GetTreeResponse]:
        storage = self._storage

        if not request.page_size:
            request.page_size = MAX_REQUEST_COUNT

        if tree := self.lookup_tree_cache(request.root_digest):
            # Cache hit, yield responses based on page size
            directories = [tree.root]
            directories.extend(tree.children)
            yield from (
                GetTreeResponse(directories=directories[start : start + request.page_size])
                for start in range(0, len(directories), request.page_size)
            )
            return

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name) as bytes_counter:
            # From the spec, a NotFound response only occurs if the root directory is missing.
            root_directory = storage.get_message(request.root_digest, Directory)
            if not root_directory:
                raise NotFoundError(
                    f"Root digest not found: {request.root_digest.hash}/{request.root_digest.size_bytes}"
                )

            bytes_counter.increment(request.root_digest.size_bytes)

            results = [root_directory]
            offset = 0
            queue = [subdir.digest for subdir in root_directory.directories]
            while queue:
                bytes_counter.increment(sum(d.size_bytes for d in queue))
                blobs = storage.bulk_read_blobs(queue)

                directories = [Directory.FromString(b) for b in blobs.values()]
                queue = [subdir.digest for d in directories for subdir in d.directories]

                results.extend(directories)
                while len(results) - offset >= request.page_size:
                    yield GetTreeResponse(directories=results[offset : offset + request.page_size])
                    offset += request.page_size

            if len(results) - offset > 0:
                yield GetTreeResponse(directories=results[offset:])
            if results:
                self.put_tree_cache(request.root_digest, results[0], results[1:])
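
    # Illustrative paging sketch for get_tree(), assuming a `cas` instance built
    # as in the earlier sketch and a root digest already present in storage
    # (both are assumptions, not provided by this module):
    #
    #     request = GetTreeRequest(root_digest=root_digest, page_size=32)
    #     for page in cas.get_tree(request):
    #         for directory in page.directories:
    #             ...  # at most 32 Directory messages per GetTreeResponse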


class ByteStreamInstance(Instance):
    SERVICE_NAME = bs_pb2.DESCRIPTOR.services_by_name["ByteStream"].full_name

    BLOCK_SIZE = 1 * 1024 * 1024  # 1 MB block size

    def __init__(
        self,
        storage: StorageABC,
        read_only: bool = False,
        disable_overwrite_early_return: bool = False,
    ) -> None:
        self._storage = storage
        self._query_activity_timeout = 30

        self.__read_only = read_only

        # If set, prevents `ByteStream.Write()` from returning without
        # reading all the client's `WriteRequests` for a digest that is
        # already in storage (i.e. not following the REAPI-specified
        # behavior).
        self.__disable_overwrite_early_return = disable_overwrite_early_return
        # (Should only be used to work around issues with implementations
        # that treat the server half-closing its end of the gRPC stream
        # as an HTTP/2 stream error.)

    # --- Public API ---

    def start(self) -> None:
        self._storage.start()

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME)
    def set_instance_name(self, instance_name: str) -> None:
        super().set_instance_name(instance_name)
        self._storage.set_instance_name(instance_name)

    bytestream_read_ignored_exceptions = (NotFoundError, RetriableError)

    @generator_method_duration_metric(CAS_BYTESTREAM_READ_TIME_METRIC_NAME)
    @generator_method_exception_counter(
        CAS_BYTESTREAM_READ_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=bytestream_read_ignored_exceptions,
    )
    def read_cas_blob(
        self, digest_hash: str, digest_size: str, read_offset: int, read_limit: int
    ) -> Iterator[bs_pb2.ReadResponse]:
        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = Digest(hash=digest_hash, size_bytes=int(digest_size))

        # Check the given read offset and limit.
        if read_offset < 0 or read_offset > digest.size_bytes:
            raise OutOfRangeError("Read offset out of range")

        elif read_limit == 0:
            bytes_remaining = digest.size_bytes - read_offset

        elif read_limit > 0:
            bytes_remaining = read_limit

        else:
            raise InvalidArgumentError("Negative read_limit is invalid")

        # Read the blob from storage and send its contents to the client.
        result = self._storage.get_blob(digest)
        if result is None:
            raise NotFoundError("Blob not found")

        try:
            if read_offset > 0:
                result.seek(read_offset)

            with Distribution(
                metric_name=CAS_BYTESTREAM_READ_SIZE_BYTES, instance_name=self._instance_name
            ) as metric_blob_size:
                metric_blob_size.count = float(digest.size_bytes)

            with Counter(
                metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name
            ) as bytes_counter:
                while bytes_remaining > 0:
                    block_data = result.read(min(self.BLOCK_SIZE, bytes_remaining))
                    yield bs_pb2.ReadResponse(data=block_data)
                    bytes_counter.increment(len(block_data))
                    bytes_remaining -= self.BLOCK_SIZE
        finally:
            result.close()
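
    # Read-semantics sketch (comment only; the `bytestream` instance and digest
    # values are assumptions for illustration). A read_limit of 0 streams from
    # read_offset to the end of the blob, while a positive read_limit caps the
    # number of bytes returned:
    #
    #     for response in bytestream.read_cas_blob(
    #         digest.hash, str(digest.size_bytes), read_offset=0, read_limit=0
    #     ):
    #         chunk = response.data  # at most BLOCK_SIZE bytes per message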

    bytestream_write_ignored_exceptions = (NotFoundError, RetriableError)

    @DurationMetric(CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME, instanced=True)
    @ExceptionCounter(
        CAS_BYTESTREAM_WRITE_EXCEPTION_COUNT_METRIC_NAME,
        ignored_exceptions=bytestream_write_ignored_exceptions,
    )
    def write_cas_blob(
        self, digest_hash: str, digest_size: str, requests: Iterator[bs_pb2.WriteRequest]
    ) -> bs_pb2.WriteResponse:
        if self.__read_only:
            raise PermissionDeniedError(f"ByteStream instance {self._instance_name} is read-only")

        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = Digest(hash=digest_hash, size_bytes=int(digest_size))

        with Distribution(
            metric_name=CAS_BYTESTREAM_WRITE_SIZE_BYTES, instance_name=self._instance_name
        ) as metric_blob_size:
            metric_blob_size.count = float(digest.size_bytes)

        if self._storage.has_blob(digest):
            # According to the REAPI specification:
            # "When attempting an upload, if another client has already
            # completed the upload (which may occur in the middle of a single
            # upload if another client uploads the same blob concurrently),
            # the request will terminate immediately [...]".
            #
            # However, half-closing the stream can be problematic with some
            # intermediaries like HAProxy.
            # (https://github.com/haproxy/haproxy/issues/1219)
            #
            # If half-closing the stream is not allowed, we read and drop
            # all the client's messages before returning, still saving
            # the cost of a write to storage.
            if self.__disable_overwrite_early_return:
                try:
                    for request in requests:
                        if request.finish_write:
                            break
                        continue
                except RpcError:
                    msg = "ByteStream client disconnected whilst streaming requests, upload cancelled."
                    LOGGER.debug(msg)
                    raise RetriableError(msg, retry_period=timedelta(seconds=STREAM_ERROR_RETRY_PERIOD))

            return bs_pb2.WriteResponse(committed_size=digest.size_bytes)

        # Start the write session and write the first request's data.
        bytes_counter = Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME, instance_name=self._instance_name)
        write_session = create_write_session(digest)
        with bytes_counter, write_session:
            computed_hash = HASH()

            # Handle subsequent write requests.
            try:
                for request in requests:
                    write_session.write(request.data)

                    computed_hash.update(request.data)
                    bytes_counter.increment(len(request.data))

                    if request.finish_write:
                        break
            except RpcError:
                write_session.close()
                msg = "ByteStream client disconnected whilst streaming requests, upload cancelled."
                LOGGER.debug(msg)
                raise RetriableError(msg, retry_period=timedelta(seconds=STREAM_ERROR_RETRY_PERIOD))

            # Check that the data matches the provided digest.
            if bytes_counter.count != digest.size_bytes:
                raise NotImplementedError(
                    "Cannot close stream before finishing write, "
                    f"got {bytes_counter.count} bytes but expected {digest.size_bytes}"
                )

            if computed_hash.hexdigest() != digest.hash:
                raise InvalidArgumentError("Data does not match hash")

            self._storage.commit_write(digest, write_session)
            return bs_pb2.WriteResponse(committed_size=int(bytes_counter.count))
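

# Minimal wiring sketch for both instances, left as a comment so that nothing
# here runs at import time. `SomeStorage` stands in for any concrete StorageABC
# implementation and is an assumption, not part of this file:
#
#     storage = SomeStorage()
#     cas = ContentAddressableStorageInstance(storage)
#     bytestream = ByteStreamInstance(storage)
#     for instance in (cas, bytestream):
#         instance.set_instance_name("main")
#         instance.start()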