Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/storage/remote.py: 92.62%

122 statements  

coverage.py v7.4.1, created at 2025-04-14 16:27 +0000

# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
RemoteStorage
==================

Forwards storage requests to a remote storage.
"""

import io
import logging
from tempfile import NamedTemporaryFile
from typing import IO, Any, Sequence

import grpc

from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest
from buildgrid._protos.google.rpc import code_pb2, status_pb2
from buildgrid._protos.google.rpc.status_pb2 import Status
from buildgrid.server.client.authentication import ClientCredentials
from buildgrid.server.client.cas import download, upload
from buildgrid.server.client.channel import setup_channel
from buildgrid.server.context import current_instance
from buildgrid.server.decorators import timed
from buildgrid.server.exceptions import GrpcUninitializedError, NotFoundError
from buildgrid.server.logging import buildgrid_logger
from buildgrid.server.metadata import metadata_list
from buildgrid.server.metrics_names import METRIC
from buildgrid.server.settings import HASH, MAX_IN_MEMORY_BLOB_SIZE_BYTES

from .storage_abc import StorageABC

LOGGER = buildgrid_logger(__name__)


class RemoteStorage(StorageABC):
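    """Storage backend which forwards storage requests to a remote CAS server over gRPC."""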

    TYPE = "Remote"

    def __init__(
        self,
        remote: str,
        instance_name: str | None = None,
        channel_options: Sequence[tuple[str, Any]] | None = None,
        credentials: ClientCredentials | None = None,
        retries: int = 0,
        max_backoff: int = 64,
        request_timeout: float | None = None,
    ) -> None:
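        """Set up a storage backend which talks to a remote CAS.

        ``remote`` is the address of the remote CAS server; ``instance_name``,
        when given, overrides the current instance name used on the remote.
        ``credentials`` may carry auth-token and TLS settings, while
        ``retries``, ``max_backoff`` and ``request_timeout`` tune how remote
        calls are made.
        """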

        self._remote_instance_name = instance_name
        self._remote = remote
        self._channel_options = channel_options
        if credentials is None:
            credentials = {}
        self.credentials = credentials
        self.retries = retries
        self.max_backoff = max_backoff
        self._request_timeout = request_timeout

        self._stub_cas: remote_execution_pb2_grpc.ContentAddressableStorageStub | None = None
        self.channel: grpc.Channel | None = None

    def start(self) -> None:
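        """Open the gRPC channel and create the CAS stub, if not already done."""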

        if self.channel is None:
            self.channel, *_ = setup_channel(
                self._remote,
                auth_token=self.credentials.get("auth-token"),
                auth_token_refresh_seconds=self.credentials.get("token-refresh-seconds"),
                client_key=self.credentials.get("tls-client-key"),
                client_cert=self.credentials.get("tls-client-cert"),
                server_cert=self.credentials.get("tls-server-cert"),
                timeout=self._request_timeout,
            )

        if self._stub_cas is None:
            self._stub_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)

    def stop(self) -> None:
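        """Close the gRPC channel, if one was opened."""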

        if self.channel is not None:
            self.channel.close()

    @property
    def remote_instance_name(self) -> str:
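        """Instance name to use on the remote, falling back to the current instance."""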

        if self._remote_instance_name is not None:
            return self._remote_instance_name
        return current_instance()

    @timed(METRIC.STORAGE.STAT_DURATION, type=TYPE)
    def has_blob(self, digest: Digest) -> bool:
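        """Return True if the remote CAS reports the blob as present."""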

        LOGGER.debug("Checking for blob.", tags=dict(digest=digest))
        if not self.missing_blobs([digest]):
            return True
        return False

    @timed(METRIC.STORAGE.READ_DURATION, type=TYPE)
    def get_blob(self, digest: Digest) -> IO[bytes] | None:
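        """Fetch a blob from the remote CAS, returning None if it is not found.

        Blobs larger than MAX_IN_MEMORY_BLOB_SIZE_BYTES are spooled to a
        temporary file rather than being held entirely in memory.
        """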

        if self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        LOGGER.debug("Getting blob.", tags=dict(digest=digest))
        with download(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as downloader:
            if digest.size_bytes > MAX_IN_MEMORY_BLOB_SIZE_BYTES:
                # Avoid storing the large blob completely in memory.
                temp_file = NamedTemporaryFile(delete=True)
                try:
                    downloader.download_file(digest, temp_file.name, queue=False)
                except NotFoundError:
                    return None
                # TODO Fix this: incompatible type "_TemporaryFileWrapper[bytes]"; expected "RawIOBase"
                reader = io.BufferedReader(temp_file)  # type: ignore[arg-type]
                reader.seek(0)
                return reader
            else:
                blob = downloader.get_blob(digest)
                if blob is not None:
                    return io.BytesIO(blob)
                else:
                    return None

    def delete_blob(self, digest: Digest) -> None:
        """The REAPI doesn't have a deletion method, so we can't support
        deletion for remote storage.
        """
        raise NotImplementedError("Deletion is not supported for remote storage!")

    def bulk_delete(self, digests: list[Digest]) -> list[str]:
        """The REAPI doesn't have a deletion method, so we can't support
        bulk deletion for remote storage.
        """
        raise NotImplementedError("Bulk deletion is not supported for remote storage!")

    @timed(METRIC.STORAGE.WRITE_DURATION, type=TYPE)
    def commit_write(self, digest: Digest, write_session: IO[bytes]) -> None:
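        """Upload the contents of ``write_session`` to the remote CAS for ``digest``."""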

        if self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        write_session.seek(0)
        LOGGER.debug("Writing blob.", tags=dict(digest=digest))
        with upload(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as uploader:
            uploader.put_blob(write_session, digest=digest)

    @timed(METRIC.STORAGE.BULK_STAT_DURATION, type=TYPE)
    def missing_blobs(self, digests: list[Digest]) -> list[Digest]:
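        """Return the subset of ``digests`` which the remote CAS reports as missing."""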

        if self._stub_cas is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        # Avoid expensive string creation.
        if LOGGER.is_enabled_for(logging.DEBUG):
            if len(digests) > 100:
                LOGGER.debug(f"Missing blobs request for: {digests[:100]} (truncated)")
            else:
                LOGGER.debug(f"Missing blobs request for: {digests}")

        request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=self.remote_instance_name)

        for blob in digests:
            request_digest = request.blob_digests.add()
            request_digest.hash = blob.hash
            request_digest.size_bytes = blob.size_bytes

        response = self._stub_cas.FindMissingBlobs(request, metadata=metadata_list())

        return list(response.missing_blob_digests)

    @timed(METRIC.STORAGE.BULK_WRITE_DURATION, type=TYPE)
    def bulk_update_blobs(self, blobs: list[tuple[Digest, bytes]]) -> list[Status]:
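        """Upload multiple blobs to the remote CAS, returning a Status for each.

        Blobs whose length or hash does not match their digest are skipped and
        reported with a non-OK status.
        """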

        if self._stub_cas is None or self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        sent_digests = []
        with upload(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as uploader:
            for digest, blob in blobs:
                if len(blob) != digest.size_bytes or HASH(blob).hexdigest() != digest.hash:
                    sent_digests.append(remote_execution_pb2.Digest())
                else:
                    sent_digests.append(uploader.put_blob(io.BytesIO(blob), digest=digest, queue=True))

        assert len(sent_digests) == len(blobs)

        return [
            status_pb2.Status(code=code_pb2.OK) if d.ByteSize() > 0 else status_pb2.Status(code=code_pb2.UNKNOWN)
            for d in sent_digests
        ]

    @timed(METRIC.STORAGE.BULK_READ_DURATION, type=TYPE)
    def bulk_read_blobs(self, digests: list[Digest]) -> dict[str, bytes]:
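        """Fetch multiple blobs from the remote CAS, keyed by digest hash.

        Blobs which are not available on the remote are omitted from the result.
        """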

        if self._stub_cas is None or self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        # Avoid expensive string creation.
        if LOGGER.is_enabled_for(logging.DEBUG):
            LOGGER.debug(f"Bulk read blobs request for: {digests}")

        with download(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as downloader:
            results = downloader.get_available_blobs(digests)
            # Transform List of (data, digest) pairs to expected hash-blob map
            return {digest.hash: data for data, digest in results}
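
For context, a minimal sketch of how this backend might be exercised, assuming a reachable CAS endpoint and SHA-256 digests; the address, instance name, and blob contents below are hypothetical placeholders rather than values taken from this module.

import hashlib
import io

from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import Digest
from buildgrid.server.cas.storage.remote import RemoteStorage

# Hypothetical endpoint and instance name; substitute a real CAS deployment.
storage = RemoteStorage("http://cas.example.com:50051", instance_name="main", retries=3)
storage.start()

data = b"example blob"
digest = Digest(hash=hashlib.sha256(data).hexdigest(), size_bytes=len(data))

# Upload the blob, then check for it and read it back through the remote CAS.
storage.commit_write(digest, io.BytesIO(data))
assert storage.has_blob(digest)
blob = storage.get_blob(digest)
if blob is not None:
    print(blob.read())

storage.stop()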