Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/decorators/instance.py: 100.00%

31 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-10-04 17:48 +0000

1# Copyright (C) 2024 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import functools 

16import inspect 

17import itertools 

18from typing import Any, Callable, Iterator, TypeVar, cast 

19 

20import grpc 

21 

22from buildgrid.server.context import instance_context 

23 

# Type of the decorated callable. ``Callable[..., Any]`` accepts any signature and,
# unlike a bare ``Callable``, needs no ``type: ignore`` under strict generic checking.
Func = TypeVar("Func", bound=Callable[..., Any])

25 

26 

def instanced(get_instance_name: Callable[[Any], str]) -> Callable[[Func], Func]:
    """Build a decorator that runs a servicer method inside ``instance_context``.

    The instance name is read from the request with ``get_instance_name``. The
    decorated method is then executed within ``instance_context(instance_name)``.

    Requests may be a single message or an iterator of messages (client
    streaming). For an iterator, the first message is consumed to read the
    instance name. It is then put back in front of the remaining stream, so the
    decorated method still receives every message.

    If the decorated method is a generator function (server streaming), it is
    wrapped in a generator. The context is therefore held while responses are
    being yielded. Any other function is wrapped in a plain function whose
    return value is passed through.

    Args:
        get_instance_name: Callable that extracts the instance name from a
            single request message.

    Returns:
        A decorator that applies the instance context to a servicer method.
    """

    def resolve_instance(message: Any) -> "tuple[str, Any]":
        """Return ``(instance_name, message)`` for a single or streaming request.

        Single-message requests are returned unchanged. For an iterator, the
        first message is popped to read the instance name from it. A new
        iterator that yields that message followed by the rest of the stream is
        returned in place of the original.
        """
        if isinstance(message, Iterator):
            # NOTE(review): an empty request stream makes ``next`` raise
            # StopIteration here. Inside the generator wrapper this surfaces as a
            # RuntimeError (PEP 479). This behaviour is kept unchanged.
            first_message = next(message)
            return get_instance_name(first_message), itertools.chain([first_message], message)
        return get_instance_name(message), message

    def decorator(f: Func) -> Func:
        @functools.wraps(f)
        def server_stream_wrapper(self: Any, message: Any, context: grpc.ServicerContext) -> Iterator[Any]:
            # Resolution runs lazily, on the first iteration of this generator.
            instance_name, message = resolve_instance(message)
            with instance_context(instance_name):
                yield from f(self, message, context)

        @functools.wraps(f)
        def server_unary_wrapper(self: Any, message: Any, context: grpc.ServicerContext) -> Any:
            instance_name, message = resolve_instance(message)
            with instance_context(instance_name):
                return f(self, message, context)

        # Generator methods need a generator wrapper; a plain wrapper would exit
        # the context before the caller consumed any responses.
        if inspect.isgeneratorfunction(f):
            return cast(Func, server_stream_wrapper)
        return cast(Func, server_unary_wrapper)

    return decorator