Merge pull request #58 from deepchatterjeeligo/stop-consumer
add a stop method to the consumer class; address #57
cnweaver authored Aug 31, 2022
2 parents d204eb5 + 37dc33b commit 869b19d
Showing 2 changed files with 46 additions and 2 deletions.
19 changes: 17 additions & 2 deletions adc/consumer.py
@@ -2,6 +2,7 @@
 import enum
 import logging
 from datetime import timedelta
+import threading
 from typing import Dict, Iterable, Iterator, List, Optional, Set, Union
 from collections import defaultdict

@@ -25,6 +26,7 @@ def __init__(self, conf: 'ConsumerConfig') -> None:
         # Workaround for https://github.com/edenhill/librdkafka/issues/3263.
         # Remove once confluent-kafka-python 1.9.0 has been released.
         self._consumer.poll(0)
+        self._stop_event = threading.Event()
 
     def subscribe(self,
                   topics: Union[str, Iterable],
@@ -84,6 +86,12 @@ def mark_done(self, msg: confluent_kafka.Message, asynchronous: bool = True):
         else:
             self._consumer.commit(msg, asynchronous=False)
 
+    def stop(self):
+        """Stops the runloop of the consumer. Useful when running the
+        consumer in a different thread.
+        """
+        self._stop_event.set()
+
     def stream(self,
                autocommit: bool = True,
                batch_size: int = 100,
@@ -121,11 +129,14 @@ def _stream_forever(self,
                         batch_size: int = 100,
                         batch_timeout: timedelta = timedelta(seconds=1.0),
                         ) -> Iterator[confluent_kafka.Message]:
-        while True:
+        self._stop_event.clear()
+        while not self._stop_event.is_set():
             try:
                 messages = self._consumer.consume(batch_size,
                                                   batch_timeout.total_seconds())
                 for m in messages:
+                    if self._stop_event.is_set():
+                        break
                     err = m.error()
                     if err is None:
                         self.logger.debug(f"read message from partition {m.partition()}")
@@ -153,10 +164,13 @@ def _stream_until_eof(self,
             self.logger.debug(f"tracking until eof for topic={tp.topic} partition={tp.partition}")
             active_partitions[tp.topic].add(tp.partition)
 
-        while len(active_partitions) > 0:
+        self._stop_event.clear()
+        while len(active_partitions) > 0 and not self._stop_event.is_set():
             messages = self._consumer.consume(batch_size, batch_timeout.total_seconds())
             try:
                 for m in messages:
+                    if self._stop_event.is_set():
+                        raise StopIteration
                     err = m.error()
                     # A new message may arrive from a previously removed topic/partition,
                     # in which case it must be re-added
@@ -181,6 +195,7 @@
             finally:
                 if autocommit:
                     self._consumer.commit(asynchronous=True)
+        self._stop_event.set()
 
     def close(self):
         """ Close the consumer, ending its subscriptions. """
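
The new shutdown hook follows the standard cooperative-cancellation pattern: the consume loop polls a threading.Event instead of looping unconditionally, and stop() simply sets that event, so it is safe to call from any thread. A standalone sketch of the pattern (the loop body is illustrative filler, not adc code):

    import threading
    import time

    stop_event = threading.Event()

    def run_loop() -> None:
        # Mirrors _stream_forever: reset the event, then loop until it is set.
        stop_event.clear()
        while not stop_event.is_set():
            for item in range(5):        # stand-in for consumer.consume(...)
                if stop_event.is_set():  # per-message check lets a batch be
                    break                # abandoned partway through
                time.sleep(0.1)          # stand-in for message handling

    t = threading.Thread(target=run_loop, name="LoopThread")
    t.start()
    time.sleep(0.5)
    stop_event.set()                     # what Consumer.stop() does
    t.join()
    assert not t.is_alive()
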
29 changes: 29 additions & 0 deletions tests/test_kafka_integration.py
@@ -108,6 +108,7 @@ def test_consume_from_beginning(self):
         stream = consumer.stream()
         msgs = [msg for msg in stream]
 
+        assert consumer._stop_event.is_set()
         self.assertEqual(len(batch), len(msgs))
         for expected, actual in zip(batch, msgs):
             self.assertEqual(actual.topic(), topic)
@@ -137,6 +138,7 @@ def test_consume_stored_offsets(self):
         msgs_1 = [msg for msg in stream_1]
 
         # Check that all messages from first batch are processed.
+        assert consumer_1._stop_event.is_set()
         self.assertEqual(len(batch_1), len(msgs_1))
         for expected, actual in zip(batch_1, msgs_1):
             self.assertEqual(actual.topic(), topic)
@@ -155,6 +157,7 @@
         # batch is processed.
         stream_1 = consumer_1.stream()
         msgs_1 = [msg for msg in stream_1]
+        assert consumer_1._stop_event.is_set()
         self.assertEqual(len(batch_2), len(msgs_1))
         for expected, actual in zip(batch_2, msgs_1):
             self.assertEqual(actual.topic(), topic)
@@ -173,6 +176,7 @@
         msgs_2 = [msg for msg in stream_2]
 
         # Now check that messages from both batches are processed.
+        assert consumer_2._stop_event.is_set()
         self.assertEqual(len(batch_1 + batch_2), len(msgs_2))
         for expected, actual in zip(batch_1 + batch_2, msgs_2):
             self.assertEqual(actual.topic(), topic)
@@ -196,8 +200,33 @@ def test_consume_not_forever(self):
             raise Exception(msg.error())
         self.assertEqual(msg.topic(), topic)
         self.assertEqual(msg.value(), b"message 1")
+        assert not consumer._stop_event.is_set()
         with self.assertRaises(StopIteration):
             next(stream)
+        assert consumer._stop_event.is_set()
+
+    def test_consumer_terminating_in_thread(self):
+        topic = "test_consume_forever_in_thread"
+        simple_write_msgs(
+            self.kafka, topic, ["message 1", "message 2", "message 3"])
+
+        consumer = adc.consumer.Consumer(adc.consumer.ConsumerConfig(
+            broker_urls=[self.kafka.address],
+            group_id="test_consumer",
+            auth=self.kafka.auth,
+            read_forever=True
+        ))
+        consumer.subscribe(topic)
+
+        import threading
+        t = threading.Thread(
+            target=lambda c: {_ for _ in c.stream()}, args=(consumer,),
+            name="ListenerThread")
+        t.start()
+        # stop listener
+        consumer.stop()
+        t.join()
+        assert t.is_alive() is False
+
     def test_contextmanager_support(self):
         topic = "test_contextmanager_support"
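
Taken together with the new test, the intended usage is: run stream() in a background thread and call stop() from the controlling thread. A minimal sketch under stated assumptions; the broker URL, group id, and topic are placeholders, and the join timeout is a defensive addition (guarding against the stop signal racing the loop's startup, since _stream_forever clears the event when iteration begins), not something the library requires:

    import threading
    import adc.consumer

    consumer = adc.consumer.Consumer(adc.consumer.ConsumerConfig(
        broker_urls=["kafka://example.org/"],   # placeholder broker address
        group_id="example-group",               # placeholder group id
        read_forever=True,
    ))
    consumer.subscribe("example-topic")         # placeholder topic

    def listen(c: adc.consumer.Consumer) -> None:
        for message in c.stream():              # runs until stop() is called
            print(message.value())

    t = threading.Thread(target=listen, args=(consumer,), name="ListenerThread")
    t.start()
    # ... later, from the controlling thread ...
    consumer.stop()        # sets the event; the stream loop exits
    t.join(timeout=10)     # defensive timeout in case stop() raced loop startup
    consumer.close()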