Coverage for /builds/BuildGrid/buildgrid/buildgrid/client/channel.py: 30.00%

160 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-22 21:04 +0000

1# Copyright (C) 2019 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from collections import namedtuple 

16from urllib.parse import urlparse 

17from typing import Any, List, Optional, TYPE_CHECKING, Union 

18 

19import grpc 

20from grpc import aio # type: ignore 

21 

22from buildgrid.client.authentication import AsyncAuthMetadataClientInterceptor, AuthMetadataClientInterceptor 

23from buildgrid.client.authentication import load_channel_authorization_token 

24from buildgrid.client.authentication import load_tls_channel_credentials 

25from buildgrid._exceptions import InvalidArgumentError 

26from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 

27from buildgrid.settings import REQUEST_METADATA_HEADER_NAME 

28from buildgrid.settings import REQUEST_METADATA_TOOL_NAME, REQUEST_METADATA_TOOL_VERSION 

29from buildgrid.utils import insecure_uri_schemes, secure_uri_schemes 

30 

31 

def setup_channel(remote_url: str,
                  auth_token: Optional[str] = None, client_key: Optional[str] = None,
                  client_cert: Optional[str] = None, server_cert: Optional[str] = None,
                  action_id: Optional[str] = None, tool_invocation_id: Optional[str] = None,
                  correlated_invocations_id: Optional[str] = None, asynchronous: bool = False,
                  timeout: Optional[float] = None):
    """Creates a new gRPC client communication channel.

    If `remote_url` does not point to a Unix domain socket and does not
    specify a port number, port 50051 is used.

    Args:
        remote_url (str): URL for the remote, including protocol and,
            if not a Unix domain socket, a port.
        auth_token (str): Authorization token file path.
        client_key (str): TLS client private key file path.
        client_cert (str): TLS client certificate (chain) file path.
        server_cert (str): TLS server/root certificate file path.
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.
        asynchronous (bool): If True, create a `grpc.aio` channel; otherwise
            create a synchronous `grpc` channel.
        timeout (float): Request timeout in seconds.

    Returns:
        tuple: ``(channel, details)``. ``channel`` is the client channel to be
        used in order to access the server at `remote_url`. ``details`` is
        ``(None, None, None)`` for insecure channels, otherwise the TLS
        details returned by `load_tls_channel_credentials`.

    Raises:
        InvalidArgumentError: On any input parsing error.
    """
    url = urlparse(remote_url)

    url_is_socket = (url.scheme == 'unix')
    if url_is_socket:
        remote = remote_url
    else:
        try:
            # `ParseResult.port` raises ValueError for non-numeric or
            # out-of-range ports; surface that as the documented error type.
            port = url.port or 50051
        except ValueError as err:
            raise InvalidArgumentError(f"Given remote has an invalid port: {remote_url}") from err

        host = url.hostname
        # urlparse strips the brackets from IPv6 literals ("[::1]" -> "::1"),
        # but gRPC targets need them back to separate host from port.
        if host and ':' in host:
            host = f'[{host}]'
        remote = f'{host}:{port}'

    details = None, None, None
    credentials_provided = any((server_cert, client_cert, client_key))

    if asynchronous:
        async_interceptors = _create_async_interceptors(
            auth_token=auth_token,
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
            timeout=timeout)
    else:
        sync_interceptors = _create_sync_interceptors(
            auth_token=auth_token,
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
            timeout=timeout)

    # Unix sockets default to insecure unless TLS material was supplied.
    if url.scheme in insecure_uri_schemes or (url_is_socket and not credentials_provided):
        if asynchronous:
            channel = aio.insecure_channel(remote, interceptors=async_interceptors)
        else:
            channel = grpc.insecure_channel(remote)
    elif url.scheme in secure_uri_schemes or (url_is_socket and credentials_provided):
        credentials, details = load_tls_channel_credentials(client_key, client_cert, server_cert)
        if not credentials:
            raise InvalidArgumentError("Given TLS details (or defaults) could not be loaded")

        if asynchronous:
            channel = aio.secure_channel(remote, credentials, interceptors=async_interceptors)
        else:
            channel = grpc.secure_channel(remote, credentials)

    else:
        raise InvalidArgumentError("Given remote does not specify a protocol")

    # Synchronous channels take interceptors by wrapping, not at construction.
    if not asynchronous:
        for interceptor in sync_interceptors:
            channel = grpc.intercept_channel(channel, interceptor)

    return channel, details

