Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/auth/manager.py: 90.13%

152 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-10-04 17:48 +0000

1# Copyright (C) 2023 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16from abc import ABC, abstractmethod 

17from contextvars import ContextVar 

18from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Type, Union, cast 

19 

20import grpc 

21import jwt 

22 

23from buildgrid._protos.buildgrid.v2.identity_pb2 import ClientIdentity 

24from buildgrid.server.auth.config import InstanceAuthorizationConfig 

25from buildgrid.server.auth.enums import AuthMetadataAlgorithm 

26from buildgrid.server.auth.exceptions import ( 

27 AuthError, 

28 ExpiredTokenError, 

29 InvalidAuthorizationHeaderError, 

30 InvalidTokenError, 

31 MissingTokenError, 

32 SigningKeyNotFoundError, 

33 UnboundedTokenError, 

34 UnexpectedTokenParsingError, 

35) 

36from buildgrid.server.exceptions import InvalidArgumentError 

37from buildgrid.server.logging import buildgrid_logger 

38from buildgrid.server.settings import AUTH_CACHE_SIZE 

39 

# Module-wide logger, named after this module.
LOGGER = buildgrid_logger(__name__)


# Type alias covering the PyJWT algorithm classes referenced by the mapping below.
AlgorithmType = Union[
    Type[jwt.algorithms.RSAAlgorithm],
    Type[jwt.algorithms.ECAlgorithm],
    Type[jwt.algorithms.HMACAlgorithm],
]

# Algorithm classes defined in: https://github.com/jpadilla/pyjwt/blob/master/jwt/algorithms.py
# NOTE(review): the keys look like JWK "kty" values -- confirm against callers.
ALGORITHM_TO_PYJWT_CLASS: Dict[str, AlgorithmType] = {
    "RSA": jwt.algorithms.RSAAlgorithm,
    "EC": jwt.algorithms.ECAlgorithm,
    "oct": jwt.algorithms.HMACAlgorithm,
}

53 

54 

def _log_and_raise(request_name: str, exception: AuthError) -> str:
    """Log an authorization rejection and then raise ``exception``.

    The function never returns normally; the ``str`` return annotation is
    kept unchanged for compatibility with existing callers.

    Args:
        request_name: Name of the request being rejected (logged as a tag).
        exception: The error to log (as ``reason``) and raise.

    Raises:
        AuthError: Always; the ``exception`` argument itself.
    """
    log_tags = dict(request_name=request_name, reason=str(exception))
    LOGGER.info("Authorization error. Rejecting.", tags=log_tags)
    raise exception

58 

59 

