Coverage for /builds/BuildGrid/buildgrid/buildgrid/client/channel.py: 72.64%

106 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-06-11 15:37 +0000

1# Copyright (C) 2019 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from collections import namedtuple 

16from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union 

17from urllib.parse import urlparse 

18 

19import grpc 

20from grpc import aio 

21 

22from buildgrid._exceptions import InvalidArgumentError 

23from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 

24from buildgrid.client.auth_token_loader import AuthTokenLoader 

25from buildgrid.client.authentication import AuthMetadataClientInterceptorBase, load_tls_channel_credentials 

26from buildgrid.client.interceptors import ( 

27 AsyncStreamStreamInterceptor, 

28 AsyncStreamUnaryInterceptor, 

29 AsyncUnaryStreamInterceptor, 

30 AsyncUnaryUnaryInterceptor, 

31 SyncStreamStreamInterceptor, 

32 SyncStreamUnaryInterceptor, 

33 SyncUnaryStreamInterceptor, 

34 SyncUnaryUnaryInterceptor, 

35) 

36from buildgrid.settings import REQUEST_METADATA_HEADER_NAME, REQUEST_METADATA_TOOL_NAME, REQUEST_METADATA_TOOL_VERSION 

37from buildgrid.utils import insecure_uri_schemes, secure_uri_schemes 

38 

39 

def setup_channel(
    remote_url: str,
    auth_token: Optional[str] = None,
    auth_token_refresh_seconds: Optional[int] = None,
    client_key: Optional[str] = None,
    client_cert: Optional[str] = None,
    server_cert: Optional[str] = None,
    action_id: Optional[str] = None,
    tool_invocation_id: Optional[str] = None,
    correlated_invocations_id: Optional[str] = None,
    asynchronous: bool = False,
    timeout: Optional[float] = None,
) -> Tuple[grpc.Channel, Tuple[Optional[str], ...]]:
    """Creates a new gRPC client communication channel.

    If `remote_url` does not point to a socket and does not specify a
    port number, the port defaults to 50051.

    Args:
        remote_url (str): URL for the remote, including protocol and,
            if not a Unix domain socket, a port.
        auth_token (str): Authorization token file path.
        auth_token_refresh_seconds (int): Time in seconds after which the
            authorization token is read again from file.
        client_key (str): Client TLS private key file path.
        client_cert (str): Client TLS certificate (chain) file path.
        server_cert (str): TLS root certificate file path used to verify
            the server.
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.
        asynchronous (bool): If True, create a `grpc.aio` channel instead of
            a synchronous `grpc` channel.
        timeout (float): Request timeout in seconds.

    Returns:
        tuple: ``(channel, details)`` where ``channel`` is the client channel
        to be used in order to access the server at `remote_url`, and
        ``details`` is the TLS details tuple returned by
        `load_tls_channel_credentials` for secure channels, or
        ``(None, None, None)`` for insecure ones.

    Raises:
        InvalidArgumentError: If the remote has an invalid port, does not
            specify a supported protocol, or TLS credentials cannot be loaded.
    """
    url = urlparse(remote_url)

    url_is_socket = url.scheme == "unix"
    if url_is_socket:
        remote = remote_url
    else:
        try:
            # `ParseResult.port` raises ValueError for non-numeric or
            # out-of-range ports; surface that as the documented error type.
            port = url.port
        except ValueError as e:
            raise InvalidArgumentError(f"Invalid port in remote [{remote_url}]: {e}") from e
        host = url.hostname
        # `urlparse` strips the brackets from IPv6 literals ("[::1]" -> "::1"),
        # but a gRPC target needs them to separate the address from the port.
        if host is not None and ":" in host:
            host = f"[{host}]"
        remote = f"{host}:{port or 50051}"

    auth_token_loader: Optional[AuthTokenLoader] = None
    if auth_token:
        auth_token_loader = AuthTokenLoader(auth_token, auth_token_refresh_seconds)

    if asynchronous:
        async_interceptors = _create_async_interceptors(
            auth_token_loader=auth_token_loader,
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
            timeout=timeout,
        )

        credentials, details = _load_channel_credentials(
            url.scheme, url_is_socket, client_key, client_cert, server_cert
        )
        if credentials is None:
            async_channel = aio.insecure_channel(remote, interceptors=async_interceptors)
        else:
            async_channel = aio.secure_channel(remote, credentials, interceptors=async_interceptors)

        # TODO use overloads to make this return an async channel when asynchronous == True
        return async_channel, details  # type: ignore[return-value]

    sync_interceptors = _create_sync_interceptors(
        auth_token_loader=auth_token_loader,
        action_id=action_id,
        tool_invocation_id=tool_invocation_id,
        correlated_invocations_id=correlated_invocations_id,
        timeout=timeout,
    )

    credentials, details = _load_channel_credentials(
        url.scheme, url_is_socket, client_key, client_cert, server_cert
    )
    if credentials is None:
        sync_channel = grpc.insecure_channel(remote)
    else:
        sync_channel = grpc.secure_channel(remote, credentials)

    # Each `intercept_channel` call wraps the previous channel.
    for interceptor in sync_interceptors:
        sync_channel = grpc.intercept_channel(sync_channel, interceptor)

    return sync_channel, details


