Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/storage/sharded.py: 99.11% (112 statements)
coverage.py v7.4.1, created at 2025-07-10 13:10 +0000
# Copyright (C) 2023 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from contextlib import ExitStack
from typing import IO, Callable, Iterable, Iterator, TypeVar

import mmh3

from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest
from buildgrid._protos.google.rpc import code_pb2
from buildgrid._protos.google.rpc.status_pb2 import Status
from buildgrid.server.decorators import timed
from buildgrid.server.logging import buildgrid_logger
from buildgrid.server.metrics_names import METRIC
from buildgrid.server.metrics_utils import publish_counter_metric
from buildgrid.server.threading import ContextThreadPoolExecutor

from .storage_abc import StorageABC

LOGGER = buildgrid_logger(__name__)

_T = TypeVar("_T")
_R = TypeVar("_R")


# Single-argument wrappers around the bulk StorageABC interfaces, so that the
# (storage, digests) pairs produced when a request is partitioned across shards
# can be passed straight to map()/ThreadPoolExecutor.map().
def _bulk_delete_for_storage(storage_digests: tuple[StorageABC, list[Digest]]) -> list[str]:
    storage, digests = storage_digests
    return storage.bulk_delete(digests)


def _fmb_for_storage(storage_digests: tuple[StorageABC, list[Digest]]) -> list[Digest]:
    storage, digests = storage_digests
    return storage.missing_blobs(digests)


def _bulk_update_for_storage(
    storage_digests: tuple[StorageABC, list[tuple[Digest, bytes]]],
) -> tuple[StorageABC, list[Status]]:
    storage, digest_tuples = storage_digests
    return storage, storage.bulk_update_blobs(digest_tuples)


def _bulk_read_for_storage(storage_digests: tuple[StorageABC, list[Digest]]) -> dict[str, bytes]:
    storage, digests = storage_digests
    return storage.bulk_read_blobs(digests)


class ShardedStorage(StorageABC):
    """Storage provider which shards blobs across multiple backing storages.

    Each digest is deterministically assigned to exactly one of the configured
    shards based on its hash, and every operation is delegated to the owning
    shard. Bulk operations are fanned out across shards, in parallel when a
    thread pool is configured.
    """

    TYPE = "Sharded"

    def __init__(self, storages: dict[str, StorageABC], thread_pool_size: int | None = None):
        self._stack = ExitStack()
        if not storages:
            raise ValueError("ShardedStorage requires at least one shard")
        self._storages = storages
        self._threadpool = None
        if thread_pool_size:
            self._threadpool = ContextThreadPoolExecutor(thread_pool_size, "sharded-storage")

    def start(self) -> None:
        if self._threadpool:
            self._stack.enter_context(self._threadpool)
        for storage in self._storages.values():
            self._stack.enter_context(storage)

    def stop(self) -> None:
        self._stack.close()
        LOGGER.info(f"Stopped {type(self).__name__}")

    def _storage_from_digest(self, digest: Digest) -> StorageABC:
        def _score(shard_name: str, digest: Digest) -> int:
            return mmh3.hash(f"{shard_name}\t{digest.hash}", signed=False)

        shard_name = min(self._storages.keys(), key=lambda name: _score(name, digest))
        return self._storages[shard_name]
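
    # Shard selection above is a rendezvous-style ("highest random weight") scheme:
    # each shard name is hashed together with the digest hash, and the shard with the
    # smallest murmur3 value owns the digest. The assignment is deterministic and
    # needs no shared state, and adding or removing a shard only remaps the digests
    # that hash to that shard. A rough sketch of the idea (the hash values below are
    # made up for illustration, not real mmh3 outputs):
    #
    #   _score("shard-a", digest)  ->  912274
    #   _score("shard-b", digest)  ->   53107   <- minimum, so "shard-b" owns this digest
    #   _score("shard-c", digest)  ->  700321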

    def _partition_digests(self, digests: list[Digest]) -> dict[StorageABC, list[Digest]]:
        partition: dict[StorageABC, list[Digest]] = defaultdict(list)
        for digest in digests:
            storage = self._storage_from_digest(digest)
            partition[storage].append(digest)
        return partition

    def _map(self, fn: Callable[[_T], _R], args: Iterable[_T]) -> Iterator[_R]:
        if self._threadpool:
            return self._threadpool.map(fn, args)
        else:
            return map(fn, args)

    @timed(METRIC.STORAGE.STAT_DURATION, type=TYPE)
    def has_blob(self, digest: Digest) -> bool:
        return self._storage_from_digest(digest).has_blob(digest)

    @timed(METRIC.STORAGE.READ_DURATION, type=TYPE)
    def get_blob(self, digest: Digest) -> IO[bytes] | None:
        return self._storage_from_digest(digest).get_blob(digest)

    @timed(METRIC.STORAGE.STREAM_READ_DURATION, type=TYPE)
    def stream_read_blob(self, digest: Digest, chunk_size: int, offset: int = 0, limit: int = 0) -> Iterator[bytes]:
        yield from self._storage_from_digest(digest).stream_read_blob(digest, chunk_size, offset, limit)

    @timed(METRIC.STORAGE.STREAM_WRITE_DURATION, type=TYPE)
    def stream_write_blob(self, digest: Digest, chunks: Iterator[bytes]) -> None:
        self._storage_from_digest(digest).stream_write_blob(digest, chunks)

    @timed(METRIC.STORAGE.DELETE_DURATION, type=TYPE)
    def delete_blob(self, digest: Digest) -> None:
        self._storage_from_digest(digest).delete_blob(digest)

    @timed(METRIC.STORAGE.WRITE_DURATION, type=TYPE)
    def commit_write(self, digest: Digest, write_session: IO[bytes]) -> None:
        self._storage_from_digest(digest).commit_write(digest, write_session)

    @timed(METRIC.STORAGE.DELETE_DURATION, type=TYPE)
    def bulk_delete(self, digests: list[Digest]) -> list[str]:
        failed_deletions: list[str] = []
        for result in self._map(_bulk_delete_for_storage, self._partition_digests(digests).items()):
            failed_deletions.extend(result)

        publish_counter_metric(METRIC.STORAGE.DELETE_ERRORS_COUNT, len(failed_deletions), type=self.TYPE)
        return failed_deletions

    @timed(METRIC.STORAGE.BULK_STAT_DURATION, type=TYPE)
    def missing_blobs(self, digests: list[Digest]) -> list[Digest]:
        missing_blobs: list[Digest] = []

        for result in self._map(_fmb_for_storage, self._partition_digests(digests).items()):
            missing_blobs.extend(result)

        return missing_blobs

    @timed(METRIC.STORAGE.BULK_WRITE_DURATION, type=TYPE)
    def bulk_update_blobs(self, blobs: list[tuple[Digest, bytes]]) -> list[Status]:
        # Partition the blobs by owning shard, remembering each blob's position in
        # the original request so the per-shard statuses can be mapped back.
        partitioned_digests: dict[StorageABC, list[tuple[Digest, bytes]]] = defaultdict(list)
        idx_map: dict[StorageABC, list[int]] = defaultdict(list)
        for orig_idx, digest_tuple in enumerate(blobs):
            storage = self._storage_from_digest(digest_tuple[0])
            partitioned_digests[storage].append(digest_tuple)
            idx_map[storage].append(orig_idx)

        # Pre-fill the results so any blob a shard fails to report on surfaces as an error.
        results: list[Status] = [Status(code=code_pb2.INTERNAL, message="inconsistent batch results")] * len(blobs)
        for storage, statuses in self._map(_bulk_update_for_storage, partitioned_digests.items()):
            for status_idx, status in enumerate(statuses):
                results[idx_map[storage][status_idx]] = status
        return results

    @timed(METRIC.STORAGE.BULK_READ_DURATION, type=TYPE)
    def bulk_read_blobs(self, digests: list[Digest]) -> dict[str, bytes]:
        bulk_read_results: dict[str, bytes] = {}
        for result in self._map(_bulk_read_for_storage, self._partition_digests(digests).items()):
            bulk_read_results.update(result)

        return bulk_read_results
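

# A minimal usage sketch (illustrative only: the shard backends and digest below are
# hypothetical placeholders, and this assumes the usual start()/stop() context-manager
# behaviour provided by StorageABC):
#
#     storages = {
#         "shard-a": SomeStorageBackend(...),  # any StorageABC implementation
#         "shard-b": SomeStorageBackend(...),
#     }
#     with ShardedStorage(storages, thread_pool_size=4) as sharded:
#         missing = sharded.missing_blobs([digest])
#         blobs = sharded.bulk_read_blobs([digest])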