Skip to content

Commit e8ca7ff

Browse files
refresh_partitions should accept only boolean values, defaulting to False
1 parent 5ecf918 commit e8ca7ff

2 files changed

Lines changed: 13 additions & 19 deletions

File tree

streamz/sources.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def _close_consumer(self):
453453
class FromKafkaBatched(Stream):
454454
"""Base class for both local and cluster-based batched kafka processing"""
455455
def __init__(self, topic, consumer_params, poll_interval='1s',
456-
npartitions=None, check_npartitions_every=None,
456+
npartitions=None, refresh_partitions=False,
457457
max_batch_size=10000, keys=False,
458458
engine=None, **kwargs):
459459
self.consumer_params = consumer_params
@@ -463,7 +463,7 @@ def __init__(self, topic, consumer_params, poll_interval='1s',
463463
consumer_params['auto.offset.reset'] = 'latest'
464464
self.topic = topic
465465
self.npartitions = npartitions
466-
self.check_npartitions_every = check_npartitions_every
466+
self.refresh_partitions = refresh_partitions
467467
if self.npartitions is not None and self.npartitions <= 0:
468468
raise ValueError("Number of Kafka topic partitions must be > 0.")
469469
self.poll_interval = convert_interval(poll_interval)
@@ -511,12 +511,10 @@ def checkpoint_emit(_part):
511511
break
512512

513513
try:
514-
if self.check_npartitions_every is not None:
515-
cycles = 0
516514
while not self.stopped:
517515
out = []
518516

519-
if self.check_npartitions_every is not None and cycles == 0:
517+
if self.refresh_partitions:
520518
kafka_cluster_metadata = self.consumer.list_topics(self.topic)
521519
if self.engine == "cudf": # pragma: no cover
522520
new_partitions = len(kafka_cluster_metadata[self.topic.encode('utf-8')])
@@ -548,9 +546,6 @@ def checkpoint_emit(_part):
548546
self.positions[partition] = high
549547
self.consumer_params['auto.offset.reset'] = 'earliest'
550548

551-
if self.check_npartitions_every is not None:
552-
cycles = (cycles + 1) % self.check_npartitions_every
553-
554549
for part in out:
555550
yield self.loop.add_callback(checkpoint_emit, part)
556551

@@ -580,7 +575,7 @@ def start(self):
580575

581576
@Stream.register_api(staticmethod)
582577
def from_kafka_batched(topic, consumer_params, poll_interval='1s',
583-
npartitions=None, check_npartitions_every=None,
578+
npartitions=None, refresh_partitions=False,
584579
start=False, dask=False,
585580
max_batch_size=10000, keys=False,
586581
engine=None, **kwargs):
@@ -617,13 +612,12 @@ def from_kafka_batched(topic, consumer_params, poll_interval='1s',
617612
npartitions: int (None)
618613
| Number of partitions in the topic.
619614
| If None, streamz will poll Kafka to get the number of partitions.
620-
check_npartitions_every: int (None)
621-
| Useful if the user expects to increase the number of partitions on the fly,
622-
| maybe to handle spikes in load, etc. Streamz polls Kafka after every
623-
| 'check_npartitions_every' number of batches/cycles to determine the current
624-
| number of topic partitions. If partitions have been added, streamz will
625-
| automatically start reading data from the new partitions as well.
626-
| If set to None, streamz will not accommodate changing partitions on the fly.
615+
refresh_partitions: bool (False)
616+
| Useful if the user expects to increase the number of topic partitions on the
617+
| fly, maybe to handle spikes in load. Streamz polls Kafka in every batch to
618+
| determine the current number of partitions. If partitions have been added,
619+
| streamz will automatically start reading data from the new partitions as well.
620+
| If set to False, streamz will not accommodate adding partitions on the fly.
627621
| It is recommended to restart the stream after decreasing the number of partitions.
628622
start: bool (False)
629623
Whether to start polling upon instantiation
@@ -675,7 +669,7 @@ def from_kafka_batched(topic, consumer_params, poll_interval='1s',
675669
source = FromKafkaBatched(topic, consumer_params,
676670
poll_interval=poll_interval,
677671
npartitions=npartitions,
678-
check_npartitions_every=check_npartitions_every,
672+
refresh_partitions=refresh_partitions,
679673
max_batch_size=max_batch_size,
680674
keys=keys,
681675
engine=engine,

streamz/tests/test_kafka.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test_kafka_batch_npartitions():
290290
stream3.upstream.stopped = True
291291

292292

293-
def test_kafka_check_npartitions_every():
293+
def test_kafka_refresh_partitions():
294294
j1 = random.randint(0, 10000)
295295
ARGS = {'bootstrap.servers': 'localhost:9092',
296296
'group.id': 'streamz-test%i' % j1,
@@ -315,7 +315,7 @@ def test_kafka_check_npartitions_every():
315315

316316
stream = Stream.from_kafka_batched(TOPIC, ARGS,
317317
asynchronous=True,
318-
check_npartitions_every=1,
318+
refresh_partitions=True,
319319
poll_interval='2s')
320320
out = stream.gather().sink_to_list()
321321
stream.start()

0 commit comments

Comments (0)