Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/auth/manager.py: 90.13%
152 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-10-04 17:48 +0000
« prev ^ index » next coverage.py v7.4.1, created at 2024-10-04 17:48 +0000
1# Copyright (C) 2023 Bloomberg LP
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# <http://www.apache.org/licenses/LICENSE-2.0>
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16from abc import ABC, abstractmethod
17from contextvars import ContextVar
18from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Type, Union, cast
20import grpc
21import jwt
23from buildgrid._protos.buildgrid.v2.identity_pb2 import ClientIdentity
24from buildgrid.server.auth.config import InstanceAuthorizationConfig
25from buildgrid.server.auth.enums import AuthMetadataAlgorithm
26from buildgrid.server.auth.exceptions import (
27 AuthError,
28 ExpiredTokenError,
29 InvalidAuthorizationHeaderError,
30 InvalidTokenError,
31 MissingTokenError,
32 SigningKeyNotFoundError,
33 UnboundedTokenError,
34 UnexpectedTokenParsingError,
35)
36from buildgrid.server.exceptions import InvalidArgumentError
37from buildgrid.server.logging import buildgrid_logger
38from buildgrid.server.settings import AUTH_CACHE_SIZE
# Module-level logger used by every auth helper in this file.
LOGGER = buildgrid_logger(__name__)


# Type alias covering the three PyJWT algorithm classes referenced below.
AlgorithmType = Union[
    Type[jwt.algorithms.RSAAlgorithm], Type[jwt.algorithms.ECAlgorithm], Type[jwt.algorithms.HMACAlgorithm]
]

# Algorithm classes defined in: https://github.com/jpadilla/pyjwt/blob/master/jwt/algorithms.py
# NOTE(review): the keys look like JWK key-type ("kty") values -- confirm against the callers,
# which are not visible in this part of the file.
ALGORITHM_TO_PYJWT_CLASS: Dict[str, AlgorithmType] = {
    "RSA": jwt.algorithms.RSAAlgorithm,
    "EC": jwt.algorithms.ECAlgorithm,
    "oct": jwt.algorithms.HMACAlgorithm,
}
def _log_and_raise(request_name: str, exception: AuthError) -> str:
    """Log an authorization rejection and then raise the supplied error.

    Args:
        request_name: Name of the request being rejected; attached to the log entry as a tag.
        exception: The error to report. Its string form is logged as the rejection reason.

    Raises:
        AuthError: Always -- the ``exception`` argument itself is raised.
    """
    # This function never returns normally: the final statement unconditionally raises.
    rejection_tags = {"request_name": request_name, "reason": str(exception)}
    LOGGER.info("Authorization error. Rejecting.", tags=rejection_tags)
    raise exception
