Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/_authentication.py: 78.80%

217 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-22 21:04 +0000

1# Copyright (C) 2018 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15from collections import namedtuple, OrderedDict 

16from datetime import datetime, timedelta, timezone 

17from enum import Enum 

18import functools 

19import json 

20import logging 

21import sys 

22import threading 

23 

24import grpc 

25 

26from buildgrid._exceptions import InvalidArgumentError 

27from buildgrid.settings import AUTH_CACHE_SIZE 

28from buildgrid.server.metrics_utils import ( 

29 ExceptionCounter, 

30 DurationMetric, 

31) 

32from buildgrid.server.metrics_names import ( 

33 INVALID_JWT_COUNT_METRIC_NAME, 

34 JWK_FETCH_TIME_METRIC_NAME, 

35 JWT_DECODE_TIME_METRIC_NAME, 

36 JWT_VALIDATION_TIME_METRIC_NAME, 

37) 

38 

# JWT authentication is optional, so its third-party dependencies (PyJWT and
# requests) are imported defensively. If either import fails, the names below
# stay undefined; the interceptor checks ``sys.modules`` for PyJWT before it
# enables JWT validation.
try:
    import jwt
    import requests
except ImportError:
    pass
else:
    # Maps a JWK "kty" (key type) to the PyJWT algorithm class whose
    # ``from_jwk`` can parse it. Classes are looked up by name because PyJWT
    # only defines the RSA and EC classes when its optional ``cryptography``
    # backend is installed; a plain attribute access would raise
    # AttributeError and make this whole module fail to import, even for
    # servers that never use JWT. Unavailable key types are simply omitted.
    # Algorithm classes are defined in:
    # https://github.com/jpadilla/pyjwt/blob/master/jwt/algorithms.py
    ALGORITHM_TO_PYJWT_CLASS = {
        kty: getattr(jwt.algorithms, class_name)
        for kty, class_name in (
            ("RSA", "RSAAlgorithm"),
            ("EC", "ECAlgorithm"),
            ("oct", "HMACAlgorithm"),
        )
        if hasattr(jwt.algorithms, class_name)
    }

54 

55 

class AuthMetadataMethod(Enum):
    """Authentication methods the server interceptor can be configured with."""

    NONE = "none"  # Requests are not authenticated.
    JWT = "jwt"  # Requests carry a JSON Web Token.

61 

62 

class AuthMetadataAlgorithm(Enum):
    """Signature algorithms that can be required when validating tokens.

    Values are lowercase JWT algorithm names; the interceptor upper-cases
    them before handing them to PyJWT.
    """

    # No specific algorithm is enforced:
    UNSPECIFIED = "unspecified"

    # ECDSA signatures:
    JWT_ES256 = "es256"  # ECDSA with SHA-256
    JWT_ES384 = "es384"  # ECDSA with SHA-384
    JWT_ES512 = "es512"  # ECDSA with SHA-512
    # HMAC signatures:
    JWT_HS256 = "hs256"  # HMAC with SHA-256
    JWT_HS384 = "hs384"  # HMAC with SHA-384
    JWT_HS512 = "hs512"  # HMAC with SHA-512
    # RSASSA-PSS signatures (MGF1 padding with the same hash):
    JWT_PS256 = "ps256"  # RSASSA-PSS with SHA-256
    JWT_PS384 = "ps384"  # RSASSA-PSS with SHA-384
    JWT_PS512 = "ps512"  # RSASSA-PSS with SHA-512
    # RSASSA-PKCS1-v1_5 signatures:
    JWT_RS256 = "rs256"  # RSASSA-PKCS1-v1_5 with SHA-256
    JWT_RS384 = "rs384"  # RSASSA-PKCS1-v1_5 with SHA-384
    JWT_RS512 = "rs512"  # RSASSA-PKCS1-v1_5 with SHA-512

79 

80 

class AuthContext:
    """Class-level holder for the interceptor consulted by :func:`authorize`.

    While ``interceptor`` is ``None``, :func:`authorize`-decorated methods
    accept every request.
    """

    # Interceptor used to authorize requests; ``None`` disables the checks.
    interceptor = None

84 

85 

class _InvalidTokenError(Exception):
    """Raised when a token fails decoding/verification or has no known key."""

88 

89 

class _ExpiredTokenError(Exception):
    """Raised when PyJWT reports that a token's signature has expired."""

92 

93 

class _UnboundedTokenError(Exception):
    """Raised when a decoded token payload lacks an integer ``exp`` claim."""

96 

97 

