Skip to content

Commit 9b1a85d

Browse files
committed
Limit the number of OAuth2AuthorizationRequest per session
1 parent 97b5c77 commit 9b1a85d

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepository.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.time.Instant;
2323
import java.util.HashMap;
2424
import java.util.Map;
25+
import java.util.Map.Entry;
2526

2627
import javax.servlet.http.HttpServletRequest;
2728
import javax.servlet.http.HttpServletResponse;
@@ -58,6 +59,8 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository
5859

5960
private Duration authorizationRequestTimeToLive = Duration.ofSeconds(120);
6061

62+
private int maxActiveAuthorizationRequestsPerSession = 10;
63+
6164
@Override
6265
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
6366
Assert.notNull(request, "request cannot be null");
@@ -84,6 +87,11 @@ public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationReq
8487
Map<String, OAuth2AuthorizationRequestReference> authorizationRequests = this.getAuthorizationRequests(request);
8588
authorizationRequests.put(state, new OAuth2AuthorizationRequestReference(authorizationRequest,
8689
this.clock.instant().plus(this.authorizationRequestTimeToLive)));
90+
if (authorizationRequests.size() > this.maxActiveAuthorizationRequestsPerSession) {
91+
authorizationRequests.entrySet().stream()
92+
.sorted((e, f) -> e.getValue().expiresAt.compareTo(f.getValue().expiresAt)).findFirst()
93+
.map(Entry::getKey).ifPresent(authorizationRequests::remove);
94+
}
8795
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
8896
}
8997

@@ -167,6 +175,18 @@ public void setAuthorizationRequestTimeToLive(Duration authorizationRequestTimeT
167175
this.authorizationRequestTimeToLive = authorizationRequestTimeToLive;
168176
}
169177

178+
/**
179+
* Sets the maximum number of {@link OAuth2AuthorizationRequest} that can be
180+
* stored/active for a session. If the maximum number are present in a session when an
181+
* attempt is made to save another one, then the oldest will be removed.
182+
* @param maxActiveAuthorizationRequests must not be negative.
183+
*/
184+
public void setMaxActiveAuthorizationRequestsPerSession(int maxActiveAuthorizationRequestsPerSession) {
185+
Assert.state(maxActiveAuthorizationRequestsPerSession > 0,
186+
"maxActiveAuthorizationRequestsPerSession must be greater than zero");
187+
this.maxActiveAuthorizationRequestsPerSession = maxActiveAuthorizationRequestsPerSession;
188+
}
189+
170190
private static final class OAuth2AuthorizationRequestReference implements Serializable {
171191

172192
private static final long serialVersionUID = SpringSecurityCoreVersion.SERIAL_VERSION_UID;

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/HttpSessionOAuth2AuthorizationRequestRepositoryTests.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,36 @@ public void removeAuthorizationRequestWhenExpired() {
268268
assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
269269
}
270270

271+
@Test
272+
public void removeOldestAuthorizationRequestWhenMoreThanMax() {
273+
this.authorizationRequestRepository.setMaxActiveAuthorizationRequestsPerSession(2);
274+
MockHttpServletRequest request = new MockHttpServletRequest();
275+
MockHttpServletResponse response = new MockHttpServletResponse();
276+
String state1 = "state-1122";
277+
OAuth2AuthorizationRequest authorizationRequest1 = createAuthorizationRequest().state(state1).build();
278+
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest1, request, response);
279+
String state2 = "state-3344";
280+
OAuth2AuthorizationRequest authorizationRequest2 = createAuthorizationRequest().state(state2).build();
281+
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest2, request, response);
282+
String state3 = "state-4455";
283+
OAuth2AuthorizationRequest authorizationRequest3 = createAuthorizationRequest().state(state3).build();
284+
this.authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest3, request, response);
285+
request.addParameter(OAuth2ParameterNames.STATE, state1);
286+
OAuth2AuthorizationRequest loadedAuthorizationRequest1 = this.authorizationRequestRepository
287+
.loadAuthorizationRequest(request);
288+
assertThat(loadedAuthorizationRequest1).isNull();
289+
request.removeParameter(OAuth2ParameterNames.STATE);
290+
request.addParameter(OAuth2ParameterNames.STATE, state2);
291+
OAuth2AuthorizationRequest loadedAuthorizationRequest2 = this.authorizationRequestRepository
292+
.loadAuthorizationRequest(request);
293+
assertThat(loadedAuthorizationRequest2).isEqualTo(authorizationRequest2);
294+
request.removeParameter(OAuth2ParameterNames.STATE);
295+
request.addParameter(OAuth2ParameterNames.STATE, state3);
296+
OAuth2AuthorizationRequest loadedAuthorizationRequest3 = this.authorizationRequestRepository
297+
.loadAuthorizationRequest(request);
298+
assertThat(loadedAuthorizationRequest3).isEqualTo(authorizationRequest3);
299+
}
300+
271301
private OAuth2AuthorizationRequest.Builder createAuthorizationRequest() {
272302
return OAuth2AuthorizationRequest.authorizationCode().authorizationUri("https://example.com/oauth2/authorize")
273303
.clientId("client-id-1234").state("state-1234");

0 commit comments

Comments
 (0)