Skip to content

Commit 1689bb9

Browse files
authored
fix: do not invalidate recovery addr on update (#2699)
1 parent a0d2bfb commit 1689bb9

9 files changed

+238
-28
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/ory/kratos
22

3-
go 1.17
3+
go 1.18
44

55
replace (
66
github.com/bradleyjkemp/cupaloy/v2 => github.com/aeneasr/cupaloy/v2 v2.6.1-0.20210924214125-3dfdd01210a3

identity/identity_recovery.go

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package identity
22

33
import (
44
"context"
5+
"fmt"
56
"time"
67

78
"github.com/gofrs/uuid"
@@ -55,6 +56,11 @@ func (a RecoveryAddress) ValidateNID() error {
5556
return nil
5657
}
5758

59+
// Hash returns a unique string representation for the recovery address.
60+
func (a RecoveryAddress) Hash() string {
61+
return fmt.Sprintf("%v|%v|%v|%v", a.Value, a.Via, a.IdentityID, a.NID)
62+
}
63+
5864
func NewRecoveryEmailAddress(
5965
value string,
6066
identity uuid.UUID,

identity/identity_recovery_test.go

+40-1
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package identity
22

33
import (
44
"testing"
5+
"time"
56

67
"github.com/gofrs/uuid"
7-
88
"github.com/stretchr/testify/assert"
99

1010
"github.com/ory/kratos/x"
@@ -19,3 +19,42 @@ func TestNewRecoveryEmailAddress(t *testing.T) {
1919
assert.Equal(t, iid, a.IdentityID)
2020
assert.Equal(t, uuid.Nil, a.ID)
2121
}
22+
23+
// TestRecoveryAddress_Hash tests that the hash considers all fields that are
24+
// written to the database (ignoring some well-known fields like the ID or
25+
// timestamps).
26+
func TestRecoveryAddress_Hash(t *testing.T) {
27+
cases := []struct {
28+
name string
29+
a RecoveryAddress
30+
}{
31+
{
32+
name: "full fields",
33+
a: RecoveryAddress{
34+
ID: x.NewUUID(),
35+
36+
Via: AddressTypeEmail,
37+
CreatedAt: time.Now(),
38+
UpdatedAt: time.Now(),
39+
IdentityID: x.NewUUID(),
40+
NID: x.NewUUID(),
41+
},
42+
}, {
43+
name: "empty fields",
44+
a: RecoveryAddress{},
45+
}, {
46+
name: "constructor",
47+
a: *NewRecoveryEmailAddress("[email protected]", x.NewUUID()),
48+
},
49+
}
50+
51+
for _, tc := range cases {
52+
t.Run("case="+tc.name, func(t *testing.T) {
53+
assert.Equal(t,
54+
reflectiveHash(tc.a),
55+
tc.a.Hash(),
56+
)
57+
})
58+
}
59+
60+
}

identity/identity_verification.go

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package identity
22

33
import (
44
"context"
5+
"fmt"
56
"time"
67

78
"github.com/gofrs/uuid"
@@ -129,3 +130,8 @@ func (a VerifiableAddress) GetNID() uuid.UUID {
129130
func (a VerifiableAddress) ValidateNID() error {
130131
return nil
131132
}
133+
134+
// Hash returns a unique string representation for the recovery address.
135+
func (a VerifiableAddress) Hash() string {
136+
return fmt.Sprintf("%v|%v|%v|%v|%v|%v", a.Value, a.Verified, a.Via, a.Status, a.IdentityID, a.NID)
137+
}

identity/identity_verification_test.go

+76
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package identity
22

33
import (
4+
"fmt"
5+
"reflect"
6+
"strings"
47
"testing"
8+
"time"
59

610
"github.com/gofrs/uuid"
711

@@ -25,3 +29,75 @@ func TestNewVerifiableEmailAddress(t *testing.T) {
2529
assert.Equal(t, iid, a.IdentityID)
2630
assert.Equal(t, uuid.Nil, a.ID)
2731
}
32+
33+
var tagsIgnoredForHashing = map[string]struct{}{
34+
"id": {},
35+
"created_at": {},
36+
"updated_at": {},
37+
"verified_at": {},
38+
}
39+
40+
func reflectiveHash(record any) string {
41+
var (
42+
val = reflect.ValueOf(record)
43+
typ = reflect.TypeOf(record)
44+
values = []string{}
45+
)
46+
for i := 0; i < val.NumField(); i++ {
47+
dbTag, ok := typ.Field(i).Tag.Lookup("db")
48+
if !ok {
49+
continue
50+
}
51+
if _, ignore := tagsIgnoredForHashing[dbTag]; ignore {
52+
continue
53+
}
54+
if !val.Field(i).CanInterface() {
55+
continue
56+
}
57+
values = append(values, fmt.Sprintf("%v", val.Field(i).Interface()))
58+
}
59+
return strings.Join(values, "|")
60+
}
61+
62+
// TestVerifiableAddress_Hash tests that the hash considers all fields that are
63+
// written to the database (ignoring some well-known fields like the ID or
64+
// timestamps).
65+
func TestVerifiableAddress_Hash(t *testing.T) {
66+
now := sqlxx.NullTime(time.Now())
67+
cases := []struct {
68+
name string
69+
a VerifiableAddress
70+
}{
71+
{
72+
name: "full fields",
73+
a: VerifiableAddress{
74+
ID: x.NewUUID(),
75+
76+
Verified: false,
77+
Via: AddressTypeEmail,
78+
Status: VerifiableAddressStatusPending,
79+
VerifiedAt: &now,
80+
CreatedAt: time.Now(),
81+
UpdatedAt: time.Now(),
82+
IdentityID: x.NewUUID(),
83+
NID: x.NewUUID(),
84+
},
85+
}, {
86+
name: "empty fields",
87+
a: VerifiableAddress{},
88+
}, {
89+
name: "constructor",
90+
a: *NewVerifiableEmailAddress("[email protected]", x.NewUUID()),
91+
},
92+
}
93+
94+
for _, tc := range cases {
95+
t.Run("case="+tc.name, func(t *testing.T) {
96+
assert.Equal(t,
97+
reflectiveHash(tc.a),
98+
tc.a.Hash(),
99+
)
100+
})
101+
}
102+
103+
}

package-lock.json

+3-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

persistence/sql/persister_identity.go

+80-21
Original file line numberDiff line numberDiff line change
@@ -192,24 +192,83 @@ func (p *Persister) createVerifiableAddresses(ctx context.Context, i *identity.I
192192
defer span.End()
193193

194194
for k := range i.VerifiableAddresses {
195-
i.VerifiableAddresses[k].IdentityID = i.ID
196-
i.VerifiableAddresses[k].NID = p.NetworkID(ctx)
197-
i.VerifiableAddresses[k].Value = stringToLowerTrim(i.VerifiableAddresses[k].Value)
198195
if err := p.GetConnection(ctx).Create(&i.VerifiableAddresses[k]); err != nil {
199196
return err
200197
}
201198
}
202199
return nil
203200
}
204201

202+
func updateAssociation[T interface {
203+
Hash() string
204+
}](ctx context.Context, p *Persister, i *identity.Identity, inID []T) error {
205+
var inDB []T
206+
if err := p.GetConnection(ctx).
207+
Where("identity_id = ? AND nid = ?", i.ID, p.NetworkID(ctx)).
208+
Order("id ASC").
209+
All(&inDB); err != nil {
210+
211+
return sqlcon.HandleError(err)
212+
}
213+
214+
newAssocs := make(map[string]*T)
215+
oldAssocs := make(map[string]*T)
216+
for i, a := range inID {
217+
newAssocs[a.Hash()] = &inID[i]
218+
}
219+
for i, a := range inDB {
220+
oldAssocs[a.Hash()] = &inDB[i]
221+
}
222+
223+
// Subtle: we delete the old associations from the DB first, because else
224+
// they could cause UNIQUE constraints to fail on insert.
225+
for h, a := range oldAssocs {
226+
if _, found := newAssocs[h]; found {
227+
newAssocs[h] = nil // Ignore associations that are already in the db.
228+
} else {
229+
if err := p.GetConnection(ctx).Destroy(a); err != nil {
230+
return sqlcon.HandleError(err)
231+
}
232+
}
233+
}
234+
235+
for _, a := range newAssocs {
236+
if a != nil {
237+
if err := p.GetConnection(ctx).Create(a); err != nil {
238+
return sqlcon.HandleError(err)
239+
}
240+
}
241+
}
242+
243+
return nil
244+
}
245+
246+
func (p *Persister) normalizeAllAddressess(ctx context.Context, id *identity.Identity) {
247+
p.normalizeRecoveryAddresses(ctx, id)
248+
p.normalizeVerifiableAddresses(ctx, id)
249+
}
250+
251+
func (p *Persister) normalizeVerifiableAddresses(ctx context.Context, id *identity.Identity) {
252+
for k := range id.VerifiableAddresses {
253+
id.VerifiableAddresses[k].IdentityID = id.ID
254+
id.VerifiableAddresses[k].NID = p.NetworkID(ctx)
255+
id.VerifiableAddresses[k].Value = stringToLowerTrim(id.VerifiableAddresses[k].Value)
256+
}
257+
}
258+
259+
func (p *Persister) normalizeRecoveryAddresses(ctx context.Context, id *identity.Identity) {
260+
for k := range id.RecoveryAddresses {
261+
id.RecoveryAddresses[k].IdentityID = id.ID
262+
id.RecoveryAddresses[k].NID = p.NetworkID(ctx)
263+
id.RecoveryAddresses[k].Value = stringToLowerTrim(id.RecoveryAddresses[k].Value)
264+
}
265+
}
266+
205267
func (p *Persister) createRecoveryAddresses(ctx context.Context, i *identity.Identity) error {
206268
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createRecoveryAddresses")
207269
defer span.End()
208270

209271
for k := range i.RecoveryAddresses {
210-
i.RecoveryAddresses[k].IdentityID = i.ID
211-
i.RecoveryAddresses[k].NID = p.NetworkID(ctx)
212-
i.RecoveryAddresses[k].Value = stringToLowerTrim(i.RecoveryAddresses[k].Value)
213272
if err := p.GetConnection(ctx).Create(&i.RecoveryAddresses[k]); err != nil {
214273
return err
215274
}
@@ -285,6 +344,8 @@ func (p *Persister) CreateIdentity(ctx context.Context, i *identity.Identity) er
285344
return sqlcon.HandleError(err)
286345
}
287346

347+
p.normalizeAllAddressess(ctx, i)
348+
288349
if err := p.createVerifiableAddresses(ctx, i); err != nil {
289350
return sqlcon.HandleError(err)
290351
}
@@ -350,27 +411,25 @@ func (p *Persister) UpdateIdentity(ctx context.Context, i *identity.Identity) er
350411
return sql.ErrNoRows
351412
}
352413

353-
for _, tn := range []string{
354-
new(identity.Credentials).TableName(ctx),
355-
new(identity.VerifiableAddress).TableName(ctx),
356-
new(identity.RecoveryAddress).TableName(ctx),
357-
} {
358-
/* #nosec G201 TableName is static */
359-
if err := tx.RawQuery(fmt.Sprintf(
360-
`DELETE FROM %s WHERE identity_id = ? AND nid = ?`, tn), i.ID, p.NetworkID(ctx)).Exec(); err != nil {
361-
return err
362-
}
414+
p.normalizeAllAddressess(ctx, i)
415+
if err := updateAssociation(ctx, p, i, i.RecoveryAddresses); err != nil {
416+
return err
363417
}
364-
365-
if err := p.update(WithTransaction(ctx, tx), i); err != nil {
418+
if err := updateAssociation(ctx, p, i, i.VerifiableAddresses); err != nil {
366419
return err
367420
}
368421

369-
if err := p.createVerifiableAddresses(ctx, i); err != nil {
370-
return err
422+
/* #nosec G201 TableName is static */
423+
if err := tx.RawQuery(
424+
fmt.Sprintf(
425+
`DELETE FROM %s WHERE identity_id = ? AND nid = ?`,
426+
new(identity.Credentials).TableName(ctx)),
427+
i.ID, p.NetworkID(ctx)).Exec(); err != nil {
428+
429+
return sqlcon.HandleError(err)
371430
}
372431

373-
if err := p.createRecoveryAddresses(ctx, i); err != nil {
432+
if err := p.update(WithTransaction(ctx, tx), i); err != nil {
374433
return err
375434
}
376435

persistence/sql/persister_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ func TestPersister_Transaction(t *testing.T) {
285285
Traits: ri.Traits(`{}`),
286286
}
287287
errMessage := "failing because why not"
288-
err := p.Transaction(context.Background(), func(ctx context.Context, connection *pop.Connection) error {
288+
err := p.Transaction(context.Background(), func(_ context.Context, connection *pop.Connection) error {
289289
require.NoError(t, connection.Create(i))
290290
return errors.Errorf(errMessage)
291291
})

selfservice/strategy/link/test/persistence.go

+25-2
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,31 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
9393
assert.NotEqual(t, expected.Token, actual.Token)
9494
assert.EqualValues(t, expected.FlowID, actual.FlowID)
9595

96-
_, err = p.UseRecoveryToken(ctx, f.ID, expected.Token)
97-
require.Error(t, err)
96+
t.Run("double spend", func(t *testing.T) {
97+
_, err = p.UseRecoveryToken(ctx, f.ID, expected.Token)
98+
require.Error(t, err)
99+
})
100+
})
101+
102+
t.Run("case=update to identity should not invalidate token", func(t *testing.T) {
103+
expected, f := newRecoveryToken(t, "[email protected]")
104+
105+
require.NoError(t, p.CreateRecoveryToken(ctx, expected))
106+
id, err := p.GetIdentity(ctx, expected.IdentityID)
107+
require.NoError(t, err)
108+
require.NoError(t, p.UpdateIdentity(ctx, id))
109+
110+
actual, err := p.UseRecoveryToken(ctx, f.ID, expected.Token)
111+
require.NoError(t, err)
112+
assert.Equal(t, nid, actual.NID)
113+
assert.Equal(t, expected.IdentityID, actual.IdentityID)
114+
assert.NotEqual(t, expected.Token, actual.Token)
115+
assert.EqualValues(t, expected.FlowID, actual.FlowID)
116+
117+
t.Run("double spend", func(t *testing.T) {
118+
_, err = p.UseRecoveryToken(ctx, f.ID, expected.Token)
119+
require.Error(t, err)
120+
})
98121
})
99122

100123
})

0 commit comments

Comments
 (0)