def _load_channel_credentials(
    scheme: str,
    url_is_socket: bool,
    client_key: Optional[str],
    client_cert: Optional[str],
    server_cert: Optional[str],
) -> Tuple[Optional[grpc.ChannelCredentials], Tuple[Optional[str], ...]]:
    """Decide between an insecure and a TLS channel and load TLS credentials.

    Unix sockets are secure only when some TLS file path was supplied; other
    remotes are classified by their URL scheme.

    Returns:
        ``(None, (None, None, None))`` for an insecure remote, otherwise the
        ``(credentials, details)`` pair from `load_tls_channel_credentials`.

    Raises:
        InvalidArgumentError: If the scheme is neither secure nor insecure,
            or TLS credentials cannot be loaded.
    """
    credentials_provided = any((server_cert, client_cert, client_key))

    if scheme in insecure_uri_schemes or (url_is_socket and not credentials_provided):
        return None, (None, None, None)

    if scheme in secure_uri_schemes or (url_is_socket and credentials_provided):
        credentials, details = load_tls_channel_credentials(client_key, client_cert, server_cert)
        if not credentials:
            raise InvalidArgumentError("Given TLS details (or defaults) could not be loaded")
        return credentials, details

    raise InvalidArgumentError("Given remote does not specify a protocol")

137 

138 

