 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -79,6 +80,8 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
     private Uuid clientInstanceId;
     private int injectTimeoutExceptionCounter;
 
+    private long maxPollRecords = Long.MAX_VALUE;
+
     private final List<KafkaMetric> addedMetrics = new ArrayList<>();
 
     /**
@@ -275,14 +278,22 @@ public synchronized ConsumerRecords<K, V> poll(final Duration timeout) {
         // update the consumed offset
         final Map<TopicPartition, List<ConsumerRecord<K, V>>> results = new HashMap<>();
         final Map<TopicPartition, OffsetAndMetadata> nextOffsetAndMetadata = new HashMap<>();
-        final List<TopicPartition> toClear = new ArrayList<>();
+        long numPollRecords = 0L;
+
+        final Iterator<Map.Entry<TopicPartition, List<ConsumerRecord<K, V>>>> partitionsIter = this.records.entrySet().iterator();
+        while (partitionsIter.hasNext() && numPollRecords < this.maxPollRecords) {
+            Map.Entry<TopicPartition, List<ConsumerRecord<K, V>>> entry = partitionsIter.next();
 
-        for (Map.Entry<TopicPartition, List<ConsumerRecord<K, V>>> entry : this.records.entrySet()) {
             if (!subscriptions.isPaused(entry.getKey())) {
-                final List<ConsumerRecord<K, V>> recs = entry.getValue();
-                for (final ConsumerRecord<K, V> rec : recs) {
+                final Iterator<ConsumerRecord<K, V>> recIterator = entry.getValue().iterator();
+                while (recIterator.hasNext()) {
+                    if (numPollRecords >= this.maxPollRecords) {
+                        break;
+                    }
                     long position = subscriptions.position(entry.getKey()).offset;
 
+                    final ConsumerRecord<K, V> rec = recIterator.next();
+
                     if (beginningOffsets.get(entry.getKey()) != null && beginningOffsets.get(entry.getKey()) > position) {
                         throw new OffsetOutOfRangeException(Collections.singletonMap(entry.getKey(), position));
                     }
@@ -294,13 +305,17 @@ public synchronized ConsumerRecords<K, V> poll(final Duration timeout) {
                             rec.offset() + 1, rec.leaderEpoch(), leaderAndEpoch);
                         subscriptions.position(entry.getKey(), newPosition);
                         nextOffsetAndMetadata.put(entry.getKey(), new OffsetAndMetadata(rec.offset() + 1, rec.leaderEpoch(), ""));
+                        numPollRecords++;
+                        recIterator.remove();
                     }
                 }
-                toClear.add(entry.getKey());
+
+                if (entry.getValue().isEmpty()) {
+                    partitionsIter.remove();
+                }
             }
         }
 
-        toClear.forEach(records::remove);
         return new ConsumerRecords<>(results, nextOffsetAndMetadata);
     }
@@ -314,6 +329,18 @@ public synchronized void addRecord(ConsumerRecord<K, V> record) {
         recs.add(record);
     }
 
+    /**
+     * Sets the maximum number of records returned in a single call to {@link #poll(Duration)}.
+     *
+     * @param maxPollRecords the maximum number of records per poll, mirroring the consumer's max.poll.records config
+     */
+    public synchronized void setMaxPollRecords(long maxPollRecords) {
+        if (maxPollRecords < 1) {
+            throw new IllegalArgumentException("maxPollRecords must be greater than 0");
+        }
+        this.maxPollRecords = maxPollRecords;
+    }
+
     public synchronized void setPollException(KafkaException exception) {
         this.pollException = exception;
     }
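
For context, here is how a test might exercise the new setter once this patch is applied. This is a minimal sketch: the class name, topic, offsets, and setup values are illustrative and not part of the change; only setMaxPollRecords and the capped poll behavior come from this diff.

import java.time.Duration;
import java.util.Collections;

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.MockConsumer;
import org.apache.kafka.clients.consumer.OffsetResetStrategy;
import org.apache.kafka.common.TopicPartition;

public class MaxPollRecordsSketch {
    public static void main(String[] args) {
        MockConsumer<String, String> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
        // Hypothetical topic-partition for the sketch.
        TopicPartition tp = new TopicPartition("test-topic", 0);
        consumer.assign(Collections.singletonList(tp));
        consumer.updateBeginningOffsets(Collections.singletonMap(tp, 0L));

        // Queue three records, but cap each poll at two.
        consumer.setMaxPollRecords(2);
        for (long offset = 0; offset < 3; offset++) {
            consumer.addRecord(new ConsumerRecord<>("test-topic", 0, offset, "key" + offset, "value" + offset));
        }

        ConsumerRecords<String, String> first = consumer.poll(Duration.ofMillis(100));
        System.out.println(first.count()); // 2: capped by maxPollRecords

        ConsumerRecords<String, String> second = consumer.poll(Duration.ofMillis(100));
        System.out.println(second.count()); // 1: the leftover record
    }
}

Because polled records are now removed via the iterators rather than by clearing whole partitions, the record left behind by the capped first poll is still available to the second poll.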