class JwtParser:
    """Decodes and validates JWTs and turns their claims into a ``ClientIdentity``.

    Tokens are verified either against a shared ``secret`` or against signing
    keys looked up through one or more JWKS endpoints.
    """

    def __init__(
        self,
        secret: Optional[str] = None,
        algorithm: AuthMetadataAlgorithm = AuthMetadataAlgorithm.UNSPECIFIED,
        jwks_urls: Optional[List[str]] = None,
        audiences: Optional[List[str]] = None,
        jwks_fetch_minutes: int = 60,
    ) -> None:
        """Set up the parser.

        Args:
            secret: Key passed to ``jwt.decode`` when no JWKS urls are used.
            algorithm: Algorithm the tokens must be signed with.
            jwks_urls: Urls of JWKS endpoints used to look up signing keys.
            audiences: Audiences passed to ``jwt.decode`` for validation.
            jwks_fetch_minutes: Minutes each JWKS client caches fetched keys.

        Raises:
            InvalidArgumentError: If ``algorithm`` is ``UNSPECIFIED``.
            TypeError: If both or neither of ``secret`` and ``jwks_urls`` are given.
        """
        self._check_jwt_support(algorithm)

        self._algorithm = algorithm
        self._audiences = audiences

        # Exactly one key source must be supplied: both-None and both-set are rejected.
        if (secret is None) == (jwks_urls is None):
            raise TypeError("Exactly one of `secret` or `jwks_url` must be set")
        self._secret = secret

        cache_lifespan_seconds = 60 * jwks_fetch_minutes
        self._jwks_clients = []
        for url in jwks_urls or []:
            self._jwks_clients.append(
                jwt.PyJWKClient(url, lifespan=cache_lifespan_seconds, max_cached_keys=AUTH_CACHE_SIZE)
            )

    def _check_jwt_support(self, algorithm: AuthMetadataAlgorithm) -> None:
        """Reject configurations that do not name a signing algorithm.

        Raises:
            InvalidArgumentError: If ``algorithm`` is ``UNSPECIFIED``.
        """
        if algorithm == AuthMetadataAlgorithm.UNSPECIFIED:
            raise InvalidArgumentError("JWT authorization method requires an algorithm to be specified")

    def _decode(self, token: str, key: Any) -> Dict[str, Any]:
        """Run ``jwt.decode`` with this parser's algorithm, audiences and a mandatory ``exp`` claim."""
        return jwt.decode(
            token,
            key,
            algorithms=[self._algorithm.value.upper()],
            audience=self._audiences,
            options={"require": ["exp"], "verify_exp": True},
        )

    def parse(self, token: str) -> Dict[str, Any]:
        """Validate ``token`` and return its decoded payload.

        Args:
            token: The encoded JWT.

        Returns:
            The decoded claims.

        Raises:
            SigningKeyNotFoundError: If no JWKS client could supply a signing key.
            ExpiredTokenError: If the token has expired.
            UnboundedTokenError: If a required claim is missing.
            InvalidTokenError: If the token is otherwise invalid, or no key
                source produced a payload.
            UnexpectedTokenParsingError: On any other failure while parsing.
        """
        decoded: Optional[Dict[str, Any]] = None
        try:
            if self._secret is not None:
                decoded = self._decode(token, self._secret)

            elif self._jwks_clients:
                # Ask each JWKS endpoint in turn; the first one that knows the key wins.
                found_key: Optional[jwt.PyJWK] = None
                failures: List[Tuple[str, jwt.PyJWKClientError]] = []
                for client in self._jwks_clients:
                    try:
                        found_key = client.get_signing_key_from_jwt(token)
                    except jwt.PyJWKClientError as lookup_error:
                        failures.append((client.uri, lookup_error))
                    else:
                        break

                if found_key is None:
                    raise SigningKeyNotFoundError(", ".join(f"{uri}:{str(err)}" for uri, err in failures))

                decoded = self._decode(token, found_key.key)

        except AuthError:
            # Already one of our own errors (e.g. SigningKeyNotFoundError); pass it through.
            raise

        except jwt.exceptions.ExpiredSignatureError as e:
            raise ExpiredTokenError() from e

        except jwt.exceptions.MissingRequiredClaimError as e:
            raise UnboundedTokenError("Missing required JWT claim, likely 'exp' was not set") from e

        except jwt.exceptions.InvalidTokenError as e:
            raise InvalidTokenError() from e

        except Exception as e:
            raise UnexpectedTokenParsingError() from e

        if decoded is None:
            raise InvalidTokenError()

        return decoded

    def identity_from_jwt_payload(self, payload: Dict[str, Any]) -> ClientIdentity:
        """Build a ``ClientIdentity`` from decoded JWT claims.

        Claim mapping:
            "aud" -> workflow
            "sub" -> subject
            "act" -> actor

        If "act" is missing or empty, the subject is used as the actor.
        If "aud" is a list, its first element is used as the workflow; if
        "aud" is missing or empty, the first configured audience (if any) is
        used instead. Any value that is not a string becomes "".

        Args:
            payload: The decoded payload from the JWT.

        Returns:
            A ``ClientIdentity`` carrying actor, subject and workflow.
        """

        def _text_or_empty(value: Any) -> str:
            return value if isinstance(value, str) else ""

        workflow: Any = ""
        claimed_audience = payload.get("aud")
        if claimed_audience:
            if isinstance(claimed_audience, str):
                workflow = claimed_audience
            elif isinstance(claimed_audience, list):
                workflow = claimed_audience[0]
        elif self._audiences is not None and len(self._audiences) > 0:
            workflow = self._audiences[0]

        subject = payload.get("sub")
        actor = payload.get("act") or subject

        return ClientIdentity(
            actor=_text_or_empty(actor),
            subject=_text_or_empty(subject),
            workflow=_text_or_empty(workflow),
        )

    def identity_from_token(self, token: str) -> ClientIdentity:
        """Validate ``token`` and return the identity described by its claims."""
        return self.identity_from_jwt_payload(self.parse(token))

180 

181 

class AuthManager(ABC):
    """Abstract interface for components that decide whether a request may proceed."""

    @abstractmethod
    def authorize(self, context: grpc.ServicerContext, instance_name: str, request_name: str) -> bool:
        """Decide whether an incoming gRPC request is authorized.

        Implementations extract information about the requester's identity
        from ``context`` (for example a JWT token) and compare it with the
        ACL configuration they were constructed with.

        Args:
            context (ServicerContext): The context of the gRPC request whose
                authorization status is being checked.

            instance_name (str): The name of the instance the request will
                interact with. This is used for per-instance ACLs.

            request_name (str): The name of the request being authorized,
                for example `Execute`.

        Returns:
            bool: Whether the request is authorized.

        """