class RequestMetadataInterceptorBase:
    """Interceptor base that appends a serialized `RequestMetadata` header to
    the metadata of each call it amends."""

    # `_ClientCallDetails` classes already built, keyed by the gRPC
    # call-details class they also derive from. Building the class (and its
    # namedtuple base) is comparatively expensive, so it is done once per
    # base class instead of on every intercepted call. Concurrent first calls
    # may each build a class; whichever is stored last is reused afterwards.
    _call_details_types: "dict[Any, Any]" = {}

    def __init__(
        self,
        action_id: Optional[str] = None,
        tool_invocation_id: Optional[str] = None,
        correlated_invocations_id: Optional[str] = None,
    ) -> None:
        """Appends optional `RequestMetadata` header values to each call.

        Args:
            action_id (str): Action identifier to which the request belongs to.
            tool_invocation_id (str): Identifier for a related group of Actions.
            correlated_invocations_id (str): Identifier that ties invocations together.
        """
        self._action_id = action_id
        self._tool_invocation_id = tool_invocation_id
        self._correlated_invocations_id = correlated_invocations_id

        # The header is serialized once here and reused for every call.
        self.__header_field_name = REQUEST_METADATA_HEADER_NAME
        self.__header_field_value = self._request_metadata()

    def _request_metadata(self) -> bytes:
        """Creates a serialized RequestMetadata entry to attach to a gRPC
        call header. Arguments should be of type str or None.
        """
        request_metadata = remote_execution_pb2.RequestMetadata()
        request_metadata.tool_details.tool_name = REQUEST_METADATA_TOOL_NAME
        request_metadata.tool_details.tool_version = REQUEST_METADATA_TOOL_VERSION

        # Only non-empty identifiers are set on the message.
        if self._action_id:
            request_metadata.action_id = self._action_id
        if self._tool_invocation_id:
            request_metadata.tool_invocation_id = self._tool_invocation_id
        if self._correlated_invocations_id:
            request_metadata.correlated_invocations_id = self._correlated_invocations_id

        return request_metadata.SerializeToString()

    @classmethod
    def _call_details_type(cls, grpc_call_details_class: Any) -> Any:
        """Return the cached `_ClientCallDetails` class deriving from
        `grpc_call_details_class`, building it on first use."""
        details_type = cls._call_details_types.get(grpc_call_details_class)
        if details_type is None:

            class _ClientCallDetails(
                namedtuple(
                    "_ClientCallDetails",
                    (
                        "method",
                        "timeout",
                        "credentials",
                        "metadata",
                        "wait_for_ready",
                    ),
                ),
                grpc_call_details_class,  # type: ignore
            ):
                pass

            details_type = _ClientCallDetails
            cls._call_details_types[grpc_call_details_class] = details_type
        return details_type

    def amend_call_details(  # type: ignore[no-untyped-def] # wait for client lib updates here
        self, client_call_details, grpc_call_details_class: Any
    ):
        """Return a copy of `client_call_details` with the RequestMetadata
        header appended to its metadata.

        Args:
            client_call_details: The call details of the intercepted call.
                Its metadata is copied, not mutated.
            grpc_call_details_class: gRPC call-details class the returned
                object must also be an instance of.

        Returns:
            A `_ClientCallDetails` instance with the same method, timeout,
            credentials and wait_for_ready, and the extended metadata list.
        """
        if client_call_details.metadata is not None:
            new_metadata = list(client_call_details.metadata)
        else:
            new_metadata = []

        new_metadata.append((self.__header_field_name, self.__header_field_value))

        details_type = self._call_details_type(grpc_call_details_class)
        return details_type(
            client_call_details.method,
            client_call_details.timeout,
            client_call_details.credentials,
            new_metadata,
            client_call_details.wait_for_ready,
        )

209 

210 

class TimeoutInterceptorBase:
    """Interceptor base that applies a request timeout to each call it amends."""

    # `_ClientCallDetails` classes already built, keyed by the gRPC
    # call-details class they also derive from, so the class is built once
    # per base class instead of on every intercepted call.
    _call_details_types: "dict[Any, Any]" = {}

    def __init__(self, timeout: float) -> None:
        """Applies a request timeout to each call.

        Args:
            timeout (float): Request timeout in seconds.
        """
        self._timeout = timeout

    @classmethod
    def _call_details_type(cls, grpc_call_details_class: Any) -> Any:
        """Return the cached `_ClientCallDetails` class deriving from
        `grpc_call_details_class`, building it on first use."""
        details_type = cls._call_details_types.get(grpc_call_details_class)
        if details_type is None:

            class _ClientCallDetails(
                namedtuple(
                    "_ClientCallDetails",
                    (
                        "method",
                        "timeout",
                        "credentials",
                        "metadata",
                        "wait_for_ready",
                    ),
                ),
                grpc_call_details_class,  # type: ignore
            ):
                pass

            details_type = _ClientCallDetails
            cls._call_details_types[grpc_call_details_class] = details_type
        return details_type

    def amend_call_details(  # type: ignore[no-untyped-def] # wait for client lib updates here
        self, client_call_details, grpc_call_details_class: Any
    ):
        """Return a copy of `client_call_details` with this interceptor's
        timeout applied.

        Args:
            client_call_details: The call details of the intercepted call.
            grpc_call_details_class: gRPC call-details class the returned
                object must also be an instance of.

        Returns:
            A `_ClientCallDetails` instance identical to the input except
            that its timeout is the shorter of the call's own timeout (if
            any) and this interceptor's timeout.
        """
        # If there are multiple timeouts, apply the shorter timeout (earliest deadline wins)
        if client_call_details.timeout is not None:
            new_timeout = min(self._timeout, client_call_details.timeout)
        else:
            new_timeout = self._timeout

        details_type = self._call_details_type(grpc_call_details_class)
        return details_type(
            client_call_details.method,
            new_timeout,
            client_call_details.credentials,
            client_call_details.metadata,
            client_call_details.wait_for_ready,
        )

