Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/decorators/io.py: 90.00%

50 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-10-04 17:48 +0000

1# Copyright (C) 2024 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import functools 

16import inspect 

17from typing import Any, Callable, Iterator, TypeVar, Union, cast 

18 

19import grpc 

20from google.protobuf.message import Message 

21 

22from buildgrid.server.metrics_names import METRIC 

23from buildgrid.server.metrics_utils import publish_counter_metric 

24 

25from .time import _service_metadata 

26 

# Type of the servicer method being decorated. Bounding by ``Callable[..., Any]``
# lets decorators return ``Func`` so the decorated method keeps its signature for
# type checkers, without needing a ``type: ignore`` for a bare ``Callable``.
Func = TypeVar("Func", bound=Callable[..., Any])

# A gRPC request argument: either a single protobuf message (unary request) or an
# iterator of messages (client-streaming request).
_Message = Union[Iterator[Message], Message]

29 

30 

class _ByteCounter:
    """Running total of serialized protobuf bytes seen on one side of an RPC."""

    def __init__(self) -> None:
        self.total = 0

    def add(self, message: Message) -> None:
        """Add the serialized size of ``message`` to the running total."""
        self.total += message.ByteSize()


def _count_requests(message: _Message, counter: _ByteCounter) -> _Message:
    """Return the request to hand to the handler, recording its size in ``counter``.

    A single request message is measured immediately and returned unchanged. A
    request iterator (client streaming) is wrapped in a generator that measures
    each message as the handler pulls it, so ``counter`` only reflects the
    messages that were actually consumed.
    """
    if isinstance(message, Iterator):
        return _count_request_stream(message, counter)
    counter.add(message)
    return message


def _count_request_stream(messages: Iterator[Message], counter: _ByteCounter) -> Iterator[Message]:
    """Re-yield every message from ``messages``, adding each one's size to ``counter``."""
    for request in messages:
        counter.add(request)
        yield request


def _publish_io_metrics(
    servicer: Any, request: _Message, context: grpc.ServicerContext, input_bytes: int, output_bytes: int
) -> None:
    """Publish the input/output byte counters for one RPC, tagged with its service metadata."""
    metadata = _service_metadata(servicer, request, context)
    publish_counter_metric(METRIC.RPC.INPUT_BYTES, input_bytes, **metadata)
    publish_counter_metric(METRIC.RPC.OUTPUT_BYTES, output_bytes, **metadata)


def _metered_responses(
    produce: Callable[[], Iterator[Any]],
    servicer: Any,
    request: _Message,
    context: grpc.ServicerContext,
    inputs: _ByteCounter,
) -> Iterator[Any]:
    """Yield the responses from ``produce()``, measuring each one.

    The byte counters are published exactly once when iteration stops, whether
    the stream is exhausted, raises, or is closed early by the consumer.
    ``produce`` is only called on first iteration, inside the ``try``, so an
    error raised while obtaining the stream is still followed by publishing.
    """
    outputs = _ByteCounter()
    try:
        for response in produce():
            outputs.add(response)
            yield response
    finally:
        _publish_io_metrics(servicer, request, context, inputs.total, outputs.total)


def network_io(f: Func) -> Func:
    """Decorate a gRPC servicer method so it publishes request/response byte counters.

    The wrapped method is called as ``f(self, request, context)``. The serialized
    size (``ByteSize()``) of every request and response message is summed. When
    the call ends (normally, by raising, or by its response stream being closed),
    the totals are published as the ``METRIC.RPC.INPUT_BYTES`` and
    ``METRIC.RPC.OUTPUT_BYTES`` counters, tagged with ``_service_metadata``.

    Request iterators (client streaming) are measured lazily, so only the
    messages the handler actually consumed are counted.

    Generator functions get a generator wrapper, so ``inspect.isgeneratorfunction``
    still reports the decorated method as streaming. Any other callable gets a
    plain wrapper. If that callable nevertheless returns an iterator (for example,
    because another decorator hides a generator function), its responses are
    metered lazily rather than failing on ``ByteSize()``. A ``None`` result counts
    as zero output bytes and is returned unchanged.

    Args:
        f: The servicer method to wrap.

    Returns:
        A wrapper with the same call signature as ``f``.
    """

    @functools.wraps(f)
    def server_stream_wrapper(self: Any, message: _Message, context: grpc.ServicerContext) -> Iterator[Any]:
        inputs = _ByteCounter()
        request = _count_requests(message, inputs)
        # ``yield from`` forwards close()/throw() to the metering generator, so its
        # ``finally`` still publishes when the client cancels mid-stream.
        yield from _metered_responses(lambda: f(self, request, context), self, request, context, inputs)

    @functools.wraps(f)
    def server_unary_wrapper(self: Any, message: _Message, context: grpc.ServicerContext) -> Any:
        inputs = _ByteCounter()
        request = _count_requests(message, inputs)
        output_bytes = 0
        # Set once publishing has been handed off to a lazily consumed response
        # stream; publishing here as well would double-count the RPC.
        deferred = False
        try:
            response = f(self, request, context)
            if isinstance(response, Iterator):
                deferred = True
                return _metered_responses(lambda: response, self, request, context, inputs)
            if response is not None:
                output_bytes = response.ByteSize()
            return response
        finally:
            if not deferred:
                _publish_io_metrics(self, request, context, inputs.total, output_bytes)

    if inspect.isgeneratorfunction(f):
        return cast(Func, server_stream_wrapper)
    return cast(Func, server_unary_wrapper)