210 

211 

class JWTAuthManager(AuthManager):
    """Authorizes requests using a JWT bearer token from the ``authorization`` metadata."""

    def __init__(
        self,
        secret: Optional[str] = None,
        algorithm: AuthMetadataAlgorithm = AuthMetadataAlgorithm.UNSPECIFIED,
        jwks_urls: Optional[List[str]] = None,
        audiences: Optional[List[str]] = None,
        jwks_fetch_minutes: int = 60,
        acls: Optional[Mapping[str, InstanceAuthorizationConfig]] = None,
        allow_unauthorized_instances: Optional[Set[str]] = None,
    ) -> None:
        """Initializes a new :class:`JWTAuthManager`.

        Args:
            secret (str): The secret or key used to validate request tokens.
                Defaults to ``None``.

            algorithm (AuthMetadataAlgorithm): The cryptographic algorithm the
                tokens are signed with. Defaults to ``UNSPECIFIED``.

            jwks_urls (list[str]): The urls to fetch the JWKs from. Exactly one
                of `secret` and this field must be given. Defaults to ``None``.

            audiences (list[str]): The audiences to validate JWT tokens against.

            jwks_fetch_minutes (int): The number of minutes to cache JWKs fetches
                for. Defaults to 60.

            acls (Mapping[str, InstanceAuthorizationConfig] | None): An optional
                map of instance name -> ACL config to use for per-instance
                authorization.

            allow_unauthorized_instances (Set[str] | None): Instances that should
                be allowed unauthenticated access.

        Raises:
            InvalidArgumentError: If `algorithm` is ``UNSPECIFIED``.
            TypeError: If both or neither of `secret` and `jwks_urls` are given.

        """
        self._acls = acls
        self._allow_unauthorized_instances = allow_unauthorized_instances
        self._token_parser = JwtParser(secret, algorithm, jwks_urls, audiences, jwks_fetch_minutes)

    def _token_from_request_context(self, context: grpc.ServicerContext, request_name: str) -> str:
        """Return the bearer token carried in the request's ``authorization`` metadata.

        Raises:
            MissingTokenError: If there is no ``authorization`` entry.
            InvalidAuthorizationHeaderError: If the value does not start with ``"Bearer "``.
        """
        metadata = dict(context.invocation_metadata())

        # Reject requests not carrying a token
        if "authorization" not in metadata:
            _log_and_raise(request_name, MissingTokenError())
        header = cast(str, metadata["authorization"])

        # Reject requests with malformatted bearer
        if not header.startswith("Bearer "):
            _log_and_raise(request_name, InvalidAuthorizationHeaderError())

        # Strip the 7-character "Bearer " prefix.
        return header[7:]

    def authorize(self, context: grpc.ServicerContext, instance_name: str, request_name: str) -> bool:
        """Authorize a request from its JWT, then apply per-instance ACLs if configured.

        On successful token validation the parsed identity is also stored via
        ``set_context_client_identity``.

        Returns:
            bool: ``True`` when the instance allows unauthorized access, or when
            the token is valid and either no ACLs are configured or the ACL for
            ``instance_name`` permits the request; ``False`` otherwise.
        """
        # No need to authorize if unauthorized access is allowed for the instance
        open_instances = self._allow_unauthorized_instances
        if open_instances and instance_name in open_instances:
            return True

        try:
            token = self._token_from_request_context(context, request_name)
            identity = self._token_parser.identity_from_token(token)
            set_context_client_identity(identity)
        except NameError:
            LOGGER.error("JWT auth is enabled but PyJWT is not installed.")
            return False
        except AuthError as e:
            LOGGER.info(f"Error authorizing JWT token: {str(e)}")
            return False

        # If no ACL config was provided at all, don't do any more validation
        if self._acls is None:
            return True

        acl_for_instance = self._acls.get(instance_name)
        if acl_for_instance is None:
            # If there is an ACL, but no config for this instance, deny all
            return False

        return acl_for_instance.is_authorized(
            request_name,
            actor=identity.actor,
            subject=identity.subject,
            workflow=identity.workflow,
        )

299 

300 

