Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/client/channel.py: 72.38%

105 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-10-04 17:48 +0000

1# Copyright (C) 2019 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from collections import namedtuple 

16from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union 

17from urllib.parse import urlparse 

18 

19import grpc 

20from grpc import aio 

21 

22from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 

23from buildgrid.server.client.auth_token_loader import AuthTokenLoader 

24from buildgrid.server.client.authentication import AuthMetadataClientInterceptorBase, load_tls_channel_credentials 

25from buildgrid.server.client.interceptors import ( 

26 AsyncStreamStreamInterceptor, 

27 AsyncStreamUnaryInterceptor, 

28 AsyncUnaryStreamInterceptor, 

29 AsyncUnaryUnaryInterceptor, 

30 SyncStreamStreamInterceptor, 

31 SyncStreamUnaryInterceptor, 

32 SyncUnaryStreamInterceptor, 

33 SyncUnaryUnaryInterceptor, 

34) 

35from buildgrid.server.exceptions import InvalidArgumentError 

36from buildgrid.server.settings import ( 

37 INSECURE_URI_SCHEMES, 

38 REQUEST_METADATA_HEADER_NAME, 

39 REQUEST_METADATA_TOOL_NAME, 

40 REQUEST_METADATA_TOOL_VERSION, 

41 SECURE_URI_SCHEMES, 

42) 

43 

44 

def setup_channel(
    remote_url: str,
    auth_token: Optional[str] = None,
    auth_token_refresh_seconds: Optional[int] = None,
    client_key: Optional[str] = None,
    client_cert: Optional[str] = None,
    server_cert: Optional[str] = None,
    action_id: Optional[str] = None,
    tool_invocation_id: Optional[str] = None,
    correlated_invocations_id: Optional[str] = None,
    asynchronous: bool = False,
    timeout: Optional[float] = None,
) -> Tuple[grpc.Channel, Tuple[Optional[str], ...]]:
    """Creates a new gRPC client communication channel.

    If `remote_url` does not use the ``unix`` scheme and does not specify a
    port number, the port defaults to 50051.

    Args:
        remote_url (str): URL for the remote, including protocol and,
            if not a Unix domain socket, a port.
        auth_token (str): Authorization token file path.
        auth_token_refresh_seconds (int): Time in seconds after which the
            authorization token is read again from file.
        client_key (str): Client TLS private key file path.
        client_cert (str): Client TLS certificate chain file path.
        server_cert (str): Server TLS root certificate file path.
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.
        asynchronous (bool): When true, an ``grpc.aio`` channel is created
            instead of a synchronous one.
        timeout (float): Request timeout in seconds.

    Returns:
        Tuple: The client channel to be used in order to access the server
        at `remote_url`, followed by the details tuple produced by
        `load_tls_channel_credentials` (``(None, None, None)`` when the
        channel is not secured).

    Raises:
        InvalidArgumentError: If the remote does not specify a known
            protocol, or if the TLS credentials cannot be loaded.
    """
    url = urlparse(remote_url)

    url_is_socket = url.scheme == "unix"
    if url_is_socket:
        # Unix socket addresses are handed to gRPC verbatim
        remote = remote_url
    else:
        remote = f"{url.hostname}:{url.port or 50051}"

    auth_token_loader: Optional[AuthTokenLoader] = None
    if auth_token:
        auth_token_loader = AuthTokenLoader(auth_token, auth_token_refresh_seconds)

    if asynchronous:
        async_interceptors = _create_async_interceptors(
            auth_token_loader=auth_token_loader,
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
            timeout=timeout,
        )

        credentials, details = _resolve_channel_credentials(
            url.scheme, url_is_socket, client_key, client_cert, server_cert
        )
        if credentials is None:
            async_channel = aio.insecure_channel(remote, interceptors=async_interceptors)
        else:
            async_channel = aio.secure_channel(remote, credentials, interceptors=async_interceptors)

        # TODO use overloads to make this return an async channel when asynchronous == True
        return async_channel, details  # type: ignore[return-value]

    sync_interceptors = _create_sync_interceptors(
        auth_token_loader=auth_token_loader,
        action_id=action_id,
        tool_invocation_id=tool_invocation_id,
        correlated_invocations_id=correlated_invocations_id,
        timeout=timeout,
    )

    credentials, details = _resolve_channel_credentials(
        url.scheme, url_is_socket, client_key, client_cert, server_cert
    )
    if credentials is None:
        sync_channel = grpc.insecure_channel(remote)
    else:
        sync_channel = grpc.secure_channel(remote, credentials)

    # Synchronous channels take their interceptors by wrapping, one at a time
    for interceptor in sync_interceptors:
        sync_channel = grpc.intercept_channel(sync_channel, interceptor)

    return sync_channel, details


