Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/_authentication.py: 78.80%
217 statements
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-22 21:04 +0000
« prev ^ index » next coverage.py v6.4.1, created at 2022-06-22 21:04 +0000
1# Copyright (C) 2018 Bloomberg LP
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# <http://www.apache.org/licenses/LICENSE-2.0>
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15from collections import namedtuple, OrderedDict
16from datetime import datetime, timedelta, timezone
17from enum import Enum
18import functools
19import json
20import logging
21import sys
22import threading
24import grpc
26from buildgrid._exceptions import InvalidArgumentError
27from buildgrid.settings import AUTH_CACHE_SIZE
28from buildgrid.server.metrics_utils import (
29 ExceptionCounter,
30 DurationMetric,
31)
32from buildgrid.server.metrics_names import (
33 INVALID_JWT_COUNT_METRIC_NAME,
34 JWK_FETCH_TIME_METRIC_NAME,
35 JWT_DECODE_TIME_METRIC_NAME,
36 JWT_VALIDATION_TIME_METRIC_NAME,
37)
# Since jwt authentication is not required, make it optional.
# If used, but module not imported/found, will raise an exception.
try:
    import jwt
    import requests

    # Algorithm classes defined in: https://github.com/jpadilla/pyjwt/blob/master/jwt/algorithms.py
    # PyJWT only defines RSAAlgorithm / ECAlgorithm when its optional
    # ``cryptography`` dependency is installed, so look the classes up
    # defensively instead of letting an AttributeError break the import of
    # this whole module. Key types whose class is unavailable are left out of
    # the mapping, so a JWKS key of that type is reported as unsupported.
    ALGORITHM_TO_PYJWT_CLASS = {
        kty: alg_class
        for kty, alg_class in (
            ("RSA", getattr(jwt.algorithms, "RSAAlgorithm", None)),
            ("EC", getattr(jwt.algorithms, "ECAlgorithm", None)),
            ("oct", getattr(jwt.algorithms, "HMACAlgorithm", None)),
        )
        if alg_class is not None
    }

except ImportError:
    pass
class AuthMetadataMethod(Enum):
    """Authorization methods selectable for request authentication."""

    NONE = 'none'  # requests are not authenticated
    JWT = 'jwt'  # JSON Web Token based authentication
class AuthMetadataAlgorithm(Enum):
    """Signature algorithms accepted for token validation.

    ``UNSPECIFIED`` means no particular algorithm is pinned. Every other
    value is the lower-cased JWT ``alg`` name.
    """

    UNSPECIFIED = 'unspecified'  # no encryption involved

    # ECDSA signatures using SHA-256 / SHA-384 / SHA-512:
    JWT_ES256 = 'es256'
    JWT_ES384 = 'es384'
    JWT_ES512 = 'es512'

    # HMAC using SHA-256 / SHA-384 / SHA-512:
    JWT_HS256 = 'hs256'
    JWT_HS384 = 'hs384'
    JWT_HS512 = 'hs512'

    # RSASSA-PSS using SHA-256/384/512 with MGF1 padding of the same hash:
    JWT_PS256 = 'ps256'
    JWT_PS384 = 'ps384'
    JWT_PS512 = 'ps512'

    # RSASSA-PKCS1-v1_5 signatures using SHA-256 / SHA-384 / SHA-512:
    JWT_RS256 = 'rs256'
    JWT_RS384 = 'rs384'
    JWT_RS512 = 'rs512'
class AuthContext:
    """Holder for the interceptor consulted by :func:`authorize`.

    While ``interceptor`` is ``None``, decorated methods accept every request.
    """

    interceptor = None
class _InvalidTokenError(Exception):
    """Raised when a bearer token cannot be decoded or verified."""
class _ExpiredTokenError(Exception):
    """Raised when a bearer token is past its expiry."""
class _UnboundedTokenError(Exception):
    """Raised when a bearer token carries no usable expiry (``exp``) claim."""