class JwtParser:
    """Validates JWTs against a shared secret or JWKS-hosted keys and maps their claims to a ``ClientIdentity``."""

    def __init__(
        self,
        secret: Optional[str] = None,
        algorithm: AuthMetadataAlgorithm = AuthMetadataAlgorithm.UNSPECIFIED,
        jwks_urls: Optional[List[str]] = None,
        audiences: Optional[List[str]] = None,
        jwks_fetch_minutes: int = 60,
    ) -> None:
        """Configure the key source and validation parameters.

        Args:
            secret: Key handed to ``jwt.decode`` for verification. Mutually exclusive with ``jwks_urls``.
            algorithm: Signing algorithm; its upper-cased value is the only algorithm accepted when decoding.
            jwks_urls: JWKS endpoints; one ``jwt.PyJWKClient`` is created per URL. Mutually exclusive
                with ``secret``.
            audiences: Passed as the ``audience`` argument when decoding. The first entry is also the
                fallback workflow when a token carries no ``aud`` claim.
            jwks_fetch_minutes: Multiplied by 60 and given to each ``PyJWKClient`` as its ``lifespan``.

        Raises:
            InvalidArgumentError: If ``algorithm`` is ``UNSPECIFIED``.
            TypeError: If both or neither of ``secret`` and ``jwks_urls`` are provided.
        """
        self._check_jwt_support(algorithm)

        self._algorithm = algorithm
        self._audiences = audiences

        # Supplying both key sources, or neither, is a configuration error.
        if (secret is None) == (jwks_urls is None):
            raise TypeError("Exactly one of `secret` or `jwks_url` must be set")
        self._secret = secret

        cache_lifespan = 60 * jwks_fetch_minutes
        self._jwks_clients = [
            jwt.PyJWKClient(endpoint, lifespan=cache_lifespan, max_cached_keys=AUTH_CACHE_SIZE)
            for endpoint in (jwks_urls or [])
        ]

    def _check_jwt_support(self, algorithm: AuthMetadataAlgorithm) -> None:
        """Refuse to build a parser when no signing algorithm has been chosen.

        Raises:
            InvalidArgumentError: If ``algorithm`` is ``UNSPECIFIED``.
        """
        if algorithm == AuthMetadataAlgorithm.UNSPECIFIED:
            raise InvalidArgumentError("JWT authorization method requires an algorithm to be specified")

    def _decode(self, token: str, key: Any) -> Dict[str, Any]:
        """Decode ``token`` with ``key``, requiring and verifying the ``exp`` claim."""
        return jwt.decode(
            token,
            key,
            algorithms=[self._algorithm.value.upper()],
            audience=self._audiences,
            options={"require": ["exp"], "verify_exp": True},
        )

    def _find_signing_key(self, token: str) -> "jwt.PyJWK":
        """Ask each configured JWKS client in turn for the key that signed ``token``.

        Raises:
            SigningKeyNotFoundError: If no client yields a key; the message lists each client's
                URI alongside the error it reported.
        """
        found: Optional[jwt.PyJWK] = None
        failures: List[Tuple[str, jwt.PyJWKClientError]] = []
        for client in self._jwks_clients:
            try:
                found = client.get_signing_key_from_jwt(token)
            except jwt.PyJWKClientError as lookup_error:
                failures.append((client.uri, lookup_error))
            else:
                # First client that answers wins; later endpoints are not consulted.
                break

        if found is None:
            raise SigningKeyNotFoundError(", ".join(f"{uri}:{str(err)}" for uri, err in failures))
        return found

    def parse(self, token: str) -> Dict[str, Any]:
        """Verify ``token`` and return its decoded claims.

        The configured secret is used when present; otherwise the signing key is looked up
        through the configured JWKS clients.

        Args:
            token: The encoded JWT.

        Returns:
            The decoded payload.

        Raises:
            AuthError: Any ``AuthError`` raised during decoding is re-raised unchanged
                (e.g. ``SigningKeyNotFoundError``, presuming it derives from ``AuthError``,
                which is defined outside this file).
            ExpiredTokenError: If PyJWT reports an expired signature.
            UnboundedTokenError: If a required claim is missing.
            InvalidTokenError: For any other PyJWT token error, or if no payload was produced.
            UnexpectedTokenParsingError: For any other exception raised while decoding.
        """
        claims: Optional[Dict[str, Any]] = None
        try:
            if self._secret is not None:
                claims = self._decode(token, self._secret)
            elif self._jwks_clients:
                claims = self._decode(token, self._find_signing_key(token).key)

        except AuthError:
            raise

        except jwt.exceptions.ExpiredSignatureError as e:
            raise ExpiredTokenError() from e

        except jwt.exceptions.MissingRequiredClaimError as e:
            raise UnboundedTokenError("Missing required JWT claim, likely 'exp' was not set") from e

        except jwt.exceptions.InvalidTokenError as e:
            raise InvalidTokenError() from e

        except Exception as e:
            raise UnexpectedTokenParsingError() from e

        # Reached without a payload only when neither branch above ran (no secret, no JWKS clients).
        if claims is None:
            raise InvalidTokenError()

        return claims

    def identity_from_jwt_payload(self, payload: Dict[str, Any]) -> ClientIdentity:
        """Build a ``ClientIdentity`` from decoded JWT claims.

        Claim mapping:
            "aud" -> workflow
            "sub" -> subject
            "act" -> actor

        When "act" is absent or empty the subject doubles as the actor. When "aud" is a
        list its first element is used; when "aud" is absent or empty the first configured
        audience (if any) is used instead. Any value that is not a string becomes "".

        Args:
            payload: The decoded payload from the JWT.

        Returns:
            A ``ClientIdentity`` carrying actor, subject and workflow.
        """

        def _text(value: Any) -> str:
            # ClientIdentity fields are populated with strings only; anything else is blanked.
            return value if isinstance(value, str) else ""

        token_audience = payload.get("aud")
        workflow = ""
        if token_audience:
            if isinstance(token_audience, str):
                workflow = token_audience
            elif isinstance(token_audience, list):
                workflow = token_audience[0]
        elif self._audiences is not None and len(self._audiences) > 0:
            workflow = self._audiences[0]

        actor = payload.get("act")
        subject = payload.get("sub")
        actor = actor or subject

        return ClientIdentity(actor=_text(actor), subject=_text(subject), workflow=_text(workflow))

    def identity_from_token(self, token: str) -> ClientIdentity:
        """Verify ``token`` via :meth:`parse` and convert its claims into a ``ClientIdentity``."""
        return self.identity_from_jwt_payload(self.parse(token))
