Skip to content

Commit d56586b

Browse files
authored
fix: include flow id in use recovery token query (#2679)
This PR adds the `selfservice_recovery_flow_id` to the query used when "using" a token in the recovery flow. This PR also adds a new enum field for `identity_recovery_tokens` to distinguish the two flows: admin versus self-service recovery. BREAKING CHANGES: This patch invalidates recovery flows initiated using the Admin API. Please re-generate any admin-generated recovery flows and tokens.
1 parent 1cd2672 commit d56586b

15 files changed

+143
-50
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ALTER TABLE identity_recovery_tokens
2+
DROP token_type;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ALTER TABLE identity_recovery_tokens
2+
ADD token_type int NOT NULL DEFAULT 0;

persistence/sql/migrations/sql/20220824165300000001_populate_flow_type_in_recovery_tokens.down.sql

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
UPDATE identity_recovery_tokens
2+
SET token_type = 1
3+
WHERE selfservice_recovery_flow_id IS NULL;
4+
5+
UPDATE identity_recovery_tokens
6+
SET token_type = 2
7+
WHERE selfservice_recovery_flow_id IS NOT NULL;

persistence/sql/migrations/sql/20220824165300000002_add_flow_type_check_constraint.down.sql

Whitespace-only changes.

persistence/sql/migrations/sql/20220824165300000002_add_flow_type_check_constraint.sqlite3.down.sql

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
-- SQLITE does not support Check constraints in all cases
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ALTER TABLE identity_recovery_tokens
2+
ADD CONSTRAINT identity_recovery_tokens_token_type_ck CHECK (token_type = 1 OR token_type = 2);

persistence/sql/persister_recovery.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ func (p *Persister) CreateRecoveryToken(ctx context.Context, token *link.Recover
6565
return nil
6666
}
6767

68-
func (p *Persister) UseRecoveryToken(ctx context.Context, token string) (*link.RecoveryToken, error) {
68+
func (p *Persister) UseRecoveryToken(ctx context.Context, fID uuid.UUID, token string) (*link.RecoveryToken, error) {
6969
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.UseRecoveryToken")
7070
defer span.End()
7171

@@ -74,7 +74,7 @@ func (p *Persister) UseRecoveryToken(ctx context.Context, token string) (*link.R
7474
nid := p.NetworkID(ctx)
7575
if err := sqlcon.HandleError(p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) (err error) {
7676
for _, secret := range p.r.Config().SecretsSession(ctx) {
77-
if err = tx.Where("token = ? AND nid = ? AND NOT used", p.hmacValueWithSecret(ctx, token, secret), nid).First(&rt); err != nil {
77+
if err = tx.Where("token = ? AND nid = ? AND NOT used AND selfservice_recovery_flow_id = ?", p.hmacValueWithSecret(ctx, token, secret), nid, fID).First(&rt); err != nil {
7878
if !errors.Is(sqlcon.HandleError(err), sqlcon.ErrNoRows) {
7979
return err
8080
}

selfservice/strategy/link/persistence.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ package link
22

33
import (
44
"context"
5+
6+
"github.com/gofrs/uuid"
57
)
68

79
type (
810
RecoveryTokenPersister interface {
911
CreateRecoveryToken(ctx context.Context, token *RecoveryToken) error
10-
UseRecoveryToken(ctx context.Context, token string) (*RecoveryToken, error)
12+
UseRecoveryToken(ctx context.Context, fID uuid.UUID, token string) (*RecoveryToken, error)
1113
DeleteRecoveryToken(ctx context.Context, token string) error
1214
}
1315

selfservice/strategy/link/strategy_recovery.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func (s *Strategy) createRecoveryLink(w http.ResponseWriter, r *http.Request, _
167167
return
168168
}
169169

170-
token := NewRecoveryToken(id.ID, expiresIn)
170+
token := NewAdminRecoveryToken(id.ID, req.ID, expiresIn)
171171
if err := s.d.RecoveryTokenPersister().CreateRecoveryToken(r.Context(), token); err != nil {
172172
s.d.Writer().WriteError(w, r, err)
173173
return
@@ -222,7 +222,7 @@ func (s *Strategy) Recover(w http.ResponseWriter, r *http.Request, f *recovery.F
222222
return s.HandleRecoveryError(w, r, nil, body, err)
223223
}
224224

225-
return s.recoveryUseToken(w, r, body)
225+
return s.recoveryUseToken(w, r, f.ID, body)
226226
}
227227

228228
if _, err := s.d.SessionManager().FetchFromRequest(r.Context(), r); err == nil {
@@ -313,8 +313,8 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request,
313313
return errors.WithStack(flow.ErrCompletedByStrategy)
314314
}
315315

316-
func (s *Strategy) recoveryUseToken(w http.ResponseWriter, r *http.Request, body *recoverySubmitPayload) error {
317-
token, err := s.d.RecoveryTokenPersister().UseRecoveryToken(r.Context(), body.Token)
316+
func (s *Strategy) recoveryUseToken(w http.ResponseWriter, r *http.Request, fID uuid.UUID, body *recoverySubmitPayload) error {
317+
token, err := s.d.RecoveryTokenPersister().UseRecoveryToken(r.Context(), fID, body.Token)
318318
if err != nil {
319319
if errors.Is(err, sqlcon.ErrNoRows) {
320320
return s.retryRecoveryFlowWithMessage(w, r, flow.TypeBrowser, text.NewErrorValidationRecoveryTokenInvalidOrAlreadyUsed())
@@ -351,7 +351,7 @@ func (s *Strategy) recoveryUseToken(w http.ResponseWriter, r *http.Request, body
351351
}
352352

353353
// mark address as verified only for a self-service flow
354-
if token.FlowID.Valid {
354+
if token.TokenType == RecoveryTokenTypeSelfService {
355355
if err := s.markRecoveryAddressVerified(w, r, f, recovered, token.RecoveryAddress); err != nil {
356356
return s.HandleRecoveryError(w, r, f, body, err)
357357
}

selfservice/strategy/link/strategy_recovery_test.go

+82-28
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"testing"
1313
"time"
1414

15+
"github.com/ory/kratos/driver"
1516
"github.com/ory/kratos/session"
1617

1718
"github.com/davecgh/go-spew/spew"
@@ -56,6 +57,23 @@ func init() {
5657
corpx.RegisterFakes()
5758
}
5859

60+
func createIdentityToRecover(t *testing.T, reg *driver.RegistryDefault, email string) *identity.Identity {
61+
var id = &identity.Identity{
62+
Credentials: map[identity.CredentialsType]identity.Credentials{
63+
"password": {Type: "password", Identifiers: []string{email}, Config: sqlxx.JSONRawMessage(`{"hashed_password":"foo"}`)}},
64+
Traits: identity.Traits(fmt.Sprintf(`{"email":"%s"}`, email)),
65+
SchemaID: config.DefaultIdentityTraitsSchemaID,
66+
}
67+
require.NoError(t, reg.IdentityManager().Create(context.Background(), id, identity.ManagerAllowWriteProtectedTraits))
68+
69+
addr, err := reg.IdentityPool().FindVerifiableAddressByValue(context.Background(), identity.VerifiableAddressTypeEmail, email)
70+
assert.NoError(t, err)
71+
assert.False(t, addr.Verified)
72+
assert.Nil(t, addr.VerifiedAt)
73+
assert.Equal(t, identity.VerifiableAddressStatusPending, addr.Status)
74+
return id
75+
}
76+
5977
func TestAdminStrategy(t *testing.T) {
6078
ctx := context.Background()
6179
conf, reg := internal.NewFastRegistryWithMocks(t)
@@ -183,6 +201,59 @@ func TestAdminStrategy(t *testing.T) {
183201
assert.Nil(t, addr.VerifiedAt)
184202
assert.Equal(t, identity.VerifiableAddressStatusPending, addr.Status)
185203
})
204+
205+
t.Run("case=should not be able to use code from different flow", func(t *testing.T) {
206+
email := strings.ToLower(testhelpers.RandomEmail())
207+
id := createIdentityToRecover(t, reg, email)
208+
209+
rl1, _, err := adminSDK.V0alpha2Api.
210+
AdminCreateSelfServiceRecoveryLink(context.Background()).
211+
AdminCreateSelfServiceRecoveryLinkBody(kratos.AdminCreateSelfServiceRecoveryLinkBody{
212+
IdentityId: id.ID.String(),
213+
}).
214+
Execute()
215+
require.NoError(t, err)
216+
217+
checkLink(t, rl1, time.Now().Add(conf.SelfServiceFlowRecoveryRequestLifespan(ctx)+time.Second))
218+
219+
rl2, _, err := adminSDK.V0alpha2Api.
220+
AdminCreateSelfServiceRecoveryLink(context.Background()).
221+
AdminCreateSelfServiceRecoveryLinkBody(kratos.AdminCreateSelfServiceRecoveryLinkBody{
222+
IdentityId: id.ID.String(),
223+
}).
224+
Execute()
225+
require.NoError(t, err)
226+
227+
checkLink(t, rl2, time.Now().Add(conf.SelfServiceFlowRecoveryRequestLifespan(ctx)+time.Second))
228+
229+
recoveryUrl1, err := url.Parse(rl1.RecoveryLink)
230+
require.NoError(t, err)
231+
232+
recoveryUrl2, err := url.Parse(rl2.RecoveryLink)
233+
require.NoError(t, err)
234+
235+
token1 := recoveryUrl1.Query().Get("token")
236+
require.NotEmpty(t, token1)
237+
token2 := recoveryUrl2.Query().Get("token")
238+
require.NotEmpty(t, token2)
239+
require.NotEqual(t, token1, token2)
240+
241+
values := recoveryUrl1.Query()
242+
243+
values.Set("token", token2)
244+
245+
recoveryUrl1.RawQuery = values.Encode()
246+
247+
action := recoveryUrl1.String()
248+
// Submit the modified link with token from rl2 and flow from rl1
249+
res, err := publicTS.Client().Get(action)
250+
require.NoError(t, err)
251+
body := ioutilx.MustReadAll(res.Body)
252+
253+
action = gjson.GetBytes(body, "ui.action").String()
254+
require.NotEmpty(t, action)
255+
assert.Equal(t, "The recovery token is invalid or has already been used. Please retry the flow.", gjson.GetBytes(body, "ui.messages.0.text").String())
256+
})
186257
}
187258

188259
func TestRecovery(t *testing.T) {
@@ -197,23 +268,6 @@ func TestRecovery(t *testing.T) {
197268

198269
public, _, publicRouter, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
199270

200-
var createIdentityToRecover = func(email string) *identity.Identity {
201-
var id = &identity.Identity{
202-
Credentials: map[identity.CredentialsType]identity.Credentials{
203-
"password": {Type: "password", Identifiers: []string{email}, Config: sqlxx.JSONRawMessage(`{"hashed_password":"foo"}`)}},
204-
Traits: identity.Traits(fmt.Sprintf(`{"email":"%s"}`, email)),
205-
SchemaID: config.DefaultIdentityTraitsSchemaID,
206-
}
207-
require.NoError(t, reg.IdentityManager().Create(context.Background(), id, identity.ManagerAllowWriteProtectedTraits))
208-
209-
addr, err := reg.IdentityPool().FindVerifiableAddressByValue(context.Background(), identity.VerifiableAddressTypeEmail, email)
210-
assert.NoError(t, err)
211-
assert.False(t, addr.Verified)
212-
assert.Nil(t, addr.VerifiedAt)
213-
assert.Equal(t, identity.VerifiableAddressStatusPending, addr.Status)
214-
return id
215-
}
216-
217271
var expect = func(t *testing.T, hc *http.Client, isAPI, isSPA bool, values func(url.Values), c int) string {
218272
if hc == nil {
219273
hc = testhelpers.NewDebugClient(t)
@@ -414,23 +468,23 @@ func TestRecovery(t *testing.T) {
414468

415469
t.Run("type=browser", func(t *testing.T) {
416470
email := "[email protected]"
417-
createIdentityToRecover(email)
471+
createIdentityToRecover(t, reg, email)
418472
check(t, expectSuccess(t, nil, false, false, func(v url.Values) {
419473
v.Set("email", email)
420474
}), email, false)
421475
})
422476

423477
t.Run("type=spa", func(t *testing.T) {
424478
email := "[email protected]"
425-
createIdentityToRecover(email)
479+
createIdentityToRecover(t, reg, email)
426480
check(t, expectSuccess(t, nil, true, true, func(v url.Values) {
427481
v.Set("email", email)
428482
}), email, true)
429483
})
430484

431485
t.Run("type=api", func(t *testing.T) {
432486
email := "[email protected]"
433-
createIdentityToRecover(email)
487+
createIdentityToRecover(t, reg, email)
434488
check(t, expectSuccess(t, nil, true, false, func(v url.Values) {
435489
v.Set("email", email)
436490
}), email, true)
@@ -487,7 +541,7 @@ func TestRecovery(t *testing.T) {
487541

488542
t.Run("type=browser", func(t *testing.T) {
489543
email := "[email protected]"
490-
createIdentityToRecover(email)
544+
createIdentityToRecover(t, reg, email)
491545
check(t, expectSuccess(t, nil, false, false, func(v url.Values) {
492546
v.Set("email", email)
493547
}), email, "")
@@ -496,7 +550,7 @@ func TestRecovery(t *testing.T) {
496550
t.Run("type=browser set return_to", func(t *testing.T) {
497551
email := "[email protected]"
498552
returnTo := "https://www.ory.sh"
499-
createIdentityToRecover(email)
553+
createIdentityToRecover(t, reg, email)
500554

501555
hc := testhelpers.NewClientWithCookies(t)
502556
hc.Transport = testhelpers.NewTransportWithLogger(http.DefaultTransport, t).RoundTripper
@@ -518,15 +572,15 @@ func TestRecovery(t *testing.T) {
518572

519573
t.Run("type=spa", func(t *testing.T) {
520574
email := "[email protected]"
521-
createIdentityToRecover(email)
575+
createIdentityToRecover(t, reg, email)
522576
check(t, expectSuccess(t, nil, true, true, func(v url.Values) {
523577
v.Set("email", email)
524578
}), email, "")
525579
})
526580

527581
t.Run("type=api", func(t *testing.T) {
528582
email := "[email protected]"
529-
createIdentityToRecover(email)
583+
createIdentityToRecover(t, reg, email)
530584
check(t, expectSuccess(t, nil, true, false, func(v url.Values) {
531585
v.Set("email", email)
532586
}), email, "")
@@ -563,7 +617,7 @@ func TestRecovery(t *testing.T) {
563617
}
564618

565619
email := x.NewUUID().String() + "@ory.sh"
566-
id := createIdentityToRecover(email)
620+
id := createIdentityToRecover(t, reg, email)
567621

568622
t.Run("case=unauthenticated", func(t *testing.T) {
569623
var values = func(v url.Values) {
@@ -604,7 +658,7 @@ func TestRecovery(t *testing.T) {
604658

605659
recoveryEmail := strings.ToLower(testhelpers.RandomEmail())
606660
email := recoveryEmail
607-
id := createIdentityToRecover(email)
661+
id := createIdentityToRecover(t, reg, email)
608662

609663
sess, err := session.NewActiveSession(ctx, id, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
610664
require.NoError(t, err)
@@ -659,7 +713,7 @@ func TestRecovery(t *testing.T) {
659713

660714
t.Run("description=should not be able to use an outdated link", func(t *testing.T) {
661715
recoveryEmail := "[email protected]"
662-
createIdentityToRecover(recoveryEmail)
716+
createIdentityToRecover(t, reg, recoveryEmail)
663717
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryRequestLifespan, time.Millisecond*200)
664718
t.Cleanup(func() {
665719
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryRequestLifespan, time.Minute)
@@ -685,7 +739,7 @@ func TestRecovery(t *testing.T) {
685739

686740
t.Run("description=should not be able to use an outdated flow", func(t *testing.T) {
687741
recoveryEmail := "[email protected]"
688-
createIdentityToRecover(recoveryEmail)
742+
createIdentityToRecover(t, reg, recoveryEmail)
689743
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryRequestLifespan, time.Millisecond*200)
690744
t.Cleanup(func() {
691745
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryRequestLifespan, time.Minute)

selfservice/strategy/link/test/persistence.go

+19-13
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,8 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
3434
conf.MustSet(ctx, config.ViperKeySecretsDefault, []string{"secret-a", "secret-b"})
3535

3636
t.Run("token=recovery", func(t *testing.T) {
37-
t.Run("case=should error when the recovery token does not exist", func(t *testing.T) {
38-
_, err := p.UseRecoveryToken(ctx, "i-do-not-exist")
39-
require.Error(t, err)
40-
})
4137

42-
newRecoveryToken := func(t *testing.T, email string) *link.RecoveryToken {
38+
newRecoveryToken := func(t *testing.T, email string) (*link.RecoveryToken, *recovery.Flow) {
4339
var req recovery.Flow
4440
require.NoError(t, faker.FakeData(&req))
4541
require.NoError(t, p.CreateRecoveryFlow(ctx, &req))
@@ -52,42 +48,52 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
5248

5349
require.NoError(t, p.CreateIdentity(ctx, &i))
5450

55-
return &link.RecoveryToken{Token: x.NewUUID().String(), FlowID: uuid.NullUUID{UUID: req.ID, Valid: true},
51+
return &link.RecoveryToken{
52+
Token: x.NewUUID().String(),
53+
FlowID: uuid.NullUUID{UUID: req.ID, Valid: true},
5654
RecoveryAddress: &i.RecoveryAddresses[0],
5755
ExpiresAt: time.Now(),
5856
IssuedAt: time.Now(),
5957
IdentityID: i.ID,
60-
}
58+
TokenType: link.RecoveryTokenTypeAdmin,
59+
}, &req
6160
}
6261

6362
t.Run("case=should error when the recovery token does not exist", func(t *testing.T) {
64-
_, err := p.UseRecoveryToken(ctx, "i-do-not-exist")
63+
_, err := p.UseRecoveryToken(ctx, x.NewUUID(), "i-do-not-exist")
6564
require.Error(t, err)
6665
})
6766

6867
t.Run("case=should create a new recovery token", func(t *testing.T) {
69-
token := newRecoveryToken(t, "[email protected]")
68+
token, _ := newRecoveryToken(t, "[email protected]")
7069
require.NoError(t, p.CreateRecoveryToken(ctx, token))
7170
})
7271

72+
t.Run("case=should error when token is used with different flow id", func(t *testing.T) {
73+
token, _ := newRecoveryToken(t, "[email protected]")
74+
require.NoError(t, p.CreateRecoveryToken(ctx, token))
75+
_, err := p.UseRecoveryToken(ctx, x.NewUUID(), token.Token)
76+
require.Error(t, err)
77+
})
78+
7379
t.Run("case=should create a recovery token and use it", func(t *testing.T) {
74-
expected := newRecoveryToken(t, "[email protected]")
80+
expected, f := newRecoveryToken(t, "[email protected]")
7581
require.NoError(t, p.CreateRecoveryToken(ctx, expected))
7682

7783
t.Run("not work on another network", func(t *testing.T) {
7884
_, p := testhelpers.NewNetwork(t, ctx, p)
79-
_, err := p.UseRecoveryToken(ctx, expected.Token)
85+
_, err := p.UseRecoveryToken(ctx, f.ID, expected.Token)
8086
require.ErrorIs(t, err, sqlcon.ErrNoRows)
8187
})
8288

83-
actual, err := p.UseRecoveryToken(ctx, expected.Token)
89+
actual, err := p.UseRecoveryToken(ctx, f.ID, expected.Token)
8490
require.NoError(t, err)
8591
assert.Equal(t, nid, actual.NID)
8692
assert.Equal(t, expected.IdentityID, actual.IdentityID)
8793
assert.NotEqual(t, expected.Token, actual.Token)
8894
assert.EqualValues(t, expected.FlowID, actual.FlowID)
8995

90-
_, err = p.UseRecoveryToken(ctx, expected.Token)
96+
_, err = p.UseRecoveryToken(ctx, f.ID, expected.Token)
9197
require.Error(t, err)
9298
})
9399

0 commit comments

Comments
 (0)