Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/auth/manager.py: 87.26%

157 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-06-11 15:37 +0000

1# Copyright (C) 2023 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

import functools
import inspect
import logging
import sys
from abc import ABC, abstractmethod
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterator,
    Mapping,
    NoReturn,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import grpc

from buildgrid._exceptions import InvalidArgumentError
from buildgrid.server.auth.config import InstanceAuthorizationConfig
from buildgrid.server.auth.enums import AuthMetadataAlgorithm
from buildgrid.server.auth.exceptions import (
    AuthError,
    ExpiredTokenError,
    InvalidAuthorizationHeaderError,
    InvalidTokenError,
    MissingTokenError,
    UnboundedTokenError,
    UnexpectedTokenParsingError,
)
from buildgrid.server.instance import current_instance
from buildgrid.server.metrics_names import INVALID_JWT_COUNT_METRIC_NAME, JWT_VALIDATION_TIME_METRIC_NAME
from buildgrid.server.metrics_utils import DurationMetric, ExceptionCounter
from buildgrid.settings import AUTH_CACHE_SIZE

42 

LOGGER = logging.getLogger(__name__)


# Since jwt authorization is not required, make it optional.
# If used, but module not imported/found, will raise an exception.
try:
    import jwt

    AlgorithmType = Union[
        Type[jwt.algorithms.RSAAlgorithm], Type[jwt.algorithms.ECAlgorithm], Type[jwt.algorithms.HMACAlgorithm]
    ]

    # Algorithm classes defined in: https://github.com/jpadilla/pyjwt/blob/master/jwt/algorithms.py
    ALGORITHM_TO_PYJWT_CLASS: Dict[str, AlgorithmType] = {
        "RSA": jwt.algorithms.RSAAlgorithm,
        "EC": jwt.algorithms.ECAlgorithm,
        "oct": jwt.algorithms.HMACAlgorithm,
    }

except (ImportError, AttributeError):
    # ImportError: PyJWT is not installed at all.
    # AttributeError: PyJWT is installed without its optional `cryptography`
    # dependency, in which case `jwt.algorithms` does not define the RSA/EC
    # classes. Swallowing it keeps this module importable for deployments that
    # don't need those algorithms; `jwt` itself stays bound in that case.
    pass

64 

65 

def _log_and_raise(request_name: str, exception: AuthError) -> NoReturn:
    """Log that ``request_name`` is being rejected, then raise ``exception``.

    The function never returns normally; it is annotated ``NoReturn`` so type
    checkers know that code following a call to it is unreachable.

    Args:
        request_name (str): Name of the request being rejected, used in the log line.
        exception (AuthError): The error to log and raise.

    Raises:
        AuthError: Always; the ``exception`` argument itself.
    """
    LOGGER.info(f"Authorization error. Rejecting '{request_name}' request: " f"Reason=[{str(exception)}]")
    raise exception

69 

70 

