@@ -632,6 +632,8 @@ private final class ListenerConsumer implements SchedulingAwareRunnable, Consume
632
632
633
633
private final BatchInterceptor <K , V > commonBatchInterceptor = getBatchInterceptor ();
634
634
635
+ private final ConsumerAwareThreadStateProcessor pollThreadStateProcessor ;
636
+
635
637
private final ConsumerSeekCallback seekCallback = new InitialOrIdleSeekCallback ();
636
638
637
639
private final long maxPollInterval ;
@@ -746,12 +748,14 @@ private final class ListenerConsumer implements SchedulingAwareRunnable, Consume
746
748
this .batchListener = (BatchMessageListener <K , V >) listener ;
747
749
this .isBatchListener = true ;
748
750
this .wantsFullRecords = this .batchListener .wantsPollResult ();
751
+ this .pollThreadStateProcessor = setUpPollProcessor (true );
749
752
}
750
753
else if (listener instanceof MessageListener ) {
751
754
this .listener = (MessageListener <K , V >) listener ;
752
755
this .batchListener = null ;
753
756
this .isBatchListener = false ;
754
757
this .wantsFullRecords = false ;
758
+ this .pollThreadStateProcessor = setUpPollProcessor (false );
755
759
}
756
760
else {
757
761
throw new IllegalArgumentException ("Listener must be one of 'MessageListener', "
@@ -802,6 +806,19 @@ else if (listener instanceof MessageListener) {
802
806
this .pausedPartitions = new HashSet <>();
803
807
}
804
808
809
+ @ Nullable
810
+ private ConsumerAwareThreadStateProcessor setUpPollProcessor (boolean batch ) {
811
+ if (batch ) {
812
+ if (this .commonBatchInterceptor != null ) {
813
+ return this .commonBatchInterceptor ;
814
+ }
815
+ }
816
+ else if (this .commonRecordInterceptor != null ) {
817
+ return this .commonRecordInterceptor ;
818
+ }
819
+ return null ;
820
+ }
821
+
805
822
@ Nullable
806
823
private CommonErrorHandler determineCommonErrorHandler (@ Nullable GenericErrorHandler <?> errHandler ) {
807
824
CommonErrorHandler common = getCommonErrorHandler ();
@@ -1314,22 +1331,8 @@ private void invokeIfHaveRecords(@Nullable ConsumerRecords<K, V> records) {
1314
1331
}
1315
1332
1316
1333
private void clearThreadState () {
1317
- if (this .isBatchListener ) {
1318
- interceptClearThreadState (this .commonBatchInterceptor );
1319
- }
1320
- else {
1321
- interceptClearThreadState (this .commonRecordInterceptor );
1322
- }
1323
- }
1324
-
1325
- private void interceptClearThreadState (BeforeAfterPollProcessor <K , V > processor ) {
1326
- if (processor != null ) {
1327
- try {
1328
- processor .clearThreadState (this .consumer );
1329
- }
1330
- catch (Exception e ) {
1331
- this .logger .error (e , "BeforeAfterPollProcessor.clearThreadState threw an exception" );
1332
- }
1334
+ if (this .pollThreadStateProcessor != null ) {
1335
+ this .pollThreadStateProcessor .clearThreadState (this .consumer );
1333
1336
}
1334
1337
}
1335
1338
@@ -1480,22 +1483,8 @@ private ConsumerRecords<K, V> pollConsumer() {
1480
1483
}
1481
1484
1482
1485
private void beforePoll () {
1483
- if (this .isBatchListener ) {
1484
- interceptBeforePoll (this .commonBatchInterceptor );
1485
- }
1486
- else {
1487
- interceptBeforePoll (this .commonRecordInterceptor );
1488
- }
1489
- }
1490
-
1491
- private void interceptBeforePoll (BeforeAfterPollProcessor <K , V > processor ) {
1492
- if (processor != null ) {
1493
- try {
1494
- processor .beforePoll (this .consumer );
1495
- }
1496
- catch (Exception e ) {
1497
- this .logger .error (e , "BeforeAfterPollProcessor.beforePoll threw an exception" );
1498
- }
1486
+ if (this .pollThreadStateProcessor != null ) {
1487
+ this .pollThreadStateProcessor .setupThreadState (this .consumer );
1499
1488
}
1500
1489
}
1501
1490
@@ -2294,6 +2283,9 @@ private void invokeRecordListenerInTx(final ConsumerRecords<K, V> records) {
2294
2283
TransactionSupport .clearTransactionIdSuffix ();
2295
2284
}
2296
2285
}
2286
+ if (this .commonRecordInterceptor != null ) {
2287
+ this .commonRecordInterceptor .afterRecord (record , this .consumer );
2288
+ }
2297
2289
if (this .nackSleep >= 0 ) {
2298
2290
handleNack (records , record );
2299
2291
break ;
@@ -2374,6 +2366,9 @@ private void doInvokeWithRecords(final ConsumerRecords<K, V> records) {
2374
2366
}
2375
2367
this .logger .trace (() -> "Processing " + ListenerUtils .recordToString (record ));
2376
2368
doInvokeRecordListener (record , iterator );
2369
+ if (this .commonRecordInterceptor != null ) {
2370
+ this .commonRecordInterceptor .afterRecord (record , this .consumer );
2371
+ }
2377
2372
if (this .nackSleep >= 0 ) {
2378
2373
handleNack (records , record );
2379
2374
break ;
0 commit comments