Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/storage/remote.py: 93.69%

111 statements  


# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
RemoteStorage
==================

Forwards storage requests to a remote storage.
"""

import io
import logging
from tempfile import NamedTemporaryFile
from typing import IO, Any, Dict, List, Optional, Sequence, Tuple

import grpc

import buildgrid.server.context as context_module
from buildgrid._exceptions import GrpcUninitializedError, NotFoundError
from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2, remote_execution_pb2_grpc
from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import CacheCapabilities, Digest
from buildgrid._protos.google.rpc import code_pb2, status_pb2
from buildgrid._protos.google.rpc.status_pb2 import Status
from buildgrid.client.authentication import ClientCredentials
from buildgrid.client.capabilities import CapabilitiesInterface
from buildgrid.client.cas import download, upload
from buildgrid.client.channel import setup_channel
from buildgrid.settings import HASH, MAX_IN_MEMORY_BLOB_SIZE_BYTES

from .storage_abc import StorageABC

LOGGER = logging.getLogger(__name__)


class RemoteStorage(StorageABC):
    def __init__(
        self,
        remote: str,
        instance_name: str,
        channel_options: Optional[Sequence[Tuple[str, Any]]] = None,
        credentials: Optional[ClientCredentials] = None,
        retries: int = 0,
        max_backoff: int = 64,
        request_timeout: Optional[float] = None,
    ) -> None:
        self.remote_instance_name = instance_name
        self._remote = remote
        self._channel_options = channel_options
        if credentials is None:
            credentials = {}
        self.credentials = credentials
        self.retries = retries
        self.max_backoff = max_backoff
        self._request_timeout = request_timeout

        self._stub_cas: Optional[remote_execution_pb2_grpc.ContentAddressableStorageStub] = None
        self.channel: Optional[grpc.Channel] = None

    def start(self) -> None:
        if self.channel is None:
            self.channel, *_ = setup_channel(
                self._remote,
                auth_token=self.credentials.get("auth-token"),
                auth_token_refresh_seconds=self.credentials.get("token-refresh-seconds"),
                client_key=self.credentials.get("tls-client-key"),
                client_cert=self.credentials.get("tls-client-cert"),
                server_cert=self.credentials.get("tls-server-cert"),
                timeout=self._request_timeout,
            )

        if self._stub_cas is None:
            self._stub_cas = remote_execution_pb2_grpc.ContentAddressableStorageStub(self.channel)
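
    # Note on configuration: ``start()`` above only consults these optional
    # keys in ``self.credentials``: "auth-token", "token-refresh-seconds",
    # "tls-client-key", "tls-client-cert" and "tls-server-cert".  A sketch of
    # a TLS credentials mapping (the paths are placeholders, not values from
    # this codebase):
    #
    #     credentials = {
    #         "tls-client-key": "/path/to/client.key",
    #         "tls-client-cert": "/path/to/client.crt",
    #         "tls-server-cert": "/path/to/server.crt",
    #     }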

    def stop(self) -> None:
        if self.channel is not None:
            self.channel.close()

    def get_capabilities(self) -> CacheCapabilities:
        if self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")
        interface = CapabilitiesInterface(self.channel)
        capabilities = interface.get_capabilities(self.remote_instance_name)
        return capabilities.cache_capabilities

    def has_blob(self, digest: Digest) -> bool:
        LOGGER.debug(f"Checking for blob: [{digest}]")
        if not self.missing_blobs([digest]):
            return True
        return False

    def get_blob(self, digest: Digest) -> Optional[IO[bytes]]:
        if self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        LOGGER.debug(f"Getting blob: [{digest}]")
        with download(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as downloader:
            if digest.size_bytes > MAX_IN_MEMORY_BLOB_SIZE_BYTES:
                # Avoid storing the large blob completely in memory.
                temp_file = NamedTemporaryFile(delete=True)
                try:
                    downloader.download_file(digest, temp_file.name, queue=False)
                except NotFoundError:
                    return None
                # TODO Fix this: incompatible type "_TemporaryFileWrapper[bytes]"; expected "RawIOBase"
                reader = io.BufferedReader(temp_file)  # type: ignore[arg-type]
                reader.seek(0)
                return reader
            else:
                blob = downloader.get_blob(digest)
                if blob is not None:
                    return io.BytesIO(blob)
                else:
                    return None

    def delete_blob(self, digest: Digest) -> None:
        """The REAPI doesn't have a deletion method, so we can't support
        deletion for remote storage.
        """
        raise NotImplementedError("Deletion is not supported for remote storage!")

    def bulk_delete(self, digests: List[Digest]) -> List[str]:
        """The REAPI doesn't have a deletion method, so we can't support
        bulk deletion for remote storage.
        """
        raise NotImplementedError("Bulk deletion is not supported for remote storage!")

    def commit_write(self, digest: Digest, write_session: IO[bytes]) -> None:
        if self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        write_session.seek(0)
        LOGGER.debug(f"Writing blob: [{digest}]")
        with upload(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as uploader:
            uploader.put_blob(write_session, digest=digest)

    def missing_blobs(self, digests: List[Digest]) -> List[Digest]:
        if self._stub_cas is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        if len(digests) > 100:
            LOGGER.debug(f"Missing blobs request for: {digests[:100]} (truncated)")
        else:
            LOGGER.debug(f"Missing blobs request for: {digests}")
        request = remote_execution_pb2.FindMissingBlobsRequest(instance_name=self.remote_instance_name)

        for blob in digests:
            request_digest = request.blob_digests.add()
            request_digest.hash = blob.hash
            request_digest.size_bytes = blob.size_bytes

        response = self._stub_cas.FindMissingBlobs(request, metadata=context_module.metadata_list())

        return list(response.missing_blob_digests)
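
    # ``bulk_update_blobs`` below verifies each blob locally (its length and
    # hash against the provided digest) before queueing it for upload; blobs
    # that fail verification are never sent and are reported back with an
    # UNKNOWN status, while successfully uploaded blobs get an OK status.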

    def bulk_update_blobs(self, blobs: List[Tuple[Digest, bytes]]) -> List[Status]:
        if self._stub_cas is None or self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        sent_digests = []
        with upload(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as uploader:
            for digest, blob in blobs:
                if len(blob) != digest.size_bytes or HASH(blob).hexdigest() != digest.hash:
                    sent_digests.append(remote_execution_pb2.Digest())
                else:
                    sent_digests.append(uploader.put_blob(io.BytesIO(blob), digest=digest, queue=True))

        assert len(sent_digests) == len(blobs)

        return [
            status_pb2.Status(code=code_pb2.OK) if d.ByteSize() > 0 else status_pb2.Status(code=code_pb2.UNKNOWN)
            for d in sent_digests
        ]

    def bulk_read_blobs(self, digests: List[Digest]) -> Dict[str, bytes]:
        if self._stub_cas is None or self.channel is None:
            raise GrpcUninitializedError("Remote CAS backend used before gRPC initialization.")

        LOGGER.debug(f"Bulk read blobs request for: {digests}")
        with download(
            self.channel, instance=self.remote_instance_name, retries=self.retries, max_backoff=self.max_backoff
        ) as downloader:
            results = downloader.get_available_blobs(digests)
            # Transform List of (data, digest) pairs to expected hash-blob map
            return {digest.hash: data for data, digest in results}