def authorize(auth_context):
    """Build a decorator that gates an RPC method behind ``auth_context``.

    The decorator is designed to be combined with an :class:`AuthContext`
    authorization context holder::

        @authorize(AuthContext)
        def Execute(self, request, context):

    While ``auth_context.interceptor`` is ``None`` every request is accepted.
    Authorization checks are switched on by installing a
    :class:`grpc.ServerInterceptor`::

        AuthContext.interceptor = AuthMetadataServerInterceptor()

    Args:
        auth_context(AuthContext): Authorization context holder.

    Returns:
        A decorator for RPC methods taking ``(self, request, context)``.
    """
    def decorate(behavior):
        """Wrap ``behavior`` so that it only runs for authorized calls."""
        call_details_cls = namedtuple(
            '_HandlerCallDetails', ['invocation_metadata', 'method'])

        @functools.wraps(behavior)
        def guarded(self, request, context):
            """Consult the interceptor, then run or abort the RPC."""
            if auth_context.interceptor is None:
                # No interceptor configured: authorization is disabled.
                return behavior(self, request, context)

            # Acceptance is signalled by the interceptor invoking this
            # continuation; if it returns without calling it, the call is
            # treated as unauthorized.
            accepted = []

            def mark_accepted(handler_call_details):
                accepted.append(handler_call_details)

            call_details = call_details_cls(
                context.invocation_metadata(), behavior.__name__)
            auth_context.interceptor.intercept_service(mark_accepted,
                                                       call_details)

            if not accepted:
                # Flatten the request so the log entry stays on one line.
                flat_request = str(request).replace("\n", "")
                logging.getLogger(__name__).info(
                    "Authentication failed for request=["
                    f"{behavior.__name__}({flat_request})], "
                    f"peer=[{context.peer()}]")

                context.abort(
                    grpc.StatusCode.UNAUTHENTICATED,
                    "No valid authorization or authentication provided")
                return None

            return behavior(self, request, context)

        return guarded

    return decorate

156 

157 