class AuthManager(ABC):
    """Interface for objects that decide whether an incoming gRPC request may proceed."""

    @abstractmethod
    def authorize(self, context: grpc.ServicerContext, instance_name: str, request_name: str) -> bool:
        """Decide whether the given request is authorized.

        Implementations receive the ``ServicerContext`` of the incoming gRPC
        request together with the target instance name and the request name.
        They are expected to derive the caller's identity from the context
        (a JWT, for example) and check it against whatever ACL configuration
        they were constructed with.

        Args:
            context (ServicerContext): Context of the gRPC request whose
                authorization status is being checked.

            instance_name (str): Name of the instance the request targets;
                used for per-instance ACLs.

            request_name (str): Name of the request being authorized, for
                example `Execute`.

        Returns:
            bool: ``True`` when the request is authorized, ``False`` otherwise.

        """
class JWTAuthManager(AuthManager):
    """Authorizes requests using a JWT bearer token carried in the ``authorization`` metadata entry."""

    def __init__(
        self,
        secret: Optional[str] = None,
        algorithm: AuthMetadataAlgorithm = AuthMetadataAlgorithm.UNSPECIFIED,
        jwks_urls: Optional[List[str]] = None,
        audiences: Optional[List[str]] = None,
        jwks_fetch_minutes: int = 60,
        acls: Optional[Mapping[str, InstanceAuthorizationConfig]] = None,
        allow_unauthorized_instances: Optional[Set[str]] = None,
    ) -> None:
        """Initializes a new :class:`JWTAuthManager`.

        The token-related arguments are forwarded unchanged to :class:`JwtParser`.

        Args:
            secret (str): The secret or key used to validate tokens.
                Defaults to ``None``.

            algorithm (AuthMetadataAlgorithm): The cryptographic algorithm the
                tokens are signed with. Defaults to ``UNSPECIFIED``.

            jwks_urls (list[str]): URLs to fetch JWKs from. Exactly one of
                ``secret`` and this field must be given. Defaults to ``None``.

            audiences (list[str]): Audiences that tokens are validated against.

            jwks_fetch_minutes (int): Number of minutes fetched JWKs are cached
                for. Defaults to 60.

            acls (Mapping[str, InstanceAuthorizationConfig] | None): Optional
                map of instance name -> ACL config used for per-instance
                authorization.

            allow_unauthorized_instances (Set[str] | None): Instances that may
                be accessed without authentication.

        Raises:
            InvalidArgumentError: If `algorithm` is not supported.

        """
        self._acls = acls
        self._allow_unauthorized_instances = allow_unauthorized_instances
        self._token_parser = JwtParser(secret, algorithm, jwks_urls, audiences, jwks_fetch_minutes)

    def _token_from_request_context(self, context: grpc.ServicerContext, request_name: str) -> str:
        """Extract the raw token from the request's ``authorization`` metadata entry.

        Raises:
            MissingTokenError: If the request carries no ``authorization`` entry.
            InvalidAuthorizationHeaderError: If the entry does not start with ``"Bearer "``.
        """
        scheme_prefix = "Bearer "
        try:
            request_metadata = dict(context.invocation_metadata())
            header = cast(str, request_metadata["authorization"])
        except KeyError:
            # No token supplied at all: log and reject.
            _log_and_raise(request_name, MissingTokenError())

        if not header.startswith(scheme_prefix):
            # Header present but not in "Bearer <token>" form: log and reject.
            _log_and_raise(request_name, InvalidAuthorizationHeaderError())

        return header[len(scheme_prefix) :]

    def authorize(self, context: grpc.ServicerContext, instance_name: str, request_name: str) -> bool:
        """Authorize a request from its bearer token and the configured ACLs.

        Instances listed in ``allow_unauthorized_instances`` are always allowed.
        Otherwise the token is parsed into a client identity, which is stored
        via ``set_context_client_identity`` and then checked against the ACL
        of the target instance.

        Returns:
            bool: ``True`` if the instance is open, or the token is valid and
            either no ACLs are configured or the instance ACL permits the
            request. ``False`` on any token error, or when ACLs exist but none
            covers ``instance_name``.
        """
        open_instances = self._allow_unauthorized_instances
        if open_instances and instance_name in open_instances:
            return True

        try:
            token = self._token_from_request_context(context, request_name)
            identity = self._token_parser.identity_from_token(token)
            workflow, actor, subject = identity.workflow, identity.actor, identity.subject
            set_context_client_identity(identity)
        except NameError:
            # Per the log message, this handler targets a missing PyJWT installation.
            LOGGER.error("JWT auth is enabled but PyJWT is not installed.")
            return False
        except AuthError as e:
            LOGGER.info(f"Error authorizing JWT token: {str(e)}")
            return False

        # Without any ACL configuration a valid token is sufficient.
        if self._acls is None:
            return True

        instance_acl = self._acls.get(instance_name)
        if instance_acl is None:
            # ACLs are configured but none covers this instance: deny.
            return False

        return instance_acl.is_authorized(request_name, actor=actor, subject=subject, workflow=workflow)
