# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Storage Instances
=================
Instances of CAS and ByteStream
"""

import logging

from buildgrid._enums import MetricRecordDomain
from buildgrid._exceptions import InvalidArgumentError, NotFoundError, OutOfRangeError, PermissionDeniedError
from buildgrid._protos.google.bytestream import bytestream_pb2
from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 as re_pb2
from buildgrid._protos.google.rpc import code_pb2, status_pb2
from buildgrid.server.metrics_utils import (
    Counter,
    DurationMetric,
    ExceptionCounter,
    Distribution)
from buildgrid.server.metrics_names import (
    CAS_EXCEPTION_COUNT_METRIC_NAME,
    CAS_DOWNLOADED_BYTES_METRIC_NAME,
    CAS_UPLOADED_BYTES_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME,
    CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME,
    CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME,
    CAS_GET_TREE_TIME_METRIC_NAME,
    CAS_BYTESTREAM_READ_TIME_METRIC_NAME,
    CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME)
from buildgrid.settings import HASH, HASH_LENGTH, MAX_REQUEST_SIZE, MAX_REQUEST_COUNT
from buildgrid.utils import create_digest, get_hash_type, get_unique_objects_by_attribute


class ContentAddressableStorageInstance:

    def __init__(self, storage, read_only=False):
        self.__logger = logging.getLogger(__name__)

        self._instance_name = None

        self.__storage = storage

        self.__read_only = read_only

        if self.__storage:
            # Some clients skip uploading a blob with size 0.
            # We pre-insert the empty blob to make sure that it is available.
            empty_digest = create_digest(b'')
            session = self.__storage.begin_write(empty_digest)
            self.__storage.commit_write(empty_digest, session)

    # --- Public API ---

    @property
    def instance_name(self):
        return self._instance_name

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME, metric_domain=MetricRecordDomain.CAS)
    def register_instance_with_server(self, instance_name, server):
        """Names and registers the CAS instance with a given server."""
        if self._instance_name is None:
            server.add_cas_instance(self, instance_name)

            self._instance_name = instance_name

        else:
            raise AssertionError("Instance already registered")

    def hash_type(self):
        return get_hash_type()

    def max_batch_total_size_bytes(self):
        return MAX_REQUEST_SIZE

    def symlink_absolute_path_strategy(self):
        # Currently, this strategy is hardcoded into BuildGrid with no
        # setting to reference.
        return re_pb2.SymlinkAbsolutePathStrategy.DISALLOWED

    @DurationMetric(CAS_FIND_MISSING_BLOBS_TIME_METRIC_NAME, MetricRecordDomain.CAS, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME, metric_domain=MetricRecordDomain.CAS)
    def find_missing_blobs(self, blob_digests):
        storage = self.__storage
        # Deduplicate the request by digest hash before querying storage.
        blob_digests = list(get_unique_objects_by_attribute(blob_digests, "hash"))
        missing_blobs = storage.missing_blobs(blob_digests)

        num_blobs_in_request = len(blob_digests)
        if num_blobs_in_request > 0:
            num_blobs_missing = len(missing_blobs)
            percent_missing = float((num_blobs_missing / num_blobs_in_request) * 100)

            with Distribution(CAS_FIND_MISSING_BLOBS_NUM_REQUESTED_METRIC_NAME,
                              instance_name=self._instance_name,
                              metric_domain=MetricRecordDomain.CAS) as metric_num_requested:
                metric_num_requested.count = float(num_blobs_in_request)

            with Distribution(CAS_FIND_MISSING_BLOBS_NUM_MISSING_METRIC_NAME,
                              instance_name=self._instance_name,
                              metric_domain=MetricRecordDomain.CAS) as metric_num_missing:
                metric_num_missing.count = float(num_blobs_missing)

            with Distribution(CAS_FIND_MISSING_BLOBS_PERCENT_MISSING_METRIC_NAME,
                              instance_name=self._instance_name,
                              metric_domain=MetricRecordDomain.CAS) as metric_percent_missing:
                metric_percent_missing.count = percent_missing

        return re_pb2.FindMissingBlobsResponse(missing_blob_digests=missing_blobs)
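
    # Example of a caller's view of find_missing_blobs() above (hypothetical
    # digests; `instance` is a registered CAS instance):
    #
    #   present = create_digest(b'already stored')
    #   absent = create_digest(b'never uploaded')
    #   response = instance.find_missing_blobs([present, absent])
    #   # response.missing_blob_digests would contain only `absent`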

    @DurationMetric(CAS_BATCH_UPDATE_BLOBS_TIME_METRIC_NAME, MetricRecordDomain.CAS, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME, metric_domain=MetricRecordDomain.CAS)
    def batch_update_blobs(self, requests):
        if self.__read_only:
            raise PermissionDeniedError(f"CAS instance {self._instance_name} is read-only")

        storage = self.__storage
        store = []
        for request_proto in get_unique_objects_by_attribute(requests, "digest.hash"):
            store.append((request_proto.digest, request_proto.data))

        response = re_pb2.BatchUpdateBlobsResponse()
        statuses = storage.bulk_update_blobs(store)

        with Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name,
                     metric_domain=MetricRecordDomain.CAS) as bytes_counter:
            for (digest, _), status in zip(store, statuses):
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)
                response_proto.status.CopyFrom(status)
                # Only count bytes for blobs that were stored successfully
                # (status code 0 is code_pb2.OK).
                if response_proto.status.code == 0:
                    bytes_counter.increment(response_proto.digest.size_bytes)

        return response
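
    # Example of the request shape consumed by batch_update_blobs() above
    # (hypothetical payload; in practice the Request protos arrive from the
    # client, using the REAPI nested BatchUpdateBlobsRequest.Request message):
    #
    #   data = b'blob contents'
    #   request = re_pb2.BatchUpdateBlobsRequest.Request(
    #       digest=create_digest(data), data=data)
    #   response = instance.batch_update_blobs([request])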

    @DurationMetric(CAS_BATCH_READ_BLOBS_TIME_METRIC_NAME, MetricRecordDomain.CAS, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME, metric_domain=MetricRecordDomain.CAS)
    def batch_read_blobs(self, digests):
        storage = self.__storage

        response = re_pb2.BatchReadBlobsResponse()

        max_batch_size = self.max_batch_total_size_bytes()

        # Only process unique digests
        good_digests = []
        bad_digests = []
        requested_bytes = 0
        for digest in get_unique_objects_by_attribute(digests, "hash"):
            if len(digest.hash) != HASH_LENGTH:
                bad_digests.append(digest)
            else:
                good_digests.append(digest)
                requested_bytes += digest.size_bytes

        if requested_bytes > max_batch_size:
            raise InvalidArgumentError('Combined total size of blobs exceeds '
                                       'server limit. '
                                       f'({requested_bytes} > {max_batch_size} [bytes])')

        blobs_read = storage.bulk_read_blobs(good_digests)

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name,
                     metric_domain=MetricRecordDomain.CAS) as bytes_counter:
            for digest in good_digests:
                response_proto = response.responses.add()
                response_proto.digest.CopyFrom(digest)

                if digest.hash in blobs_read and blobs_read[digest.hash] is not None:
                    response_proto.data = blobs_read[digest.hash].read()
                    status_code = code_pb2.OK
                    bytes_counter.increment(digest.size_bytes)
                else:
                    status_code = code_pb2.NOT_FOUND

                response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        # Digests with a malformed hash are reported as INVALID_ARGUMENT.
        for digest in bad_digests:
            response_proto = response.responses.add()
            response_proto.digest.CopyFrom(digest)
            status_code = code_pb2.INVALID_ARGUMENT
            response_proto.status.CopyFrom(status_pb2.Status(code=status_code))

        return response
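
    # Example of reading blobs back via batch_read_blobs() above (hypothetical
    # digests and a placeholder use() function; malformed hashes come back
    # with an INVALID_ARGUMENT status instead of data):
    #
    #   response = instance.batch_read_blobs([digest_a, digest_b])
    #   for blob_response in response.responses:
    #       if blob_response.status.code == code_pb2.OK:
    #           use(blob_response.data)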

    @DurationMetric(CAS_GET_TREE_TIME_METRIC_NAME, MetricRecordDomain.CAS, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME, metric_domain=MetricRecordDomain.CAS)
    def get_tree(self, request):
        storage = self.__storage

        response = re_pb2.GetTreeResponse()

        # Default to the server-wide page size if the client did not set one.
        if not request.page_size:
            request.page_size = MAX_REQUEST_COUNT

        root_digest = request.root_digest
        page_size = request.page_size

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name,
                     metric_domain=MetricRecordDomain.CAS) as bytes_counter:
            # Walk the directory tree rooted at node_digest, accumulating
            # directories into `response` and flushing a partial response
            # whenever the page size or the maximum request size is reached.
            def __get_tree(node_digest):
                nonlocal response, page_size, request

                if not page_size:
                    page_size = request.page_size
                    yield response
                    response = re_pb2.GetTreeResponse()

                if response.ByteSize() >= MAX_REQUEST_SIZE:
                    yield response
                    response = re_pb2.GetTreeResponse()

                directory_from_digest = storage.get_message(
                    node_digest, re_pb2.Directory)

                bytes_counter.increment(node_digest.size_bytes)

                page_size -= 1
                response.directories.extend([directory_from_digest])

                for directory in directory_from_digest.directories:
                    yield from __get_tree(directory.digest)

                yield response
                response = re_pb2.GetTreeResponse()

            yield from __get_tree(root_digest)


class ByteStreamInstance:

    BLOCK_SIZE = 1 * 1024 * 1024  # 1 MB block size

    def __init__(self, storage, read_only=False):
        self.__logger = logging.getLogger(__name__)

        self._instance_name = None

        self.__storage = storage

        self.__read_only = read_only

        if self.__storage:
            # Some clients skip uploading a blob with size 0.
            # We pre-insert the empty blob to make sure that it is available.
            empty_digest = create_digest(b'')
            session = self.__storage.begin_write(empty_digest)
            self.__storage.commit_write(empty_digest, session)

    # --- Public API ---

    @property
    def instance_name(self):
        return self._instance_name

    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME, metric_domain=MetricRecordDomain.CAS)
    def register_instance_with_server(self, instance_name, server):
        """Names and registers the byte-stream instance with a given server."""
        if self._instance_name is None:
            server.add_bytestream_instance(self, instance_name)

            self._instance_name = instance_name

        else:
            raise AssertionError("Instance already registered")

    @DurationMetric(CAS_BYTESTREAM_READ_TIME_METRIC_NAME, MetricRecordDomain.CAS, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME, metric_domain=MetricRecordDomain.CAS)
    def read(self, digest_hash, digest_size, read_offset, read_limit):
        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = re_pb2.Digest(hash=digest_hash, size_bytes=int(digest_size))

        # Check the given read offset and limit.
        if read_offset < 0 or read_offset > digest.size_bytes:
            raise OutOfRangeError("Read offset out of range")

        elif read_limit == 0:
            bytes_remaining = digest.size_bytes - read_offset

        elif read_limit > 0:
            bytes_remaining = read_limit

        else:
            raise InvalidArgumentError("Negative read_limit is invalid")

        # Read the blob from storage and send its contents to the client.
        result = self.__storage.get_blob(digest)
        if result is None:
            raise NotFoundError("Blob not found")

        elif result.seekable():
            result.seek(read_offset)

        else:
            # Non-seekable streams: skip to the offset by reading and
            # discarding the first read_offset bytes.
            result.read(read_offset)

        with Counter(metric_name=CAS_DOWNLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name,
                     metric_domain=MetricRecordDomain.CAS) as bytes_counter:
            while bytes_remaining > 0:
                block_data = result.read(min(self.BLOCK_SIZE, bytes_remaining))
                yield bytestream_pb2.ReadResponse(data=block_data)
                bytes_counter.increment(len(block_data))
                # Assumes each read returns the full amount requested, so
                # decrementing by BLOCK_SIZE (re-clamped by min() above)
                # terminates the loop once the range is exhausted.
                bytes_remaining -= self.BLOCK_SIZE

    @DurationMetric(CAS_BYTESTREAM_WRITE_TIME_METRIC_NAME, MetricRecordDomain.CAS, instanced=True)
    @ExceptionCounter(CAS_EXCEPTION_COUNT_METRIC_NAME, metric_domain=MetricRecordDomain.CAS)
    def write(self, digest_hash, digest_size, first_block, other_blocks):
        if self.__read_only:
            raise PermissionDeniedError(
                f"ByteStream instance {self._instance_name} is read-only")

        if len(digest_hash) != HASH_LENGTH or not digest_size.isdigit():
            raise InvalidArgumentError(f"Invalid digest [{digest_hash}/{digest_size}]")

        digest = re_pb2.Digest(hash=digest_hash, size_bytes=int(digest_size))

        write_session = self.__storage.begin_write(digest)

        # Start the write session and write the first request's data.
        with Counter(metric_name=CAS_UPLOADED_BYTES_METRIC_NAME,
                     instance_name=self._instance_name,
                     metric_domain=MetricRecordDomain.CAS) as bytes_counter:
            write_session.write(first_block)
            computed_hash = HASH(first_block)
            bytes_counter.increment(len(first_block))

            # Handle subsequent write requests.
            for next_block in other_blocks:
                write_session.write(next_block)

                computed_hash.update(next_block)
                bytes_counter.increment(len(next_block))

        # Check that the data matches the provided digest.
        if bytes_counter.count != digest.size_bytes:
            raise NotImplementedError(
                "Cannot close stream before finishing write")

        elif computed_hash.hexdigest() != digest.hash:
            raise InvalidArgumentError("Data does not match hash")

        self.__storage.commit_write(digest, write_session)

        return bytestream_pb2.WriteResponse(committed_size=int(bytes_counter.count))
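

# Rough sketch of a full ByteStream round trip through the two methods above
# (hypothetical values; offsets, limits and data blocks normally come from
# the gRPC request stream, and `instance` is a registered ByteStreamInstance):
#
#   data = b'example payload'
#   digest = create_digest(data)
#   instance.write(digest.hash, str(digest.size_bytes), data, [])
#   for read_response in instance.read(digest.hash, str(digest.size_bytes),
#                                      read_offset=0, read_limit=0):
#       chunk = read_response.data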