Coverage for /builds/BuildGrid/buildgrid/buildgrid/server/rabbitmq/pika_consumer.py: 54.69%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

245 statements  

1# Copyright (C) 2021 Bloomberg LP 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# <http://www.apache.org/licenses/LICENSE-2.0> 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import functools 

16import logging 

17import threading 

18import time 

19 

20from typing import Callable, Dict, Optional, Set 

21 

22import pika # type: ignore 

23 

24from buildgrid.utils import retry_delay 

25 

26 

class QueueBinding:
    """Description of how one queue is attached to one exchange.

    `PikaConsumer` declares `queue` as a durable queue (passing
    `auto_delete_queue` as its `auto_delete` flag) and binds it to
    `exchange` with `routing_key`.
    """

    def __init__(self, queue: str, exchange: str, routing_key: str,
                 auto_delete_queue: bool = False):
        """Record the binding parameters.

        :param queue: name of the queue to declare and bind.
        :param exchange: name of the exchange the queue is bound to.
        :param routing_key: routing key used for the binding.
        :param auto_delete_queue: `auto_delete` flag used when declaring
            the queue.
        """
        self.queue, self.exchange, self.routing_key, self.auto_delete_queue = (
            queue, exchange, routing_key, auto_delete_queue)

34 

35 

class PikaConsumer:
    """Consume messages from RabbitMQ over a pika `SelectConnection`.

    `start()` opens the connection and runs pika's I/O loop on the calling
    thread until the consumer is stopped, so it is normally run in its own
    thread. Once the connection is open the consumer, in order: opens a
    channel, sets the channel QoS (only when a prefetch limit is positive),
    declares every exchange in `exchanges` as durable, declares the durable
    queues of the bindings that reference each exchange, and binds them.
    When every binding has been confirmed the consumer is considered
    connected: blocked `subscribe()` calls proceed and
    `on_connection_established_callback` is invoked.

    Any failure (the connection cannot be opened, the connection is closed,
    the channel is closed) stops the consumer and invokes
    `on_connection_error_callback` once.

    NOTE(review): `subscribe()` and `unsubscribe()` call channel methods
    directly on the calling thread, while pika connections are not
    thread-safe; confirm callers only do this where that is acceptable.
    """

    def __init__(self,
                 connection_parameters: pika.ConnectionParameters,
                 exchanges: Dict[str, pika.exchange_type.ExchangeType],
                 bindings: Set[QueueBinding],
                 prefetch_size: int = 0,
                 prefetch_count: int = 0,
                 on_connection_established_callback: Optional[Callable] = None,
                 on_connection_error_callback: Optional[Callable] = None):
        """Configure the consumer; no connection is made until `start()`.

        :param connection_parameters: parameters used to open the connection.
        :param exchanges: exchange names mapped to their types; each one is
            declared as durable.
        :param bindings: queues to declare and bind. Queues are declared only
            for bindings whose exchange appears in `exchanges`.
        :param prefetch_size: channel QoS prefetch size.
        :param prefetch_count: channel QoS prefetch count. QoS is configured
            only when at least one of the two prefetch values is positive.
        :param on_connection_established_callback: called with no arguments
            once the setup above has completed.
        :param on_connection_error_callback: called with no arguments the
            first time the connection or channel fails.
        :raises ValueError: if `prefetch_size` or `prefetch_count` is negative.
        """
        self._connection_parameters = connection_parameters

        if prefetch_size < 0 or prefetch_count < 0:
            raise ValueError("prefetch_size and prefetch_count cannot be negative")

        self._prefetch_size = prefetch_size
        self._prefetch_count = prefetch_count

        self._exchanges = exchanges
        self._bindings = bindings

        self._connection: Optional[pika.connection.Connection] = None
        self._channel: Optional[pika.channel.Channel] = None

        # Identifier of the thread that runs the I/O loop (set by `start()`).
        # `stop()` uses it to decide whether it may touch the connection
        # directly or has to hand the teardown over to the I/O loop.
        self._ioloop_thread_id: Optional[int] = None

        # Status of the consumer. The Condition allows `subscribe()` to
        # block until either a connection attempt succeeds or is aborted.
        self._connection_status_cv = threading.Condition()
        self._stopped = False
        self._connected = False

        # Callbacks to notify the caller that the connection became ready or
        # is no longer active:
        self._on_connection_error_callback = on_connection_error_callback
        self._on_connection_established_callback = on_connection_established_callback
        # Several error handlers can fire for one failure (e.g. the channel
        # closes and then the connection closes); the caller is told once.
        self._connection_error_reported = False

        self._logger = logging.getLogger(__name__)

        # Only used when there are no bindings: the setup then completes
        # once every exchange has been declared.
        self._declared_exchanges_counter = 0

        self._successful_bindings_counter = 0
        self._successful_bindings_counter_lock = threading.Lock()

    def start(self):
        """Start the consumer, blocking while it runs.

        Opens the connection and runs pika's I/O loop on the calling thread
        until it is stopped. Returns immediately if `stop()` was called
        before `start()`.
        """
        with self._connection_status_cv:
            if self._stopped:
                self._logger.debug('Consumer was stopped before starting')
                return
            # Both values are published under the lock so that a concurrent
            # `stop()` sees either no connection at all, or a connection
            # together with the thread that will run its I/O loop.
            self._ioloop_thread_id = threading.get_ident()
            self._connect()
        self._connection.ioloop.start()  # type: ignore

    def stop(self):
        """Stop the consumer task.

        Unblocks any `subscribe()` waiting for the connection, closes the
        connection (unless it is already closing or closed) and stops the
        I/O loop so that `start()` returns. May be called more than once and
        from any thread: pika connections are not thread-safe, so when this
        runs off the I/O-loop thread the teardown is scheduled on the loop
        with `add_callback_threadsafe()` instead of being run directly.
        """
        with self._connection_status_cv:
            if not self._stopped:
                self._stopped = True
                self._connection_status_cv.notify_all()
            connection = self._connection
            ioloop_thread_id = self._ioloop_thread_id

        self._logger.debug('Stopping consumer...')
        if connection is None:
            # `start()` has not connected and, now that `_stopped` is set,
            # never will: there is nothing to tear down.
            self._logger.info('Consumer stopped')
        elif ioloop_thread_id == threading.get_ident():
            self._shutdown()
        else:
            connection.ioloop.add_callback_threadsafe(self._shutdown)

    def _shutdown(self):
        """Close the connection and stop the I/O loop.

        Only called by `stop()`, either directly on the I/O-loop thread or
        scheduled onto it.
        """
        self._close_connection()
        self._connection.ioloop.stop()  # type: ignore
        self._logger.info('Consumer stopped')

    def subscribe(self, queue_name: str, callback: Callable) -> str:
        """Register a callback to be invoked when receiving a message
        from the specified queue. That callback must take an argument
        of type `bytes` that will contain the message body and a
        `delivery_tag` string that should be used to ACK or NACK the
        message.

        Blocks until the consumer is connected or stopped.

        Returns a consumer tag that can later be passed to `unsubscribe()`.

        Raises `RuntimeError` if the consumer was stopped or no channel
        was opened.

        A sample usage looks like this:

        .. code-block::

            consumer = PikaConsumer(...)
            def process_message_callback(body: bytes, delivery_tag: str):
                # make some preparations
                if process_message_and_save_results(body):
                    consumer.ack_message(delivery_tag)
                else:
                    consumer.nack_message(delivery_tag)
                # clean up (message already ACK'd or NACK'd)

            consumer_tag = consumer.subscribe('queue', process_message_callback)
            ...
            # Stop receiving messages from 'queue':
            consumer.unsubscribe(consumer_tag)

        """

        self._logger.debug(f"Registering callback for queue '{queue_name}'")

        self._wait_for_connection_open_or_stop_request()

        # The wait above also returns when the consumer was stopped; its
        # channel is then closed (or about to be), so consuming would fail.
        if self._stopped or self._channel is None:
            raise RuntimeError("Channel is not open.")

        # pika passes the channel as an argument to the `on_message_callback`,
        # but we do not want to expose that. We only forward the message
        # and the delivery tag to the user-provided callback.
        def callback_invoker(channel: pika.channel.Channel,
                             method: pika.spec.Basic.Deliver,
                             properties: pika.spec.BasicProperties,
                             body: bytes):
            self._logger.debug(f"Channel {channel} "
                               f"received message #{method.delivery_tag} "
                               f"from app_id='{properties.app_id}' "
                               f"with consumer tag='{method.consumer_tag}' and "
                               f"containing {len(body)} bytes in its body")
            callback(body, method.delivery_tag)

        consumer_tag = self._channel.basic_consume(queue=queue_name,
                                                   on_message_callback=callback_invoker)

        self._logger.debug(f"Registered callback for queue '{queue_name}', "
                           f"consumer_tag='{consumer_tag}'")
        return consumer_tag

    def unsubscribe(self, consumer_tag: str, timeout_seconds: int = 60) -> bool:
        """Stop receiving messages with the subscribed callback.
        Issues a cancel request to the server and blocks waiting for it
        to return or the specified timeout to be exceeded.

        Returns `True` if `consumer_tag` was successfully cancelled, which
        guarantees that no more messages will arrive to the subscribed
        callback, and `False` if the timeout expired first. (If called from
        the I/O-loop thread the confirmation cannot be processed while this
        method blocks, so it returns `False` after the timeout.)

        Raises `RuntimeError` if no channel was opened.
        """
        self._logger.debug(f"Unsubscribing consumer tag '{consumer_tag}'")

        if self._channel is None:
            raise RuntimeError("Channel is not open.")

        cancelled_event = threading.Event()

        def on_cancelled(_method: pika.frame.Method):
            cancelled_event.set()

        self._channel.basic_cancel(consumer_tag, callback=on_cancelled)

        return cancelled_event.wait(timeout=timeout_seconds)

    def ack_message(self, delivery_tag: str):
        """Schedule the ACK of a message on the I/O loop. (The delivery tag
        was given to the subscribed callback that received the message.)

        Raises `RuntimeError` if no channel was opened.
        """
        if self._channel is None or self._connection is None:
            raise RuntimeError("Channel is not open.")
        callback = functools.partial(self._channel.basic_ack,
                                     delivery_tag=delivery_tag)
        self._connection.ioloop.add_callback_threadsafe(callback)

    def nack_message(self, delivery_tag: str):
        """Schedule the NACK of a message on the I/O loop. (The delivery tag
        was given to the subscribed callback that received the message.)

        Raises `RuntimeError` if no channel was opened.
        """
        if self._channel is None or self._connection is None:
            raise RuntimeError("Channel is not open.")
        callback = functools.partial(self._channel.basic_nack,
                                     delivery_tag=delivery_tag)
        self._connection.ioloop.add_callback_threadsafe(callback)

    def _wait_for_connection_open_or_stop_request(self):
        with self._connection_status_cv:
            self._connection_status_cv.wait_for(lambda: self._connected or self._stopped)

    #
    # Set-up callbacks (all invoked by pika on the I/O-loop thread):
    #
    def _connect(self):
        self._logger.info(f'Connecting to {self._connection_parameters}')
        self._connection = pika.SelectConnection(  # 1) Open connection
            parameters=self._connection_parameters,
            on_open_callback=self._on_connection_open,
            on_open_error_callback=self._on_connection_open_error,
            on_close_callback=self._on_connection_closed)

    def _on_connection_open(self, connection: pika.connection.Connection):  # 2) Create channel
        self._logger.info('Creating new channel')
        connection.channel(on_open_callback=self._on_channel_open)

    def _on_channel_open(self, channel: pika.channel.Channel):  # 3) Configure channel
        self._channel = channel
        self._logger.debug(f'Channel successfully created: {self._channel}')

        if self._prefetch_size > 0 or self._prefetch_count > 0:
            self._logger.info(f'Setting channel QOS: '
                              f'prefetch_size={self._prefetch_size}, '
                              f'prefetch_count={self._prefetch_count}')
            self._channel.basic_qos(prefetch_size=self._prefetch_size,
                                    prefetch_count=self._prefetch_count)

        self._channel.add_on_close_callback(self._on_channel_closed)

        self._declare_exchanges()  # 4) Declare exchanges

    def _declare_exchanges(self):
        if not self._exchanges and not self._bindings:
            # Nothing to declare or bind: the setup is already complete.
            self._mark_connection_established()
            return

        for exchange_name, exchange_type in self._exchanges.items():
            self._logger.info(f"Declaring exchange '{exchange_name}' "
                              f"of type {exchange_type}")

            callback = functools.partial(self._on_exchange_declared,
                                         exchange=exchange_name)

            self._channel.exchange_declare(exchange=exchange_name,  # type: ignore
                                           exchange_type=exchange_type,
                                           durable=True,
                                           callback=callback)

    def _on_exchange_declared(self, frame: pika.frame.Method, exchange: str):
        self._logger.debug(f"Exchange '{exchange}' of type "
                           f"'{self._exchanges[exchange]}' successfully "
                           f"declared: {frame}")

        self._declare_queues(exchange)

        # Without bindings `_on_queue_bind_succeeded()` never runs, so the
        # setup is complete once every exchange has been declared.
        if not self._bindings:
            self._declared_exchanges_counter += 1
            if self._declared_exchanges_counter == len(self._exchanges):
                self._mark_connection_established()

    def _declare_queues(self, exchange_name: str):  # 5) Declare queues
        for binding in self._bindings:
            if binding.exchange == exchange_name:
                self._logger.info(f"Declaring queue '{binding.queue}'")

                callback = functools.partial(self._on_queue_declared,
                                             queue=binding.queue)

                self._channel.queue_declare(queue=binding.queue,  # type: ignore
                                            auto_delete=binding.auto_delete_queue,
                                            durable=True,
                                            callback=callback)

    def _on_queue_declared(self, frame: pika.frame.Method, queue: str):
        self._logger.debug(f"Queue '{queue}' successfully declared: "
                           f"{frame}")

        self._bind_queue(queue)

    def _bind_queue(self, queue_name: str):  # 6) Bind queues
        for binding in self._bindings:
            if binding.queue == queue_name:
                self._logger.info(f"Binding queue '{binding.queue}' "
                                  f"to exchange '{binding.exchange}' "
                                  f"with routing key '{binding.routing_key}'")

                callback = functools.partial(self._on_queue_bind_succeeded,
                                             queue_name=binding.queue,
                                             exchange_name=binding.exchange,
                                             routing_key=binding.routing_key)

                self._channel.queue_bind(binding.queue, binding.exchange,  # type: ignore
                                         binding.routing_key,
                                         callback=callback)

    def _on_queue_bind_succeeded(self,
                                 frame: pika.frame.Method,
                                 queue_name: str,
                                 exchange_name: str,
                                 routing_key: str):
        self._logger.info(f"Queue '{queue_name}' successfully bound "
                          f"to exchange '{exchange_name}' "
                          f"with routing key '{routing_key}': {frame}")

        with self._successful_bindings_counter_lock:
            self._successful_bindings_counter += 1
            # `==` (not `>=`): a queue declared more than once is bound again
            # and pushes the counter past the total; notify only once.
            all_bound = self._successful_bindings_counter == len(self._bindings)

        if all_bound:
            self._mark_connection_established()

    def _mark_connection_established(self):
        """Flag the consumer as connected, wake up waiters and notify the owner."""
        with self._connection_status_cv:
            self._connected = True
            self._connection_status_cv.notify_all()

        if self._on_connection_established_callback:
            self._on_connection_established_callback()

    #
    # Error-handling callbacks
    #
    def _on_connection_open_error(self,
                                  connection: pika.connection.Connection,
                                  error: Exception):
        self._logger.error(f'Error opening connection: {error}')
        self._stop_and_report_connection_error()

    def _on_connection_closed(self,
                              connection: pika.connection.Connection,
                              reason: Exception):
        self._logger.info(f'Connection closed. Reason: {reason}')
        self._stop_and_report_connection_error()

    def _on_channel_closed(self,
                           channel: pika.channel.Channel,
                           reason: Exception):
        self._logger.error(f'Channel closed. Reason: {reason}')
        self._stop_and_report_connection_error()

    def _stop_and_report_connection_error(self):
        """Stop the consumer and invoke the error callback, at most once."""
        self.stop()
        if self._connection_error_reported:
            return
        self._connection_error_reported = True
        if self._on_connection_error_callback:
            self._on_connection_error_callback()

    def _close_connection(self):
        if self._connection is None:
            return
        if self._connection.is_closing or self._connection.is_closed:
            self._logger.debug('Connection is closing or already closed')
        else:
            self._logger.info('Closing connection')
            self._connection.close()

