Skip to content

Commit 5f720ff

Browse files
committed
Remove OAuth2AuthorizationRequest when too many per registration id
1 parent 18978e6 commit 5f720ff

File tree

2 files changed

+73
-16
lines changed

2 files changed

+73
-16
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
import java.time.Duration;
2222
import java.time.Instant;
2323
import java.util.HashMap;
24+
import java.util.List;
2425
import java.util.Map;
25-
import java.util.Map.Entry;
26+
import java.util.Objects;
27+
import java.util.stream.Collectors;
2628

2729
import javax.servlet.http.HttpServletRequest;
2830
import javax.servlet.http.HttpServletResponse;
@@ -59,7 +61,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository
5961

6062
private Duration authorizationRequestTimeToLive = Duration.ofSeconds(120);
6163

62-
private int maxActiveAuthorizationRequestsPerSession = 10;
64+
private int maxActiveAuthorizationRequestsPerRegistrationIdPerSession = 3;
6365

6466
@Override
6567
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
@@ -87,10 +89,14 @@ public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationReq
8789
Map<String, OAuth2AuthorizationRequestReference> authorizationRequests = this.getAuthorizationRequests(request);
8890
authorizationRequests.put(state, new OAuth2AuthorizationRequestReference(authorizationRequest,
8991
this.clock.instant().plus(this.authorizationRequestTimeToLive)));
90-
if (authorizationRequests.size() > this.maxActiveAuthorizationRequestsPerSession) {
91-
authorizationRequests.entrySet().stream()
92-
.sorted((e, f) -> e.getValue().expiresAt.compareTo(f.getValue().expiresAt)).findFirst()
93-
.map(Entry::getKey).ifPresent(authorizationRequests::remove);
92+
for (String registrationId : authorizationRequests.values().stream().map((r) -> r.getRegistrationId())
93+
.distinct().collect(Collectors.toList())) {
94+
List<OAuth2AuthorizationRequestReference> references = authorizationRequests.values().stream()
95+
.filter((r) -> Objects.equals(registrationId, r.getRegistrationId())).collect(Collectors.toList());
96+
if (references.size() > this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession) {
97+
references.stream().sorted((a, b) -> a.expiresAt.compareTo(b.expiresAt)).findFirst()
98+
.map((r) -> r.getState()).ifPresent(authorizationRequests::remove);
99+
}
94100
}
95101
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
96102
}
@@ -177,14 +183,16 @@ void setAuthorizationRequestTimeToLive(Duration authorizationRequestTimeToLive)
177183

178184
/**
179185
* Sets the maximum number of {@link OAuth2AuthorizationRequest} that can be
180-
* stored/active for a session. If the maximum number are present in a session when an
181-
* attempt is made to save another one, then the oldest will be removed.
186+
* stored/active per registration id for a session. If the maximum number are present
187+
* in a session when an attempt is made to save another one, then the oldest will be
188+
* removed.
182189
* @param maxActiveAuthorizationRequestsPerSession must not be negative.
183190
*/
184-
void setMaxActiveAuthorizationRequestsPerSession(int maxActiveAuthorizationRequestsPerSession) {
185-
Assert.state(maxActiveAuthorizationRequestsPerSession > 0,
186-
"maxActiveAuthorizationRequestsPerSession must be greater than zero");
187-
this.maxActiveAuthorizationRequestsPerSession = maxActiveAuthorizationRequestsPerSession;
191+
void setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(
192+
int maxActiveAuthorizationRequestsPerRegistrationIdPerSession) {
193+
Assert.state(maxActiveAuthorizationRequestsPerRegistrationIdPerSession > 0,
194+
"maxActiveAuthorizationRequestsPerRegistrationIdPerSession must be greater than zero");
195+
this.maxActiveAuthorizationRequestsPerRegistrationIdPerSession = maxActiveAuthorizationRequestsPerRegistrationIdPerSession;
188196
}
189197

190198
private static final class OAuth2AuthorizationRequestReference implements Serializable {
@@ -202,6 +210,14 @@ private OAuth2AuthorizationRequestReference(OAuth2AuthorizationRequest authoriza
202210
this.authorizationRequest = authorizationRequest;
203211
}
204212

213+
private String getRegistrationId() {
214+
return this.authorizationRequest.getAttribute(OAuth2ParameterNames.REGISTRATION_ID);
215+
}
216+
217+
private String getState() {
218+
return this.authorizationRequest.getState();
219+
}
220+
205221
}
206222

207223
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.time.Duration;
2121
import java.time.Instant;
2222
import java.time.ZoneId;
23+
import java.util.Collections;
2324
import java.util.HashMap;
2425
import java.util.Map;
2526

@@ -270,17 +271,21 @@ public void removeAuthorizationRequestWhenExpired() {
270271

271272
@Test
272273
public void removeOldestAuthorizationRequestWhenMoreThanMax() {
273-
this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerSession(2);
274+
String registrationId = "registration-id-1";
275+
this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2);
274276
MockHttpServletRequest request = new MockHttpServletRequest();
275277
MockHttpServletResponse response = new MockHttpServletResponse();
276278
String state1 = "state-1122";
277-
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
279+
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1)
280+
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build();
278281
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
279282
String state2 = "state-3344";
280-
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
283+
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2)
284+
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build();
281285
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
282286
String state3 = "state-4455";
283-
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build();
287+
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3)
288+
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, registrationId)).build();
284289
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
285290
request.addParameter(OAuth2ParameterNames.STATE, state1);
286291
OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
@@ -298,6 +303,42 @@ public void removeOldestAuthorizationRequestWhenMoreThanMax() {
298303
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
299304
}
300305

306+
@Test
307+
public void doNotremoveOldestAuthorizationRequestWhenLessThanMax() {
308+
this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerRegistrationIdPerSession(2);
309+
MockHttpServletRequest request = new MockHttpServletRequest();
310+
MockHttpServletResponse response = new MockHttpServletResponse();
311+
String state1 = "state-1122";
312+
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1)
313+
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-1"))
314+
.build();
315+
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
316+
String state2 = "state-3344";
317+
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2)
318+
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-2"))
319+
.build();
320+
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
321+
String state3 = "state-4455";
322+
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3)
323+
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, "registration-id-3"))
324+
.build();
325+
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
326+
request.addParameter(OAuth2ParameterNames.STATE, state1);
327+
OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
328+
.loadAuthorizationRequest(request);
329+
assertThat(loadedAuthorizationRequest1).isEqualTo(authorizationRequest1);
330+
request.removeParameter(OAuth2ParameterNames.STATE);
331+
request.addParameter(OAuth2ParameterNames.STATE, state2);
332+
OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
333+
.loadAuthorizationRequest(request);
334+
assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
335+
request.removeParameter(OAuth2ParameterNames.STATE);
336+
request.addParameter(OAuth2ParameterNames.STATE, state3);
337+
OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
338+
.loadAuthorizationRequest(request);
339+
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
340+
}
341+
301342
private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() {
302343
return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize")
303344
.clientId("client-id-1234").state("state-1234");

0 commit comments

Comments
 (0)