251 

252 

if TYPE_CHECKING:
    # One synchronous gRPC client interceptor of any call arity. These aliases
    # exist only for the type checker and are not defined at runtime.
    _SyncInterceptor = Union[
        grpc.UnaryUnaryClientInterceptor[Any, Any],
        grpc.UnaryStreamClientInterceptor[Any, Any],
        grpc.StreamUnaryClientInterceptor[Any, Any],
        grpc.StreamStreamClientInterceptor[Any, Any],
    ]
    SyncInterceptorsList = List[_SyncInterceptor]

262 

263 

def _create_sync_interceptors(
    auth_token_loader: Optional[AuthTokenLoader] = None,
    action_id: Optional[str] = None,
    tool_invocation_id: Optional[str] = None,
    correlated_invocations_id: Optional[str] = None,
    timeout: Optional[float] = None,
) -> "SyncInterceptorsList":
    """Build the synchronous client interceptors for a channel.

    Each interceptor base below is wrapped once per call arity
    (unary-unary, unary-stream, stream-unary, stream-stream), in this order:

    * request metadata (always),
    * auth metadata (only when `auth_token_loader` is given),
    * request timeout (only when `timeout` is given).
    """
    arity_wrappers = (
        SyncUnaryUnaryInterceptor,
        SyncUnaryStreamInterceptor,
        SyncStreamUnaryInterceptor,
        SyncStreamStreamInterceptor,
    )
    interceptors: "SyncInterceptorsList" = []

    def add(base: Any) -> None:
        # Wrap `base` in every arity-specific interceptor, keeping wrapper order.
        interceptors.extend(wrapper(base) for wrapper in arity_wrappers)

    add(
        RequestMetadataInterceptorBase(
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
        )
    )

    if auth_token_loader is not None:
        add(AuthMetadataClientInterceptorBase(auth_token_loader=auth_token_loader))

    if timeout is not None:
        add(TimeoutInterceptorBase(timeout))

    return interceptors

305 

306 

def _create_async_interceptors(
    auth_token_loader: Optional[AuthTokenLoader] = None,
    action_id: Optional[str] = None,
    tool_invocation_id: Optional[str] = None,
    correlated_invocations_id: Optional[str] = None,
    timeout: Optional[float] = None,
) -> List[aio.ClientInterceptor]:
    """Build the asyncio client interceptors for a channel.

    Each interceptor base below is wrapped once per call arity
    (unary-unary, unary-stream, stream-unary, stream-stream), in this order:

    * request metadata (always),
    * auth metadata (only when `auth_token_loader` is given),
    * request timeout (only when `timeout` is given).
    """
    arity_wrappers = (
        AsyncUnaryUnaryInterceptor,
        AsyncUnaryStreamInterceptor,
        AsyncStreamUnaryInterceptor,
        AsyncStreamStreamInterceptor,
    )
    # FIXME Types not happy... "list" has incompatible type "..."; expected "_PartialStubMustCastOrIgnore"
    interceptors: List[Any] = []

    def add(base: Any) -> None:
        # Wrap `base` in every arity-specific interceptor, keeping wrapper order.
        interceptors.extend(wrapper(base) for wrapper in arity_wrappers)

    add(
        RequestMetadataInterceptorBase(
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
        )
    )

    if auth_token_loader is not None:
        add(AuthMetadataClientInterceptorBase(auth_token_loader=auth_token_loader))

    if timeout is not None:
        add(TimeoutInterceptorBase(timeout))

    return interceptors