Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/client/channel.py: 72.38%
105 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-10-04 17:48 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-10-04 17:48 +0000
1# Copyright (C) 2019 Bloomberg LP
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# <http://www.apache.org/licenses/LICENSE-2.0>
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15from collections import namedtuple
16from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
17from urllib.parse import urlparse
19import grpc
20from grpc import aio
22from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2
23from buildgrid.server.client.auth_token_loader import AuthTokenLoader
24from buildgrid.server.client.authentication import AuthMetadataClientInterceptorBase, load_tls_channel_credentials
25from buildgrid.server.client.interceptors import (
26 AsyncStreamStreamInterceptor,
27 AsyncStreamUnaryInterceptor,
28 AsyncUnaryStreamInterceptor,
29 AsyncUnaryUnaryInterceptor,
30 SyncStreamStreamInterceptor,
31 SyncStreamUnaryInterceptor,
32 SyncUnaryStreamInterceptor,
33 SyncUnaryUnaryInterceptor,
34)
35from buildgrid.server.exceptions import InvalidArgumentError
36from buildgrid.server.settings import (
37 INSECURE_URI_SCHEMES,
38 REQUEST_METADATA_HEADER_NAME,
39 REQUEST_METADATA_TOOL_NAME,
40 REQUEST_METADATA_TOOL_VERSION,
41 SECURE_URI_SCHEMES,
42)
def _format_remote_address(url: Any) -> str:
    """Build a ``host:port`` gRPC target from a parsed, non-socket URL.

    The port defaults to 50051 when the URL does not specify one. IPv6
    literals are re-wrapped in brackets because ``urlparse`` strips them from
    ``hostname``.

    Raises:
        InvalidArgumentError: If the URL has no host or an invalid port.
    """
    host = url.hostname
    if not host:
        raise InvalidArgumentError(f"Given remote does not specify a host: [{url.geturl()}]")
    try:
        port = url.port or 50051
    except ValueError as e:
        raise InvalidArgumentError(f"Given remote has an invalid port: [{url.geturl()}]") from e
    if ":" in host:
        host = f"[{host}]"
    return f"{host}:{port}"


def _load_tls_credentials(
    client_key: Optional[str], client_cert: Optional[str], server_cert: Optional[str]
) -> Tuple[grpc.ChannelCredentials, Tuple[Optional[str], ...]]:
    """Load TLS channel credentials, raising if nothing usable was loaded.

    Raises:
        InvalidArgumentError: If ``load_tls_channel_credentials`` returns no credentials.
    """
    credentials, details = load_tls_channel_credentials(client_key, client_cert, server_cert)
    if not credentials:
        raise InvalidArgumentError("Given TLS details (or defaults) could not be loaded")
    return credentials, details


