Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/server.py: 79.65% (285 statements)
# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import faulthandler
import logging
import logging.handlers
import os
import signal
import sys
import threading
import time
import traceback
from collections import defaultdict
from contextlib import ExitStack
from datetime import datetime
from queue import Empty, Queue
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple

import grpc
from grpc_reflection.v1alpha import reflection

from buildgrid._protos.buildgrid.v2.monitoring_pb2 import LogRecord
from buildgrid.server.actioncache.service import ActionCacheService
from buildgrid.server.bots.service import BotsService
from buildgrid.server.build_events.service import PublishBuildEventService, QueryBuildEventsService
from buildgrid.server.capabilities.instance import CapabilitiesInstance
from buildgrid.server.capabilities.service import CapabilitiesService
from buildgrid.server.cas.service import ByteStreamService, ContentAddressableStorageService
from buildgrid.server.context import instance_context
from buildgrid.server.controller import ExecutionController
from buildgrid.server.enums import LeaseState, LogRecordLevel, MetricCategories, OperationStage
from buildgrid.server.exceptions import PermissionDeniedError
from buildgrid.server.execution.service import ExecutionService
from buildgrid.server.introspection.service import IntrospectionService
from buildgrid.server.logging import buildgrid_logger
from buildgrid.server.metrics_names import METRIC
from buildgrid.server.metrics_utils import publish_gauge_metric
from buildgrid.server.monitoring import get_monitoring_bus
from buildgrid.server.operations.service import OperationsService
from buildgrid.server.scheduler import Scheduler, SchedulerMetrics
from buildgrid.server.servicer import Instance, InstancedServicer
from buildgrid.server.settings import LOG_RECORD_FORMAT, MIN_THREAD_POOL_SIZE, MONITORING_PERIOD, SHUTDOWN_ALARM_DELAY
from buildgrid.server.threading import ContextThreadPoolExecutor, ContextWorker
from buildgrid.server.types import OnServerStartCallback, PortAssignedCallback

LOGGER = buildgrid_logger(__name__)


def load_tls_server_credentials(
    server_key: Optional[str] = None, server_cert: Optional[str] = None, client_certs: Optional[str] = None
) -> Optional[grpc.ServerCredentials]:
    """Looks up and loads TLS server gRPC credentials.

    All private and public keys are expected to be PEM-encoded.

    Args:
        server_key (str): private server key file path.
        server_cert (str): public server certificate file path.
        client_certs (str): public client certificates file path.

    Returns:
        :obj:`ServerCredentials`: The credentials to use for a
        TLS-encrypted gRPC server channel.
    """
    if not server_key or not os.path.exists(server_key):
        return None

    if not server_cert or not os.path.exists(server_cert):
        return None

    with open(server_key, "rb") as f:
        server_key_pem = f.read()
    with open(server_cert, "rb") as f:
        server_cert_pem = f.read()

    if client_certs and os.path.exists(client_certs):
        with open(client_certs, "rb") as f:
            client_certs_pem = f.read()
    else:
        client_certs_pem = None
        client_certs = None

    credentials = grpc.ssl_server_credentials(
        [(server_key_pem, server_cert_pem)], root_certificates=client_certs_pem, require_client_auth=bool(client_certs)
    )

    # TODO: Fix this (missing stubs?) "ServerCredentials" has no attribute
    credentials.server_key = server_key  # type: ignore[attr-defined]
    credentials.server_cert = server_cert  # type: ignore[attr-defined]
    credentials.client_certs = client_certs  # type: ignore[attr-defined]

    return credentials
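
# A minimal usage sketch for the helper above (the PEM paths are illustrative
# assumptions, not BuildGrid defaults). If either the key or the certificate
# path is missing, the helper returns ``None`` and the caller should fall back
# to an insecure port:
#
#     credentials = load_tls_server_credentials(
#         server_key="/etc/buildgrid/tls/server.key",
#         server_cert="/etc/buildgrid/tls/server.crt",
#         client_certs="/etc/buildgrid/tls/clients.crt",  # enables mTLS if present
#     )
#     if credentials is None:
#         ...  # fall back to an insecure port
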
class Server:
    """Creates a BuildGrid server instance.

    The :class:`Server` class binds together all the gRPC services.
    """

    def __init__(
        self,
        server_reflection: bool,
        grpc_compression: grpc.Compression,
        is_instrumented: bool,
        grpc_server_options: Optional[Sequence[Tuple[str, Any]]],
        max_workers: Optional[int],
        monitoring_period: float = MONITORING_PERIOD,
    ):
        self._stack = ExitStack()

        self._server_reflection = server_reflection
        self._grpc_compression = grpc_compression
        self._is_instrumented = is_instrumented
        self._grpc_server_options = grpc_server_options

        self._action_cache_service = ActionCacheService()
        self._bots_service = BotsService()
        self._bytestream_service = ByteStreamService()
        self._capabilities_service = CapabilitiesService()
        self._cas_service = ContentAddressableStorageService()
        self._execution_service = ExecutionService()
        self._operations_service = OperationsService()
        self._introspection_service = IntrospectionService()

        # Special cases
        self._build_events_service = PublishBuildEventService()
        self._query_build_events_service = QueryBuildEventsService()

        self._schedulers: Dict[str, Set[Scheduler]] = defaultdict(set)

        self._ports: List[Tuple[str, Optional[Dict[str, str]]]] = []
        self._port_map: Dict[str, int] = {}

        self._logging_queue: Queue[Any] = Queue()
        self._monitoring_period = monitoring_period

        if max_workers is None:
            # Use the max_workers default from Python 3.4+
            max_workers = max(MIN_THREAD_POOL_SIZE, (os.cpu_count() or 1) * 5)

        elif max_workers < MIN_THREAD_POOL_SIZE:
            LOGGER.warning(
                "Specified thread-limit is too small, bumping it.",
                tags=dict(requested_thread_limit=max_workers, new_thread_limit=MIN_THREAD_POOL_SIZE),
            )
            # Enforce a minimum for max_workers
            max_workers = MIN_THREAD_POOL_SIZE

        self._max_grpc_workers = max_workers
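
    # Construction sketch (the argument values below are illustrative
    # assumptions; passing ``max_workers=None`` falls back to the computed
    # default above):
    #
    #     server = Server(
    #         server_reflection=True,
    #         grpc_compression=grpc.Compression.NoCompression,
    #         is_instrumented=False,
    #         grpc_server_options=None,
    #         max_workers=None,
    #     )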

    def register_instance(self, instance_name: str, instance: Instance) -> None:
        """
        Register an instance with the server. Handles the logic of mapping
        instances to the correct servicer container.

        Args:
            instance_name (str): The name of the instance.
            instance (Instance): The instance implementation.
        """

        # Special case to handle the ExecutionController, which combines the service interfaces.
        if isinstance(instance, ExecutionController):
            if bots_interface := instance.bots_interface:
                self.register_instance(instance_name, bots_interface)
            if execution_instance := instance.execution_instance:
                self.register_instance(instance_name, execution_instance)
            if operations_instance := instance.operations_instance:
                self.register_instance(instance_name, operations_instance)

        elif action_instance := self._action_cache_service.cast(instance):
            self._action_cache_service.add_instance(instance_name, action_instance)
            capabilities = self._capabilities_service.instances.setdefault(instance_name, CapabilitiesInstance())
            capabilities.add_action_cache_instance(action_instance)

        elif bots_instance := self._bots_service.cast(instance):
            self._bots_service.add_instance(instance_name, bots_instance)
            self._schedulers[instance_name].add(bots_instance.scheduler)

        elif bytestream_instance := self._bytestream_service.cast(instance):
            self._bytestream_service.add_instance(instance_name, bytestream_instance)

        elif cas_instance := self._cas_service.cast(instance):
            self._cas_service.add_instance(instance_name, cas_instance)
            capabilities = self._capabilities_service.instances.setdefault(instance_name, CapabilitiesInstance())
            capabilities.add_cas_instance(cas_instance)

        elif execution_instance := self._execution_service.cast(instance):
            self._execution_service.add_instance(instance_name, execution_instance)
            self._schedulers[instance_name].add(execution_instance.scheduler)
            capabilities = self._capabilities_service.instances.setdefault(instance_name, CapabilitiesInstance())
            capabilities.add_execution_instance(execution_instance)

        elif operations_instance := self._operations_service.cast(instance):
            self._operations_service.add_instance(instance_name, operations_instance)

        elif introspection_instance := self._introspection_service.cast(instance):
            self._introspection_service.add_instance(instance_name, introspection_instance)

        # The Build Events services have no support for instance names, so this
        # is a bit of a special case where the storage backend itself is the
        # trigger for creating the gRPC services.
        elif instance.SERVICE_NAME == "BuildEvents":
            self._build_events_service.add_instance("", instance)  # type: ignore[arg-type]
            self._query_build_events_service.add_instance("", instance)  # type: ignore[arg-type]

        else:
            raise ValueError(f"Instance of type {type(instance)} not supported by {type(self)}")
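
    # Registration sketch (assuming ``controller`` is an already-configured
    # ExecutionController; it fans out to the bots, execution, and operations
    # services under the same instance name):
    #
    #     server.register_instance("main", controller)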

    @property
    def _services(self) -> Iterable[InstancedServicer[Any]]:
        return (
            self._action_cache_service,
            self._bots_service,
            self._bytestream_service,
            self._capabilities_service,
            self._cas_service,
            self._execution_service,
            self._operations_service,
            self._introspection_service,
            # Special cases
            self._build_events_service,
            self._query_build_events_service,
        )

    def add_port(self, address: str, credentials: Optional[Dict[str, str]]) -> None:
        """Adds a port to the server.

        Must be called before the server starts. If a credentials mapping is
        provided, a secure port will be created.

        Args:
            address (str): The address with port number.
            credentials (Dict[str, str]): Optional TLS credential file paths,
                keyed by ``tls-server-key``, ``tls-server-cert`` and
                ``tls-client-certs``.
        """
        self._ports.append((address, credentials))
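
    # Port sketch (the paths are illustrative assumptions; pass ``None``
    # instead of the mapping to create an insecure port):
    #
    #     server.add_port(
    #         "localhost:50051",
    #         {
    #             "tls-server-key": "/etc/buildgrid/tls/server.key",
    #             "tls-server-cert": "/etc/buildgrid/tls/server.crt",
    #             "tls-client-certs": "/etc/buildgrid/tls/clients.crt",
    #         },
    #     )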

    def start(
        self,
        *,
        on_server_start_cb: Optional[OnServerStartCallback] = None,
        port_assigned_callback: Optional[PortAssignedCallback] = None,
        run_forever: bool = True,
    ) -> None:
        """Starts the BuildGrid server.

        BuildGrid server startup consists of four stages:

        1. Starting logging and monitoring

            This step starts up the logging coroutine, the periodic status metrics
            coroutine, and the monitoring bus' publishing subprocess. Since this
            step involves forking, anything not fork-safe needs to be done *after*
            this step.

        2. Instantiate gRPC

            This step instantiates the gRPC server, and tells all the instances
            which have been attached to the server to instantiate their gRPC
            objects. It is also responsible for creating the various service
            objects and connecting them to the server and the instances.

            After this step, gRPC core is running and it's no longer safe to fork
            the process.

        3. Start instances

            Several of BuildGrid's services use background threads that need to
            be explicitly started when BuildGrid starts up. Rather than doing
            this at configuration parsing time, this step provides a hook for
            services to start up in a more organised fashion.

        4. Start the gRPC server

            The final step is starting up the gRPC server. The callback passed in
            via ``on_server_start_cb`` is executed in this step once the server
            has started. After this point BuildGrid is ready to serve requests.

        Finally, if ``run_forever`` is ``True``, this method blocks on the gRPC
        server's ``wait_for_termination``; the ``Server.stop`` method can be
        wired up (e.g. as a ``SIGTERM`` handler) to shut the server down.

        Args:
            on_server_start_cb (Callable): Callback function to execute once
                the gRPC server has started up.
            port_assigned_callback (Callable): Callback function to execute
                once the gRPC server has started up. The mapping of addresses
                to ports is passed to this callback.
        """
        # 1. Start logging and monitoring
        self._stack.enter_context(
            ContextWorker(
                self._logging_worker,
                "ServerLogger",
                # Add a dummy value to the queue to unblock the get call.
                on_shutdown_requested=lambda: self._logging_queue.put(None),
            )
        )
        if self._is_instrumented:
            self._stack.enter_context(get_monitoring_bus())
            self._stack.enter_context(ContextWorker(self._state_monitoring_worker, "ServerMonitor"))

        # 2. Instantiate gRPC objects
        grpc_server = self.setup_grpc()

        # 3. Start background threads
        for service in self._services:
            self._stack.enter_context(service)

        # 4. Start the gRPC server.
        grpc_server.start()
        self._stack.callback(grpc_server.stop, None)

        if on_server_start_cb:
            on_server_start_cb()
        if port_assigned_callback:
            port_assigned_callback(port_map=self._port_map)

        # Block until termination, if requested
        if run_forever:
            grpc_server.wait_for_termination()
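
    # Startup sketch (the callback is an illustrative assumption;
    # ``run_forever=False`` returns immediately instead of blocking on
    # ``wait_for_termination``):
    #
    #     server.start(
    #         on_server_start_cb=lambda: LOGGER.info("Server is up."),
    #         run_forever=False,
    #     )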

    def setup_grpc(self) -> grpc.Server:
        """Instantiate the gRPC objects.

        This creates the gRPC server, and causes the instances attached to
        this server to instantiate any gRPC channels they need. This also
        sets up the services which route to those instances, and sets up
        gRPC reflection.
        """
        LOGGER.info(
            "Setting up gRPC server.",
            tags=dict(
                maximum_concurrent_rpcs=self._max_grpc_workers,
                compression=self._grpc_compression,
                options=self._grpc_server_options,
            ),
        )

        grpc_server = grpc.server(
            ContextThreadPoolExecutor(self._max_grpc_workers, "gRPC_Executor", immediate_copy=True),
            maximum_concurrent_rpcs=self._max_grpc_workers,
            compression=self._grpc_compression,
            options=self._grpc_server_options,
        )

        # Add the requested ports to the gRPC server
        for address, credentials in self._ports:
            port_number = 0
            if credentials is not None:
                LOGGER.info("Adding secure connection.", tags=dict(address=address))
                server_key = credentials.get("tls-server-key")
                server_cert = credentials.get("tls-server-cert")
                client_certs = credentials.get("tls-client-certs")
                server_credentials = load_tls_server_credentials(
                    server_cert=server_cert, server_key=server_key, client_certs=client_certs
                )
                # TODO should this error out??
                if server_credentials:
                    port_number = grpc_server.add_secure_port(address, server_credentials)

            else:
                LOGGER.info("Adding insecure connection.", tags=dict(address=address))
                port_number = grpc_server.add_insecure_port(address)

            if not port_number:
                raise PermissionDeniedError("Unable to configure socket")

            self._port_map[address] = port_number

        for service in self._services:
            service.setup_grpc(grpc_server)

        if self._server_reflection:
            reflection_services = [service.FULL_NAME for service in self._services if service.enabled]
            LOGGER.info("Server reflection is enabled.", tags=dict(reflection_services=reflection_services))
            reflection.enable_server_reflection([reflection.SERVICE_NAME] + reflection_services, grpc_server)
        else:
            LOGGER.info("Server reflection is not enabled.")

        return grpc_server
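
    # Dynamic port sketch: binding to port 0 lets gRPC choose a free port, and
    # the chosen port is recorded in the address-to-port mapping handed to
    # ``port_assigned_callback`` (which receives it as the ``port_map`` keyword
    # argument):
    #
    #     server.add_port("localhost:0", None)
    #     server.start(
    #         port_assigned_callback=lambda port_map: print(port_map),
    #         run_forever=False,
    #     )
    #     # prints {'localhost:0': <assigned port>}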

    def stop(self, *args: Any, **kwargs: Any) -> None:
        LOGGER.info("Stopping BuildGrid server.")

        def alarm_handler(*args: Any, **kwargs: Any) -> None:
            LOGGER.warning(
                "Shutdown still ongoing after shutdown delay.",
                tags=dict(
                    shutdown_alarm_delay_seconds=SHUTDOWN_ALARM_DELAY, active_thread_count=threading.active_count()
                ),
            )
            for thread in threading.enumerate():
                LOGGER.warning(f" - {thread.name}")
            LOGGER.warning("Traceback for all threads:")
            faulthandler.dump_traceback()

        LOGGER.debug("Setting alarm for delayed shutdown.")
        signal.signal(signal.SIGALRM, alarm_handler)
        signal.alarm(SHUTDOWN_ALARM_DELAY)

        self._stack.close()
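
    # Shutdown sketch: ``stop`` accepts and ignores arbitrary arguments, so it
    # can be installed directly as a signal handler, as the ``start`` docstring
    # suggests:
    #
    #     signal.signal(signal.SIGTERM, server.stop)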

    def _logging_worker(self, shutdown_requested: threading.Event) -> None:
        """Publishes log records to the monitoring bus."""
        logging_formatter = logging.Formatter(fmt=LOG_RECORD_FORMAT)
        logging_handler = logging.handlers.QueueHandler(self._logging_queue)

        # Set up the main logging handler:
        root_logger = logging.getLogger()

        for log_filter in root_logger.filters[:]:
            logging_handler.addFilter(log_filter)
            root_logger.removeFilter(log_filter)

        for log_handler in root_logger.handlers[:]:
            for log_filter in log_handler.filters[:]:
                logging_handler.addFilter(log_filter)
            root_logger.removeHandler(log_handler)
        root_logger.addHandler(logging_handler)

        def logging_worker() -> None:
            monitoring_bus = get_monitoring_bus()

            try:
                log_record = self._logging_queue.get(timeout=self._monitoring_period)
            except Empty:
                return
            if log_record is None:
                return

            # Print log records to stdout, if required:
            if not self._is_instrumented or not monitoring_bus.prints_records:
                record = logging_formatter.format(log_record)
                # TODO: Investigate if an async write would be worthwhile here.
                sys.stdout.write(f"{record}\n")
                sys.stdout.flush()

            # Emit a log record if the server is instrumented:
            if self._is_instrumented:
                log_record_level = LogRecordLevel(int(log_record.levelno / 10))
                log_record_creation_time = datetime.fromtimestamp(log_record.created)
                # logging.LogRecord.extra must be a str to str dict:
                if "extra" in log_record.__dict__ and log_record.extra:
                    log_record_metadata = log_record.extra
                else:
                    log_record_metadata = None
                forged_record = self._forge_log_record(
                    domain=log_record.name,
                    level=log_record_level,
                    message=log_record.message,
                    creation_time=log_record_creation_time,
                    metadata=log_record_metadata,
                )
                monitoring_bus.send_record_nowait(forged_record)

        while not shutdown_requested.is_set():
            try:
                logging_worker()
            except Exception:
                # The thread shouldn't exit on exceptions, but output the exception so that
                # it can be found in the logs.
                #
                # Note, we DO NOT use `LOGGER` here, because we don't want to write
                # anything new to the logging queue in case the Exception isn't some transient
                # issue.
                try:
                    sys.stdout.write("Exception in logging worker\n")
                    sys.stdout.flush()
                    traceback.print_exc()
                except Exception:
                    # There's not a lot we can do at this point really.
                    pass

        if shutdown_requested.is_set():
            # Reset logging, so anything logged after shutting down the logging worker
            # still gets written to stdout and the queue doesn't receive any more logs.
            stream_handler = logging.StreamHandler(stream=sys.stdout)
            stream_handler.setFormatter(logging_formatter)
            root_logger = logging.getLogger()

            for log_filter in root_logger.filters[:]:
                stream_handler.addFilter(log_filter)
                root_logger.removeFilter(log_filter)

            for log_handler in root_logger.handlers[:]:
                for log_filter in log_handler.filters[:]:
                    stream_handler.addFilter(log_filter)
                root_logger.removeHandler(log_handler)
            root_logger.addHandler(stream_handler)

            # Drain the log message queue
            while self._logging_queue.qsize() > 0:
                logging_worker()

    def _forge_log_record(
        self,
        *,
        domain: str,
        level: LogRecordLevel,
        message: str,
        creation_time: datetime,
        metadata: Optional[Dict[str, str]] = None,
    ) -> LogRecord:
        log_record = LogRecord()

        log_record.creation_timestamp.FromDatetime(creation_time)
        log_record.domain = domain
        log_record.level = level.value
        log_record.message = message
        if metadata is not None:
            log_record.metadata.update(metadata)

        return log_record
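
    # Level-mapping sketch: the logging worker above derives the monitoring
    # level by dividing the stdlib ``levelno`` by 10, e.g. ``logging.INFO``
    # (20) becomes ``LogRecordLevel(2)``. The domain and message values below
    # are illustrative assumptions:
    #
    #     record = self._forge_log_record(
    #         domain="buildgrid.server",
    #         level=LogRecordLevel(int(logging.INFO / 10)),
    #         message="Server is up.",
    #         creation_time=datetime.now(),
    #     )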

    def _state_monitoring_worker(self, shutdown_requested: threading.Event) -> None:
        """Periodically publishes state metrics to the monitoring bus."""
        while not shutdown_requested.is_set():
            start = time.time()
            try:
                if self._execution_service.enabled:
                    for instance_name in self._execution_service.instances:
                        self._publish_client_metrics_for_instance(instance_name)

                if self._bots_service.enabled:
                    for instance_name in self._bots_service.instances:
                        self._publish_bot_metrics_for_instance(instance_name)

                if self._schedulers:
                    for instance_name in self._schedulers:
                        self._publish_scheduler_metrics_for_instance(instance_name)
            except Exception:
                # The thread shouldn't exit on exceptions, but log at a severe enough
                # level that it doesn't get lost in the logs.
                LOGGER.exception("Exception while gathering state metrics.")

            end = time.time()
            shutdown_requested.wait(timeout=max(0, self._monitoring_period - (end - start)))

    def _publish_client_metrics_for_instance(self, instance_name: str) -> None:
        """Queries the number of clients connected for a given instance."""
        with instance_context(instance_name):
            n_clients = self._execution_service.query_connected_clients_for_instance(instance_name)
            publish_gauge_metric(METRIC.CONNECTIONS.CLIENT_COUNT, n_clients)

    def _publish_bot_metrics_for_instance(self, instance_name: str) -> None:
        with instance_context(instance_name):
            for bot_status, count in self._bots_service.count_bots_by_status(instance_name).items():
                publish_gauge_metric(METRIC.SCHEDULER.BOTS_COUNT, count, state=bot_status.name)

            n_workers = self._bots_service.query_connected_bots_for_instance(instance_name)
            publish_gauge_metric(METRIC.CONNECTIONS.WORKER_COUNT, n_workers)

    def _publish_scheduler_metrics_for_instance(self, instance_name: str) -> None:
        with instance_context(instance_name):
            # Multiple schedulers may be active for this instance, but they should
            # be using the same database, so just use the first one.
            scheduler_metrics: Optional[SchedulerMetrics] = None
            for scheduler in self._schedulers[instance_name]:
                scheduler_metrics = scheduler.get_metrics()
                break
            if scheduler_metrics is None:
                return

            # Jobs
            for stage, n_jobs in scheduler_metrics[MetricCategories.JOBS.value].items():
                stage = OperationStage(stage)
                publish_gauge_metric(METRIC.SCHEDULER.JOB_COUNT, n_jobs, state=stage.name)
            # Leases
            for state, n_leases in scheduler_metrics[MetricCategories.LEASES.value].items():
                state = LeaseState(state)
                publish_gauge_metric(METRIC.SCHEDULER.LEASE_COUNT, n_leases, state=state.name)
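

# End-to-end sketch (illustrative assumptions: ``controller`` is a configured
# ExecutionController instance and the port is local and insecure):
#
#     server = Server(
#         server_reflection=True,
#         grpc_compression=grpc.Compression.NoCompression,
#         is_instrumented=False,
#         grpc_server_options=None,
#         max_workers=None,
#     )
#     server.register_instance("", controller)
#     server.add_port("localhost:50051", None)
#     signal.signal(signal.SIGTERM, server.stop)
#     server.start()  # blocks until terminated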