Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/cas/service.py: 94.00%

150 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-06-11 15:37 +0000

1# Copyright (C) 2018 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16""" 

17CAS services 

18================== 

19 

20Implements the Content Addressable Storage API and ByteStream API. 

21""" 

22 

23import itertools 

24import logging 

25import re 

26from typing import Any, Dict, Iterator, Tuple, cast 

27 

28import grpc 

29 

30import buildgrid.server.context as context_module 

31from buildgrid._enums import ByteStreamResourceType 

32from buildgrid._exceptions import InvalidArgumentError 

33from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import DESCRIPTOR as RE_DESCRIPTOR 

34from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2 import ( 

35 BatchReadBlobsRequest, 

36 BatchReadBlobsResponse, 

37 BatchUpdateBlobsRequest, 

38 BatchUpdateBlobsResponse, 

39 Digest, 

40 FindMissingBlobsRequest, 

41 FindMissingBlobsResponse, 

42 GetTreeRequest, 

43 GetTreeResponse, 

44) 

45from buildgrid._protos.build.bazel.remote.execution.v2.remote_execution_pb2_grpc import ( 

46 ContentAddressableStorageServicer, 

47 add_ContentAddressableStorageServicer_to_server, 

48) 

49from buildgrid._protos.google.bytestream import bytestream_pb2, bytestream_pb2_grpc 

50from buildgrid._protos.google.bytestream.bytestream_pb2 import ( 

51 QueryWriteStatusRequest, 

52 QueryWriteStatusResponse, 

53 ReadRequest, 

54 ReadResponse, 

55 WriteRequest, 

56 WriteResponse, 

57) 

58from buildgrid._protos.google.rpc import code_pb2, status_pb2 

59from buildgrid.server.auth.manager import authorize 

60from buildgrid.server.cas.instance import ( 

61 EMPTY_BLOB, 

62 EMPTY_BLOB_DIGEST, 

63 ByteStreamInstance, 

64 ContentAddressableStorageInstance, 

65) 

66from buildgrid.server.instance import instanced 

67from buildgrid.server.metrics_names import ( 

68 CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME, 

69 CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME, 

70 CAS_BYTESTREAM_READ_TIME_METRIC_NAME, 

71 CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME, 

72 CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME, 

73 CAS_GET_TREE_TIME_METRIC_NAME, 

74) 

75from buildgrid.server.metrics_utils import DurationMetric, generator_method_duration_metric 

76from buildgrid.server.request_metadata_utils import printable_request_metadata 

77from buildgrid.server.servicer import InstancedServicer 

78from buildgrid.server.utils.decorators import ( 

79 handle_errors_stream_unary, 

80 handle_errors_unary_stream, 

81 handle_errors_unary_unary, 

82 track_request_id, 

83 track_request_id_generator, 

84) 

85from buildgrid.settings import HASH_LENGTH 

86 

87LOGGER = logging.getLogger(__name__) 

88 

89 

90def _printable_batch_update_blobs_request(request: BatchUpdateBlobsRequest) -> Dict[str, Any]: 

91 # Log the digests but not the data 

92 return { 

93 "instance_name": request.instance_name, 

94 "digests": [r.digest for r in request.requests], 

95 } 

96 

97 

class ContentAddressableStorageService(
    ContentAddressableStorageServicer, InstancedServicer[ContentAddressableStorageInstance]
):
    """gRPC servicer implementing the REAPI ContentAddressableStorage service.

    Each RPC resolves the request's instance name to a
    ``ContentAddressableStorageInstance`` and delegates to it. The empty
    blob is special-cased throughout: it is always considered present and
    is never fetched from the backing storage.
    """

    REGISTER_METHOD = add_ContentAddressableStorageServicer_to_server
    FULL_NAME = RE_DESCRIPTOR.services_by_name["ContentAddressableStorage"].full_name

    @instanced(lambda r: cast(str, r.instance_name))
    @authorize
    @context_module.metadatacontext()
    @track_request_id
    @DurationMetric(CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME)
    @handle_errors_unary_unary(FindMissingBlobsResponse)
    def FindMissingBlobs(
        self, request: FindMissingBlobsRequest, context: grpc.ServicerContext
    ) -> FindMissingBlobsResponse:
        """Report which of the requested digests are absent from storage."""
        LOGGER.info(
            f"FindMissingBlobs request from [{context.peer()}] "
            f"([{printable_request_metadata(context.invocation_metadata())}])"
        )

        # The empty blob cannot be missing, so skip it when querying storage.
        non_empty_digests = [d for d in request.blob_digests if d != EMPTY_BLOB_DIGEST]
        cas = self.get_instance(request.instance_name)
        return cas.find_missing_blobs(non_empty_digests)

    @instanced(lambda r: cast(str, r.instance_name))
    @authorize
    @context_module.metadatacontext()
    @track_request_id
    @DurationMetric(CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME)
    @handle_errors_unary_unary(BatchUpdateBlobsResponse, get_printable_request=_printable_batch_update_blobs_request)
    def BatchUpdateBlobs(
        self, request: BatchUpdateBlobsRequest, context: grpc.ServicerContext
    ) -> BatchUpdateBlobsResponse:
        """Store a batch of blobs in the CAS instance named by the request."""
        LOGGER.info(
            f"BatchUpdateBlobs request from [{context.peer()}] "
            f"([{printable_request_metadata(context.invocation_metadata())}])"
        )

        cas = self.get_instance(request.instance_name)
        return cas.batch_update_blobs(request.requests)

    @instanced(lambda r: cast(str, r.instance_name))
    @authorize
    @context_module.metadatacontext()
    @track_request_id
    @DurationMetric(CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME)
    @handle_errors_unary_unary(BatchReadBlobsResponse)
    def BatchReadBlobs(self, request: BatchReadBlobsRequest, context: grpc.ServicerContext) -> BatchReadBlobsResponse:
        """Fetch a batch of blobs, synthesizing responses for empty-blob digests."""
        LOGGER.info(
            f"BatchReadBlobs request from [{context.peer()}] "
            f"([{printable_request_metadata(context.invocation_metadata())}])"
        )
        # The empty blob is always available, so only non-empty digests hit
        # the backing storage; the rest are answered locally below.
        non_empty_digests = [d for d in request.digests if d != EMPTY_BLOB_DIGEST]
        skipped_empty_count = len(request.digests) - len(non_empty_digests)

        cas = self.get_instance(request.instance_name)
        response = cas.batch_read_blobs(non_empty_digests)

        # Synthesize one OK response per empty-blob digest that was requested.
        ok_status = status_pb2.Status(code=code_pb2.OK)
        for _ in range(skipped_empty_count):
            entry = response.responses.add()
            entry.data = EMPTY_BLOB
            entry.digest.CopyFrom(EMPTY_BLOB_DIGEST)
            entry.status.CopyFrom(ok_status)

        return response

    @instanced(lambda r: cast(str, r.instance_name))
    @authorize
    @track_request_id_generator
    @generator_method_duration_metric(CAS_GET_TREE_TIME_METRIC_NAME)
    @handle_errors_unary_stream(GetTreeResponse)
    def GetTree(self, request: GetTreeRequest, context: grpc.ServicerContext) -> Iterator[GetTreeResponse]:
        """Stream the directory tree rooted at the requested digest."""
        LOGGER.info(
            f"GetTree request from [{context.peer()}] "
            f"([{printable_request_metadata(context.invocation_metadata())}])"
        )

        cas = self.get_instance(request.instance_name)
        for tree_page in cas.get_tree(request):
            yield tree_page

182 

183 

class ResourceNameRegex:
    """Regular expressions for parsing ByteStream resource names.

    Each pattern captures two groups: group 1 is the (possibly empty)
    instance name and group 2 is the remainder of the resource name,
    as consumed by ``_parse_resource_name``.
    """

    # CAS read name format: "{instance_name}/blobs/{hash}/{size}"
    READ = "^(.*?)/?(blobs/.*/[0-9]*)$"

    # CAS write name format: "{instance_name}/uploads/{uuid}/blobs/{hash}/{size}[optional arbitrary extra content]"
    # NOTE: intentionally unanchored at the end so clients may append
    # arbitrary trailing content after the size component.
    WRITE = "^(.*?)/?(uploads/.*/blobs/.*/[0-9]*)"

190 

191 

def _parse_resource_name(resource_name: str, regex: str) -> Tuple[str, str, "ByteStreamResourceType"]:
    """Split a ByteStream resource name into its components.

    Args:
        resource_name: The raw resource name from the request.
        regex: One of the ``ResourceNameRegex`` patterns to match against.

    Returns:
        A tuple of (instance name, remaining resource name, resource type).

    Raises:
        InvalidArgumentError: If the resource name does not match the pattern.
    """
    match = re.match(regex, resource_name)
    if match is None:
        raise InvalidArgumentError(f"Invalid resource name: [{resource_name}]")
    return match[1], match[2], ByteStreamResourceType.CAS

198 

199 

200def _printable_write_request(request: WriteRequest) -> Dict[str, Any]: 

201 # Log all the fields except `data`: 

202 return { 

203 "resource_name": request.resource_name, 

204 "write_offset": request.write_offset, 

205 "finish_write": request.finish_write, 

206 } 

207 

208 

209class ByteStreamService(bytestream_pb2_grpc.ByteStreamServicer, InstancedServicer[ByteStreamInstance]): 

210 REGISTER_METHOD = bytestream_pb2_grpc.add_ByteStreamServicer_to_server 

211 FULL_NAME = bytestream_pb2.DESCRIPTOR.services_by_name["ByteStream"].full_name 

212 

213 @instanced(lambda r: _parse_resource_name(r.resource_name, ResourceNameRegex.READ)[0]) 

214 @authorize 

215 @context_module.metadatacontext() 

216 @track_request_id_generator 

217 @generator_method_duration_metric(CAS_BYTESTREAM_READ_TIME_METRIC_NAME) 

218 @handle_errors_unary_stream(ReadResponse) 

219 def Read(self, request: ReadRequest, context: grpc.ServicerContext) -> Iterator[ReadResponse]: 

220 LOGGER.info( 

221 f"Read request from [{context.peer()}] ([{printable_request_metadata(context.invocation_metadata())}])" 

222 ) 

223 instance_name, resource_name, resource_type = _parse_resource_name( 

224 request.resource_name, ResourceNameRegex.READ 

225 ) 

226 instance = self.get_instance(instance_name) 

227 if resource_type == ByteStreamResourceType.CAS: 

228 blob_details = resource_name.split("/") 

229 if len(blob_details[1]) != HASH_LENGTH: 

230 raise InvalidArgumentError(f"Invalid digest [{resource_name}]") 

231 try: 

232 digest = Digest(hash=blob_details[1], size_bytes=int(blob_details[2])) 

233 except ValueError: 

234 raise InvalidArgumentError(f"Invalid digest [{resource_name}]") 

235 

236 bytes_returned = 0 

237 expected_bytes = digest.size_bytes - request.read_offset 

238 if request.read_limit: 

239 expected_bytes = min(expected_bytes, request.read_limit) 

240 

241 try: 

242 if digest.size_bytes == 0: 

243 if digest.hash != EMPTY_BLOB_DIGEST.hash: 

244 raise InvalidArgumentError(f"Invalid digest [{digest.hash}/{digest.size_bytes}]") 

245 yield bytestream_pb2.ReadResponse(data=EMPTY_BLOB) 

246 return 

247 

248 for blob in instance.read_cas_blob(digest, request.read_offset, request.read_limit): 

249 bytes_returned += len(blob.data) 

250 yield blob 

251 finally: 

252 if bytes_returned != expected_bytes: 

253 LOGGER.warning( 

254 f"Read request {digest.hash}/{digest.size_bytes} exited early." 

255 f" bytes_returned={bytes_returned} expected_bytes={expected_bytes}" 

256 f" read_offset={request.read_offset} read_limit={request.read_limit}" 

257 ) 

258 else: 

259 LOGGER.info(f"Read request {digest.hash}/{digest.size_bytes} completed") 

260 

261 @instanced(lambda r: _parse_resource_name(r.resource_name, ResourceNameRegex.WRITE)[0]) 

262 @authorize 

263 @context_module.metadatacontext() 

264 @track_request_id 

265 @DurationMetric(CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME) 

266 @handle_errors_stream_unary(WriteResponse, get_printable_request=_printable_write_request) 

267 def Write(self, request_iterator: Iterator[WriteRequest], context: grpc.ServicerContext) -> WriteResponse: 

268 LOGGER.info( 

269 f"Write request from [{context.peer()}] " 

270 f"([{printable_request_metadata(context.invocation_metadata())}])" 

271 ) 

272 

273 request = next(request_iterator) 

274 instance_name, resource_name, resource_type = _parse_resource_name( 

275 request.resource_name, 

276 ResourceNameRegex.WRITE, 

277 ) 

278 instance = self.get_instance(instance_name) 

279 if resource_type == ByteStreamResourceType.CAS: 

280 blob_details = resource_name.split("/") 

281 _, hash_, size_bytes = blob_details[1], blob_details[3], blob_details[4] 

282 return instance.write_cas_blob(hash_, size_bytes, itertools.chain([request], request_iterator)) 

283 return bytestream_pb2.WriteResponse() 

284 

285 @instanced(lambda r: _parse_resource_name(r.resource_name, ResourceNameRegex.WRITE)[0]) 

286 @authorize 

287 @track_request_id 

288 @handle_errors_unary_unary(QueryWriteStatusResponse) 

289 def QueryWriteStatus( 

290 self, request: QueryWriteStatusRequest, context: grpc.ServicerContext 

291 ) -> QueryWriteStatusResponse: 

292 LOGGER.info(f"QueryWriteStatus request from [{context.peer()}]") 

293 context.abort(grpc.StatusCode.UNIMPLEMENTED, "Method not implemented!")