111 

112 

class RequestMetadataInterceptorBase:
    """Shared logic for interceptors that attach a serialized
    `RequestMetadata` header to every outgoing call."""

    # Concrete call-details classes built by `_amend_call_details`, keyed by
    # the gRPC call-details base class (e.g. sync vs. aio). Creating a
    # namedtuple subclass is comparatively costly, so it is done once per
    # base class rather than on every intercepted RPC.
    _call_details_classes: dict = {}

    def __init__(self, action_id: Optional[str] = None,
                 tool_invocation_id: Optional[str] = None,
                 correlated_invocations_id: Optional[str] = None):
        """Appends optional `RequestMetadata` header values to each call.

        Args:
            action_id (str): Action identifier to which the request belongs to.
            tool_invocation_id (str): Identifier for a related group of Actions.
            correlated_invocations_id (str): Identifier that ties invocations together.
        """
        self._action_id = action_id
        self._tool_invocation_id = tool_invocation_id
        self._correlated_invocations_id = correlated_invocations_id

        # The header is computed once here and reused for every call.
        self.__header_field_name = REQUEST_METADATA_HEADER_NAME
        self.__header_field_value = self._request_metadata()

    def _request_metadata(self) -> bytes:
        """Creates a serialized RequestMetadata entry to attach to a gRPC
        call header.

        The tool name and version are always set; `action_id`,
        `tool_invocation_id` and `correlated_invocations_id` are only set
        when the corresponding instance attribute is truthy.

        Returns:
            bytes: The serialized `RequestMetadata` message.
        """
        request_metadata = remote_execution_pb2.RequestMetadata()
        request_metadata.tool_details.tool_name = REQUEST_METADATA_TOOL_NAME
        request_metadata.tool_details.tool_version = REQUEST_METADATA_TOOL_VERSION

        if self._action_id:
            request_metadata.action_id = self._action_id
        if self._tool_invocation_id:
            request_metadata.tool_invocation_id = self._tool_invocation_id
        if self._correlated_invocations_id:
            request_metadata.correlated_invocations_id = self._correlated_invocations_id

        return request_metadata.SerializeToString()

    @classmethod
    def _call_details_class(cls, grpc_call_details_class):
        """Returns (creating once) a namedtuple-backed subclass of
        `grpc_call_details_class` holding the five call-detail fields."""
        details_class = cls._call_details_classes.get(grpc_call_details_class)
        if details_class is None:
            class _ClientCallDetails(
                    namedtuple('_ClientCallDetails',
                               ('method', 'timeout', 'credentials', 'metadata', 'wait_for_ready',)),
                    grpc_call_details_class):
                pass

            details_class = _ClientCallDetails
            cls._call_details_classes[grpc_call_details_class] = details_class
        return details_class

    def _amend_call_details(self, client_call_details, grpc_call_details_class):
        """Returns a copy of `client_call_details` whose metadata has the
        `RequestMetadata` header appended.

        Args:
            client_call_details: The incoming call details.
            grpc_call_details_class: The gRPC call-details base class the
                returned object must be an instance of.
        """
        if client_call_details.metadata is not None:
            new_metadata = list(client_call_details.metadata)
        else:
            new_metadata = []

        new_metadata.append((self.__header_field_name,
                             self.__header_field_value))

        details_class = self._call_details_class(grpc_call_details_class)
        return details_class(client_call_details.method,
                             client_call_details.timeout,
                             client_call_details.credentials,
                             new_metadata,
                             client_call_details.wait_for_ready)

169 

170 