class HeadersAuthManager(AuthManager):
    """Authorizes requests using identity values supplied in ``x-request-*`` metadata entries."""

    def __init__(
        self,
        acls: Optional[Mapping[str, InstanceAuthorizationConfig]] = None,
        allow_unauthorized_instances: Optional[Set[str]] = None,
    ) -> None:
        """Initializes a new :class:`HeadersAuthManager`.

        Args:
            acls (Mapping[str, InstanceAuthorizationConfig] | None): An optional
                map of instance name -> ACL config to use for per-instance
                authorization.

            allow_unauthorized_instances (Set[str] | None): Instances that may
                be accessed without authentication.

        """
        self._acls = acls
        self._allow_unauthorized_instances = allow_unauthorized_instances

    @staticmethod
    def _header_value(metadata: Mapping[Any, Any], key: str) -> str:
        """Return the metadata value for ``key`` as a string, or ``""`` when it is absent.

        Calling ``str()`` directly on a missing entry would yield the literal text
        ``"None"``, which would then be recorded as the caller's identity.
        """
        value = metadata.get(key)
        return "" if value is None else str(value)

    def authorize(self, context: grpc.ServicerContext, instance_name: str, request_name: str) -> bool:
        """Authorize a request from its ``x-request-*`` metadata and the configured ACLs.

        The actor, subject and workflow are read from the ``x-request-actor``,
        ``x-request-subject`` and ``x-request-workflow`` metadata entries; a
        missing entry is treated as an empty string. The resulting identity is
        stored via ``set_context_client_identity`` before any ACL check.

        Args:
            context (ServicerContext): Context of the incoming gRPC request.
            instance_name (str): Name of the instance the request targets.
            request_name (str): Name of the request being authorized.

        Returns:
            bool: ``True`` if the instance allows unauthorized access, if no
            ACLs are configured, or if the instance ACL permits the request.
            ``False`` when ACLs exist but none covers ``instance_name``, or the
            ACL rejects the request.
        """
        # No need to authorize if unauthorized access is allowed for the instance
        if self._allow_unauthorized_instances and instance_name in self._allow_unauthorized_instances:
            return True

        metadata_dict = dict(context.invocation_metadata())
        actor = self._header_value(metadata_dict, "x-request-actor")
        subject = self._header_value(metadata_dict, "x-request-subject")
        workflow = self._header_value(metadata_dict, "x-request-workflow")
        set_context_client_identity(ClientIdentity(actor=actor, subject=subject, workflow=workflow))

        # If no ACL config was provided at all, don't do any more validation
        if self._acls is None:
            return True

        instance_acl_config = self._acls.get(instance_name)
        if instance_acl_config is not None:
            return instance_acl_config.is_authorized(request_name, actor=actor, subject=subject, workflow=workflow)

        # If there is an ACL, but no config for this instance, deny all
        return False