def setup_channel(
    remote_url: str,
    auth_token: Optional[str] = None,
    auth_token_refresh_seconds: Optional[int] = None,
    client_key: Optional[str] = None,
    client_cert: Optional[str] = None,
    server_cert: Optional[str] = None,
    action_id: Optional[str] = None,
    tool_invocation_id: Optional[str] = None,
    correlated_invocations_id: Optional[str] = None,
    asynchronous: bool = False,
    timeout: Optional[float] = None,
) -> Tuple[grpc.Channel, Tuple[Optional[str], ...]]:
    """Creates a new gRPC client communication channel.

    If `remote_url` does not point to a socket and does not specify a
    port number, the port defaults to 50051.

    Args:
        remote_url (str): URL for the remote, including protocol and,
            if not a Unix domain socket, a port.
        auth_token (str): Authorization token file path.
        auth_token_refresh_seconds (int): Time in seconds to read the
            authorization token again from file.
        client_key (str): TLS client key file path.
        client_cert (str): TLS client certificate file path.
        server_cert (str): TLS server certificate file path.
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.
        asynchronous (bool): Create a ``grpc.aio`` channel instead of a
            synchronous ``grpc`` channel.
        timeout (float): Request timeout in seconds.

    Returns:
        A ``(channel, details)`` pair: the client channel to be used in order
        to access the server at `remote_url`, and the TLS details returned by
        ``load_tls_channel_credentials`` (``(None, None, None)`` for an
        insecure channel).

    Raises:
        InvalidArgumentError: On any input parsing error: a malformed URL, an
            unknown protocol, a missing host, an invalid port, or TLS
            credentials that could not be loaded.
    """
    try:
        url = urlparse(remote_url)
    except ValueError as e:
        # e.g. unbalanced IPv6 brackets
        raise InvalidArgumentError(f"Given remote is not a valid URL: [{remote_url}]") from e

    url_is_socket = url.scheme == "unix"
    credentials_provided = any((server_cert, client_cert, client_key))

    # Network URLs are classified by scheme; Unix domain sockets are secured
    # only when some TLS material was supplied.
    if url.scheme in INSECURE_URI_SCHEMES or (url_is_socket and not credentials_provided):
        secure = False
    elif url.scheme in SECURE_URI_SCHEMES or (url_is_socket and credentials_provided):
        secure = True
    else:
        raise InvalidArgumentError("Given remote does not specify a protocol")

    remote = remote_url if url_is_socket else _format_remote_address(url)

    auth_token_loader: Optional[AuthTokenLoader] = None
    if auth_token:
        auth_token_loader = AuthTokenLoader(auth_token, auth_token_refresh_seconds)

    details: Tuple[Optional[str], ...] = (None, None, None)
    credentials: Optional[grpc.ChannelCredentials] = None
    if secure:
        credentials, details = _load_tls_credentials(client_key, client_cert, server_cert)

    if asynchronous:
        async_interceptors = _create_async_interceptors(
            auth_token_loader=auth_token_loader,
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
            timeout=timeout,
        )
        if credentials is None:
            async_channel = aio.insecure_channel(remote, interceptors=async_interceptors)
        else:
            async_channel = aio.secure_channel(remote, credentials, interceptors=async_interceptors)

        # TODO use overloads to make this return an async channel when asynchronous == True
        return async_channel, details  # type: ignore[return-value]

    sync_interceptors = _create_sync_interceptors(
        auth_token_loader=auth_token_loader,
        action_id=action_id,
        tool_invocation_id=tool_invocation_id,
        correlated_invocations_id=correlated_invocations_id,
        timeout=timeout,
    )
    if credentials is None:
        sync_channel = grpc.insecure_channel(remote)
    else:
        sync_channel = grpc.secure_channel(remote, credentials)

    # Wrapping one interceptor at a time leaves the last one in the list as
    # the outermost layer; keep this loop to preserve that nesting order.
    for interceptor in sync_interceptors:
        sync_channel = grpc.intercept_channel(sync_channel, interceptor)

    return sync_channel, details