def authorize(auth_context):
    """Build a decorator that gates a gRPC servicer method behind authorization.

    Intended to be combined with an :class:`AuthContext` holder::

        @authorize(AuthContext)
        def Execute(self, request, context):

    As long as ``auth_context.interceptor`` is ``None`` every request goes
    through. Assigning a :class:`grpc.ServerInterceptor` switches validation
    on::

        AuthContext.interceptor = AuthMetadataServerInterceptor()

    Args:
        auth_context(AuthContext): Authorization context holder.
    """
    def _decorator(behavior):
        """Wrap ``behavior`` so it only runs once the interceptor approves."""
        _HandlerCallDetails = namedtuple(
            '_HandlerCallDetails', ('invocation_metadata', 'method'))

        @functools.wraps(behavior)
        def _wrapper(self, request, context):
            """Ask the interceptor about this call, then run or abort it."""
            if auth_context.interceptor is None:
                return behavior(self, request, context)

            # The interceptor only invokes its continuation for accepted calls,
            # so a continuation that records the invocation acts as the verdict.
            approvals = []
            call_details = _HandlerCallDetails(context.invocation_metadata(),
                                               behavior.__name__)
            auth_context.interceptor.intercept_service(
                lambda _details: approvals.append(True), call_details)

            if not approvals:
                flat_request = str(request).replace("\n", "")
                logging.getLogger(__name__).info(
                    "Authentication failed for request=["
                    f"{behavior.__name__}({flat_request})], "
                    f"peer=[{context.peer()}]")

                # abort() is expected to raise; the explicit return mirrors
                # the handler's contract should it not.
                context.abort(grpc.StatusCode.UNAUTHENTICATED,
                              "No valid authorization or authentication provided")
                return None

            return behavior(self, request, context)

        return _wrapper

    return _decorator