class AuthMetadataServerInterceptor(grpc.ServerInterceptor):
    """gRPC server interceptor that authenticates calls from bearer tokens.

    Each intercepted call must carry an ``authorization`` metadata entry of
    the form ``Bearer <token>``. The token is checked by the validator
    installed for the configured method (only JWT installs one). Validated
    tokens are cached with their expiry time, up to ``AUTH_CACHE_SIZE``
    entries. Rejected calls are routed to a handler that aborts them with
    ``UNAUTHENTICATED``.
    """

    # Maps each rejection code to the message sent back to the client.
    __auth_errors = {
        'missing-bearer': "Missing authentication header field",
        'invalid-bearer': "Invalid authentication header field",
        'invalid-token': "Invalid authentication token",
        'expired-token': "Expired authentication token",
        'unbounded-token': "Unbounded authentication token",
    }

    # Upper bound, in seconds, for a single JWKs HTTP request. Without a
    # timeout, a stalled JWKs endpoint hangs the fetching thread while it
    # holds the JWKs update lock, and request threads waiting on that lock
    # hang with it.
    _JWKS_REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self,
                 method,
                 secret=None,
                 algorithm=AuthMetadataAlgorithm.UNSPECIFIED,
                 jwks_url=None,
                 audience=None,
                 jwks_fetch_minutes=60):
        """Initializes a new :class:`AuthMetadataServerInterceptor`.

        Args:
            method (AuthMetadataMethod): Type of authorization method.
            secret (str): The secret or key to be used for validating request,
                depending on `method`. Defaults to ``None``.
            algorithm (AuthMetadataAlgorithm): The cryptographic algorithm used
                to encode `secret`. Defaults to ``UNSPECIFIED``.
            jwks_url (str): The url to fetch the JWKs. Either secret or
                this field must be specified if the authentication method is JWT.
                Defaults to ``None``.
            audience (str): The audience used to validate jwt tokens against.
                The tokens must have an audience field. Only applied when
                validating against keys fetched from `jwks_url`.
            jwks_fetch_minutes (int): The number of minutes to wait before
                refreshing the jwks set. Default: 60 minutes.

        Raises:
            InvalidArgumentError: If PyJWT (or, when `jwks_url` is set,
                requests) is not installed, or if `algorithm` is not
                supported by PyJWT.
            RuntimeError: If both `secret` and `jwks_url` are set for JWT, or
                if the initial JWKs fetch yields no public keys. Errors from
                the initial JWKs request or parsing are also propagated.
        """
        self.__logger = logging.getLogger(__name__)
        self.__bearer_cache = OrderedDict()
        self.__terminators = {}
        self.__validator = None
        self.__secret = secret
        self.__jwk_update_lock = threading.Lock()

        self._audience = audience
        self._jwks_url = jwks_url
        self._public_keys = {}
        self._jwks_fetch_minutes = jwks_fetch_minutes
        # Replaced by an aware UTC datetime after each successful JWKs fetch.
        self._last_fetch_time = 0
        self._method = method
        self._algorithm = algorithm

        if self._method == AuthMetadataMethod.JWT:
            if self.__secret and self._jwks_url:
                raise RuntimeError(
                    "Only allowed to set secret or jwks-url. Not both.")

            # Check the optional dependencies before anything uses PyJWT or
            # requests, so a missing package surfaces as InvalidArgumentError
            # instead of a NameError during the initial JWKs fetch.
            self._check_jwt_support(self._algorithm)
            if self._jwks_url and 'requests' not in sys.modules:
                raise InvalidArgumentError(
                    "JWKs URL support requires the requests package")

            if self._jwks_url:
                # Fetch jwk and store
                self._get_and_parse_jwks_from_url()

            self.__validator = self._validate_jwt_token

        for code, message in self.__auth_errors.items():
            self.__terminators[code] = _unary_unary_rpc_terminator(message)

    def _error_message_for_call(self, call_details, auth_error_type, exception_details=""):
        """Format the log line describing why a call was rejected."""
        return (
            f"Authentication error. Rejecting '{str(call_details.method)}' request: "
            f"Reason=[{self.__auth_errors[auth_error_type]}], "
            f"{exception_details}")

    # --- Public API ---

    @property
    def method(self):
        """The configured :class:`AuthMetadataMethod`."""
        return self._method

    @property
    def algorithm(self):
        """The configured :class:`AuthMetadataAlgorithm`."""
        return self._algorithm

    def intercept_service(self, continuation, handler_call_details):
        """Authenticate a call before passing it on.

        Args:
            continuation: Callable invoked with ``handler_call_details`` when
                the call is accepted.
            handler_call_details: Call details exposing ``method`` and
                ``invocation_metadata``.

        Returns:
            ``continuation(handler_call_details)`` for an accepted call,
            otherwise a handler that aborts the call as ``UNAUTHENTICATED``.
        """
        try:
            # Reject requests not carrying a token:
            bearer = dict(
                handler_call_details.invocation_metadata)['authorization']

        except KeyError:
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'missing-bearer'))
            return self.__terminators['missing-bearer']

        # Reject requests with a malformed bearer:
        if not bearer.startswith('Bearer '):
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'invalid-bearer'))
            return self.__terminators['invalid-bearer']

        # Hit the cache for already validated tokens. The lookup is kept out
        # of any ``try`` block so that a KeyError raised by ``continuation``
        # cannot be mistaken for a cache miss, which would lead to the token
        # being re-validated and ``continuation`` being invoked a second time.
        expiration_time = self.__bearer_cache.get(bearer)
        if expiration_time is not None:
            # Accept request if cached token hasn't expired yet:
            if expiration_time >= datetime.utcnow():
                return continuation(handler_call_details)  # Accepted

            # Cached token has expired: evict it and reject the request.
            self.__bearer_cache.pop(bearer, None)
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'expired-token'))
            # TODO: Use grpc.Status.details to inform the client of the expiry?
            return self.__terminators['expired-token']

        assert self.__validator is not None

        try:
            # Decode and validate the new token (without the 'Bearer ' prefix):
            expiration_time = self.__validator(bearer[len('Bearer '):])

        except _InvalidTokenError as e:
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'invalid-token', str(e)))
            return self.__terminators['invalid-token']

        except _ExpiredTokenError as e:
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'expired-token', str(e)))
            return self.__terminators['expired-token']

        except _UnboundedTokenError as e:
            self.__logger.info(
                self._error_message_for_call(handler_call_details,
                                             'unbounded-token', str(e)))
            return self.__terminators['unbounded-token']

        # Cache the validated token and store expiration time, evicting the
        # oldest inserted entry once the cache grows past its bound:
        self.__bearer_cache[bearer] = expiration_time
        if len(self.__bearer_cache) > AUTH_CACHE_SIZE:
            self.__bearer_cache.popitem(last=False)

        return continuation(handler_call_details)  # Accepted

    # --- Private API: JWT ---

    def _check_jwt_support(self, algorithm=AuthMetadataAlgorithm.UNSPECIFIED):
        """Ensure PyJWT is available and supports ``algorithm``.

        Raises:
            InvalidArgumentError: If PyJWT is not installed, or if
                ``algorithm`` is not supported by it.
        """
        if 'jwt' not in sys.modules:
            raise InvalidArgumentError(
                "JWT authorization method requires PyJWT")

        if algorithm == AuthMetadataAlgorithm.UNSPECIFIED:
            return

        # Probe PyJWT by registering a ``None`` handler under the algorithm's
        # name. PyJWT's ``register_algorithm`` raises ValueError when the
        # name already has a handler (supported) and TypeError when the
        # handler is not an Algorithm (unknown name), so nothing is ever
        # actually registered.
        try:
            jwt.register_algorithm(algorithm.value.upper(), None)

        except TypeError:
            raise InvalidArgumentError(
                f"Algorithm not supported for JWT decoding: [{algorithm}]"
            ) from None

        except ValueError:
            pass

    jwt_invalid_exceptions = (_ExpiredTokenError, _InvalidTokenError,
                              _UnboundedTokenError)

    @ExceptionCounter(INVALID_JWT_COUNT_METRIC_NAME,
                      exceptions=jwt_invalid_exceptions)
    @DurationMetric(JWT_VALIDATION_TIME_METRIC_NAME)
    def _validate_jwt_token(self, token):
        """Validate a JWT and return its expiry as a naive UTC datetime.

        Raises:
            _InvalidTokenError: If the token cannot be decoded or verified, or
                no secret or JWKs URL is configured.
            _ExpiredTokenError: If the token's signature has expired.
            _UnboundedTokenError: If the payload has no integer ``exp`` claim.
        """
        if self._algorithm != AuthMetadataAlgorithm.UNSPECIFIED:
            algorithms = [self._algorithm.value.upper()]
        else:
            algorithms = None

        try:
            if self.__secret:
                with DurationMetric(JWT_DECODE_TIME_METRIC_NAME):
                    payload = jwt.decode(token,
                                         self.__secret,
                                         algorithms=algorithms)
            elif self._jwks_url:
                payload = self._decode_with_jwks(token, algorithms)
            else:
                # Misconfiguration: reject the token instead of failing with
                # an unbound ``payload``.
                raise _InvalidTokenError(
                    "No secret or JWKs URL configured to validate JWT tokens")

        except jwt.exceptions.ExpiredSignatureError as e:
            raise _ExpiredTokenError(e) from e

        except jwt.exceptions.InvalidTokenError as e:
            raise _InvalidTokenError(e) from e

        if 'exp' not in payload or not isinstance(payload['exp'], int):
            raise _UnboundedTokenError("Missing 'exp' in payload")

        return datetime.utcfromtimestamp(payload['exp'])

    def _decode_with_jwks(self, token, algorithms):
        """Decode ``token`` with the cached public key matching its ``kid``.

        The JWK set is refreshed first when it is older than
        ``jwks_fetch_minutes``, and again when the token's ``kid`` is not
        among the cached keys.

        Raises:
            _InvalidTokenError: If the token header has no ``kid`` or no
                public key is known for it.
            jwt.exceptions.InvalidTokenError: Propagated from PyJWT.
        """
        self.__logger.debug(
            f"Validating token with JWKS fetched from url: [{self._jwks_url}]")

        # Refetch the jwks if the current time is past the last fetch time
        # plus the configured delta. Only the thread that acquires the lock
        # refreshes; the others carry on with the keys already cached.
        refresh_due = (self._last_fetch_time +
                       timedelta(minutes=self._jwks_fetch_minutes) <=
                       datetime.now(tz=timezone.utc))
        if refresh_due:
            self._try_refresh_jwks()

        kid = jwt.get_unverified_header(token).get('kid')
        if kid is None:
            # The token comes from the client: reject it as invalid rather
            # than letting a generic error escape the interceptor.
            raise _InvalidTokenError("JWT token is missing kid.")

        key = self._public_keys.get(kid)
        if key is None:
            # Unknown kid, possibly because the cached set is stale: refresh
            # now, or if another thread is already refreshing, block until it
            # finishes, then look the key up again.
            if not self._try_refresh_jwks():
                with self.__jwk_update_lock:
                    pass
            key = self._public_keys.get(kid)
            if key is None:
                raise _InvalidTokenError(
                    f"No public key found for token with kid: {kid}")

        with DurationMetric(JWT_DECODE_TIME_METRIC_NAME):
            payload = jwt.decode(token,
                                 key,
                                 algorithms=algorithms,
                                 audience=self._audience)
        self.__logger.debug(
            f"JWT validated from JWK set fetched from: [{self._jwks_url}]")
        return payload

    def _try_refresh_jwks(self):
        """Refetch the JWK set unless another thread is already doing so.

        Fetch failures are logged and swallowed so the caller can continue
        with the previously cached keys.

        Returns:
            bool: ``True`` if this thread acquired the update lock (whether
            or not the fetch succeeded), ``False`` if the lock was busy.
        """
        # pylint: disable=consider-using-with
        if not self.__jwk_update_lock.acquire(False):
            return False
        try:
            self._get_and_parse_jwks_from_url()
        except Exception:  # pylint: disable=broad-except
            # Continue if an exception occurred.
            self.__logger.exception(
                "Exception thrown while fetching jwk. "
                "Continuing with request using previously cached keys.")
        finally:
            self.__jwk_update_lock.release()
        return True

    @DurationMetric(JWK_FETCH_TIME_METRIC_NAME)
    def _get_and_parse_jwks_from_url(self):
        """Fetch the JWK set from ``jwks_url`` and replace the cached keys.

        An empty result keeps the existing keys (if any) unchanged.

        Raises:
            requests.exceptions.RequestException: If the HTTP request fails
                or times out.
            AttributeError, ValueError: If the response cannot be parsed.
            RuntimeError: If a key lacks ``kid``/``kty``, uses an unsupported
                ``kty``, or no keys are available at all.
        """
        # pyJWT 2.0 will support these operations, once merged:
        # https://github.com/jpadilla/pyjwt/pull/470/files
        #
        # jwks_client = PyJWKClient(self._jwks_url)
        # signing_key = jwks_client.get_signing_key_from_jwt(token)
        # payload = jwt.decode(token, signing_key.key, algorithms=algorithms)
        try:
            self.__logger.info(
                f"Sending request to fetch JWKs from provided url: [{self._jwks_url}]"
            )
            data = requests.get(self._jwks_url,
                                timeout=self._JWKS_REQUEST_TIMEOUT_SECONDS)
        except requests.exceptions.RequestException:
            self.__logger.exception(
                f"Error sending request to: [{self._jwks_url}]")
            raise

        # Bound before parsing so the error log below can always reference it.
        jwks = None
        temp_keys = {}
        try:
            jwks = data.json()
            # A missing or null "keys" member is treated as an empty set.
            for jwk in jwks.get('keys') or []:
                kid = jwk.get('kid')
                kty = jwk.get('kty')
                if kid is None or kty is None:
                    raise RuntimeError(
                        f"A key in the JWKs fetched from [{self._jwks_url}], "
                        "doesn't include one of the required properties: kid, or kty."
                    )
                alg_class = ALGORITHM_TO_PYJWT_CLASS.get(kty)
                if alg_class is None:
                    raise RuntimeError(
                        "Unsupported algorithm type provided by "
                        f"JWKs: [{kty}], fetched from [{self._jwks_url}]"
                    )
                temp_keys[kid] = alg_class.from_jwk(json.dumps(jwk))
        except (AttributeError, ValueError):
            self.__logger.exception(
                f"Error parsing input: [{jwks}], "
                f"fetched from [{self._jwks_url}]")
            raise

        if not temp_keys:
            self.__logger.error(
                f"No public keys returned from url: [{self._jwks_url}]")
            # If there are no public keys at all, raise an exception.
            if not self._public_keys:
                raise RuntimeError(
                    "Error fetching public keys, non-existing public keys.")

            self.__logger.info(
                f"Unable to fetch proper JWKs from [{self._jwks_url}], "
                f"leaving existing set last fetched at [{self._last_fetch_time}] unchanged."
            )
            return

        # Record the fetch time; it decides when the set is next refreshed.
        self._last_fetch_time = datetime.now(tz=timezone.utc)

        self.__logger.info(
            "Replacing existing JWKs set, with one fetched at time: "
            f"[{self._last_fetch_time}] from url: [{self._jwks_url}]"
        )

        # Swap in the new keys as a whole.
        self._public_keys = temp_keys

490 

491 

def _unary_unary_rpc_terminator(details):
    """Build a unary-unary handler that aborts every call as UNAUTHENTICATED.

    Args:
        details (str): Message attached to the aborted call.

    Returns:
        The handler created by :func:`grpc.unary_unary_rpc_method_handler`.
    """
    def _abort_unauthenticated(_request, context):
        # The request itself is irrelevant: every call is rejected.
        context.abort(grpc.StatusCode.UNAUTHENTICATED, details)

    return grpc.unary_unary_rpc_method_handler(_abort_unauthenticated)
496 return grpc.unary_unary_rpc_method_handler(terminate)