class RequestMetadataInterceptor(RequestMetadataInterceptorBase,
                                 grpc.UnaryUnaryClientInterceptor,
                                 grpc.UnaryStreamClientInterceptor,
                                 grpc.StreamUnaryClientInterceptor,
                                 grpc.StreamStreamClientInterceptor):
    """Synchronous interceptor that adds the `RequestMetadata` header to
    unary and streaming calls alike."""

    def __init__(self, action_id: Optional[str] = None,
                 tool_invocation_id: Optional[str] = None,
                 correlated_invocations_id: Optional[str] = None):
        super().__init__(action_id=action_id,
                         tool_invocation_id=tool_invocation_id,
                         correlated_invocations_id=correlated_invocations_id)

    def _forward_amended(self, continuation, client_call_details, payload):
        """Amends the call details, then hands the call to `continuation`."""
        amended = self._amend_call_details(client_call_details, grpc.ClientCallDetails)
        return continuation(amended, payload)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._forward_amended(continuation, client_call_details, request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return self._forward_amended(continuation, client_call_details, request)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return self._forward_amended(continuation, client_call_details, request_iterator)

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return self._forward_amended(continuation, client_call_details, request_iterator)

206 

207 

class AsyncRequestMetadataInterceptor(RequestMetadataInterceptorBase,
                                      aio.UnaryUnaryClientInterceptor,
                                      aio.UnaryStreamClientInterceptor,
                                      aio.StreamUnaryClientInterceptor,
                                      aio.StreamStreamClientInterceptor):
    """`grpc.aio` interceptor that adds the `RequestMetadata` header to
    unary and streaming calls alike."""

    def __init__(self, action_id: Optional[str] = None,
                 tool_invocation_id: Optional[str] = None,
                 correlated_invocations_id: Optional[str] = None):
        super().__init__(action_id=action_id,
                         tool_invocation_id=tool_invocation_id,
                         correlated_invocations_id=correlated_invocations_id)

    async def _forward_amended(self, continuation, client_call_details, payload):
        """Amends the call details, then awaits `continuation` with them."""
        amended = self._amend_call_details(client_call_details, aio.ClientCallDetails)
        return await continuation(amended, payload)

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        return await self._forward_amended(continuation, client_call_details, request)

    async def intercept_unary_stream(self, continuation, client_call_details, request):
        return await self._forward_amended(continuation, client_call_details, request)

    async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return await self._forward_amended(continuation, client_call_details, request_iterator)

    async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return await self._forward_amended(continuation, client_call_details, request_iterator)

243 

244 

class TimeoutInterceptorBase:
    """Shared logic for interceptors that cap each call's timeout."""

    # Concrete call-details classes built by `_amend_call_details`, keyed by
    # the gRPC call-details base class (e.g. sync vs. aio). Creating a
    # namedtuple subclass is comparatively costly, so it is done once per
    # base class rather than on every intercepted RPC.
    _call_details_classes: dict = {}

    def __init__(self, timeout: float):
        """Applies a request timeout to each call.

        Args:
            timeout (float): Request timeout in seconds.
        """
        self._timeout = timeout

    @classmethod
    def _call_details_class(cls, grpc_call_details_class):
        """Returns (creating once) a namedtuple-backed subclass of
        `grpc_call_details_class` holding the five call-detail fields."""
        details_class = cls._call_details_classes.get(grpc_call_details_class)
        if details_class is None:
            class _ClientCallDetails(
                    namedtuple('_ClientCallDetails',
                               ('method', 'timeout', 'credentials', 'metadata', 'wait_for_ready',)),
                    grpc_call_details_class):
                pass

            details_class = _ClientCallDetails
            cls._call_details_classes[grpc_call_details_class] = details_class
        return details_class

    def _amend_call_details(self, client_call_details, grpc_call_details_class):
        """Returns a copy of `client_call_details` with its timeout set to
        the shorter of the existing timeout (if any) and this interceptor's.

        Args:
            client_call_details: The incoming call details.
            grpc_call_details_class: The gRPC call-details base class the
                returned object must be an instance of.
        """
        # If there are multiple timeouts, apply the shorter timeout (earliest deadline wins)
        if client_call_details.timeout is not None:
            new_timeout = min(self._timeout, client_call_details.timeout)
        else:
            new_timeout = self._timeout

        details_class = self._call_details_class(grpc_call_details_class)
        return details_class(client_call_details.method,
                             new_timeout,
                             client_call_details.credentials,
                             client_call_details.metadata,
                             client_call_details.wait_for_ready)

273 

274 

class TimeoutInterceptor(TimeoutInterceptorBase,
                         grpc.UnaryUnaryClientInterceptor,
                         grpc.UnaryStreamClientInterceptor,
                         grpc.StreamUnaryClientInterceptor,
                         grpc.StreamStreamClientInterceptor):
    """Synchronous interceptor that bounds every call by a fixed timeout."""

    def __init__(self, timeout: float):
        super().__init__(timeout=timeout)

    def _forward_amended(self, continuation, client_call_details, payload):
        """Applies the timeout cap, then hands the call to `continuation`."""
        amended = self._amend_call_details(client_call_details, grpc.ClientCallDetails)
        return continuation(amended, payload)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return self._forward_amended(continuation, client_call_details, request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return self._forward_amended(continuation, client_call_details, request)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return self._forward_amended(continuation, client_call_details, request_iterator)

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return self._forward_amended(continuation, client_call_details, request_iterator)

303 

304 

class AsyncTimeoutInterceptor(TimeoutInterceptorBase,
                              aio.UnaryUnaryClientInterceptor,
                              aio.UnaryStreamClientInterceptor,
                              aio.StreamUnaryClientInterceptor,
                              aio.StreamStreamClientInterceptor):
    """`grpc.aio` interceptor that bounds every call by a fixed timeout."""

    def __init__(self, timeout: float):
        super().__init__(timeout=timeout)

    async def _forward_amended(self, continuation, client_call_details, payload):
        """Applies the timeout cap, then awaits `continuation` with it."""
        amended = self._amend_call_details(client_call_details, aio.ClientCallDetails)
        return await continuation(amended, payload)

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        return await self._forward_amended(continuation, client_call_details, request)

    async def intercept_unary_stream(self, continuation, client_call_details, request):
        return await self._forward_amended(continuation, client_call_details, request)

    async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return await self._forward_amended(continuation, client_call_details, request_iterator)

    async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return await self._forward_amended(continuation, client_call_details, request_iterator)

333 

334 

if TYPE_CHECKING:
    # Static-typing-only alias for the list built by `_create_sync_interceptors`.
    # Because it is guarded by TYPE_CHECKING the name does not exist at runtime,
    # so it is referenced through string annotations.
    # pylint: disable=unsubscriptable-object
    _SyncClientInterceptor = Union[
        grpc.UnaryUnaryClientInterceptor[Any, Any],
        grpc.UnaryStreamClientInterceptor[Any, Any],
        grpc.StreamUnaryClientInterceptor[Any, Any],
        grpc.StreamStreamClientInterceptor[Any, Any],
    ]
    SyncInterceptorsList = List[_SyncClientInterceptor]

345 

346 

def _create_sync_interceptors(auth_token: Optional[str] = None, action_id: Optional[str] = None,
                              tool_invocation_id: Optional[str] = None,
                              correlated_invocations_id: Optional[str] = None,
                              timeout: Optional[float] = None) -> 'SyncInterceptorsList':
    """Builds the interceptors for a synchronous channel.

    A `RequestMetadataInterceptor` is always included. An
    `AuthMetadataClientInterceptor` is added when `auth_token` is given, and
    a `TimeoutInterceptor` when `timeout` is given.

    Args:
        auth_token (str): Authorization token file path.
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.
        timeout (float): Request timeout in seconds.

    Returns:
        list: The interceptors, in the order listed above.

    Raises:
        InvalidArgumentError: If the authorization token cannot be loaded.
    """
    interceptors: 'SyncInterceptorsList' = [
        RequestMetadataInterceptor(
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id),
    ]

    if auth_token is not None:
        token = load_channel_authorization_token(auth_token)
        if not token:
            raise InvalidArgumentError("Given authorization token could not be loaded")
        interceptors.append(AuthMetadataClientInterceptor(auth_token=token))

    if timeout is not None:
        interceptors.append(TimeoutInterceptor(timeout))

    return interceptors

367 

368 

def _create_async_interceptors(auth_token: Optional[str] = None, action_id: Optional[str] = None,
                               tool_invocation_id: Optional[str] = None,
                               correlated_invocations_id: Optional[str] = None,
                               timeout: Optional[float] = None) -> List[aio.ClientInterceptor]:
    """Builds the interceptors for a `grpc.aio` channel.

    An `AsyncRequestMetadataInterceptor` is always included. An
    `AsyncAuthMetadataClientInterceptor` is added when `auth_token` is given,
    and an `AsyncTimeoutInterceptor` when `timeout` is given.

    Args:
        auth_token (str): Authorization token file path.
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.
        timeout (float): Request timeout in seconds.

    Returns:
        list: The interceptors, in the order listed above.

    Raises:
        InvalidArgumentError: If the authorization token cannot be loaded.
    """
    interceptors: List[aio.ClientInterceptor] = [
        AsyncRequestMetadataInterceptor(
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id),
    ]

    if auth_token is not None:
        token = load_channel_authorization_token(auth_token)
        if not token:
            raise InvalidArgumentError("Given authorization token could not be loaded")
        interceptors.append(AsyncAuthMetadataClientInterceptor(auth_token=token))

    if timeout is not None:
        interceptors.append(AsyncTimeoutInterceptor(timeout))

    return interceptors