Coverage for /builds/BuildGrid/buildgrid/buildgrid/client/channel.py: 80.17%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

116 statements  

1# Copyright (C) 2019 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from collections import namedtuple 

16from urllib.parse import urlparse 

17from typing import Any, List, Optional, TYPE_CHECKING, Union 

18 

19import grpc 

20from grpc import aio # type: ignore 

21 

22from buildgrid.client.authentication import AsyncAuthMetadataClientInterceptor, AuthMetadataClientInterceptor 

23from buildgrid.client.authentication import load_channel_authorization_token 

24from buildgrid.client.authentication import load_tls_channel_credentials 

25from buildgrid._exceptions import InvalidArgumentError 

26from buildgrid._protos.build.bazel.remote.execution.v2 import remote_execution_pb2 

27from buildgrid.settings import REQUEST_METADATA_HEADER_NAME 

28from buildgrid.settings import REQUEST_METADATA_TOOL_NAME, REQUEST_METADATA_TOOL_VERSION 

29from buildgrid.utils import insecure_uri_schemes, secure_uri_schemes 

30 

31 

def setup_channel(remote_url: str,
                  auth_token: Optional[str] = None, client_key: Optional[str] = None,
                  client_cert: Optional[str] = None, server_cert: Optional[str] = None,
                  action_id: Optional[str] = None, tool_invocation_id: Optional[str] = None,
                  correlated_invocations_id: Optional[str] = None, asynchronous: bool = False):
    """Creates a new gRPC client communication channel.

    If `remote_url` does not point to a socket and does not specify a
    port number, the port defaults to 50051.

    Args:
        remote_url (str): URL for the remote, including protocol and,
            if not a Unix domain socket, a port.
        auth_token (str): Authorization token file path.
        client_key (str): TLS client private key file path (forwarded to
            `load_tls_channel_credentials`).
        client_cert (str): TLS client certificate chain file path (forwarded
            to `load_tls_channel_credentials`).
        server_cert (str): TLS server root certificate file path (forwarded
            to `load_tls_channel_credentials`).
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.
        asynchronous (bool): If True, build a `grpc.aio` channel with async
            interceptors; otherwise a synchronous `grpc` channel wrapped by
            `grpc.intercept_channel`.

    Returns:
        tuple: ``(channel, details)``. ``channel`` is the client channel to be
        used in order to access the server at `remote_url`. ``details`` is the
        second value returned by `load_tls_channel_credentials` for secure
        channels, and ``(None, None, None)`` for insecure ones.

    Raises:
        InvalidArgumentError: On any input parsing error.
    """
    url = urlparse(remote_url)

    url_is_socket = (url.scheme == 'unix')
    if url_is_socket:
        remote = remote_url
    else:
        try:
            port = url.port or 50051
        except ValueError as err:
            # urllib raises ValueError for non-numeric or out-of-range ports.
            raise InvalidArgumentError(f"Given remote has an invalid port: {remote_url}") from err
        host = url.hostname
        # urlparse strips the brackets from IPv6 literals; put them back so the
        # "host:port" target string is unambiguous.
        if host and ':' in host:
            host = f'[{host}]'
        remote = f'{host}:{port}'

    details = None, None, None
    credentials_provided = any((server_cert, client_cert, client_key))

    # Build interceptors up front (this also loads the auth token, if any).
    if asynchronous:
        async_interceptors = _create_async_interceptors(
            auth_token=auth_token,
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id)
    else:
        sync_interceptors = _create_sync_interceptors(
            auth_token=auth_token,
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id)

    # Unix sockets are secure only when TLS material was explicitly provided.
    if url.scheme in insecure_uri_schemes or (url_is_socket and not credentials_provided):
        if asynchronous:
            channel = aio.insecure_channel(remote, interceptors=async_interceptors)
        else:
            channel = grpc.insecure_channel(remote)
    elif url.scheme in secure_uri_schemes or (url_is_socket and credentials_provided):
        credentials, details = load_tls_channel_credentials(client_key, client_cert, server_cert)
        if not credentials:
            raise InvalidArgumentError("Given TLS details (or defaults) could not be loaded")

        if asynchronous:
            channel = aio.secure_channel(remote, credentials, interceptors=async_interceptors)
        else:
            channel = grpc.secure_channel(remote, credentials)

    else:
        raise InvalidArgumentError("Given remote does not specify a protocol")

    # Sync channels take interceptors by wrapping; aio channels got them above.
    if not asynchronous:
        for interceptor in sync_interceptors:
            channel = grpc.intercept_channel(channel, interceptor)

    return channel, details

107 

108 

