Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/storage/sharded.py: 99.06%

106 statements  

coverage.py v7.4.1, created at 2025-02-11 15:07 +0000

# Copyright (C) 2023 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import defaultdict
from contextlib import ExitStack
from typing import IO, Callable, Iterable, Iterator, TypeVar

import mmh3

from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest
from buildgrid._protos.google.rpc import code_pb2
from buildgrid._protos.google.rpc.status_pb2 import Status
from buildgrid.server.decorators import timed
from buildgrid.server.logging import buildgrid_logger
from buildgrid.server.metrics_names import METRIC
from buildgrid.server.metrics_utils import publish_counter_metric
from buildgrid.server.threading import ContextThreadPoolExecutor

from .storage_abc import StorageABC

LOGGER = buildgrid_logger(__name__)


_T = TypeVar("_T")
_R = TypeVar("_R")

# wrapper functions for the bulk StorageABC interfaces
def _bulk_delete_for_storage(storage_digests: tuple[StorageABC, list[Digest]]) -> list[str]:
    storage, digests = storage_digests
    return storage.bulk_delete(digests)


def _fmb_for_storage(storage_digests: tuple[StorageABC, list[Digest]]) -> list[Digest]:
    storage, digests = storage_digests
    return storage.missing_blobs(digests)


def _bulk_update_for_storage(
    storage_digests: tuple[StorageABC, list[tuple[Digest, bytes]]],
) -> tuple[StorageABC, list[Status]]:
    storage, digest_tuples = storage_digests
    return storage, storage.bulk_update_blobs(digest_tuples)


def _bulk_read_for_storage(storage_digests: tuple[StorageABC, list[Digest]]) -> dict[str, bytes]:
    storage, digests = storage_digests
    return storage.bulk_read_blobs(digests)

class ShardedStorage(StorageABC):
    TYPE = "Sharded"

    def __init__(self, storages: dict[str, StorageABC], thread_pool_size: int | None = None):
        self._stack = ExitStack()
        if not storages:
            raise ValueError("ShardedStorage requires at least one shard")
        self._storages = storages
        self._threadpool = None
        if thread_pool_size:
            self._threadpool = ContextThreadPoolExecutor(thread_pool_size, "sharded-storage")

    def start(self) -> None:
        if self._threadpool:
            self._stack.enter_context(self._threadpool)
        for storage in self._storages.values():
            self._stack.enter_context(storage)

    def stop(self) -> None:
        self._stack.close()
        LOGGER.info(f"Stopped {type(self).__name__}")
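    # Shard selection below uses rendezvous (highest-random-weight) hashing: every
    # shard name is scored against the digest with mmh3 and the lowest score wins,
    # so a given digest always maps to the same shard and changing the shard set
    # only remaps a minimal fraction of digests.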

    def _storage_from_digest(self, digest: Digest) -> StorageABC:
        def _score(shard_name: str, digest: Digest) -> int:
            hash = mmh3.hash(f"{shard_name}\t{digest.hash}", signed=False)
            return hash

        shard_name = min(self._storages.keys(), key=lambda name: _score(name, digest))
        return self._storages[shard_name]

    def _partition_digests(self, digests: list[Digest]) -> dict[StorageABC, list[Digest]]:
        partition: dict[StorageABC, list[Digest]] = defaultdict(list)
        for digest in digests:
            storage = self._storage_from_digest(digest)
            partition[storage].append(digest)
        return partition
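    # _map fans the per-shard bulk calls out to the thread pool when one is
    # configured, otherwise it runs them serially in the calling thread.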

    def _map(self, fn: Callable[[_T], _R], args: Iterable[_T]) -> Iterator[_R]:
        if self._threadpool:
            return self._threadpool.map(fn, args)
        else:
            return map(fn, args)

    @timed(METRIC.STORAGE.STAT_DURATION, type=TYPE)
    def has_blob(self, digest: Digest) -> bool:
        return self._storage_from_digest(digest).has_blob(digest)

    @timed(METRIC.STORAGE.READ_DURATION, type=TYPE)
    def get_blob(self, digest: Digest) -> IO[bytes] | None:
        return self._storage_from_digest(digest).get_blob(digest)

    @timed(METRIC.STORAGE.DELETE_DURATION, type=TYPE)
    def delete_blob(self, digest: Digest) -> None:
        self._storage_from_digest(digest).delete_blob(digest)

    @timed(METRIC.STORAGE.WRITE_DURATION, type=TYPE)
    def commit_write(self, digest: Digest, write_session: IO[bytes]) -> None:
        self._storage_from_digest(digest).commit_write(digest, write_session)

    @timed(METRIC.STORAGE.DELETE_DURATION, type=TYPE)
    def bulk_delete(self, digests: list[Digest]) -> list[str]:
        failed_deletions: list[str] = []
        for result in self._map(_bulk_delete_for_storage, self._partition_digests(digests).items()):
            failed_deletions.extend(result)

        publish_counter_metric(METRIC.STORAGE.DELETE_ERRORS_COUNT, len(failed_deletions), type=self.TYPE)
        return failed_deletions

    @timed(METRIC.STORAGE.BULK_STAT_DURATION, type=TYPE)
    def missing_blobs(self, digests: list[Digest]) -> list[Digest]:
        missing_blobs: list[Digest] = []

        for result in self._map(_fmb_for_storage, self._partition_digests(digests).items()):
            missing_blobs.extend(result)

        return missing_blobs
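    # Writes are partitioned per shard below; idx_map records each blob's original
    # position so the per-shard Status lists can be merged back into request order.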

    @timed(METRIC.STORAGE.BULK_WRITE_DURATION, type=TYPE)
    def bulk_update_blobs(self, blobs: list[tuple[Digest, bytes]]) -> list[Status]:
        partitioned_digests: dict[StorageABC, list[tuple[Digest, bytes]]] = defaultdict(list)
        idx_map: dict[StorageABC, list[int]] = defaultdict(list)
        for orig_idx, digest_tuple in enumerate(blobs):
            storage = self._storage_from_digest(digest_tuple[0])
            partitioned_digests[storage].append(digest_tuple)
            idx_map[storage].append(orig_idx)

        results: list[Status] = [Status(code=code_pb2.INTERNAL, message="inconsistent batch results")] * len(blobs)
        for storage, statuses in self._map(_bulk_update_for_storage, partitioned_digests.items()):
            for status_idx, status in enumerate(statuses):
                results[idx_map[storage][status_idx]] = status
        return results

    @timed(METRIC.STORAGE.BULK_READ_DURATION, type=TYPE)
    def bulk_read_blobs(self, digests: list[Digest]) -> dict[str, bytes]:
        bulk_read_results: dict[str, bytes] = {}
        for result in self._map(_bulk_read_for_storage, self._partition_digests(digests).items()):
            bulk_read_results.update(result)

        return bulk_read_results
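For reference, below is a minimal standalone sketch of the shard-selection rule used by _storage_from_digest above. It is not part of sharded.py: the shard names and digests are invented for illustration, and the snippet simply reproduces the mmh3 scoring shown in the listing to demonstrate that assignment is deterministic and roughly uniform across shards.

# Illustration only: reproduces the scoring rule from _storage_from_digest with
# made-up shard names and digest hashes (not part of sharded.py).
import hashlib

import mmh3

SHARDS = ["shard-a", "shard-b", "shard-c"]  # hypothetical shard names


def pick_shard(digest_hash: str) -> str:
    # Lowest mmh3 score wins, as in ShardedStorage._storage_from_digest.
    return min(SHARDS, key=lambda name: mmh3.hash(f"{name}\t{digest_hash}", signed=False))


# Deterministic: the same digest hash always lands on the same shard.
example = hashlib.sha256(b"hello").hexdigest()
assert pick_shard(example) == pick_shard(example)

# Roughly uniform: each shard receives about a third of many distinct digests.
counts = {name: 0 for name in SHARDS}
for i in range(9000):
    counts[pick_shard(hashlib.sha256(str(i).encode()).hexdigest())] += 1
print(counts)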