class HeadersAuthManager(AuthManager):
    """Authorizes requests using identity fields supplied as request metadata.

    The identity is read from the ``x-request-actor``, ``x-request-subject``
    and ``x-request-workflow`` metadata entries.
    """

    def __init__(
        self,
        acls: Optional[Mapping[str, InstanceAuthorizationConfig]] = None,
        allow_unauthorized_instances: Optional[Set[str]] = None,
    ) -> None:
        """Initializes a new :class:`HeadersAuthManager`.

        Args:
            acls (Mapping[str, InstanceAuthorizationConfig] | None): An optional
                map of instance name -> ACL config to use for per-instance
                authorization.

            allow_unauthorized_instances (Set[str] | None): Instances that should
                be allowed unauthenticated access.

        """
        self._acls = acls
        self._allow_unauthorized_instances = allow_unauthorized_instances

    def authorize(self, context: grpc.ServicerContext, instance_name: str, request_name: str) -> bool:
        """Authorize a request from its identity metadata and the configured ACLs.

        The identity read from the metadata is also stored via
        ``set_context_client_identity``. A missing metadata entry yields an
        empty string for that field.

        Args:
            context (ServicerContext): The context of the gRPC request.
            instance_name (str): The instance the request is intended for.
            request_name (str): The name of the request being authorized.

        Returns:
            bool: ``True`` when the instance allows unauthorized access, when no
            ACLs are configured, or when the ACL for ``instance_name`` permits
            the request; ``False`` otherwise.
        """
        # No need to authorize if unauthorized access is allowed for the instance
        if self._allow_unauthorized_instances and instance_name in self._allow_unauthorized_instances:
            return True

        metadata_dict = dict(context.invocation_metadata())

        def _metadata_text(key: str) -> str:
            # A missing entry must not become the literal string "None": that value
            # would be recorded in the identity and compared against ACL entries.
            value = metadata_dict.get(key)
            return "" if value is None else str(value)

        actor = _metadata_text("x-request-actor")
        subject = _metadata_text("x-request-subject")
        workflow = _metadata_text("x-request-workflow")
        set_context_client_identity(ClientIdentity(actor=actor, subject=subject, workflow=workflow))

        # If no ACL config was provided at all, don't do any more validation
        if self._acls is None:
            return True

        instance_acl_config = self._acls.get(instance_name)
        if instance_acl_config is not None:
            return instance_acl_config.is_authorized(request_name, actor=actor, subject=subject, workflow=workflow)

        # If there is an ACL, but no config for this instance, deny all
        return False

340 

341 

# TODO: Once https://github.com/grpc/grpc/issues/33071 is resolved this AuthContext can be
# replaced with a gRPC interceptor that stores the AuthManager in a request-local ContextVar.
#
# Holds the active AuthManager; None (the default) means no manager has been set.
AuthContext: "ContextVar[Optional[AuthManager]]" = ContextVar("AuthManager", default=None)

345 

346 

def set_auth_manager(manager: Optional[AuthManager]) -> None:
    """Store ``manager`` in ``AuthContext`` as the active auth manager.

    Args:
        manager: The manager to install, or ``None`` to clear it.
    """
    AuthContext.set(manager)

349 

350 

def get_auth_manager() -> Optional[AuthManager]:
    """Return the auth manager currently held in ``AuthContext``, or ``None`` if unset."""
    active_manager = AuthContext.get()
    return active_manager

353 

354 

def authorize_request(request_context: grpc.ServicerContext, instance_name: str, request_name: str) -> None:
    """Abort the request unless the active auth manager authorizes it.

    Returns silently when no auth manager is configured or when the manager
    authorizes the request. Otherwise the failure is logged and the gRPC call
    is aborted with ``UNAUTHENTICATED``.

    Args:
        request_context: The context of the gRPC request being checked.
        instance_name: The instance the request is intended for.
        request_name: The name of the request being authorized.
    """
    manager = get_auth_manager()

    # If no auth is configured, don't do authz; otherwise let the manager decide.
    if manager is None or manager.authorize(request_context, instance_name, request_name):
        return

    failure_tags = dict(request_name=request_name, peer=request_context.peer())
    LOGGER.info("Authentication failed for request.", tags=failure_tags)

    # No need to yield here since calling `abort` raises an Exception
    request_context.abort(grpc.StatusCode.UNAUTHENTICATED, "No valid authorization or authentication provided")

370 

371 

# Holds the identity of the client for the current context; None until one is set.
ContextClientIdentity: "ContextVar[Optional[ClientIdentity]]" = ContextVar("ClientIdentity", default=None)

373 

374 

def set_context_client_identity(clientIdentity: ClientIdentity) -> None:
    """Store ``clientIdentity`` in ``ContextClientIdentity`` for the current context."""
    ContextClientIdentity.set(clientIdentity)

377 

378 

def get_context_client_identity() -> Optional[ClientIdentity]:
    """Return the client identity held in ``ContextClientIdentity``, or ``None`` if unset."""
    current_identity = ContextClientIdentity.get()
    return current_identity