Skip to content

Add expireAt to OAuth2AuthorizationRequest and clean-up expired requests on save #7381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
*/
package org.springframework.security.oauth2.client.web;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
Expand Down Expand Up @@ -48,6 +52,10 @@
* <b>NOTE:</b> The default base {@code URI} {@code /oauth2/authorization} may be overridden
* via it's constructor {@link #DefaultOAuth2AuthorizationRequestResolver(ClientRegistrationRepository, String)}.
*
* <p>
* <b>NOTE:</b> {@link OAuth2AuthorizationRequest}s expire after two minutes, the default duration can be configured via
* {@link #setOAuth2AuthorizationRequestExpiresIn(Duration)}.
*
* @author Joe Grandja
* @author Rob Winch
* @author Eddú Meléndez
Expand All @@ -62,6 +70,8 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
private final AntPathRequestMatcher authorizationRequestMatcher;
private final StringKeyGenerator stateGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
private final StringKeyGenerator codeVerifierGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
private Duration oAuth2AuthorizationRequestExpiresIn = Duration.ofSeconds(120);
private Clock clock = Clock.systemUTC();

/**
* Constructs a {@code DefaultOAuth2AuthorizationRequestResolver} using the provided parameters.
Expand Down Expand Up @@ -140,6 +150,7 @@ private OAuth2AuthorizationRequest resolve(HttpServletRequest request, String re
.scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.attributes(attributes)
.expiresAt(calculateExpiration())
.build();

return authorizationRequest;
Expand Down Expand Up @@ -230,4 +241,35 @@ private String createCodeChallenge(String codeVerifier) throws NoSuchAlgorithmEx
byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
}

private Instant calculateExpiration() {
return this.oAuth2AuthorizationRequestExpiresIn.isNegative() ? null
: Instant.now(this.clock).plus(this.oAuth2AuthorizationRequestExpiresIn);
}

/**
* Sets the {@link Duration} used in {@link Instant#plus(TemporalAmount)} when calculating the {@link
* OAuth2AuthorizationRequest#getExpiresAt()}, a negative {@link Duration} indicates that {@link
* OAuth2AuthorizationRequest} never expire.
*
* @param oAuth2AuthorizationRequestExpiresIn the {@link Duration} a {@link OAuth2AuthorizationRequest} is
* considered not expired
* @since 5.2
*/
public void setOAuth2AuthorizationRequestExpiresIn(Duration oAuth2AuthorizationRequestExpiresIn) {
Assert.notNull(oAuth2AuthorizationRequestExpiresIn, "oAuth2AuthorizationRequestExpiresIn cannot be null");
this.oAuth2AuthorizationRequestExpiresIn = oAuth2AuthorizationRequestExpiresIn;
}

/**
* Sets the {@link Clock} used in {@link Instant#now(Clock)} when setting the {@link
* OAuth2AuthorizationRequest#getExpiresAt()}.
*
* @param clock the clock
* @since 5.2
*/
public void setClock(Clock clock) {
Assert.notNull(clock, "clock cannot be null");
this.clock = clock;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package org.springframework.security.oauth2.client.web;

import java.time.Clock;
import java.time.Instant;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.Assert;
Expand All @@ -40,6 +42,7 @@ public final class HttpSessionOAuth2AuthorizationRequestRepository implements Au
HttpSessionOAuth2AuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST";

private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
private Clock clock = Clock.systemUTC();

@Override
public OAuth2AuthorizationRequest loadAuthorizationRequest(HttpServletRequest request) {
Expand All @@ -64,6 +67,7 @@ public void saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationReq
String state = authorizationRequest.getState();
Assert.hasText(state, "authorizationRequest.state cannot be empty");
Map<String, OAuth2AuthorizationRequest> authorizationRequests = this.getAuthorizationRequests(request);
cleaUpExpiredAuthorizationRequests(authorizationRequests);
authorizationRequests.put(state, authorizationRequest);
request.getSession().setAttribute(this.sessionAttributeName, authorizationRequests);
}
Expand Down Expand Up @@ -100,11 +104,17 @@ private String getStateParameter(HttpServletRequest request) {
return request.getParameter(OAuth2ParameterNames.STATE);
}

private void cleaUpExpiredAuthorizationRequests(
Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest) {
stateToAuthzRequest.values().removeIf(request -> request.isExpired(this.clock));
}

/**
* Gets a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}
* @param request
* @return a non-null and mutable map of {@link OAuth2AuthorizationRequest#getState()} to an {@link OAuth2AuthorizationRequest}.
*/
@SuppressWarnings("unchecked")
private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpServletRequest request) {
HttpSession session = request.getSession(false);
Map<String, OAuth2AuthorizationRequest> authorizationRequests = session == null ? null :
Expand All @@ -114,4 +124,15 @@ private Map<String, OAuth2AuthorizationRequest> getAuthorizationRequests(HttpSer
}
return authorizationRequests;
}

/**
* Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking {@link OAuth2AuthorizationRequest#isExpired(Clock)}.
*
* @param clock the clock
* @since 5.2
*/
public void setClock(Clock clock) {
Assert.notNull(clock, "clock cannot be null");
this.clock = clock;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

package org.springframework.security.oauth2.client.web.server;

import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
Expand Down Expand Up @@ -77,6 +81,10 @@ public class DefaultServerOAuth2AuthorizationRequestResolver

private final StringKeyGenerator codeVerifierGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);

private Duration oAuth2AuthorizationRequestExpiresIn = Duration.ofSeconds(120);

private Clock clock = Clock.systemUTC();

/**
* Creates a new instance
* @param clientRegistrationRepository the repository to resolve the {@link ClientRegistration}
Expand Down Expand Up @@ -152,6 +160,7 @@ else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizat
.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.attributes(attributes)
.expiresAt(calculateExpiration())
.build();
}

Expand Down Expand Up @@ -236,4 +245,35 @@ private String createCodeChallenge(String codeVerifier) throws NoSuchAlgorithmEx
byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
}

private Instant calculateExpiration() {
return this.oAuth2AuthorizationRequestExpiresIn.isNegative() ? null
: Instant.now(this.clock).plus(this.oAuth2AuthorizationRequestExpiresIn);
}

/**
* Sets the {@link Duration} used in {@link Instant#plus(TemporalAmount)} when calculating the {@link
* OAuth2AuthorizationRequest#getExpiresAt()}, a negative {@link Duration} indicates that {@link
* OAuth2AuthorizationRequest} never expire.
*
* @param oAuth2AuthorizationRequestExpiresIn the {@link Duration} a {@link OAuth2AuthorizationRequest} is
* considered not expired
* @since 5.2
*/
public void setOAuth2AuthorizationRequestExpiresIn(Duration oAuth2AuthorizationRequestExpiresIn) {
Assert.notNull(oAuth2AuthorizationRequestExpiresIn, "oAuth2AuthorizationRequestExpiresIn cannot be null");
this.oAuth2AuthorizationRequestExpiresIn = oAuth2AuthorizationRequestExpiresIn;
}

/**
* Sets the {@link Clock} used in {@link Instant#now(Clock)} when setting the {@link
* OAuth2AuthorizationRequest#getExpiresAt()}.
*
* @param clock the clock
* @since 5.2
*/
public void setClock(Clock clock) {
Assert.notNull(clock, "clock cannot be null");
this.clock = clock;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.security.oauth2.client.web.server;

import java.time.Clock;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -45,6 +47,7 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
WebSessionOAuth2ServerAuthorizationRequestRepository.class.getName() + ".AUTHORIZATION_REQUEST";

private final String sessionAttributeName = DEFAULT_AUTHORIZATION_REQUEST_ATTR_NAME;
private Clock clock = Clock.systemUTC();

@Override
public Mono<OAuth2AuthorizationRequest> loadAuthorizationRequest(
Expand Down Expand Up @@ -123,19 +126,38 @@ private Mono<Map<String, OAuth2AuthorizationRequest>> saveStateToAuthorizationRe

return getSessionAttributes(exchange)
.doOnNext(sessionAttrs -> {
Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName);
Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(sessionAttrs);

if (stateToAuthzRequest == null) {
stateToAuthzRequest = new HashMap<String, OAuth2AuthorizationRequest>();
stateToAuthzRequest = new HashMap<>();
}

cleaUpExpiredAuthorizationRequests(stateToAuthzRequest);

// No matter stateToAuthzRequest was in session or not, we should always put it into session again
// in case of redis or hazelcast session. #6215
sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
}).flatMap(sessionAttrs -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
}

private void cleaUpExpiredAuthorizationRequests(
Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest) {
stateToAuthzRequest.values().removeIf(request -> request.isExpired(this.clock));
}

@SuppressWarnings("unchecked")
private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(Map<String, Object> sessionAttrs) {
return (Map<String, OAuth2AuthorizationRequest>) sessionAttrs.get(this.sessionAttributeName);
}

/**
* Sets the {@link Clock} used in {@link Instant#now(Clock)} when checking {@link OAuth2AuthorizationRequest#isExpired(Clock)}.
*
* @param clock the clock
* @since 5.2
*/
public void setClock(Clock clock) {
Assert.notNull(clock, "clock cannot be null");
this.clock = clock;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
package org.springframework.security.oauth2.core.endpoint;

import java.time.Clock;
import java.time.Instant;
import org.springframework.security.core.SpringSecurityCoreVersion;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -55,6 +57,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
private Map<String, Object> additionalParameters;
private String authorizationRequestUri;
private Map<String, Object> attributes;
private Instant expiresAt;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An Authorization Request does not have an expiresAt attribute/parameter as per spec. This really is an implementation detail and should be externalized to the expiring component.


private OAuth2AuthorizationRequest() {
}
Expand Down Expand Up @@ -154,6 +157,27 @@ public <T> T getAttribute(String name) {
return (T) this.getAttributes().get(name);
}

/**
* Returns the {@code Instant} when this {@link OAuth2AuthorizationRequest} expires, or {@code null} when it
* never expires.
*
* @since 5.2
* @return the {@code Instant} when this request expires, or {@code null} when it never expires.
*/
public Instant getExpiresAt() {
return expiresAt;
}

/**
* Returns whether the request is expired.
*
* @param clock the {@link Clock} to validate the expiration with.
* @return whether this request is expired.
*/
public boolean isExpired(Clock clock) {
return this.getExpiresAt() != null && this.getExpiresAt().isBefore(Instant.now(clock));
}

/**
* Returns the {@code URI} string representation of the OAuth 2.0 Authorization Request.
*
Expand Down Expand Up @@ -204,7 +228,8 @@ public static Builder from(OAuth2AuthorizationRequest authorizationRequest) {
.scopes(authorizationRequest.getScopes())
.state(authorizationRequest.getState())
.additionalParameters(authorizationRequest.getAdditionalParameters())
.attributes(authorizationRequest.getAttributes());
.attributes(authorizationRequest.getAttributes())
.expiresAt(authorizationRequest.getExpiresAt());
}

/**
Expand All @@ -221,6 +246,7 @@ public static class Builder {
private Map<String, Object> additionalParameters;
private String authorizationRequestUri;
private Map<String, Object> attributes;
private Instant expiresAt;

private Builder(AuthorizationGrantType authorizationGrantType) {
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
Expand Down Expand Up @@ -323,6 +349,18 @@ public Builder attributes(Map<String, Object> attributes) {
return this;
}

/**
* Sets when the request expires.
*
* @since 5.2
* @param expiresAt the instant when the request should be considered expired
* @return the {@link Builder}
*/
public Builder expiresAt(Instant expiresAt) {
this.expiresAt = expiresAt;
return this;
}

/**
* Sets the {@code URI} string representation of the OAuth 2.0 Authorization Request.
*
Expand Down Expand Up @@ -370,6 +408,7 @@ public OAuth2AuthorizationRequest build() {
authorizationRequest.attributes = Collections.unmodifiableMap(
CollectionUtils.isEmpty(this.attributes) ?
Collections.emptyMap() : new LinkedHashMap<>(this.attributes));
authorizationRequest.expiresAt = this.expiresAt;

return authorizationRequest;
}
Expand Down
Loading