def _resolve_channel_credentials(
    scheme: str,
    is_socket: bool,
    client_key: Optional[str],
    client_cert: Optional[str],
    server_cert: Optional[str],
) -> Tuple[Optional[grpc.ChannelCredentials], Tuple[Optional[str], ...]]:
    """Decides whether a channel is plaintext or TLS and loads credentials.

    Returns ``(None, (None, None, None))`` for a plaintext channel, otherwise
    the loaded credentials and the details reported by
    `load_tls_channel_credentials`.

    Raises:
        InvalidArgumentError: If the scheme is neither secure, insecure nor a
            Unix socket, or if TLS credentials cannot be loaded.
    """
    credentials_provided = any((server_cert, client_cert, client_key))

    # A Unix socket is only secured when at least one TLS file was supplied
    if scheme in INSECURE_URI_SCHEMES or (is_socket and not credentials_provided):
        return None, (None, None, None)

    if scheme in SECURE_URI_SCHEMES or (is_socket and credentials_provided):
        credentials, details = load_tls_channel_credentials(client_key, client_cert, server_cert)
        if not credentials:
            raise InvalidArgumentError("Given TLS details (or defaults) could not be loaded")
        return credentials, details

    raise InvalidArgumentError("Given remote does not specify a protocol")

142 

143 

class RequestMetadataInterceptorBase:
    """Adds a serialized `RequestMetadata` header entry to call details."""

    def __init__(
        self,
        action_id: Optional[str] = None,
        tool_invocation_id: Optional[str] = None,
        correlated_invocations_id: Optional[str] = None,
    ) -> None:
        """Stores the optional identifiers and builds the header entry once.

        Args:
            action_id (str): Action identifier to which the request belongs to.
            tool_invocation_id (str): Identifier for a related group of Actions.
            correlated_invocations_id (str): Identifier that ties invocations together.
        """
        self._action_id = action_id
        self._tool_invocation_id = tool_invocation_id
        self._correlated_invocations_id = correlated_invocations_id

        # The header only depends on the constructor arguments, so it is
        # serialized here rather than on every call.
        self.__header_field_name = REQUEST_METADATA_HEADER_NAME
        self.__header_field_value = self._request_metadata()

    def _request_metadata(self) -> bytes:
        """Builds the RequestMetadata message for this interceptor and
        returns it serialized, ready to be used as a header value.
        """
        metadata = remote_execution_pb2.RequestMetadata()

        tool_details = metadata.tool_details
        tool_details.tool_name = REQUEST_METADATA_TOOL_NAME
        tool_details.tool_version = REQUEST_METADATA_TOOL_VERSION

        # Identifiers that are unset (or empty) are left out of the message
        optional_fields = (
            ("action_id", self._action_id),
            ("tool_invocation_id", self._tool_invocation_id),
            ("correlated_invocations_id", self._correlated_invocations_id),
        )
        for field_name, field_value in optional_fields:
            if field_value:
                setattr(metadata, field_name, field_value)

        return metadata.SerializeToString()

    def amend_call_details(  # type: ignore[no-untyped-def]  # wait for client lib updates here
        self, client_call_details, grpc_call_details_class: Any
    ):
        """Returns new call details equal to `client_call_details` except
        that the request metadata header is appended to the metadata list.

        The result is an instance of a namedtuple that also inherits from
        `grpc_call_details_class`.
        """
        existing_metadata = client_call_details.metadata
        new_metadata = [] if existing_metadata is None else list(existing_metadata)
        new_metadata.append((self.__header_field_name, self.__header_field_value))

        details_fields = namedtuple(
            "_ClientCallDetails",
            ("method", "timeout", "credentials", "metadata", "wait_for_ready"),
        )

        class _ClientCallDetails(details_fields, grpc_call_details_class):  # type: ignore
            pass

        return _ClientCallDetails(
            method=client_call_details.method,
            timeout=client_call_details.timeout,
            credentials=client_call_details.credentials,
            metadata=new_metadata,
            wait_for_ready=client_call_details.wait_for_ready,
        )

214 

215 