class JwtParser:
    """Verifies JWTs and extracts the identity claims used for authorization.

    Tokens are verified either with a static ``secret`` or with signing keys
    fetched from a JWKS endpoint; exactly one of the two must be configured.
    """

    def __init__(
        self,
        secret: Optional[str] = None,
        algorithm: AuthMetadataAlgorithm = AuthMetadataAlgorithm.UNSPECIFIED,
        jwks_url: Optional[str] = None,
        audience: Optional[str] = None,
        jwks_fetch_minutes: int = 60,
    ) -> None:
        """Create a parser.

        Args:
            secret (str | None): Key used to verify token signatures.
            algorithm (AuthMetadataAlgorithm): Signing algorithm tokens must use.
            jwks_url (str | None): URL from which signing keys are fetched.
            audience (str | None): Audience tokens are validated against.
            jwks_fetch_minutes (int): Lifespan, in minutes, of the JWKS client cache.

        Raises:
            InvalidArgumentError: If ``algorithm`` is ``UNSPECIFIED`` or PyJWT is not loaded.
            TypeError: If both or neither of ``secret`` and ``jwks_url`` are given.
            NameError: If ``jwks_url`` is given but PyJWT could not be imported.
        """
        self._check_jwt_support(algorithm)

        self._algorithm = algorithm
        self._audience = audience
        self._jwks_client = None

        if (secret is None and jwks_url is None) or (secret is not None and jwks_url is not None):
            raise TypeError("Exactly one of `secret` or `jwks_url` must be set")

        self._secret = secret

        if jwks_url is not None:
            try:
                jwks_lifespan = jwks_fetch_minutes * 60
                self._jwks_client = jwt.PyJWKClient(jwks_url, lifespan=jwks_lifespan, max_cached_keys=AUTH_CACHE_SIZE)
            except NameError:
                # `jwt` is only bound when the optional PyJWT import succeeded.
                LOGGER.error("JWT auth is enabled but PyJWT could not be imported.")
                raise

    def _check_jwt_support(self, algorithm: AuthMetadataAlgorithm) -> None:
        """Ensures JWT and possible dependencies are available."""
        if algorithm == AuthMetadataAlgorithm.UNSPECIFIED:
            raise InvalidArgumentError("JWT authorization method requires an algorithm to be specified")

        if "jwt" not in sys.modules:
            raise InvalidArgumentError("JWT authorization method requires PyJWT")

    # Exceptions from `parse` that are counted by the invalid-JWT metric.
    jwt_invalid_exceptions = (ExpiredTokenError, InvalidTokenError, UnboundedTokenError)

    @ExceptionCounter(INVALID_JWT_COUNT_METRIC_NAME, exceptions=jwt_invalid_exceptions)
    @DurationMetric(JWT_VALIDATION_TIME_METRIC_NAME)
    def parse(self, token: str) -> Dict[str, Any]:
        """Verify ``token`` and return its decoded claims.

        The token must be signed with the configured algorithm, must match the
        configured audience, and must carry an ``exp`` claim, which is verified.

        Raises:
            ExpiredTokenError: The token's ``exp`` is in the past.
            UnboundedTokenError: A required claim (e.g. ``exp``) is missing.
            InvalidTokenError: The token failed any other validation.
            UnexpectedTokenParsingError: Any other error occurred while parsing.
        """
        payload: Optional[Dict[str, Any]] = None
        try:
            # Pick the verification key; the constructor guarantees exactly one
            # source is configured.
            if self._secret is not None:
                key: Any = self._secret
            elif self._jwks_client is not None:
                key = self._jwks_client.get_signing_key_from_jwt(token).key
            else:
                key = None

            if self._secret is not None or self._jwks_client is not None:
                payload = jwt.decode(
                    token,
                    key,
                    algorithms=[self._algorithm.value.upper()],
                    audience=self._audience,
                    options={"require": ["exp"], "verify_exp": True},
                )

        except jwt.exceptions.ExpiredSignatureError as e:
            raise ExpiredTokenError() from e

        except jwt.exceptions.MissingRequiredClaimError as e:
            # Name the actual claim: besides `exp`, `aud` is also required when
            # an audience is configured.
            raise UnboundedTokenError(f"Missing required JWT claim '{e.claim}'") from e

        except jwt.exceptions.InvalidTokenError as e:
            raise InvalidTokenError() from e

        except Exception as e:
            raise UnexpectedTokenParsingError() from e

        if payload is None:
            raise InvalidTokenError()

        return payload

    def identity_from_token(self, token: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """Return ``(actor, subject, workflow)`` from the ``act``, ``sub`` and ``aud`` claims.

        Missing claims are returned as ``None``. Errors from :meth:`parse` propagate.
        NOTE(review): per RFC 7519 ``aud`` may be a list; the annotation assumes a
        string — confirm against the ACL matcher.
        """
        payload = self.parse(token)
        return payload.get("act"), payload.get("sub"), payload.get("aud")

153 

154 

class AuthManager(ABC):
    """Interface for objects that decide whether incoming gRPC requests may proceed."""

    @abstractmethod
    def authorize(self, context: grpc.ServicerContext, instance_name: str, request_name: str) -> bool:
        """Decide whether a request is authorized.

        Implementations inspect the ``ServicerContext`` of an incoming gRPC
        request, together with the request's name and the instance it targets.
        They extract identity information about the requester from the context,
        for example from a JWT, and compare it against the ACL configuration
        supplied to the implementation.

        Args:
            context (ServicerContext): The context of the gRPC request whose
                authorization status is being checked.

            instance_name (str): The name of the instance the request will
                interact with; used to select per-instance ACLs.

            request_name (str): The name of the request being authorized, for
                example `Execute`.

        Returns:
            bool: ``True`` if the request is authorized, ``False`` otherwise.

        """

183 

184 

class JWTAuthManager(AuthManager):
    """Authorizes requests using a JWT carried in the ``authorization`` metadata.

    The token's ``act``, ``sub`` and ``aud`` claims are used as the actor,
    subject and workflow when checking per-instance ACLs.
    """

    def __init__(
        self,
        secret: Optional[str] = None,
        algorithm: AuthMetadataAlgorithm = AuthMetadataAlgorithm.UNSPECIFIED,
        jwks_url: Optional[str] = None,
        audience: Optional[str] = None,
        jwks_fetch_minutes: int = 60,
        acls: Optional[Mapping[str, InstanceAuthorizationConfig]] = None,
        allow_unauthorized_instances: Optional[Set[str]] = None,
    ) -> None:
        """Initializes a new :class:`JWTAuthManager`.

        Args:
            secret (str): The secret or key used to validate tokens. Exactly one
                of ``secret`` and ``jwks_url`` must be set. Defaults to ``None``.

            algorithm (AuthMetadataAlgorithm): The cryptographic algorithm used
                to sign tokens. Defaults to ``UNSPECIFIED``, which is rejected.

            jwks_url (str): The URL to fetch the JWKs from. Exactly one of
                ``secret`` and ``jwks_url`` must be set. Defaults to ``None``.

            audience (str): The audience used to validate JWT tokens against.
                The tokens must have an audience field.

            jwks_fetch_minutes (int): The number of minutes to cache JWKs fetches
                for. Defaults to 60.

            acls (Mapping[str, InstanceAuthorizationConfig] | None): An optional
                map of instance name -> ACL config to use for per-instance
                authorization.

            allow_unauthorized_instances (Set[str] | None): Names of instances
                that should be allowed unauthenticated access.

        Raises:
            InvalidArgumentError: If `algorithm` is unspecified or PyJWT is unavailable.
            TypeError: If both or neither of ``secret`` and ``jwks_url`` are set.

        """
        self._acls = acls
        self._allow_unauthorized_instances = allow_unauthorized_instances
        self._token_parser = JwtParser(secret, algorithm, jwks_url, audience, jwks_fetch_minutes)

    def _token_from_request_context(self, context: grpc.ServicerContext, request_name: str) -> str:
        """Extract the bearer token from the request's ``authorization`` metadata.

        Raises:
            MissingTokenError: If the request carries no ``authorization`` metadata.
            InvalidAuthorizationHeaderError: If the value does not use the
                ``Bearer`` scheme.
        """
        metadata = dict(context.invocation_metadata())

        # Reject requests not carrying a token. Checking membership explicitly
        # keeps `bearer` bound on every path that reaches the code below.
        if "authorization" not in metadata:
            _log_and_raise(request_name, MissingTokenError())
        bearer = cast(str, metadata["authorization"])

        # Reject requests with a malformed header. The auth-scheme is matched
        # case-insensitively, as required by RFC 7235 section 2.1.
        scheme_prefix = "bearer "
        if bearer[: len(scheme_prefix)].lower() != scheme_prefix:
            _log_and_raise(request_name, InvalidAuthorizationHeaderError())

        return bearer[len(scheme_prefix) :]

    def authorize(self, context: grpc.ServicerContext, instance_name: str, request_name: str) -> bool:
        """Authorize a request using the JWT found in its metadata.

        Instances listed in ``allow_unauthorized_instances`` are always allowed.
        Any other request must carry a valid bearer token. When ACLs are
        configured, the token's identity must also be permitted by the ACL for
        ``instance_name``; instances without an ACL entry are denied.
        """
        # No need to authorize if unauthorized access is allowed for the instance
        if self._allow_unauthorized_instances and instance_name in self._allow_unauthorized_instances:
            return True
        try:
            token = self._token_from_request_context(context, request_name)
            actor, subject, workflow = self._token_parser.identity_from_token(token)
        except NameError:
            # `jwt` is only bound when the optional PyJWT import succeeded.
            LOGGER.error("JWT auth is enabled but PyJWT is not installed.")
            return False
        except AuthError as e:
            LOGGER.info(f"Error authorizing JWT token: {str(e)}")
            return False

        # If no ACL config was provided at all, don't do any more validation
        if self._acls is None:
            return True

        instance_acl_config = self._acls.get(instance_name)
        if instance_acl_config is None:
            # If there is an ACL, but no config for this instance, deny all
            return False

        return instance_acl_config.is_authorized(request_name, actor=actor, subject=subject, workflow=workflow)

268 

269 

class HeadersAuthManager(AuthManager):
    """Authorizes requests using identity values supplied as plain request metadata.

    The actor, subject and workflow are read from the ``x-request-actor``,
    ``x-request-subject`` and ``x-request-workflow`` metadata keys. Nothing here
    verifies those values.
    NOTE(review): presumably they are set by a trusted front-end; confirm.
    """

    def __init__(
        self,
        acls: Optional[Mapping[str, InstanceAuthorizationConfig]] = None,
        allow_unauthorized_instances: Optional[Set[str]] = None,
    ) -> None:
        """Initializes a new :class:`HeadersAuthManager`.

        Args:
            acls (Mapping[str, InstanceAuthorizationConfig] | None): An optional
                map of instance name -> ACL config to use for per-instance
                authorization.

            allow_unauthorized_instances (Set[str] | None): Names of instances
                that should be allowed unauthenticated access.

        """
        self._acls = acls
        self._allow_unauthorized_instances = allow_unauthorized_instances

    def authorize(self, context: grpc.ServicerContext, instance_name: str, request_name: str) -> bool:
        """Authorize a request using the identity values in its metadata.

        Instances listed in ``allow_unauthorized_instances`` and all requests
        when no ACLs are configured are allowed. Otherwise the ACL for
        ``instance_name`` decides; instances without an ACL entry are denied.
        A metadata key that is absent is passed to the ACL as ``None``.
        """
        # No need to authorize if unauthorized access is allowed for the instance
        if self._allow_unauthorized_instances and instance_name in self._allow_unauthorized_instances:
            return True

        # If no ACL config was provided at all, don't do any more validation
        if self._acls is None:
            return True

        instance_acl_config = self._acls.get(instance_name)
        if instance_acl_config is None:
            # If there is an ACL, but no config for this instance, deny all
            return False

        metadata_dict = dict(context.invocation_metadata())

        def header(key: str) -> Optional[str]:
            # Keep a missing header as None instead of the string "None", which
            # is what `str(None)` would produce.
            value = metadata_dict.get(key)
            return None if value is None else str(value)

        return instance_acl_config.is_authorized(
            request_name,
            actor=header("x-request-actor"),
            subject=header("x-request-subject"),
            workflow=header("x-request-workflow"),
        )

309 

310 

# TODO: Once https://github.com/grpc/grpc/issues/33071 is resolved this AuthContext can be
# replaced with a gRPC interceptor that stores the AuthManager in a request-local ContextVar.
AuthContext: "ContextVar[Optional[AuthManager]]" = ContextVar("AuthManager", default=None)


def set_auth_manager(manager: Optional[AuthManager]) -> None:
    """Install ``manager`` as the auth manager for the current context.

    Passing ``None`` means no authorization is performed by ``_authorize``.
    """
    AuthContext.set(manager)


def get_auth_manager() -> Optional[AuthManager]:
    """Return the auth manager of the current context, or ``None`` if unset."""
    current = AuthContext.get()
    return current

322 

323 

@contextmanager
def _authorize(
    request_context: grpc.ServicerContext, instance_name: str, request_name: str
) -> Generator[None, None, None]:
    """Enter the ``with`` body only if the request is authorized.

    With no auth manager installed for the current context, every request is
    allowed. Otherwise the installed manager decides; a rejected request is
    aborted with ``UNAUTHENTICATED``.
    """
    manager = get_auth_manager()
    allowed = manager is None or manager.authorize(request_context, instance_name, request_name)
    if allowed:
        yield
        return

    LOGGER.info("Authentication failed for request=[" f"{request_name}], peer=[{request_context.peer()}]")
    # `abort` raises an Exception, so this generator never yields on the
    # rejection path and the body of the `with` statement does not run.
    request_context.abort(grpc.StatusCode.UNAUTHENTICATED, "No valid authorization or authentication provided")

341 

342 

# A fully parameterized bound avoids the need for a `type: ignore[type-arg]`.
Func = TypeVar("Func", bound=Callable[..., Any])


def authorize(f: Func) -> Func:
    """Decorate a gRPC servicer method so it only runs for authorized requests.

    The request is checked with ``_authorize`` using the current instance name
    and the decorated method's ``__name__`` as the request name. Both unary
    methods and server-streaming methods (generator functions) are supported.
    An ``InvalidArgumentError`` raised while handling the request is reported
    to the client as ``INVALID_ARGUMENT``.
    """
    if inspect.isgeneratorfunction(f):

        @functools.wraps(f)
        def server_stream_wrapper(self: Any, message: Any, context: grpc.ServicerContext) -> Iterator[Any]:
            try:
                with _authorize(context, current_instance(), f.__name__):
                    yield from f(self, message, context)
            except InvalidArgumentError as e:
                context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))

        return cast(Func, server_stream_wrapper)

    @functools.wraps(f)
    def server_unary_wrapper(self: Any, message: Any, context: grpc.ServicerContext) -> Any:
        try:
            with _authorize(context, current_instance(), f.__name__):
                return f(self, message, context)
        except InvalidArgumentError as e:
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))

    return cast(Func, server_unary_wrapper)