# TODO: Once https://github.com/grpc/grpc/issues/33071 is resolved this AuthContext can be
# replaced with a gRPC interceptor that stores the AuthManager in a request-local ContextVar.
# Holds the currently configured AuthManager; ``None`` (the default) means no auth is configured.
AuthContext: "ContextVar[Optional[AuthManager]]" = ContextVar("AuthManager", default=None)
def set_auth_manager(manager: Optional[AuthManager]) -> None:
    """Store ``manager`` in ``AuthContext``; pass ``None`` to clear it."""
    AuthContext.set(manager)
def get_auth_manager() -> Optional[AuthManager]:
    """Return the manager currently held in ``AuthContext``, or ``None`` if none has been set."""
    current_manager = AuthContext.get()
    return current_manager
def authorize_request(request_context: grpc.ServicerContext, instance_name: str, request_name: str) -> None:
    """Abort the gRPC request unless the configured auth manager authorizes it.

    Args:
        request_context: Context of the incoming gRPC request.
        instance_name: Name of the instance the request targets.
        request_name: Name of the request being authorized.

    Returns:
        ``None`` when no auth manager is configured or the manager authorizes the
        request. Otherwise the request is aborted with ``UNAUTHENTICATED``.
    """
    manager = get_auth_manager()

    # Nothing to enforce when auth is not configured; otherwise defer to the manager.
    if manager is None or manager.authorize(request_context, instance_name, request_name):
        return

    failure_tags = {"request_name": request_name, "peer": request_context.peer()}
    LOGGER.info("Authentication failed for request.", tags=failure_tags)

    # No need to yield here since calling `abort` raises an Exception
    request_context.abort(grpc.StatusCode.UNAUTHENTICATED, "No valid authorization or authentication provided")
# Holds the identity of the client for the current context; ``None`` until an auth manager sets one.
ContextClientIdentity: "ContextVar[Optional[ClientIdentity]]" = ContextVar("ClientIdentity", default=None)
def set_context_client_identity(clientIdentity: ClientIdentity) -> None:
    """Record ``clientIdentity`` in ``ContextClientIdentity`` for later retrieval."""
    ContextClientIdentity.set(clientIdentity)
def get_context_client_identity() -> Optional[ClientIdentity]:
    """Return the identity held in ``ContextClientIdentity``, or ``None`` if none has been recorded."""
    recorded_identity = ContextClientIdentity.get()
    return recorded_identity