class RequestMetadataInterceptorBase:
    """Shared logic for interceptors that append a serialized `RequestMetadata`
    header to the metadata of every outgoing call."""

    # Call-details classes built by `_call_details_class`, keyed by the gRPC
    # base class they derive from. Building a namedtuple plus a subclass is
    # comparatively slow, so it is done once per base class rather than once
    # per intercepted RPC.
    _call_details_classes: dict = {}

    def __init__(self, action_id: Optional[str] = None,
                 tool_invocation_id: Optional[str] = None,
                 correlated_invocations_id: Optional[str] = None):
        """Appends optional `RequestMetadata` header values to each call.

        Args:
            action_id (str): Action identifier to which the request belongs to.
            tool_invocation_id (str): Identifier for a related group of Actions.
            correlated_invocations_id (str): Identifier that ties invocations together.
        """
        self._action_id = action_id
        self._tool_invocation_id = tool_invocation_id
        self._correlated_invocations_id = correlated_invocations_id

        self.__header_field_name = REQUEST_METADATA_HEADER_NAME
        # Serialized once at construction and reused for every call.
        self.__header_field_value = self._request_metadata()

    def _request_metadata(self):
        """Creates a serialized RequestMetadata entry to attach to a gRPC
        call header. Arguments should be of type str or None.

        Tool name and version are always set; each identifier is only set
        when it is truthy.

        Returns:
            bytes: The serialized `RequestMetadata` message.
        """
        request_metadata = remote_execution_pb2.RequestMetadata()
        request_metadata.tool_details.tool_name = REQUEST_METADATA_TOOL_NAME
        request_metadata.tool_details.tool_version = REQUEST_METADATA_TOOL_VERSION

        if self._action_id:
            request_metadata.action_id = self._action_id
        if self._tool_invocation_id:
            request_metadata.tool_invocation_id = self._tool_invocation_id
        if self._correlated_invocations_id:
            request_metadata.correlated_invocations_id = self._correlated_invocations_id

        return request_metadata.SerializeToString()

    @staticmethod
    def _call_details_class(grpc_call_details_class):
        """Returns a namedtuple-backed subclass of `grpc_call_details_class`.

        The class is created on first use for each base class and cached in
        `_call_details_classes` afterwards.
        """
        cache = RequestMetadataInterceptorBase._call_details_classes
        details_class = cache.get(grpc_call_details_class)
        if details_class is None:
            class _ClientCallDetails(
                    namedtuple('_ClientCallDetails',
                               ('method', 'timeout', 'credentials', 'metadata', 'wait_for_ready',)),
                    grpc_call_details_class):
                pass

            # setdefault keeps a single class per base even if two callers race here.
            details_class = cache.setdefault(grpc_call_details_class, _ClientCallDetails)
        return details_class

    def _amend_call_details(self, client_call_details, grpc_call_details_class):
        """Returns a copy of `client_call_details` with the request metadata
        header appended to its metadata.

        The incoming metadata is copied, not mutated.

        Args:
            client_call_details: The call details handed to the interceptor.
            grpc_call_details_class: The gRPC `ClientCallDetails` base class
                (sync or aio) that the returned object must be an instance of.

        Returns:
            An instance of a subclass of `grpc_call_details_class`.
        """
        if client_call_details.metadata is not None:
            new_metadata = list(client_call_details.metadata)
        else:
            new_metadata = []

        new_metadata.append((self.__header_field_name,
                             self.__header_field_value))

        details_class = self._call_details_class(grpc_call_details_class)
        return details_class(client_call_details.method,
                             client_call_details.timeout,
                             client_call_details.credentials,
                             new_metadata,
                             client_call_details.wait_for_ready)

165 

166 

