Skip to content

Commit b77fc05

Browse files
Urbanssonaeneasr
andauthored
fix: return browser to 'return_to' when logging in without registered account using oidc. (ory#2496)
Close ory#2444 Co-authored-by: aeneasr <[email protected]>
1 parent 71387d5 commit b77fc05

File tree

16 files changed

+135
-10
lines changed

16 files changed

+135
-10
lines changed

selfservice/flow/login/flow.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,10 @@ func (f Flow) MarshalJSON() ([]byte, error) {
192192
}
193193

194194
func (f *Flow) SetReturnTo() {
195+
// Return to is already set, do not overwrite it.
196+
if len(f.ReturnTo) > 0 {
197+
return
198+
}
195199
if u, err := url.Parse(f.RequestURL); err == nil {
196200
f.ReturnTo = u.Query().Get("return_to")
197201
}

selfservice/flow/login/flow_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,13 @@ func TestFlowEncodeJSON(t *testing.T) {
158158
assert.EqualValues(t, "/bar", gjson.Get(jsonx.TestMarshalJSONString(t, &login.Flow{RequestURL: "https://foo.bar?return_to=/bar"}), "return_to").String())
159159
assert.EqualValues(t, "/bar", gjson.Get(jsonx.TestMarshalJSONString(t, login.Flow{RequestURL: "https://foo.bar?return_to=/bar"}), "return_to").String())
160160
}
161+
162+
func TestFlowDontOverrideReturnTo(t *testing.T) {
163+
f := &login.Flow{ReturnTo: "/foo"}
164+
f.SetReturnTo()
165+
assert.Equal(t, "/foo", f.ReturnTo)
166+
167+
f = &login.Flow{RequestURL: "https://foo.bar?return_to=/bar"}
168+
f.SetReturnTo()
169+
assert.Equal(t, "/bar", f.ReturnTo)
170+
}

selfservice/flow/login/handler.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,23 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
8686
admin.GET(RouteSubmitFlow, x.RedirectToPublicRoute(h.d))
8787
}
8888

89-
func (h *Handler) NewLoginFlow(w http.ResponseWriter, r *http.Request, ft flow.Type) (*Flow, error) {
89+
type FlowOption func(f *Flow)
90+
91+
func WithFlowReturnTo(returnTo string) FlowOption {
92+
return func(f *Flow) {
93+
f.ReturnTo = returnTo
94+
}
95+
}
96+
97+
func (h *Handler) NewLoginFlow(w http.ResponseWriter, r *http.Request, ft flow.Type, opts ...FlowOption) (*Flow, error) {
9098
conf := h.d.Config(r.Context())
9199
f, err := NewFlow(conf, conf.SelfServiceFlowLoginRequestLifespan(), h.d.GenerateCSRFToken(r), r, ft)
92100
if err != nil {
93101
return nil, err
94102
}
103+
for _, o := range opts {
104+
o(f)
105+
}
95106

96107
if f.RequestedAAL == "" {
97108
f.RequestedAAL = identity.AuthenticatorAssuranceLevel1

selfservice/flow/login/hook.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, a *
8282
// Verify the redirect URL before we do any other processing.
8383
c := e.d.Config(r.Context())
8484
returnTo, err := x.SecureRedirectTo(r, c.SelfServiceBrowserDefaultReturnTo(),
85+
x.SecureRedirectReturnTo(a.ReturnTo),
8586
x.SecureRedirectUseSourceURL(a.RequestURL),
8687
x.SecureRedirectAllowURLs(c.SelfServiceBrowserAllowedReturnToDomains()),
8788
x.SecureRedirectAllowSelfServiceURLs(c.SelfPublicURL()),

selfservice/flow/recovery/flow.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ func (f Flow) MarshalJSON() ([]byte, error) {
191191
}
192192

193193
func (f *Flow) SetReturnTo() {
194+
// Return to is already set, do not overwrite it.
195+
if len(f.ReturnTo) > 0 {
196+
return
197+
}
194198
if u, err := url.Parse(f.RequestURL); err == nil {
195199
f.ReturnTo = u.Query().Get("return_to")
196200
}

selfservice/flow/recovery/flow_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,13 @@ func TestFromOldFlow(t *testing.T) {
101101
})
102102
}
103103
}
104+
105+
func TestFlowDontOverrideReturnTo(t *testing.T) {
106+
f := &recovery.Flow{ReturnTo: "/foo"}
107+
f.SetReturnTo()
108+
assert.Equal(t, "/foo", f.ReturnTo)
109+
110+
f = &recovery.Flow{RequestURL: "https://foo.bar?return_to=/bar"}
111+
f.SetReturnTo()
112+
assert.Equal(t, "/bar", f.ReturnTo)
113+
}

selfservice/flow/registration/flow.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ func (f Flow) MarshalJSON() ([]byte, error) {
159159
}
160160

161161
func (f *Flow) SetReturnTo() {
162+
// Return to is already set, do not overwrite it.
163+
if len(f.ReturnTo) > 0 {
164+
return
165+
}
162166
if u, err := url.Parse(f.RequestURL); err == nil {
163167
f.ReturnTo = u.Query().Get("return_to")
164168
}

selfservice/flow/registration/flow_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,13 @@ func TestFlowEncodeJSON(t *testing.T) {
125125
assert.EqualValues(t, "/bar", gjson.Get(jsonx.TestMarshalJSONString(t, &registration.Flow{RequestURL: "https://foo.bar?return_to=/bar"}), "return_to").String())
126126
assert.EqualValues(t, "/bar", gjson.Get(jsonx.TestMarshalJSONString(t, registration.Flow{RequestURL: "https://foo.bar?return_to=/bar"}), "return_to").String())
127127
}
128+
129+
func TestFlowDontOverrideReturnTo(t *testing.T) {
130+
f := &registration.Flow{ReturnTo: "/foo"}
131+
f.SetReturnTo()
132+
assert.Equal(t, "/foo", f.ReturnTo)
133+
134+
f = &registration.Flow{RequestURL: "https://foo.bar?return_to=/bar"}
135+
f.SetReturnTo()
136+
assert.Equal(t, "/bar", f.ReturnTo)
137+
}

selfservice/flow/registration/handler.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,15 @@ func (h *Handler) RegisterAdminRoutes(admin *x.RouterAdmin) {
9191
admin.GET(RouteSubmitFlow, x.RedirectToPublicRoute(h.d))
9292
}
9393

94-
func (h *Handler) NewRegistrationFlow(w http.ResponseWriter, r *http.Request, ft flow.Type) (*Flow, error) {
94+
type FlowOption func(f *Flow)
9595

96+
func WithFlowReturnTo(returnTo string) FlowOption {
97+
return func(f *Flow) {
98+
f.ReturnTo = returnTo
99+
}
100+
}
101+
102+
func (h *Handler) NewRegistrationFlow(w http.ResponseWriter, r *http.Request, ft flow.Type, opts ...FlowOption) (*Flow, error) {
96103
if !h.d.Config(r.Context()).SelfServiceFlowRegistrationEnabled() {
97104
return nil, errors.WithStack(ErrRegistrationDisabled)
98105
}
@@ -101,6 +108,9 @@ func (h *Handler) NewRegistrationFlow(w http.ResponseWriter, r *http.Request, ft
101108
if err != nil {
102109
return nil, err
103110
}
111+
for _, o := range opts {
112+
o(f)
113+
}
104114

105115
for _, s := range h.d.RegistrationStrategies(r.Context()) {
106116
if err := s.PopulateRegistrationMethod(r, f); err != nil {

selfservice/flow/registration/hook.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
128128
// Verify the redirect URL before we do any other processing.
129129
c := e.d.Config(r.Context())
130130
returnTo, err := x.SecureRedirectTo(r, c.SelfServiceBrowserDefaultReturnTo(),
131+
x.SecureRedirectReturnTo(a.ReturnTo),
131132
x.SecureRedirectUseSourceURL(a.RequestURL),
132133
x.SecureRedirectAllowURLs(c.SelfServiceBrowserAllowedReturnToDomains()),
133134
x.SecureRedirectAllowSelfServiceURLs(c.SelfPublicURL()),

selfservice/flow/settings/flow.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ func (f Flow) MarshalJSON() ([]byte, error) {
203203
}
204204

205205
func (f *Flow) SetReturnTo() {
206+
// Return to is already set, do not overwrite it.
207+
if len(f.ReturnTo) > 0 {
208+
return
209+
}
206210
if u, err := url.Parse(f.RequestURL); err == nil {
207211
f.ReturnTo = u.Query().Get("return_to")
208212
}

selfservice/flow/settings/flow_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,13 @@ func TestFlowEncodeJSON(t *testing.T) {
166166
assert.EqualValues(t, "/bar", gjson.Get(jsonx.TestMarshalJSONString(t, &settings.Flow{RequestURL: "https://foo.bar?return_to=/bar"}), "return_to").String())
167167
assert.EqualValues(t, "/bar", gjson.Get(jsonx.TestMarshalJSONString(t, settings.Flow{RequestURL: "https://foo.bar?return_to=/bar"}), "return_to").String())
168168
}
169+
170+
func TestFlowDontOverrideReturnTo(t *testing.T) {
171+
f := &settings.Flow{ReturnTo: "/foo"}
172+
f.SetReturnTo()
173+
assert.Equal(t, "/foo", f.ReturnTo)
174+
175+
f = &settings.Flow{RequestURL: "https://foo.bar?return_to=/bar"}
176+
f.SetReturnTo()
177+
assert.Equal(t, "/bar", f.ReturnTo)
178+
}

selfservice/strategy/oidc/strategy_login.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,16 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login
8585
// not need additional consent/login.
8686

8787
// This is kinda hacky but the only way to ensure seamless login/registration flows when using OIDC.
88-
8988
s.d.Logger().WithField("provider", provider.Config().ID).WithField("subject", claims.Subject).Debug("Received successful OpenID Connect callback but user is not registered. Re-initializing registration flow now.")
9089

90+
// If return_to was set before, we need to preserve it.
91+
var opts []registration.FlowOption
92+
if len(a.ReturnTo) > 0 {
93+
opts = append(opts, registration.WithFlowReturnTo(a.ReturnTo))
94+
}
95+
9196
// This flow only works for browsers anyways.
92-
aa, err := s.d.RegistrationHandler().NewRegistrationFlow(w, r, flow.TypeBrowser)
97+
aa, err := s.d.RegistrationHandler().NewRegistrationFlow(w, r, flow.TypeBrowser, opts...)
9398
if err != nil {
9499
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
95100
}

selfservice/strategy/oidc/strategy_registration.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,14 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, a
177177
WithField("subject", claims.Subject).
178178
Debug("Received successful OpenID Connect callback but user is already registered. Re-initializing login flow now.")
179179

180+
// If return_to was set before, we need to preserve it.
181+
var opts []login.FlowOption
182+
if len(a.ReturnTo) > 0 {
183+
opts = append(opts, login.WithFlowReturnTo(a.ReturnTo))
184+
}
185+
180186
// This endpoint only handles browser flow at the moment.
181-
ar, err := s.d.LoginHandler().NewLoginFlow(w, r, flow.TypeBrowser)
187+
ar, err := s.d.LoginHandler().NewLoginFlow(w, r, flow.TypeBrowser, opts...)
182188
if err != nil {
183189
return nil, s.handleError(w, r, a, provider.Config().ID, nil, err)
184190
}

selfservice/strategy/oidc/strategy_test.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ func TestStrategy(t *testing.T) {
5858
claims idTokenClaims
5959
scope []string
6060
)
61-
6261
remoteAdmin, remotePublic, hydraIntegrationTSURL := newHydra(t, &subject, &claims, &scope)
6362
returnTS := newReturnTs(t, reg)
63+
conf.MustSet(config.ViperKeyURLsAllowedReturnToDomains, []string{returnTS.URL})
6464
uiTS := newUI(t, reg)
6565
errTS := testhelpers.NewErrorTestServer(t, reg)
6666
routerP := x.NewRouterPublic()
@@ -368,6 +368,20 @@ func TestStrategy(t *testing.T) {
368368
})
369369
})
370370

371+
t.Run("case=login without registered account with return_to", func(t *testing.T) {
372+
subject = "[email protected]"
373+
scope = []string{"openid"}
374+
returnTo := "/foo"
375+
376+
t.Run("case=should pass login", func(t *testing.T) {
377+
r := newLoginFlow(t, fmt.Sprintf("%s?return_to=%s", returnTS.URL, returnTo), time.Minute)
378+
action := afv(t, r.ID, "valid")
379+
res, body := makeRequest(t, "valid", action, url.Values{})
380+
assert.True(t, strings.HasSuffix(res.Request.URL.String(), returnTo))
381+
ai(t, res, body)
382+
})
383+
})
384+
371385
t.Run("case=register and register again but login", func(t *testing.T) {
372386
subject = "[email protected]"
373387
scope = []string{"openid"}
@@ -385,6 +399,15 @@ func TestStrategy(t *testing.T) {
385399
res, body := makeRequest(t, "valid", action, url.Values{})
386400
ai(t, res, body)
387401
})
402+
403+
t.Run("case=should pass third time registration with return to", func(t *testing.T) {
404+
returnTo := "/foo"
405+
r := newLoginFlow(t, fmt.Sprintf("%s?return_to=%s", returnTS.URL, returnTo), time.Minute)
406+
action := afv(t, r.ID, "valid")
407+
res, body := makeRequest(t, "valid", action, url.Values{})
408+
assert.True(t, strings.HasSuffix(res.Request.URL.String(), returnTo))
409+
ai(t, res, body)
410+
})
388411
})
389412

390413
t.Run("case=register, merge, and complete data", func(t *testing.T) {

x/http_secure_redirect.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
type secureRedirectOptions struct {
2121
allowlist []url.URL
2222
defaultReturnTo *url.URL
23+
returnTo string
2324
sourceURL string
2425
}
2526

@@ -40,6 +41,13 @@ func SecureRedirectUseSourceURL(source string) SecureRedirectOption {
4041
}
4142
}
4243

44+
// SecureRedirectReturnTo uses the provided URL to redirect the user to it.
45+
func SecureRedirectReturnTo(returnTo string) SecureRedirectOption {
46+
return func(o *secureRedirectOptions) {
47+
o.returnTo = returnTo
48+
}
49+
}
50+
4351
// SecureRedirectAllowSelfServiceURLs allows the caller to define `?return_to=` values
4452
// which contain the server's URL and `/self-service` path prefix. Useful for redirecting
4553
// to the login endpoint, for example.
@@ -81,14 +89,18 @@ func SecureRedirectTo(r *http.Request, defaultReturnTo *url.URL, opts ...SecureR
8189
if o.sourceURL != "" {
8290
source, err = url.ParseRequestURI(o.sourceURL)
8391
if err != nil {
84-
return nil, herodot.ErrInternalServerError.WithWrap(err).WithReasonf("Unable to parse the original request URL: %s", err)
92+
return nil, errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReasonf("Unable to parse the original request URL: %s", err))
8593
}
8694
}
8795

88-
if len(source.Query().Get("return_to")) == 0 {
96+
rawReturnTo := stringsx.Coalesce(o.returnTo, source.Query().Get("return_to"))
97+
if rawReturnTo == "" {
8998
return o.defaultReturnTo, nil
90-
} else if returnTo, err = url.Parse(source.Query().Get("return_to")); err != nil {
91-
return nil, herodot.ErrInternalServerError.WithWrap(err).WithReasonf("Unable to parse the return_to query parameter as an URL: %s", err)
99+
}
100+
101+
returnTo, err = url.Parse(rawReturnTo)
102+
if err != nil {
103+
return nil, errors.WithStack(herodot.ErrInternalServerError.WithWrap(err).WithReasonf("Unable to parse the return_to query parameter as an URL: %s", err))
92104
}
93105

94106
returnTo.Host = stringsx.Coalesce(returnTo.Host, o.defaultReturnTo.Host)

0 commit comments

Comments
 (0)