class TimeoutInterceptorBase:
    """Caps the timeout carried by call details at a configured value."""

    def __init__(self, timeout: float) -> None:
        """Remembers the request timeout to apply to each call.

        Args:
            timeout (float): Request timeout in seconds.
        """
        self._timeout = timeout

    def amend_call_details(  # type: ignore[no-untyped-def]  # wait for client lib updates here
        self, client_call_details, grpc_call_details_class: Any
    ):
        """Returns new call details equal to `client_call_details` except
        for the timeout, which becomes the smaller of the call's own timeout
        and the configured one (or the configured one if the call has none).

        The result is an instance of a namedtuple that also inherits from
        `grpc_call_details_class`.
        """
        requested_timeout = client_call_details.timeout
        # Earliest deadline wins when the call already carries a timeout
        effective_timeout = (
            self._timeout if requested_timeout is None else min(self._timeout, requested_timeout)
        )

        details_fields = namedtuple(
            "_ClientCallDetails",
            ("method", "timeout", "credentials", "metadata", "wait_for_ready"),
        )

        class _ClientCallDetails(details_fields, grpc_call_details_class):  # type: ignore
            pass

        return _ClientCallDetails(
            method=client_call_details.method,
            timeout=effective_timeout,
            credentials=client_call_details.credentials,
            metadata=client_call_details.metadata,
            wait_for_ready=client_call_details.wait_for_ready,
        )

256 

257 

if TYPE_CHECKING:
    # Alias used in annotations for the list of synchronous client
    # interceptors. It is only defined for static type checkers, presumably
    # because the subscripted grpc interceptor classes are generic only in
    # the type stubs -- TODO confirm.
    _SyncInterceptor = Union[
        grpc.UnaryUnaryClientInterceptor[Any, Any],
        grpc.UnaryStreamClientInterceptor[Any, Any],
        grpc.StreamUnaryClientInterceptor[Any, Any],
        grpc.StreamStreamClientInterceptor[Any, Any],
    ]
    SyncInterceptorsList = List[_SyncInterceptor]

267 

268 

def _create_sync_interceptors(
    auth_token_loader: Optional[AuthTokenLoader] = None,
    action_id: Optional[str] = None,
    tool_invocation_id: Optional[str] = None,
    correlated_invocations_id: Optional[str] = None,
    timeout: Optional[float] = None,
) -> "SyncInterceptorsList":
    """Builds the list of synchronous client interceptors for a channel.

    Each interceptor base is wrapped once per call type (unary-unary,
    unary-stream, stream-unary, stream-stream). Request metadata wrappers
    always come first, followed by the auth wrappers when
    `auth_token_loader` is given, then the timeout wrappers when `timeout`
    is given.
    """
    wrapper_classes = (
        SyncUnaryUnaryInterceptor,
        SyncUnaryStreamInterceptor,
        SyncStreamUnaryInterceptor,
        SyncStreamStreamInterceptor,
    )
    interceptors: "SyncInterceptorsList" = []

    request_metadata_base = RequestMetadataInterceptorBase(
        action_id=action_id,
        tool_invocation_id=tool_invocation_id,
        correlated_invocations_id=correlated_invocations_id,
    )
    interceptors.extend(wrap(request_metadata_base) for wrap in wrapper_classes)

    if auth_token_loader is not None:
        auth_base = AuthMetadataClientInterceptorBase(auth_token_loader=auth_token_loader)
        interceptors.extend(wrap(auth_base) for wrap in wrapper_classes)

    if timeout is not None:
        timeout_base = TimeoutInterceptorBase(timeout)
        interceptors.extend(wrap(timeout_base) for wrap in wrapper_classes)

    return interceptors

310 

311 

def _create_async_interceptors(
    auth_token_loader: Optional[AuthTokenLoader] = None,
    action_id: Optional[str] = None,
    tool_invocation_id: Optional[str] = None,
    correlated_invocations_id: Optional[str] = None,
    timeout: Optional[float] = None,
) -> List[aio.ClientInterceptor]:
    """Builds the list of asynchronous client interceptors for a channel.

    Each interceptor base is wrapped once per call type (unary-unary,
    unary-stream, stream-unary, stream-stream). Request metadata wrappers
    always come first, followed by the auth wrappers when
    `auth_token_loader` is given, then the timeout wrappers when `timeout`
    is given.
    """
    wrapper_classes = (
        AsyncUnaryUnaryInterceptor,
        AsyncUnaryStreamInterceptor,
        AsyncStreamUnaryInterceptor,
        AsyncStreamStreamInterceptor,
    )
    # FIXME Types not happy... "list" has incompatible type "..."; expected "_PartialStubMustCastOrIgnore"
    interceptors: List[Any] = []

    request_metadata_base = RequestMetadataInterceptorBase(
        action_id=action_id,
        tool_invocation_id=tool_invocation_id,
        correlated_invocations_id=correlated_invocations_id,
    )
    interceptors.extend(wrap(request_metadata_base) for wrap in wrapper_classes)

    if auth_token_loader is not None:
        auth_base = AuthMetadataClientInterceptorBase(auth_token_loader=auth_token_loader)
        interceptors.extend(wrap(auth_base) for wrap in wrapper_classes)

    if timeout is not None:
        timeout_base = TimeoutInterceptorBase(timeout)
        interceptors.extend(wrap(timeout_base) for wrap in wrapper_classes)

    return interceptors