class RequestMetadataInterceptorBase:
    """Attaches a serialized `RequestMetadata` header to every outgoing call."""

    # Namedtuple-backed call-details classes, keyed by the gRPC call-details
    # class they extend. Filled lazily by `_call_details_class` so a new class
    # is not created for every intercepted call.
    _details_class_cache: "dict[Any, Any]" = {}

    def __init__(
        self,
        action_id: Optional[str] = None,
        tool_invocation_id: Optional[str] = None,
        correlated_invocations_id: Optional[str] = None,
    ) -> None:
        """Appends optional `RequestMetadata` header values to each call.

        The header value is serialized once here and reused for every call.

        Args:
            action_id (str): Action identifier to which the request belongs to.
            tool_invocation_id (str): Identifier for a related group of Actions.
            correlated_invocations_id (str): Identifier that ties invocations together.
        """
        self._action_id = action_id
        self._tool_invocation_id = tool_invocation_id
        self._correlated_invocations_id = correlated_invocations_id

        self.__header_field_name = REQUEST_METADATA_HEADER_NAME
        self.__header_field_value = self._request_metadata()

    def _request_metadata(self) -> bytes:
        """Creates a serialized RequestMetadata entry to attach to a gRPC
        call header.

        The tool name and version are always set. The action, tool invocation
        and correlated invocations identifiers stored on the instance are set
        only when they are non-empty.
        """
        request_metadata = remote_execution_pb2.RequestMetadata()
        request_metadata.tool_details.tool_name = REQUEST_METADATA_TOOL_NAME
        request_metadata.tool_details.tool_version = REQUEST_METADATA_TOOL_VERSION

        if self._action_id:
            request_metadata.action_id = self._action_id
        if self._tool_invocation_id:
            request_metadata.tool_invocation_id = self._tool_invocation_id
        if self._correlated_invocations_id:
            request_metadata.correlated_invocations_id = self._correlated_invocations_id

        return request_metadata.SerializeToString()

    @classmethod
    def _call_details_class(cls, grpc_call_details_class: Any) -> Any:
        """Return a namedtuple subclass of `grpc_call_details_class`, building it at most once."""
        details_class = cls._details_class_cache.get(grpc_call_details_class)
        if details_class is None:

            class _ClientCallDetails(
                namedtuple(
                    "_ClientCallDetails",
                    (
                        "method",
                        "timeout",
                        "credentials",
                        "metadata",
                        "wait_for_ready",
                    ),
                ),
                grpc_call_details_class,  # type: ignore
            ):
                pass

            details_class = _ClientCallDetails
            cls._details_class_cache[grpc_call_details_class] = details_class
        return details_class

    def amend_call_details(  # type: ignore[no-untyped-def] # wait for client lib updates here
        self, client_call_details, grpc_call_details_class: Any
    ):
        """Return a copy of `client_call_details` with the RequestMetadata header appended.

        Any existing metadata entries are kept, in order, ahead of the new
        header. The result is an instance of `grpc_call_details_class`.
        """
        if client_call_details.metadata is not None:
            new_metadata = list(client_call_details.metadata)
        else:
            new_metadata = []

        new_metadata.append((self.__header_field_name, self.__header_field_value))

        return self._call_details_class(grpc_call_details_class)(
            client_call_details.method,
            client_call_details.timeout,
            client_call_details.credentials,
            new_metadata,
            client_call_details.wait_for_ready,
        )
class TimeoutInterceptorBase:
    """Caps the timeout of every outgoing call at a fixed number of seconds."""

    # Namedtuple-backed call-details classes, keyed by the gRPC call-details
    # class they extend. Filled lazily by `_call_details_class` so a new class
    # is not created for every intercepted call.
    _details_class_cache: "dict[Any, Any]" = {}

    def __init__(self, timeout: float) -> None:
        """Applies a request timeout to each call.

        Args:
            timeout (float): Request timeout in seconds.
        """
        self._timeout = timeout

    @classmethod
    def _call_details_class(cls, grpc_call_details_class: Any) -> Any:
        """Return a namedtuple subclass of `grpc_call_details_class`, building it at most once."""
        details_class = cls._details_class_cache.get(grpc_call_details_class)
        if details_class is None:

            class _ClientCallDetails(
                namedtuple(
                    "_ClientCallDetails",
                    (
                        "method",
                        "timeout",
                        "credentials",
                        "metadata",
                        "wait_for_ready",
                    ),
                ),
                grpc_call_details_class,  # type: ignore
            ):
                pass

            details_class = _ClientCallDetails
            cls._details_class_cache[grpc_call_details_class] = details_class
        return details_class

    def amend_call_details(  # type: ignore[no-untyped-def] # wait for client lib updates here
        self, client_call_details, grpc_call_details_class: Any
    ):
        """Return a copy of `client_call_details` whose timeout is at most this interceptor's.

        All other fields are copied unchanged. The result is an instance of
        `grpc_call_details_class`.
        """
        # If there are multiple timeouts, apply the shorter timeout (earliest deadline wins)
        if client_call_details.timeout is not None:
            new_timeout = min(self._timeout, client_call_details.timeout)
        else:
            new_timeout = self._timeout

        return self._call_details_class(grpc_call_details_class)(
            client_call_details.method,
            new_timeout,
            client_call_details.credentials,
            client_call_details.metadata,
            client_call_details.wait_for_ready,
        )