class AuthMetadataServerInterceptor(grpc.ServerInterceptor):
    """Server interceptor that authenticates calls from their bearer token.

    Every call must carry an ``authorization`` metadata entry of the form
    ``Bearer <token>``. Tokens that pass validation are cached together with
    their expiry time (at most ``AUTH_CACHE_SIZE`` entries, oldest inserted
    evicted first), so repeated calls skip re-validation until the token
    expires. Rejected calls get a handler that aborts with UNAUTHENTICATED.
    """

    __auth_errors = {
        'missing-bearer': "Missing authentication header field",
        'invalid-bearer': "Invalid authentication header field",
        'invalid-token': "Invalid authentication token",
        'expired-token': "Expired authentication token",
        'unbounded-token': "Unbounded authentication token",
    }

    def __init__(self,
                 method,
                 secret=None,
                 algorithm=AuthMetadataAlgorithm.UNSPECIFIED,
                 jwks_url=None,
                 audience=None,
                 jwks_fetch_minutes=60,
                 jwks_request_timeout=30):
        """Initializes a new :class:`AuthMetadataServerInterceptor`.

        Args:
            method (AuthMetadataMethod): Type of authorization method.
            secret (str): The secret or key to be used for validating request,
                depending on `method`. Defaults to ``None``.
            algorithm (AuthMetadataAlgorithm): The crytographic algorithm used
                to encode `secret`. Defaults to ``UNSPECIFIED``.
            jwks_url (str): The url to fetch the JWKs. Either secret or
                this field must be specified if the authentication method is JWT.
                Defaults to ``None``.
            audience (str): The audience used to validate jwt tokens against.
                The tokens must have an audience field.
            jwks_fetch_minutes (int): The number of minutes to wait before
                refreshing the jwks set. Default: 60 minutes.
            jwks_request_timeout (float): Timeout in seconds for the HTTP
                request fetching the JWKs; ``None`` waits indefinitely.
                Default: 30 seconds.

        Raises:
            InvalidArgumentError: If `method` is not supported or if `algorithm`
                is not supported for the given `method`.
            RuntimeError: If both `secret` and `jwks_url` are given for JWT, or
                if the initial JWKs fetch yields no public keys.
        """
        self.__logger = logging.getLogger(__name__)
        self.__bearer_cache = OrderedDict()
        self.__terminators = {}
        self.__validator = None
        self.__secret = secret
        self.__jwk_update_lock = threading.Lock()

        self._audience = audience
        self._jwks_url = jwks_url
        self._public_keys = {}
        self._jwks_fetch_minutes = jwks_fetch_minutes
        self._jwks_request_timeout = jwks_request_timeout
        # Timezone-aware "never fetched" sentinel: it is compared against
        # datetime.now(tz=timezone.utc) when deciding whether to refetch.
        self._last_fetch_time = datetime.min.replace(tzinfo=timezone.utc)
        self._method = method
        self._algorithm = algorithm

        if self._method == AuthMetadataMethod.JWT:
            if self.__secret and self._jwks_url:
                raise RuntimeError(
                    "Only allowed to set secret or jwks-url. Not both.")

            # Validate PyJWT availability and the algorithm before any network
            # access, so a missing dependency surfaces as InvalidArgumentError.
            self._check_jwt_support(self._algorithm)

            if self._jwks_url:
                # Fetch jwk and store
                self._get_and_parse_jwks_from_url()

            self.__validator = self._validate_jwt_token

        for code, message in self.__auth_errors.items():
            self.__terminators[code] = _unary_unary_rpc_terminator(message)

    def _error_message_for_call(self, call_details, auth_error_type, exception_details=""):
        """Build the log line describing why `call_details` was rejected."""
        return (
            f"Authentication error. Rejecting '{str(call_details.method)}' request: "
            f"Reason=[{self.__auth_errors[auth_error_type]}], "
            f"{exception_details}")

    # --- Public API ---

    @property
    def method(self):
        return self._method

    @property
    def algorithm(self):
        return self._algorithm

    def intercept_service(self, continuation, handler_call_details):
        """Authenticate an incoming call from its ``authorization`` metadata.

        Returns:
            ``continuation(handler_call_details)`` when the bearer token is
            valid, otherwise a handler that aborts the call as UNAUTHENTICATED.
        """
        try:
            # Reject requests not carrying a token:
            bearer = dict(
                handler_call_details.invocation_metadata)['authorization']
        except KeyError:
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'missing-bearer'))
            return self.__terminators['missing-bearer']

        # Reject requests with malformated bearer:
        if not bearer.startswith('Bearer '):
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'invalid-bearer'))
            return self.__terminators['invalid-bearer']

        # Hit the cache for already validated token:
        expiration_time = self.__bearer_cache.get(bearer)
        if expiration_time is not None:
            # Accept request if cached token hasn't expired yet:
            if expiration_time >= datetime.utcnow():
                return continuation(handler_call_details)  # Accepted

            # Cached token has expired: evict it and reject the request.
            # pop() with a default tolerates the entry already being gone.
            self.__bearer_cache.pop(bearer, None)
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'expired-token'))
            # TODO: Use grpc.Status.details to inform the client of the expiry?
            return self.__terminators['expired-token']

        assert self.__validator is not None

        try:
            # Decode and validate the new token (strip the 'Bearer ' prefix):
            expiration_time = self.__validator(bearer[len('Bearer '):])

        except _InvalidTokenError as e:
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'invalid-token', str(e)))
            return self.__terminators['invalid-token']

        except _ExpiredTokenError as e:
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'expired-token', str(e)))
            return self.__terminators['expired-token']

        except _UnboundedTokenError as e:
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'unbounded-token', str(e)))
            return self.__terminators['unbounded-token']

        # Cache the validated token and store expiration time:
        self.__bearer_cache[bearer] = expiration_time
        if len(self.__bearer_cache) > AUTH_CACHE_SIZE:
            self.__bearer_cache.popitem(last=False)

        return continuation(handler_call_details)  # Accepted

    # --- Private API: JWT ---

    def _check_jwt_support(self, algorithm=AuthMetadataAlgorithm.UNSPECIFIED):
        """Ensures JWT and possible dependencies are available.

        Raises:
            InvalidArgumentError: If PyJWT is not importable, or if `algorithm`
                is not one PyJWT has a registered handler for.
        """
        if 'jwt' not in sys.modules:
            raise InvalidArgumentError(
                "JWT authorization method requires PyJWT")

        if algorithm == AuthMetadataAlgorithm.UNSPECIFIED:
            return

        try:
            # Probe PyJWT's registry: re-registering a known name raises
            # ValueError (=> supported), while an unknown name reaches the
            # type check on the ``None`` handler and raises TypeError.
            jwt.register_algorithm(algorithm.value.upper(), None)
        except TypeError:
            raise InvalidArgumentError(
                f"Algorithm not supported for JWT decoding: [{algorithm}]"
            ) from None
        except ValueError:
            pass

    jwt_invalid_exceptions = (_ExpiredTokenError, _InvalidTokenError,
                              _UnboundedTokenError)

    @ExceptionCounter(INVALID_JWT_COUNT_METRIC_NAME,
                      exceptions=jwt_invalid_exceptions)
    @DurationMetric(JWT_VALIDATION_TIME_METRIC_NAME)
    def _validate_jwt_token(self, token):
        """Validates a JWT token and returns its expiry date.

        Args:
            token (str): The encoded JWT, without the ``Bearer `` prefix.

        Returns:
            datetime: The token's ``exp`` claim as a naive UTC datetime.

        Raises:
            _ExpiredTokenError: The token has expired.
            _InvalidTokenError: The token cannot be decoded or verified, has no
                usable ``kid``, or neither a secret nor a JWKS url is set.
            _UnboundedTokenError: The payload has no integer ``exp`` claim.
        """
        if self._algorithm != AuthMetadataAlgorithm.UNSPECIFIED:
            algorithms = [self._algorithm.value.upper()]
        else:
            algorithms = None

        try:
            if self.__secret:
                key = self.__secret
            elif self._jwks_url:
                self.__logger.debug(
                    f"Validating token with JWKS fetched from url: [{self._jwks_url}]"
                )
                key = self._get_jwks_key_for_token(token)
            else:
                raise _InvalidTokenError(
                    "No JWT secret or JWKS url configured for token validation")

            with DurationMetric(JWT_DECODE_TIME_METRIC_NAME):
                # The audience is passed on both paths: PyJWT rejects tokens
                # carrying an ``aud`` claim when decode() gets no audience.
                payload = jwt.decode(token,
                                     key,
                                     algorithms=algorithms,
                                     audience=self._audience)

            if self._jwks_url:
                self.__logger.debug(
                    f"JWT validated from JWK set fetched from: [{self._jwks_url}]"
                )

        except jwt.exceptions.ExpiredSignatureError as e:
            raise _ExpiredTokenError(e) from e

        except jwt.exceptions.InvalidTokenError as e:
            raise _InvalidTokenError(e) from e

        if 'exp' not in payload or not isinstance(payload['exp'], int):
            raise _UnboundedTokenError("Missing 'exp' in payload")

        return datetime.utcfromtimestamp(payload['exp'])

    def _get_jwks_key_for_token(self, token):
        """Return the cached JWK public key matching `token`'s ``kid`` header.

        The key set is refreshed first when it is older than
        ``jwks_fetch_minutes``, and once more when the ``kid`` is unknown.

        Raises:
            _InvalidTokenError: The header has no ``kid`` or no key matches it.
            jwt.exceptions.InvalidTokenError: The token header is malformed.
        """
        refresh_due = (self._last_fetch_time +
                       timedelta(minutes=self._jwks_fetch_minutes))
        if refresh_due <= datetime.now(tz=timezone.utc):
            # Only the thread that wins the lock refreshes; others carry on
            # with the currently cached keys.
            self._try_refresh_jwks()

        kid = jwt.get_unverified_header(token).get('kid')
        if kid is None:
            # Client-supplied input: report as an invalid token rather than
            # letting an unexpected exception escape the interceptor.
            raise _InvalidTokenError("JWT token is missing kid.")

        key = self._public_keys.get(kid)
        if key is None:
            # Unknown kid: the provider may have rotated keys. Refresh, or, if
            # another thread is already refreshing, wait for it to finish.
            # NOTE(review): each unknown kid triggers a fetch; consider rate
            # limiting if tokens are attacker-controlled.
            if not self._try_refresh_jwks():
                with self.__jwk_update_lock:
                    pass
            key = self._public_keys.get(kid)
            if key is None:
                raise _InvalidTokenError(
                    f"No public key found for token with kid: {kid}")

        return key

    def _try_refresh_jwks(self):
        """Refetch the JWK set unless another thread is already doing so.

        Fetch failures are logged and swallowed so requests can continue with
        the previously cached keys.

        Returns:
            bool: ``True`` if this thread performed the refresh attempt,
            ``False`` if the update lock was already held.
        """
        # pylint: disable=consider-using-with
        if not self.__jwk_update_lock.acquire(False):
            return False
        try:
            self._get_and_parse_jwks_from_url()
        except Exception:  # pylint: disable=broad-except
            self.__logger.exception(
                "Exception thrown while fetching jwk. "
                "Continuing with request using previously cached keys."
            )
        finally:
            self.__jwk_update_lock.release()
        return True

    @DurationMetric(JWK_FETCH_TIME_METRIC_NAME)
    def _get_and_parse_jwks_from_url(self):
        """Get JWKs from url, and parse JSON web key set.

        On success the cached public keys are replaced and the fetch time is
        recorded. An empty key set leaves any existing keys untouched.

        Raises:
            requests.exceptions.RequestException: The HTTP request failed.
            AttributeError, ValueError: The response is not a parsable JWK set.
            RuntimeError: A key lacks ``kid``/``kty``, has an unsupported key
                type, or no public keys are available at all.
        """
        # NOTE: PyJWT >= 2.0 provides ``jwt.PyJWKClient``, which could replace
        # this hand-rolled fetch/parse (https://github.com/jpadilla/pyjwt/pull/470).
        try:
            self.__logger.info(
                f"Sending request to fetch JWKs from provided url: [{self._jwks_url}]"
            )
            # Without a timeout a hung endpoint would block forever while the
            # update lock is held.
            data = requests.get(self._jwks_url,
                                timeout=self._jwks_request_timeout)
        except requests.exceptions.RequestException:
            self.__logger.exception(
                f"Error sending request to: [{self._jwks_url}]")
            raise

        jwks = None  # stays None if the body is not valid JSON
        temp_keys = {}
        try:
            jwks = data.json()
            # A document without "keys" yields no keys and is handled by the
            # empty-set check below instead of failing on iteration.
            for jwk in jwks.get('keys') or []:
                kid = jwk.get('kid')
                kty = jwk.get('kty')
                if kid is None or kty is None:
                    raise RuntimeError(
                        f"A key in the JWKs fetched from [{self._jwks_url}], "
                        "doesn't include one of the required properties: kid, or kty."
                    )
                alg_class = ALGORITHM_TO_PYJWT_CLASS.get(kty)
                if alg_class is None:
                    raise RuntimeError(
                        "Unsupported algorithm type provided by "
                        f"JWKs: [{kty}], fetched from [{self._jwks_url}]"
                    )
                temp_keys[kid] = alg_class.from_jwk(json.dumps(jwk))
        except (AttributeError, ValueError):
            self.__logger.exception(
                f"Error parsing input: [{jwks}], "
                f"fetched from [{self._jwks_url}]")
            raise

        if not temp_keys:
            self.__logger.error(
                f"No public keys returned from url: [{self._jwks_url}]")
            # If there are no public keys at all, no token can be validated.
            if not self._public_keys:
                raise RuntimeError(
                    "Error fetching public keys, non-existing public keys.")

            self.__logger.info(
                f"Unable to fetch proper JWKs from [{self._jwks_url}], "
                f"leaving existing set last fetched at [{self._last_fetch_time}] unchanged."
            )
            return

        # Set _last_fetch_time, this will be used to check
        # whether to refetch the token after a certain amount of time.
        self._last_fetch_time = datetime.now(tz=timezone.utc)

        self.__logger.info(
            "Replacing existing JWKs set, with one fetched at time: "
            f"[{self._last_fetch_time}] from url: [{self._jwks_url}]"
        )

        # Set the class member variable to the new keys.
        self._public_keys = temp_keys
def _unary_unary_rpc_terminator(details):
    """Return a unary-unary handler that aborts every call as UNAUTHENTICATED.

    Args:
        details (str): Message handed to ``context.abort``.
    """
    def _abort_call(_request, context):
        context.abort(grpc.StatusCode.UNAUTHENTICATED, details)

    return grpc.unary_unary_rpc_method_handler(_abort_call)