class RequestMetadataInterceptor(RequestMetadataInterceptorBase,
                                 grpc.UnaryUnaryClientInterceptor,
                                 grpc.UnaryStreamClientInterceptor,
                                 grpc.StreamUnaryClientInterceptor,
                                 grpc.StreamStreamClientInterceptor):
    """Synchronous gRPC client interceptor that adds the `RequestMetadata`
    header to unary and streaming calls alike."""

    def __init__(self, action_id: Optional[str] = None,
                 tool_invocation_id: Optional[str] = None,
                 correlated_invocations_id: Optional[str] = None):
        RequestMetadataInterceptorBase.__init__(
            self,
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
        )

    def _details_with_metadata(self, client_call_details):
        # Synchronous channels expect grpc.ClientCallDetails instances.
        return self._amend_call_details(client_call_details, grpc.ClientCallDetails)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return continuation(self._details_with_metadata(client_call_details), request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return continuation(self._details_with_metadata(client_call_details), request)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return continuation(self._details_with_metadata(client_call_details), request_iterator)

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return continuation(self._details_with_metadata(client_call_details), request_iterator)

202 

203 

class AsyncRequestMetadataInterceptor(RequestMetadataInterceptorBase,
                                      aio.UnaryUnaryClientInterceptor,
                                      aio.UnaryStreamClientInterceptor,
                                      aio.StreamUnaryClientInterceptor,
                                      aio.StreamStreamClientInterceptor):
    """`grpc.aio` client interceptor that adds the `RequestMetadata` header to
    unary and streaming calls alike."""

    def __init__(self, action_id: Optional[str] = None,
                 tool_invocation_id: Optional[str] = None,
                 correlated_invocations_id: Optional[str] = None):
        RequestMetadataInterceptorBase.__init__(
            self,
            action_id=action_id,
            tool_invocation_id=tool_invocation_id,
            correlated_invocations_id=correlated_invocations_id,
        )

    def _details_with_metadata(self, client_call_details):
        # aio channels expect aio.ClientCallDetails instances.
        return self._amend_call_details(client_call_details, aio.ClientCallDetails)

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        amended = self._details_with_metadata(client_call_details)
        return await continuation(amended, request)

    async def intercept_unary_stream(self, continuation, client_call_details, request):
        amended = self._details_with_metadata(client_call_details)
        return await continuation(amended, request)

    async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        amended = self._details_with_metadata(client_call_details)
        return await continuation(amended, request_iterator)

    async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        amended = self._details_with_metadata(client_call_details)
        return await continuation(amended, request_iterator)

239 

240 

if TYPE_CHECKING:
    # Type-checker-only aliases; nothing here exists at runtime.
    # pylint: disable=unsubscriptable-object
    _SyncInterceptor = Union[
        grpc.UnaryUnaryClientInterceptor[Any, Any],
        grpc.UnaryStreamClientInterceptor[Any, Any],
        grpc.StreamUnaryClientInterceptor[Any, Any],
        grpc.StreamStreamClientInterceptor[Any, Any],
    ]
    SyncInterceptorsList = List[_SyncInterceptor]

251 

252 

def _create_sync_interceptors(auth_token: Optional[str] = None, action_id: Optional[str] = None,
                              tool_invocation_id: Optional[str] = None,
                              correlated_invocations_id: Optional[str] = None) -> 'SyncInterceptorsList':
    """Builds the interceptors for a synchronous gRPC channel.

    The request-metadata interceptor always comes first; an authorization
    interceptor is appended only when `auth_token` is given.

    Args:
        auth_token (str): Authorization token file path, or None for no auth.
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.

    Returns:
        list: Interceptors in the order they should be applied.

    Raises:
        InvalidArgumentError: If `auth_token` is given but no token could be
            loaded from it.
    """
    interceptors: 'SyncInterceptorsList' = [RequestMetadataInterceptor(
        action_id=action_id,
        tool_invocation_id=tool_invocation_id,
        correlated_invocations_id=correlated_invocations_id)]

    if auth_token is not None:
        token = load_channel_authorization_token(auth_token)
        if not token:
            raise InvalidArgumentError("Given authorization token could not be loaded")
        interceptors.append(AuthMetadataClientInterceptor(auth_token=token))

    return interceptors

269 

270 

def _create_async_interceptors(auth_token: Optional[str] = None, action_id: Optional[str] = None,
                               tool_invocation_id: Optional[str] = None,
                               correlated_invocations_id: Optional[str] = None) -> List[aio.ClientInterceptor]:
    """Builds the interceptors for a `grpc.aio` channel.

    The request-metadata interceptor always comes first; an authorization
    interceptor is appended only when `auth_token` is given.

    Args:
        auth_token (str): Authorization token file path, or None for no auth.
        action_id (str): Action identifier to which the request belongs to.
        tool_invocation_id (str): Identifier for a related group of Actions.
        correlated_invocations_id (str): Identifier that ties invocations together.

    Returns:
        list: Interceptors in the order they should be applied.

    Raises:
        InvalidArgumentError: If `auth_token` is given but no token could be
            loaded from it.
    """
    interceptors: List[aio.ClientInterceptor] = [AsyncRequestMetadataInterceptor(
        action_id=action_id,
        tool_invocation_id=tool_invocation_id,
        correlated_invocations_id=correlated_invocations_id)]

    if auth_token is not None:
        token = load_channel_authorization_token(auth_token)
        if not token:
            raise InvalidArgumentError("Given authorization token could not be loaded")
        interceptors.append(AsyncAuthMetadataClientInterceptor(auth_token=token))

    return interceptors