if TYPE_CHECKING:
    # Any one of the four synchronous gRPC client interceptor kinds.
    _SyncInterceptor = Union[
        grpc.UnaryUnaryClientInterceptor[Any, Any],
        grpc.UnaryStreamClientInterceptor[Any, Any],
        grpc.StreamUnaryClientInterceptor[Any, Any],
        grpc.StreamStreamClientInterceptor[Any, Any],
    ]
    # Type-checking-only alias; it does not exist at runtime, so it is used
    # as a string annotation (e.g. by `_create_sync_interceptors`).
    SyncInterceptorsList = List[_SyncInterceptor]
def _create_sync_interceptors(
    auth_token_loader: Optional[AuthTokenLoader] = None,
    action_id: Optional[str] = None,
    tool_invocation_id: Optional[str] = None,
    correlated_invocations_id: Optional[str] = None,
    timeout: Optional[float] = None,
) -> "SyncInterceptorsList":
    """Build the synchronous client interceptors for a channel.

    The request-metadata interceptor is always included. The auth-metadata
    interceptor is added when `auth_token_loader` is given, and the timeout
    interceptor is added when `timeout` is set. Each of these is wrapped once per
    RPC shape, in the order unary-unary, unary-stream, stream-unary,
    stream-stream.
    """
    wrapper_types = (
        SyncUnaryUnaryInterceptor,
        SyncUnaryStreamInterceptor,
        SyncStreamUnaryInterceptor,
        SyncStreamStreamInterceptor,
    )

    def wrap_all_shapes(base: Any) -> "SyncInterceptorsList":
        return [wrapper_type(base) for wrapper_type in wrapper_types]

    interceptors = wrap_all_shapes(
        RequestMetadataInterceptorBase(
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
        )
    )

    if auth_token_loader is not None:
        interceptors.extend(wrap_all_shapes(AuthMetadataClientInterceptorBase(auth_token_loader=auth_token_loader)))

    if timeout is not None:
        interceptors.extend(wrap_all_shapes(TimeoutInterceptorBase(timeout)))

    return interceptors
def _create_async_interceptors(
    auth_token_loader: Optional[AuthTokenLoader] = None,
    action_id: Optional[str] = None,
    tool_invocation_id: Optional[str] = None,
    correlated_invocations_id: Optional[str] = None,
    timeout: Optional[float] = None,
) -> List[aio.ClientInterceptor]:
    """Build the ``grpc.aio`` client interceptors for a channel.

    The request-metadata interceptor is always included. The auth-metadata
    interceptor is added when `auth_token_loader` is given, and the timeout
    interceptor is added when `timeout` is set. Each of these is wrapped once per
    RPC shape, in the order unary-unary, unary-stream, stream-unary,
    stream-stream.
    """
    wrapper_types = (
        AsyncUnaryUnaryInterceptor,
        AsyncUnaryStreamInterceptor,
        AsyncStreamUnaryInterceptor,
        AsyncStreamStreamInterceptor,
    )

    # FIXME Types not happy... "list" has incompatible type "..."; expected "_PartialStubMustCastOrIgnore"
    def wrap_all_shapes(base: Any) -> List[Any]:
        return [wrapper_type(base) for wrapper_type in wrapper_types]

    interceptors = wrap_all_shapes(
        RequestMetadataInterceptorBase(
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
        )
    )

    if auth_token_loader is not None:
        interceptors.extend(wrap_all_shapes(AuthMetadataClientInterceptorBase(auth_token_loader=auth_token_loader)))

    if timeout is not None:
        interceptors.extend(wrap_all_shapes(TimeoutInterceptorBase(timeout)))

    return interceptors