324 

325 

class RetryingPikaConsumer:
    """Keep a `PikaConsumer` connected, recreating it after failures.

    A watcher thread waits for connection errors reported by the current
    `PikaConsumer`. It replaces the consumer after a `retry_delay()` backoff
    until `max_connection_attempts` consecutive attempts have failed. A
    value of 0 means retry forever. Subscriptions made through this class
    are remembered and re-registered on every new consumer.

    NOTE(review): between a connection failure and the creation of the next
    consumer, `ack_message()`/`nack_message()` still reach the failed
    consumer and are silently dropped; delivery tags from an old channel
    are also not valid on a new one.
    """

    def __init__(self,
                 connection_parameters: pika.ConnectionParameters,
                 exchanges: Dict[str, pika.exchange_type.ExchangeType],
                 bindings: Set[QueueBinding],
                 max_connection_attempts: int = 4,
                 retry_delay_base: int = 1,
                 prefetch_size: Optional[int] = 0,
                 prefetch_count: Optional[int] = 0,
                 on_connection_attempts_exceeded_callback: Optional[Callable] = None):
        """
        Create a `PikaConsumer` and watch it to ensure that it stays
        connected.

        If provided, the `on_connection_attempts_exceeded_callback` will
        be invoked after the last connection attempt has failed.

        `prefetch_size`/`prefetch_count` of `None` are treated as 0.
        """

        self._logger = logging.getLogger(__name__)

        self._connection_parameters = connection_parameters
        self._exchanges = exchanges
        self._bindings = bindings

        self._consumer: Optional[PikaConsumer] = None
        self._consumer_thread: Optional[threading.Thread] = None

        # Status of the consumer. The Condition allows methods like
        # `un/subscribe()` to block until either a connection attempt
        # succeeds or is aborted. It also serializes consumer replacement
        # with `stop()`.
        self._connection_status_cv = threading.Condition()
        self._connected = False
        self._stopped = False

        self._connection_attempts = 0

        # If `max_connection_attempts == 0`, retry forever.
        self._max_connection_attempts = max_connection_attempts
        self._retry_delay_base = retry_delay_base

        # `PikaConsumer` rejects `None` (it compares the values against 0),
        # so the `Optional` values are normalized here.
        self._prefetch_size = prefetch_size or 0
        self._prefetch_count = prefetch_count or 0

        self._on_connection_attempts_exceeded_callback = on_connection_attempts_exceeded_callback

        # Mapping of queues to user-provided callbacks:
        self._callbacks_lock = threading.Lock()
        self._callbacks: Dict[str, Callable] = {}
        # Mapping of queues to the consumer tags returned by the `PikaConsumer`
        # (needed to unsubscribe):
        self._consumer_tags: Dict[str, str] = {}

        # The watcher is started only once every attribute it can reach
        # exists.
        self._connection_error_event = threading.Event()
        self._connection_watcher_thread = threading.Thread(target=self._watch_consumer)
        self._connection_watcher_thread.start()

        self._start_consumer()

    def stop(self):
        """Stop the underlying consumer and the monitoring of it.

        Waits for the watcher thread to exit unless called from that thread,
        for example from `on_connection_attempts_exceeded_callback`.
        """
        with self._connection_status_cv:
            if not self._stopped:
                self._stopped = True
                self._connection_status_cv.notify_all()
            # Read under the lock: the watcher swaps consumers while holding
            # it, so this is the consumer that is actually current.
            consumer = self._consumer

        if consumer is not None:
            consumer.stop()

        # Wake the watcher so that it notices `_stopped` and exits.
        self._connection_error_event.set()
        if threading.current_thread() is not self._connection_watcher_thread:
            self._connection_watcher_thread.join()

    def _start_consumer(self):
        """Create a new `PikaConsumer` and start it in a background
        thread.
        (This method will also be called on reconnection attempts.)
        """
        with self._connection_status_cv:
            self._connected = False

            if self._consumer is not None:
                self._consumer.stop()
                self._consumer = None

            self._consumer_tags = {}

            self._connection_error_event.clear()

            self._consumer = PikaConsumer(connection_parameters=self._connection_parameters,
                                          exchanges=self._exchanges,
                                          bindings=self._bindings,
                                          prefetch_count=self._prefetch_count,
                                          prefetch_size=self._prefetch_size,
                                          on_connection_established_callback=self._on_connection_established,
                                          on_connection_error_callback=self._on_connection_error)
            self._consumer_thread = threading.Thread(target=self._consumer.start)
            self._consumer_thread.start()

    def _on_connection_established(self):
        """When the underlying consumer is connected, re-register the
        callbacks that might have been present on a previous connection,
        then report the connection as open.
        """
        self._logger.debug("on_connection_established() called")

        try:
            # Re-adding existing subscriptions to the brand-new consumer
            # *before* announcing the connection. Callers unblocked by the
            # announcement (e.g. `unsubscribe()`) then find up-to-date
            # consumer tags instead of the map reset by `_start_consumer()`.
            with self._callbacks_lock:
                for queue, callback in self._callbacks.items():
                    self._consumer_tags[queue] = self._consumer.subscribe(queue, callback)
        finally:
            # Announce even if a re-subscription failed, so waiters do not
            # block forever.
            with self._connection_status_cv:
                self._connected = True
                self._connection_status_cv.notify_all()

    def _on_connection_error(self):
        """When the underlying consumer reports a connection failure,
        notify the watcher thread so that it can either reconnect or
        stop monitoring.
        """
        self._logger.debug("on_connection_error() called")
        self._connection_error_event.set()

    def _watch_consumer(self):
        """Thread that monitors the connection status of the underlying
        consumer and makes sure that the connection stays open.
        On connection failures, as long as the retry limit is not
        exceeded, it will call `_start_consumer()` to create and
        configure a new `PikaConsumer`.
        """
        while not self._stopped:
            self._logger.debug("Waiting for error event...")

            self._connection_error_event.wait()
            if self._stopped:
                return
            self._logger.debug("Connection error event set")

            if self._connected:
                self._logger.error("Consumer was disconnected, "
                                   "attempting to reconnect")
                self._connection_attempts = 1
            else:
                self._connection_attempts += 1
                if self._max_connection_attempts > 0:
                    self._logger.warning(f"Failed connection attempt "
                                         f"{self._connection_attempts}/"
                                         f"{self._max_connection_attempts}.")
                else:
                    self._logger.warning(f"Failed connection attempt "
                                         f"{self._connection_attempts}")

            if self._max_connection_attempts == 0 or \
                    self._connection_attempts < self._max_connection_attempts:
                delay = retry_delay(self._connection_attempts, self._retry_delay_base)
                self._logger.info(f"Waiting {delay}s before "
                                  f"attempting to reconnect")

                with self._connection_status_cv:
                    # Unlike `time.sleep()`, this returns as soon as `stop()`
                    # is called. Because the new consumer is created while the
                    # lock is still held, `stop()` cannot slip in between the
                    # check and the creation and leave a consumer running.
                    if self._connection_status_cv.wait_for(lambda: self._stopped,
                                                           timeout=delay):
                        return
                    self._start_consumer()
            else:
                self._logger.error(f"Failed last ({self._connection_attempts}/"
                                   f"{self._max_connection_attempts}) "
                                   f"connection attempt.")

                with self._connection_status_cv:
                    self._stopped = True
                    self._connection_status_cv.notify_all()

                if self._on_connection_attempts_exceeded_callback:
                    self._on_connection_attempts_exceeded_callback()

    def _wait_for_connection_open_or_stop_request(self):
        with self._connection_status_cv:
            self._connection_status_cv.wait_for(lambda: self._connected or self._stopped)

    def subscribe(self, queue_name: str, callback: Callable):
        """Register a callback to process messages that arrive to the
        queue. The callback is re-registered automatically after
        reconnections.

        Raises `RuntimeError` if the consumer was not created or has been
        stopped.
        """
        self._wait_for_connection_open_or_stop_request()

        if self._consumer is None:
            raise RuntimeError("Consumer was not created.")
        if self._stopped:
            raise RuntimeError("Consumer is stopped.")

        with self._callbacks_lock:
            self._callbacks[queue_name] = callback
            self._consumer_tags[queue_name] = self._consumer.subscribe(queue_name, callback)

    def unsubscribe(self, queue_name: str) -> bool:
        """Stop receiving messages from the queue.

        Returns the result of the underlying cancellation: `True` if it was
        confirmed by the server, `False` if it timed out. Raises `KeyError`
        if there is no subscription for `queue_name`.
        """
        self._wait_for_connection_open_or_stop_request()

        if self._consumer is None:
            raise RuntimeError("Consumer was not created.")

        with self._callbacks_lock:
            tag = self._consumer_tags[queue_name]
            cancelled = self._consumer.unsubscribe(tag)
            self._callbacks.pop(queue_name, None)
            # Drop the tag too, so a later unsubscribe does not reuse it.
            self._consumer_tags.pop(queue_name, None)
        return cancelled

    def _consumer_for(self, action: str, delivery_tag: str) -> "PikaConsumer":
        """Return the current consumer, or raise `ConnectionError` if it is
        missing, stopped or not (yet) connected."""
        consumer = self._consumer
        if consumer is None or self._stopped or not self._connected:
            self._logger.debug(f"Consumer is disconnected, could not "
                               f"{action} delivery tag '{delivery_tag}'")
            raise ConnectionError("Consumer is disconnected")
        return consumer

    def ack_message(self, delivery_tag: str):
        """ACK a message using the delivery tag that was given to the
        subscribed callback.
        If the consumer is stopped or disconnected, raises
        `ConnectionError`.
        """
        self._consumer_for("ACK", delivery_tag).ack_message(delivery_tag)

    def nack_message(self, delivery_tag: str):
        """NACK a message using the delivery tag that was given to the
        subscribed callback.
        If the consumer is stopped or disconnected, raises
        `ConnectionError`.
        """
        self._consumer_for("NACK", delivery_tag).nack_message(delivery_tag)