Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/server.py: 78.75% (287 statements)
# Copyright (C) 2018 Bloomberg LP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# <http://www.apache.org/licenses/LICENSE-2.0>
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import logging.handlers
import os
import signal
import sys
import threading
import time
import traceback
from collections import defaultdict
from contextlib import ExitStack
from datetime import datetime
from queue import Empty, Queue
from types import FrameType
from typing import Any, Iterable, Sequence

import grpc
from grpc_reflection.v1alpha import reflection

from buildgrid._protos.buildgrid.v2.monitoring_pb2 import LogRecord
from buildgrid.server.actioncache.service import ActionCacheService
from buildgrid.server.bots.service import BotsService
from buildgrid.server.build_events.service import PublishBuildEventService, QueryBuildEventsService
from buildgrid.server.capabilities.instance import CapabilitiesInstance
from buildgrid.server.capabilities.service import CapabilitiesService
from buildgrid.server.cas.service import ByteStreamService, ContentAddressableStorageService
from buildgrid.server.context import instance_context
from buildgrid.server.controller import ExecutionController
from buildgrid.server.enums import LogRecordLevel, MetricCategories
from buildgrid.server.exceptions import PermissionDeniedError, ShutdownDelayedError
from buildgrid.server.execution.service import ExecutionService
from buildgrid.server.introspection.service import IntrospectionService
from buildgrid.server.logging import buildgrid_logger
from buildgrid.server.metrics_names import METRIC
from buildgrid.server.metrics_utils import publish_gauge_metric
from buildgrid.server.monitoring import get_monitoring_bus
from buildgrid.server.operations.service import OperationsService
from buildgrid.server.scheduler import Scheduler, SchedulerMetrics
from buildgrid.server.servicer import Instance, InstancedServicer
from buildgrid.server.settings import LOG_RECORD_FORMAT, MIN_THREAD_POOL_SIZE, MONITORING_PERIOD, SHUTDOWN_ALARM_DELAY
from buildgrid.server.threading import ContextThreadPoolExecutor, ContextWorker
from buildgrid.server.types import OnServerStartCallback, PortAssignedCallback

LOGGER = buildgrid_logger(__name__)


def load_tls_server_credentials(
    server_key: str | None = None, server_cert: str | None = None, client_certs: str | None = None
) -> grpc.ServerCredentials | None:
    """Looks up and loads TLS server gRPC credentials.

    All private and public keys are expected to be PEM-encoded.

    Args:
        server_key (str): private server key file path.
        server_cert (str): public server certificate file path.
        client_certs (str): public client certificates file path.

    Returns:
        :obj:`ServerCredentials`: The credentials to use for a
        TLS-encrypted gRPC server channel.
    """
    if not server_key or not os.path.exists(server_key):
        return None

    if not server_cert or not os.path.exists(server_cert):
        return None

    with open(server_key, "rb") as f:
        server_key_pem = f.read()
    with open(server_cert, "rb") as f:
        server_cert_pem = f.read()

    if client_certs and os.path.exists(client_certs):
        with open(client_certs, "rb") as f:
            client_certs_pem = f.read()
    else:
        client_certs_pem = None
        client_certs = None

    credentials = grpc.ssl_server_credentials(
        [(server_key_pem, server_cert_pem)], root_certificates=client_certs_pem, require_client_auth=bool(client_certs)
    )

    # TODO: Fix this (missing stubs?) "ServerCredentials" has no attribute
    credentials.server_key = server_key  # type: ignore[attr-defined]
    credentials.server_cert = server_cert  # type: ignore[attr-defined]
    credentials.client_certs = client_certs  # type: ignore[attr-defined]

    return credentials
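
# A minimal usage sketch (the paths below are hypothetical, not part of this
# module). Client authentication is only required when a CA bundle is
# supplied; the function returns ``None`` if the key or certificate path is
# unset or missing, so callers can fall back to an insecure port:
#
#     credentials = load_tls_server_credentials(
#         server_key="/etc/buildgrid/tls/server.key",
#         server_cert="/etc/buildgrid/tls/server.crt",
#         client_certs="/etc/buildgrid/tls/ca.crt",
#     )
#     if credentials is None:
#         ...  # e.g. fall back to an insecure port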


class Server:
    """Creates a BuildGrid server instance.

    The :class:`Server` class binds together all the gRPC services.
    """

    def __init__(
        self,
        server_reflection: bool,
        grpc_compression: grpc.Compression,
        is_instrumented: bool,
        grpc_server_options: Sequence[tuple[str, Any]] | None,
        max_workers: int | None,
        monitoring_period: float = MONITORING_PERIOD,
    ):
        self._stack = ExitStack()

        self._server_reflection = server_reflection
        self._grpc_compression = grpc_compression
        self._is_instrumented = is_instrumented
        self._grpc_server_options = grpc_server_options

        self._action_cache_service = ActionCacheService()
        self._bots_service = BotsService()
        self._bytestream_service = ByteStreamService()
        self._capabilities_service = CapabilitiesService()
        self._cas_service = ContentAddressableStorageService()
        self._execution_service = ExecutionService()
        self._operations_service = OperationsService()
        self._introspection_service = IntrospectionService()

        # Special cases
        self._build_events_service = PublishBuildEventService()
        self._query_build_events_service = QueryBuildEventsService()

        self._schedulers: dict[str, set[Scheduler]] = defaultdict(set)

        self._ports: list[tuple[str, dict[str, str] | None]] = []
        self._port_map: dict[str, int] = {}

        self._logging_queue: Queue[Any] = Queue()
        self._monitoring_period = monitoring_period

        if max_workers is None:
            # Use the max_workers default from Python 3.4+
            max_workers = max(MIN_THREAD_POOL_SIZE, (os.cpu_count() or 1) * 5)

        elif max_workers < MIN_THREAD_POOL_SIZE:
            LOGGER.warning(
                "Specified thread-limit is too small, bumping it.",
                tags=dict(requested_thread_limit=max_workers, new_thread_limit=MIN_THREAD_POOL_SIZE),
            )
            # Enforce a minimum for max_workers
            max_workers = MIN_THREAD_POOL_SIZE

        self._max_grpc_workers = max_workers

    def register_instance(self, instance_name: str, instance: Instance) -> None:
        """
        Register an instance with the server. Handles the logic of mapping instances to the
        correct servicer container.

        Args:
            instance_name (str): The name of the instance.

            instance (Instance): The instance implementation.
        """

        # Special case to handle the ExecutionController which combines the service interfaces.
        if isinstance(instance, ExecutionController):
            if bots_interface := instance.bots_interface:
                self.register_instance(instance_name, bots_interface)
            if execution_instance := instance.execution_instance:
                self.register_instance(instance_name, execution_instance)
            if operations_instance := instance.operations_instance:
                self.register_instance(instance_name, operations_instance)

        elif action_instance := self._action_cache_service.cast(instance):
            self._action_cache_service.add_instance(instance_name, action_instance)
            capabilities = self._capabilities_service.instances.setdefault(instance_name, CapabilitiesInstance())
            capabilities.add_action_cache_instance(action_instance)

        elif bots_instance := self._bots_service.cast(instance):
            self._bots_service.add_instance(instance_name, bots_instance)
            self._schedulers[instance_name].add(bots_instance.scheduler)

        elif bytestream_instance := self._bytestream_service.cast(instance):
            self._bytestream_service.add_instance(instance_name, bytestream_instance)

        elif cas_instance := self._cas_service.cast(instance):
            self._cas_service.add_instance(instance_name, cas_instance)
            capabilities = self._capabilities_service.instances.setdefault(instance_name, CapabilitiesInstance())
            capabilities.add_cas_instance(cas_instance)

        elif execution_instance := self._execution_service.cast(instance):
            self._execution_service.add_instance(instance_name, execution_instance)
            self._schedulers[instance_name].add(execution_instance.scheduler)
            capabilities = self._capabilities_service.instances.setdefault(instance_name, CapabilitiesInstance())
            capabilities.add_execution_instance(execution_instance)

        elif operations_instance := self._operations_service.cast(instance):
            self._operations_service.add_instance(instance_name, operations_instance)

        elif introspection_instance := self._introspection_service.cast(instance):
            self._introspection_service.add_instance(instance_name, introspection_instance)

        # The Build Events Services have no support for instance names, so this
        # is a bit of a special case where the storage backend itself is the
        # trigger for creating the gRPC services.
        elif instance.SERVICE_NAME == "BuildEvents":
            self._build_events_service.add_instance("", instance)  # type: ignore[arg-type]
            self._query_build_events_service.add_instance("", instance)  # type: ignore[arg-type]

        else:
            raise ValueError(f"Instance of type {type(instance)} not supported by {type(self)}")
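
    # A hedged sketch of how registration is typically driven (the controller's
    # constructor arguments are elided here; they are defined in
    # buildgrid.server.controller, not in this module):
    #
    #     controller = ExecutionController(...)
    #     server.register_instance("", controller)
    #
    # Registering an ExecutionController recursively registers its bots,
    # execution, and operations interfaces under the same instance name.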

    @property
    def _services(self) -> Iterable[InstancedServicer[Any]]:
        return (
            self._action_cache_service,
            self._bots_service,
            self._bytestream_service,
            self._capabilities_service,
            self._cas_service,
            self._execution_service,
            self._operations_service,
            self._introspection_service,
            # Special cases
            self._build_events_service,
            self._query_build_events_service,
        )

    def add_port(self, address: str, credentials: dict[str, str] | None) -> None:
        """Adds a port to the server.

        Must be called before the server starts. If a credentials dict is
        provided, a secure port will be created.

        Args:
            address (str): The address with port number.
            credentials (dict): Dict of paths to the TLS key and certificate
                files, or ``None`` for an insecure port.
        """
        self._ports.append((address, credentials))
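
    # A minimal usage sketch (addresses and paths are hypothetical). The dict
    # keys shown are the ones ``setup_grpc`` reads when securing the port:
    #
    #     server.add_port("[::]:50051", None)  # insecure
    #     server.add_port(
    #         "[::]:50052",
    #         {
    #             "tls-server-key": "/etc/buildgrid/tls/server.key",
    #             "tls-server-cert": "/etc/buildgrid/tls/server.crt",
    #             "tls-client-certs": "/etc/buildgrid/tls/ca.crt",
    #         },
    #     )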

    def start(
        self,
        *,
        on_server_start_cb: OnServerStartCallback | None = None,
        port_assigned_callback: PortAssignedCallback | None = None,
        run_forever: bool = True,
    ) -> None:
        """Starts the BuildGrid server.

        BuildGrid server startup consists of four stages:

        1. Starting logging and monitoring

        This step starts up the logging coroutine, the periodic status metrics
        coroutine, and the monitoring bus' publishing subprocess. Since this
        step involves forking, anything not fork-safe needs to be done *after*
        this step.

        2. Instantiate gRPC

        This step instantiates the gRPC server, and tells all the instances
        which have been attached to the server to instantiate their gRPC
        objects. It is also responsible for creating the various service
        objects and connecting them to the server and the instances.

        After this step, gRPC core is running and it's no longer safe to fork
        the process.

        3. Start instances

        Several of BuildGrid's services use background threads that need to
        be explicitly started when BuildGrid starts up. Rather than doing
        this at configuration parsing time, this step provides a hook for
        services to start up in a more organised fashion.

        4. Start the gRPC server

        The final step is starting up the gRPC server. The callback passed in
        via ``on_server_start_cb`` is executed in this step once the server
        has started. After this point BuildGrid is ready to serve requests.

        Finally, if ``run_forever`` is set, this method blocks until the gRPC
        server terminates, by calling its ``wait_for_termination`` method.

        Args:
            on_server_start_cb (Callable): Callback function to execute once
                the gRPC server has started up.
            port_assigned_callback (Callable): Callback function to execute
                once the gRPC server has started up. The mapping of addresses
                to ports is passed to this callback.
        """
        # 1. Start logging and monitoring
        self._stack.enter_context(
            ContextWorker(
                self._logging_worker,
                "ServerLogger",
                # Add a dummy value to the queue to unblock the get call.
                on_shutdown_requested=lambda: self._logging_queue.put(None),
            )
        )
        if self._is_instrumented:
            self._stack.enter_context(get_monitoring_bus())
            self._stack.enter_context(ContextWorker(self._state_monitoring_worker, "ServerMonitor"))

        # 2. Instantiate gRPC objects
        grpc_server = self.setup_grpc()

        # 3. Start background threads
        for service in self._services:
            self._stack.enter_context(service)

        # 4. Start the gRPC server.
        grpc_server.start()
        self._stack.callback(grpc_server.stop, None)

        if on_server_start_cb:
            on_server_start_cb()
        if port_assigned_callback:
            port_assigned_callback(port_map=self._port_map)

        # Add the stop handler and run the event loop
        if run_forever:
            grpc_server.wait_for_termination()
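
    # A hedged end-to-end sketch of the lifecycle described above (constructor
    # values are illustrative; ``some_instance`` stands in for any registered
    # Instance implementation):
    #
    #     server = Server(
    #         server_reflection=True,
    #         grpc_compression=grpc.Compression.NoCompression,
    #         is_instrumented=False,
    #         grpc_server_options=None,
    #         max_workers=None,
    #     )
    #     server.register_instance("", some_instance)
    #     server.add_port("[::]:50051", None)
    #     server.start(run_forever=True)  # blocks until the server terminates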

    def setup_grpc(self) -> grpc.Server:
        """Instantiate the gRPC objects.

        This creates the gRPC server, and causes the instances attached to
        this server to instantiate any gRPC channels they need. This also
        sets up the services which route to those instances, and sets up
        gRPC reflection.
        """
        LOGGER.info(
            "Setting up gRPC server.",
            tags=dict(
                maximum_concurrent_rpcs=self._max_grpc_workers,
                compression=self._grpc_compression,
                options=self._grpc_server_options,
            ),
        )

        grpc_server = grpc.server(
            ContextThreadPoolExecutor(self._max_grpc_workers, "gRPC_Executor", immediate_copy=True),
            maximum_concurrent_rpcs=self._max_grpc_workers,
            compression=self._grpc_compression,
            options=self._grpc_server_options,
        )

        # Add the requested ports to the gRPC server
        for address, credentials in self._ports:
            port_number = 0
            if credentials is not None:
                LOGGER.info("Adding secure connection.", tags=dict(address=address))
                server_key = credentials.get("tls-server-key")
                server_cert = credentials.get("tls-server-cert")
                client_certs = credentials.get("tls-client-certs")
                server_credentials = load_tls_server_credentials(
                    server_cert=server_cert, server_key=server_key, client_certs=client_certs
                )
                # TODO should this error out??
                if server_credentials:
                    port_number = grpc_server.add_secure_port(address, server_credentials)

            else:
                LOGGER.info("Adding insecure connection.", tags=dict(address=address))
                port_number = grpc_server.add_insecure_port(address)

            if not port_number:
                raise PermissionDeniedError("Unable to configure socket")

            self._port_map[address] = port_number

        for service in self._services:
            service.setup_grpc(grpc_server)

        if self._server_reflection:
            reflection_services = [service.FULL_NAME for service in self._services if service.enabled]
            LOGGER.info("Server reflection is enabled.", tags=dict(reflection_services=reflection_services))
            reflection.enable_server_reflection([reflection.SERVICE_NAME] + reflection_services, grpc_server)
        else:
            LOGGER.info("Server reflection is not enabled.")

        return grpc_server

    def stop(self) -> None:
        LOGGER.info("Stopping BuildGrid server.")

        def alarm_handler(_signal: int, _frame: FrameType | None) -> None:
            LOGGER.warning(
                "Shutdown still ongoing after shutdown delay.",
                tags=dict(
                    shutdown_alarm_delay_seconds=SHUTDOWN_ALARM_DELAY, active_thread_count=threading.active_count()
                ),
            )
            for thread in threading.enumerate():
                if thread.ident is not None:
                    tb = "".join(traceback.format_stack(sys._current_frames()[thread.ident]))
                    LOGGER.warning(f"Thread {thread.name} ({thread.ident})\n{tb}")
            raise ShutdownDelayedError(f"Shutdown took more than {SHUTDOWN_ALARM_DELAY} seconds")

        LOGGER.debug("Setting alarm for delayed shutdown.")
        signal.signal(signal.SIGALRM, alarm_handler)
        signal.alarm(SHUTDOWN_ALARM_DELAY)

        try:
            self._stack.close()
        except ShutdownDelayedError:
            # Do nothing, this was raised to interrupt a potentially stuck stack close
            pass
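
    # ``stop`` is intended to be driven externally, e.g. from a ``SIGTERM``
    # handler; a hedged sketch of the typical wiring (not part of this module):
    #
    #     import signal
    #     signal.signal(signal.SIGTERM, lambda _sig, _frame: server.stop())
    #
    # The SIGALRM watchdog above then ensures shutdown either completes or
    # raises within SHUTDOWN_ALARM_DELAY seconds.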

    def _logging_worker(self, shutdown_requested: threading.Event) -> None:
        """Publishes log records to the monitoring bus."""

        logging_handler = logging.handlers.QueueHandler(self._logging_queue)

        # Set up the main logging handler:
        root_logger = logging.getLogger()

        for log_filter in root_logger.filters[:]:
            logging_handler.addFilter(log_filter)
            root_logger.removeFilter(log_filter)

        # Default formatter, used unless one is found on a root_logger handler
        logging_formatter = logging.Formatter(fmt=LOG_RECORD_FORMAT)

        for root_log_handler in root_logger.handlers[:]:
            for log_filter in root_log_handler.filters[:]:
                logging_handler.addFilter(log_filter)
            if root_log_handler.formatter:
                logging_formatter = root_log_handler.formatter
            root_logger.removeHandler(root_log_handler)
        root_logger.addHandler(logging_handler)

        def logging_worker() -> None:
            monitoring_bus = get_monitoring_bus()

            try:
                log_record = self._logging_queue.get(timeout=self._monitoring_period)
            except Empty:
                return
            if log_record is None:
                return

            # Print log records to stdout, if required:
            if not self._is_instrumented or not monitoring_bus.prints_records:
                record = logging_formatter.format(log_record)
                # TODO: Investigate whether an async write would be worthwhile here.
                sys.stdout.write(f"{record}\n")
                sys.stdout.flush()

            # Emit a log record if the server is instrumented:
            if self._is_instrumented:
                log_record_level = LogRecordLevel(int(log_record.levelno / 10))
                log_record_creation_time = datetime.fromtimestamp(log_record.created)
                # logging.LogRecord.extra must be a str to str dict:
                if "extra" in log_record.__dict__ and log_record.extra:
                    log_record_metadata = log_record.extra
                else:
                    log_record_metadata = None
                forged_record = self._forge_log_record(
                    domain=log_record.name,
                    level=log_record_level,
                    message=log_record.message,
                    creation_time=log_record_creation_time,
                    metadata=log_record_metadata,
                )
                monitoring_bus.send_record_nowait(forged_record)

        while not shutdown_requested.is_set():
            try:
                logging_worker()
            except Exception:
                # The thread shouldn't exit on exceptions, but output the exception so that
                # it can be found in the logs.
                #
                # Note, we DO NOT use `LOGGER` here, because we don't want to write
                # anything new to the logging queue in case the Exception isn't some transient
                # issue.
                try:
                    sys.stdout.write("Exception in logging worker\n")
                    sys.stdout.flush()
                    traceback.print_exc()
                except Exception:
                    # There's not a lot we can do at this point really.
                    pass

        if shutdown_requested.is_set():
            # Reset logging, so any logging after shutting down the logging worker
            # still gets written to stdout and the queue doesn't get any more logs
            stream_handler = logging.StreamHandler(stream=sys.stdout)
            stream_handler.setFormatter(logging_formatter)
            root_logger = logging.getLogger()

            for log_filter in root_logger.filters[:]:
                stream_handler.addFilter(log_filter)
                root_logger.removeFilter(log_filter)

            for log_handler in root_logger.handlers[:]:
                for log_filter in log_handler.filters[:]:
                    stream_handler.addFilter(log_filter)
                root_logger.removeHandler(log_handler)
            root_logger.addHandler(stream_handler)

            # Drain the log message queue
            while self._logging_queue.qsize() > 0:
                logging_worker()

    def _forge_log_record(
        self,
        *,
        domain: str,
        level: LogRecordLevel,
        message: str,
        creation_time: datetime,
        metadata: dict[str, str] | None = None,
    ) -> LogRecord:
        log_record = LogRecord()

        log_record.creation_timestamp.FromDatetime(creation_time)
        log_record.domain = domain
        log_record.level = level.value
        log_record.message = message
        if metadata is not None:
            log_record.metadata.update(metadata)

        return log_record
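
    # Note on the level mapping in ``_logging_worker``: Python's numeric log
    # levels are divided by ten to pick a ``LogRecordLevel``, so for example
    # ``logging.INFO`` (20) becomes ``LogRecordLevel(2)`` and ``logging.ERROR``
    # (40) becomes ``LogRecordLevel(4)``.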

    def _state_monitoring_worker(self, shutdown_requested: threading.Event) -> None:
        """Periodically publishes state metrics to the monitoring bus."""
        while not shutdown_requested.is_set():
            start = time.time()
            try:
                if self._execution_service.enabled:
                    for instance_name in self._execution_service.instances:
                        self._publish_client_metrics_for_instance(instance_name)

                if self._bots_service.enabled:
                    for instance_name in self._bots_service.instances:
                        self._publish_bot_metrics_for_instance(instance_name)

                if self._schedulers:
                    for instance_name in self._schedulers:
                        self._publish_scheduler_metrics_for_instance(instance_name)
            except Exception:
                # The thread shouldn't exit on exceptions, but log at a severe enough level
                # that it doesn't get lost in the logs
                LOGGER.exception("Exception while gathering state metrics.")

            end = time.time()
            shutdown_requested.wait(timeout=max(0, self._monitoring_period - (end - start)))

    def _publish_client_metrics_for_instance(self, instance_name: str) -> None:
        """Queries the number of clients connected for a given instance."""
        with instance_context(instance_name):
            n_clients = self._execution_service.query_connected_clients_for_instance(instance_name)
            publish_gauge_metric(METRIC.CONNECTIONS.CLIENT_COUNT, n_clients)

    def _publish_bot_metrics_for_instance(self, instance_name: str) -> None:
        with instance_context(instance_name):
            for bot_status, count in self._bots_service.count_bots_by_status(instance_name).items():
                publish_gauge_metric(METRIC.SCHEDULER.BOTS_COUNT, count, state=bot_status.name)

            n_workers = self._bots_service.query_connected_bots_for_instance(instance_name)
            publish_gauge_metric(METRIC.CONNECTIONS.WORKER_COUNT, n_workers)

    def _publish_scheduler_metrics_for_instance(self, instance_name: str) -> None:
        with instance_context(instance_name):
            # Since multiple schedulers may be active for this instance, but should
            # be using the same database, just use whichever one the loop sees last
            scheduler_metrics: SchedulerMetrics | None = None
            for scheduler in self._schedulers[instance_name]:
                scheduler_metrics = scheduler.get_metrics()
            if scheduler_metrics is None:
                return

            # Jobs
            for [stage_name, property_label], number_of_jobs in scheduler_metrics[MetricCategories.JOBS.value].items():
                publish_gauge_metric(
                    METRIC.SCHEDULER.JOB_COUNT, number_of_jobs, state=stage_name, propertyLabel=property_label
                )