Skip to content

Commit ceff843

Browse files
committed
Allow an existing TaskExecutor to be configured in ChannelRegistration
This commit introduces a new method to configure an existing TaskExecutor in ChannelRegistration. Contrary to TaskExecutorRegistration, a ThreadPoolTaskExecutor is not necessary, and it can't be further configured. This includes the thread name prefix. Closes spring-projectsgh-32081
1 parent 6b3bf55 commit ceff843

File tree

3 files changed

+198
-21
lines changed

3 files changed

+198
-21
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractMessageBrokerConfiguration.java

+29-19
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.HashMap;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.function.Supplier;
2425

2526
import org.springframework.beans.factory.BeanInitializationException;
2627
import org.springframework.beans.factory.annotation.Qualifier;
@@ -62,6 +63,7 @@
6263
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
6364
import org.springframework.util.Assert;
6465
import org.springframework.util.ClassUtils;
66+
import org.springframework.util.CustomizableThreadCreator;
6567
import org.springframework.util.MimeTypeUtils;
6668
import org.springframework.util.PathMatcher;
6769
import org.springframework.util.StringUtils;
@@ -164,10 +166,8 @@ public AbstractSubscribableChannel clientInboundChannel(
164166

165167
@Bean
166168
public TaskExecutor clientInboundChannelExecutor() {
167-
TaskExecutorRegistration reg = getClientInboundChannelRegistration().taskExecutor();
168-
ThreadPoolTaskExecutor executor = reg.getTaskExecutor();
169-
executor.setThreadNamePrefix("clientInboundChannel-");
170-
return executor;
169+
return getTaskExecutor(getClientInboundChannelRegistration(),
170+
"clientInboundChannel-", this::defaultTaskExecutor);
171171
}
172172

173173
protected final ChannelRegistration getClientInboundChannelRegistration() {
@@ -202,10 +202,8 @@ public AbstractSubscribableChannel clientOutboundChannel(
202202

203203
@Bean
204204
public TaskExecutor clientOutboundChannelExecutor() {
205-
TaskExecutorRegistration reg = getClientOutboundChannelRegistration().taskExecutor();
206-
ThreadPoolTaskExecutor executor = reg.getTaskExecutor();
207-
executor.setThreadNamePrefix("clientOutboundChannel-");
208-
return executor;
205+
return getTaskExecutor(getClientOutboundChannelRegistration(),
206+
"clientOutboundChannel-", this::defaultTaskExecutor);
209207
}
210208

211209
protected final ChannelRegistration getClientOutboundChannelRegistration() {
@@ -246,19 +244,31 @@ public TaskExecutor brokerChannelExecutor(
246244

247245
MessageBrokerRegistry registry = getBrokerRegistry(clientInboundChannel, clientOutboundChannel);
248246
ChannelRegistration registration = registry.getBrokerChannelRegistration();
249-
ThreadPoolTaskExecutor executor;
250-
if (registration.hasTaskExecutor()) {
251-
executor = registration.taskExecutor().getTaskExecutor();
252-
}
253-
else {
247+
return getTaskExecutor(registration, "brokerChannel-", () -> {
254248
// Should never be used
255-
executor = new ThreadPoolTaskExecutor();
256-
executor.setCorePoolSize(0);
257-
executor.setMaxPoolSize(1);
258-
executor.setQueueCapacity(0);
249+
ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
250+
threadPoolTaskExecutor.setCorePoolSize(0);
251+
threadPoolTaskExecutor.setMaxPoolSize(1);
252+
threadPoolTaskExecutor.setQueueCapacity(0);
253+
return threadPoolTaskExecutor;
254+
});
255+
}
256+
257+
private static TaskExecutor getTaskExecutor(ChannelRegistration registration,
258+
String threadNamePrefix, Supplier<TaskExecutor> fallback) {
259+
260+
return registration.getTaskExecutor(fallback,
261+
executor -> setThreadNamePrefix(executor, threadNamePrefix));
262+
}
263+
264+
private TaskExecutor defaultTaskExecutor() {
265+
return new TaskExecutorRegistration().getTaskExecutor();
266+
}
267+
268+
private static void setThreadNamePrefix(TaskExecutor taskExecutor, String name) {
269+
if (taskExecutor instanceof CustomizableThreadCreator ctc) {
270+
ctc.setThreadNamePrefix(name);
259271
}
260-
executor.setThreadNamePrefix("brokerChannel-");
261-
return executor;
262272
}
263273

264274
/**

spring-messaging/src/main/java/org/springframework/messaging/simp/config/ChannelRegistration.java

+48-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,7 +19,10 @@
1919
import java.util.ArrayList;
2020
import java.util.Arrays;
2121
import java.util.List;
22+
import java.util.function.Consumer;
23+
import java.util.function.Supplier;
2224

25+
import org.springframework.core.task.TaskExecutor;
2326
import org.springframework.lang.Nullable;
2427
import org.springframework.messaging.support.ChannelInterceptor;
2528
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
@@ -29,13 +32,17 @@
2932
* {@link org.springframework.messaging.MessageChannel}.
3033
*
3134
* @author Rossen Stoyanchev
35+
* @author Stephane Nicoll
3236
* @since 4.0
3337
*/
3438
public class ChannelRegistration {
3539

3640
@Nullable
3741
private TaskExecutorRegistration registration;
3842

43+
@Nullable
44+
private TaskExecutor executor;
45+
3946
private final List<ChannelInterceptor> interceptors = new ArrayList<>();
4047

4148

@@ -59,6 +66,18 @@ public TaskExecutorRegistration taskExecutor(@Nullable ThreadPoolTaskExecutor ta
5966
return this.registration;
6067
}
6168

69+
/**
70+
* Configure the given {@link TaskExecutor} for this message channel,
71+
* taking precedence over a {@linkplain #taskExecutor() task executor
72+
* registration} if any.
73+
* @param taskExecutor the task executor to use
74+
* @since 6.1.4
75+
*/
76+
public ChannelRegistration executor(TaskExecutor taskExecutor) {
77+
this.executor = taskExecutor;
78+
return this;
79+
}
80+
6281
/**
6382
* Configure the given interceptors for this message channel,
6483
* adding them to the channel's current list of interceptors.
@@ -71,13 +90,40 @@ public ChannelRegistration interceptors(ChannelInterceptor... interceptors) {
7190

7291

7392
protected boolean hasTaskExecutor() {
74-
return (this.registration != null);
93+
return (this.registration != null || this.executor != null);
7594
}
7695

7796
protected boolean hasInterceptors() {
7897
return !this.interceptors.isEmpty();
7998
}
8099

100+
/**
101+
* Return the {@link TaskExecutor} to use. If no task executor has been
102+
* configured, the {@code fallback} supplier is used to provide a fallback
103+
* instance.
104+
* <p>
105+
* If the {@link TaskExecutor} to use is suitable for further customizations,
106+
* the {@code customizer} consumer is invoked.
107+
* @param fallback a supplier of a fallback task executor in case none is configured
108+
* @param customizer further customizations
109+
* @return the task executor to use
110+
*/
111+
protected TaskExecutor getTaskExecutor(Supplier<TaskExecutor> fallback, Consumer<TaskExecutor> customizer) {
112+
if (this.executor != null) {
113+
return this.executor;
114+
}
115+
else if (this.registration != null) {
116+
ThreadPoolTaskExecutor registeredTaskExecutor = this.registration.getTaskExecutor();
117+
customizer.accept(registeredTaskExecutor);
118+
return registeredTaskExecutor;
119+
}
120+
else {
121+
TaskExecutor taskExecutor = fallback.get();
122+
customizer.accept(taskExecutor);
123+
return taskExecutor;
124+
}
125+
}
126+
81127
protected List<ChannelInterceptor> getInterceptors() {
82128
return this.interceptors;
83129
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.messaging.simp.config;
18+
19+
import java.util.function.Consumer;
20+
import java.util.function.Supplier;
21+
22+
import org.junit.jupiter.api.Test;
23+
24+
import org.springframework.core.task.TaskExecutor;
25+
import org.springframework.messaging.support.ChannelInterceptor;
26+
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
27+
28+
import static org.assertj.core.api.Assertions.assertThat;
29+
import static org.mockito.BDDMockito.given;
30+
import static org.mockito.Mockito.mock;
31+
import static org.mockito.Mockito.verify;
32+
import static org.mockito.Mockito.verifyNoInteractions;
33+
34+
/**
35+
* Tests for {@link ChannelRegistration}.
36+
*
37+
* @author Stephane Nicoll
38+
*/
39+
class ChannelRegistrationTests {
40+
41+
private final Supplier<TaskExecutor> fallback = mock();
42+
43+
private final Consumer<TaskExecutor> customizer = mock();
44+
45+
@Test
46+
void emptyRegistrationUsesFallback() {
47+
TaskExecutor fallbackTaskExecutor = mock(TaskExecutor.class);
48+
given(this.fallback.get()).willReturn(fallbackTaskExecutor);
49+
ChannelRegistration registration = new ChannelRegistration();
50+
assertThat(registration.hasTaskExecutor()).isFalse();
51+
TaskExecutor actual = registration.getTaskExecutor(this.fallback, this.customizer);
52+
assertThat(actual).isSameAs(fallbackTaskExecutor);
53+
verify(this.fallback).get();
54+
verify(this.customizer).accept(fallbackTaskExecutor);
55+
}
56+
57+
@Test
58+
void emptyRegistrationDoesNotHaveInterceptors() {
59+
ChannelRegistration registration = new ChannelRegistration();
60+
assertThat(registration.hasInterceptors()).isFalse();
61+
assertThat(registration.getInterceptors()).isEmpty();
62+
}
63+
64+
@Test
65+
void taskRegistrationCreatesDefaultInstance() {
66+
ChannelRegistration registration = new ChannelRegistration();
67+
registration.taskExecutor();
68+
assertThat(registration.hasTaskExecutor()).isTrue();
69+
TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer);
70+
assertThat(taskExecutor).isInstanceOf(ThreadPoolTaskExecutor.class);
71+
verifyNoInteractions(this.fallback);
72+
verify(this.customizer).accept(taskExecutor);
73+
}
74+
75+
@Test
76+
void taskRegistrationWithExistingThreadPoolTaskExecutor() {
77+
ThreadPoolTaskExecutor existingTaskExecutor = mock(ThreadPoolTaskExecutor.class);
78+
ChannelRegistration registration = new ChannelRegistration();
79+
registration.taskExecutor(existingTaskExecutor);
80+
assertThat(registration.hasTaskExecutor()).isTrue();
81+
TaskExecutor taskExecutor = registration.getTaskExecutor(this.fallback, this.customizer);
82+
assertThat(taskExecutor).isSameAs(existingTaskExecutor);
83+
verifyNoInteractions(this.fallback);
84+
verify(this.customizer).accept(taskExecutor);
85+
}
86+
87+
@Test
88+
void configureExecutor() {
89+
ChannelRegistration registration = new ChannelRegistration();
90+
TaskExecutor taskExecutor = mock(TaskExecutor.class);
91+
registration.executor(taskExecutor);
92+
assertThat(registration.hasTaskExecutor()).isTrue();
93+
TaskExecutor taskExecutor1 = registration.getTaskExecutor(this.fallback, this.customizer);
94+
assertThat(taskExecutor1).isSameAs(taskExecutor);
95+
verifyNoInteractions(this.fallback, this.customizer);
96+
}
97+
98+
@Test
99+
void configureExecutorTakesPrecedenceOverTaskRegistration() {
100+
ChannelRegistration registration = new ChannelRegistration();
101+
TaskExecutor taskExecutor = mock(TaskExecutor.class);
102+
registration.executor(taskExecutor);
103+
ThreadPoolTaskExecutor ignored = mock(ThreadPoolTaskExecutor.class);
104+
registration.taskExecutor(ignored);
105+
assertThat(registration.hasTaskExecutor()).isTrue();
106+
assertThat(registration.getTaskExecutor(this.fallback, this.customizer)).isSameAs(taskExecutor);
107+
verifyNoInteractions(ignored, this.fallback, this.customizer);
108+
109+
}
110+
111+
@Test
112+
void configureInterceptors() {
113+
ChannelRegistration registration = new ChannelRegistration();
114+
ChannelInterceptor interceptor1 = mock(ChannelInterceptor.class);
115+
registration.interceptors(interceptor1);
116+
ChannelInterceptor interceptor2 = mock(ChannelInterceptor.class);
117+
registration.interceptors(interceptor2);
118+
assertThat(registration.getInterceptors()).containsExactly(interceptor1, interceptor2);
119+
}
120+
121+
}

0 commit comments

Comments
 (0)