Skip to content

Commit 66e9f0f

Browse files
author
Ajay Kelkar
committed
test: fix signatures
1 parent c22f556 commit 66e9f0f

File tree

8 files changed

+52
-33
lines changed

8 files changed

+52
-33
lines changed

internal/testhelpers/handler_mock.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func MockSetSession(t *testing.T, reg mockDeps, conf *config.Config) httprouter.
3939

4040
func MockSetSessionWithIdentity(t *testing.T, reg mockDeps, conf *config.Config, i *identity.Identity) httprouter.Handle {
4141
return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
42-
activeSession, _ := session.NewActiveSession(r.Context(), r.Header, i, conf, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
42+
activeSession, _ := session.NewActiveSession(r, i, conf, time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
4343
if aal := r.URL.Query().Get("set_aal"); len(aal) > 0 {
4444
activeSession.AuthenticatorAssuranceLevel = identity.AuthenticatorAssuranceLevel(aal)
4545
}

internal/testhelpers/identity.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package testhelpers
22

33
import (
4-
"context"
4+
"github.com/ory/kratos/x"
55
"testing"
66
"time"
77

@@ -14,12 +14,11 @@ import (
1414
)
1515

1616
func CreateSession(t *testing.T, reg driver.Registry) *session.Session {
17-
ctx := context.Background()
18-
header := make(map[string][]string, 0)
17+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
1918
i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
20-
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(ctx, i))
21-
sess, err := session.NewActiveSession(ctx, header, i, reg.Config(), time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
19+
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(req.Context(), i))
20+
sess, err := session.NewActiveSession(req, i, reg.Config(), time.Now().UTC(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
2221
require.NoError(t, err)
23-
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, sess))
22+
require.NoError(t, reg.SessionPersister().UpsertSession(req.Context(), sess))
2423
return sess
2524
}

selfservice/flow/login/hook.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func (e *HookExecutor) handleLoginError(_ http.ResponseWriter, r *http.Request,
101101
}
102102

103103
func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, g node.UiNodeGroup, a *Flow, i *identity.Identity, s *session.Session) error {
104-
if err := s.Activate(r.Context(), r.Header, i, e.d.Config(), time.Now().UTC()); err != nil {
104+
if err := s.Activate(r, i, e.d.Config(), time.Now().UTC()); err != nil {
105105
return err
106106
}
107107

selfservice/flow/registration/hook.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque
151151
WithField("identity_id", i.ID).
152152
Info("A new identity has registered using self-service registration.")
153153

154-
s, err := session.NewActiveSession(r.Context(), r.Header, i, e.d.Config(), time.Now().UTC(), ct, identity.AuthenticatorAssuranceLevel1)
154+
s, err := session.NewActiveSession(r, i, e.d.Config(), time.Now().UTC(), ct, identity.AuthenticatorAssuranceLevel1)
155155
if err != nil {
156156
return err
157157
}

selfservice/strategy/link/strategy_recovery.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ func (s *Strategy) recoveryIssueSession(w http.ResponseWriter, r *http.Request,
272272
return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err)
273273
}
274274

275-
sess, err := session.NewActiveSession(r.Context(), r.Header, id, s.d.Config(), time.Now().UTC(), identity.CredentialsTypeRecoveryLink, identity.AuthenticatorAssuranceLevel1)
275+
sess, err := session.NewActiveSession(r, id, s.d.Config(), time.Now().UTC(), identity.CredentialsTypeRecoveryLink, identity.AuthenticatorAssuranceLevel1)
276276
if err != nil {
277277
return s.retryRecoveryFlowWithError(w, r, flow.TypeBrowser, err)
278278
}

session/manager_http_test.go

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,15 @@ func TestManagerHTTP(t *testing.T) {
141141
})
142142

143143
t.Run("suite=SessionAddAuthenticationMethod", func(t *testing.T) {
144+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
145+
144146
conf, reg := internal.NewFastRegistryWithMocks(t)
145147
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
146148

147149
i := &identity.Identity{Traits: []byte("{}"), State: identity.StateActive}
148150
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), i))
149151
sess := session.NewInactiveSession()
150-
require.NoError(t, sess.Activate(ctx, i, conf, time.Now()))
152+
require.NoError(t, sess.Activate(req, i, conf, time.Now()))
151153
require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), sess))
152154
require.NoError(t, reg.SessionManager().SessionAddAuthenticationMethods(context.Background(), sess.ID,
153155
session.AuthenticationMethod{
@@ -202,11 +204,12 @@ func TestManagerHTTP(t *testing.T) {
202204
reg.RegisterPublicRoutes(context.Background(), rp)
203205

204206
t.Run("case=valid", func(t *testing.T) {
205-
conf.MustSet(ctx, config.ViperKeySessionLifespan, "1m")
207+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
208+
conf.MustSet(req.Context(), config.ViperKeySessionLifespan, "1m")
206209

207210
i := identity.Identity{Traits: []byte("{}")}
208211
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
209-
s, _ = session.NewActiveSession(ctx, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
212+
s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
210213

211214
c := testhelpers.NewClientWithCookies(t)
212215
testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set")
@@ -217,6 +220,7 @@ func TestManagerHTTP(t *testing.T) {
217220
})
218221

219222
t.Run("case=key rotation", func(t *testing.T) {
223+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
220224
original := conf.GetProvider(ctx).Strings(config.ViperKeySecretsCookie)
221225
t.Cleanup(func() {
222226
conf.MustSet(ctx, config.ViperKeySecretsCookie, original)
@@ -226,7 +230,7 @@ func TestManagerHTTP(t *testing.T) {
226230

227231
i := identity.Identity{Traits: []byte("{}")}
228232
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
229-
s, _ = session.NewActiveSession(ctx, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
233+
s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
230234

231235
c := testhelpers.NewClientWithCookies(t)
232236
testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set")
@@ -242,6 +246,7 @@ func TestManagerHTTP(t *testing.T) {
242246
})
243247

244248
t.Run("case=no panic on invalid cookie name", func(t *testing.T) {
249+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
245250
conf.MustSet(ctx, config.ViperKeySessionLifespan, "1m")
246251
conf.MustSet(ctx, config.ViperKeySessionName, "$%˜\"")
247252
t.Cleanup(func() {
@@ -255,7 +260,7 @@ func TestManagerHTTP(t *testing.T) {
255260

256261
i := identity.Identity{Traits: []byte("{}")}
257262
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
258-
s, _ = session.NewActiveSession(ctx, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
263+
s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
259264

260265
c := testhelpers.NewClientWithCookies(t)
261266
res, err := c.Get(pts.URL + "/session/set/invalid")
@@ -264,11 +269,12 @@ func TestManagerHTTP(t *testing.T) {
264269
})
265270

266271
t.Run("case=valid and uses x-session-cookie", func(t *testing.T) {
272+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
267273
conf.MustSet(ctx, config.ViperKeySessionLifespan, "1m")
268274

269275
i := identity.Identity{Traits: []byte("{}")}
270276
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
271-
s, _ = session.NewActiveSession(ctx, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
277+
s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
272278

273279
c := testhelpers.NewClientWithCookies(t)
274280
testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set")
@@ -297,16 +303,17 @@ func TestManagerHTTP(t *testing.T) {
297303
})
298304

299305
t.Run("case=valid bearer auth as fallback", func(t *testing.T) {
306+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
300307
conf.MustSet(ctx, config.ViperKeySessionLifespan, "1m")
301308

302309
i := identity.Identity{Traits: []byte("{}"), State: identity.StateActive}
303310
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
304-
s, err := session.NewActiveSession(ctx, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
311+
s, err := session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
305312
require.NoError(t, err)
306313
require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), s))
307314
require.NotEmpty(t, s.Token)
308315

309-
req, err := http.NewRequest("GET", pts.URL+"/session/get", nil)
316+
req, err = http.NewRequest("GET", pts.URL+"/session/get", nil)
310317
require.NoError(t, err)
311318
req.Header.Set("Authorization", "Bearer "+s.Token)
312319

@@ -317,15 +324,16 @@ func TestManagerHTTP(t *testing.T) {
317324
})
318325

319326
t.Run("case=valid x-session-token auth even if bearer is set", func(t *testing.T) {
327+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
320328
conf.MustSet(ctx, config.ViperKeySessionLifespan, "1m")
321329

322330
i := identity.Identity{Traits: []byte("{}"), State: identity.StateActive}
323331
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
324-
s, err := session.NewActiveSession(ctx, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
332+
s, err := session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
325333
require.NoError(t, err)
326334
require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), s))
327335

328-
req, err := http.NewRequest("GET", pts.URL+"/session/get", nil)
336+
req, err = http.NewRequest("GET", pts.URL+"/session/get", nil)
329337
require.NoError(t, err)
330338
req.Header.Set("Authorization", "Bearer invalid")
331339
req.Header.Set("X-Session-Token", s.Token)
@@ -337,14 +345,15 @@ func TestManagerHTTP(t *testing.T) {
337345
})
338346

339347
t.Run("case=expired", func(t *testing.T) {
348+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
340349
conf.MustSet(ctx, config.ViperKeySessionLifespan, "1ns")
341350
t.Cleanup(func() {
342351
conf.MustSet(ctx, config.ViperKeySessionLifespan, "1m")
343352
})
344353

345354
i := identity.Identity{Traits: []byte("{}")}
346355
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
347-
s, _ = session.NewActiveSession(ctx, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
356+
s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
348357

349358
c := testhelpers.NewClientWithCookies(t)
350359
testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set")
@@ -357,11 +366,12 @@ func TestManagerHTTP(t *testing.T) {
357366
})
358367

359368
t.Run("case=revoked", func(t *testing.T) {
369+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
360370
i := identity.Identity{Traits: []byte("{}")}
361371
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), &i))
362-
s, _ = session.NewActiveSession(ctx, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
372+
s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
363373

364-
s, _ = session.NewActiveSession(ctx, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
374+
s, _ = session.NewActiveSession(req, &i, conf, time.Now(), identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
365375

366376
c := testhelpers.NewClientWithCookies(t)
367377
testhelpers.MockHydrateCookieClient(t, c, pts.URL+"/session/set")
@@ -379,6 +389,7 @@ func TestManagerHTTP(t *testing.T) {
379389
conf.MustSet(ctx, config.ViperKeySessionLifespan, "1m")
380390

381391
t.Run("required_aal=aal2", func(t *testing.T) {
392+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
382393
idAAL2 := createAAL2Identity(t, reg)
383394
idAAL1 := createAAL1Identity(t, reg)
384395
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), idAAL1))
@@ -389,7 +400,7 @@ func TestManagerHTTP(t *testing.T) {
389400
for _, m := range complete {
390401
s.CompletedLoginFor(m, "")
391402
}
392-
require.NoError(t, s.Activate(ctx, i, conf, time.Now().UTC()))
403+
require.NoError(t, s.Activate(req, i, conf, time.Now().UTC()))
393404
err := reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(context.Background()), s, requested)
394405
if expectedError != nil {
395406
require.ErrorAs(t, err, &expectedError)
@@ -424,7 +435,6 @@ func TestManagerHTTP(t *testing.T) {
424435
}
425436

426437
func TestDoesSessionSatisfy(t *testing.T) {
427-
ctx := context.Background()
428438
conf, reg := internal.NewFastRegistryWithMocks(t)
429439
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
430440

@@ -552,11 +562,12 @@ func TestDoesSessionSatisfy(t *testing.T) {
552562
require.NoError(t, reg.PrivilegedIdentityPool().DeleteIdentity(context.Background(), id.ID))
553563
})
554564

565+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
555566
s := session.NewInactiveSession()
556567
for _, m := range tc.amr {
557568
s.CompletedLoginFor(m.Method, m.AAL)
558569
}
559-
require.NoError(t, s.Activate(ctx, id, conf, time.Now().UTC()))
570+
require.NoError(t, s.Activate(req, id, conf, time.Now().UTC()))
560571

561572
err := reg.SessionManager().DoesSessionSatisfy((&http.Request{}).WithContext(context.Background()), s, string(tc.requested))
562573
if tc.err != nil {

session/session.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,11 @@ func (s *Session) Activate(r *http.Request, i *identity.Identity, c lifespanProv
199199
s.ClientIPAddress = trueClientIP
200200
} else if realClientIP := r.Header.Get("X-Real-IP"); realClientIP != "" {
201201
s.ClientIPAddress = realClientIP
202-
} else {
202+
} else if forwardedIP := r.Header.Get("X-Forwarded-For"); forwardedIP != "" {
203203
// TODO: Use x lib implementation to parse client IP address from the header string
204-
s.ClientIPAddress = r.Header.Get("X-Forwarded-For")
204+
s.ClientIPAddress = forwardedIP
205+
} else {
206+
s.ClientIPAddress = r.RemoteAddr
205207
}
206208

207209
clientGeoLocation := []string{r.Header.Get("Cf-Ipcity"), r.Header.Get("Cf-Ipcountry")}

session/session_test.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package session_test
33
import (
44
"context"
55
"fmt"
6+
"github.com/ory/kratos/x"
67
"testing"
78
"time"
89

@@ -22,16 +23,18 @@ func TestSession(t *testing.T) {
2223
authAt := time.Now()
2324

2425
t.Run("case=active session", func(t *testing.T) {
26+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
27+
2528
i := new(identity.Identity)
2629
i.State = identity.StateActive
27-
s, _ := session.NewActiveSession(ctx, i, conf, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
30+
s, _ := session.NewActiveSession(req, i, conf, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
2831
assert.True(t, s.IsActive())
2932
require.NotEmpty(t, s.Token)
3033
require.NotEmpty(t, s.LogoutToken)
3134
assert.EqualValues(t, identity.CredentialsTypePassword, s.AMR[0].Method)
3235

3336
i = new(identity.Identity)
34-
s, err := session.NewActiveSession(ctx, i, conf, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
37+
s, err := session.NewActiveSession(req, i, conf, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
3538
assert.Nil(t, s)
3639
assert.ErrorIs(t, err, session.ErrIdentityDisabled)
3740
})
@@ -51,14 +54,16 @@ func TestSession(t *testing.T) {
5154
})
5255

5356
t.Run("case=activate", func(t *testing.T) {
57+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
58+
5459
s := session.NewInactiveSession()
55-
require.NoError(t, s.Activate(ctx, &identity.Identity{State: identity.StateActive}, conf, authAt))
60+
require.NoError(t, s.Activate(req, &identity.Identity{State: identity.StateActive}, conf, authAt))
5661
assert.True(t, s.Active)
5762
assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, s.AuthenticatorAssuranceLevel)
5863
assert.Equal(t, authAt, s.AuthenticatedAt)
5964

6065
s = session.NewInactiveSession()
61-
require.ErrorIs(t, s.Activate(ctx, &identity.Identity{State: identity.StateInactive}, conf, authAt), session.ErrIdentityDisabled)
66+
require.ErrorIs(t, s.Activate(req, &identity.Identity{State: identity.StateInactive}, conf, authAt), session.ErrIdentityDisabled)
6267
assert.False(t, s.Active)
6368
assert.Equal(t, identity.NoAuthenticatorAssuranceLevel, s.AuthenticatorAssuranceLevel)
6469
assert.Empty(t, s.AuthenticatedAt)
@@ -192,6 +197,8 @@ func TestSession(t *testing.T) {
192197
}
193198

194199
t.Run("case=session refresh", func(t *testing.T) {
200+
req := x.NewTestHTTPRequest(t, "GET", "/sessions/whoami", nil)
201+
195202
conf.MustSet(ctx, config.ViperKeySessionLifespan, "24h")
196203
conf.MustSet(ctx, config.ViperKeySessionRefreshMinTimeLeft, "12h")
197204
t.Cleanup(func() {
@@ -200,7 +207,7 @@ func TestSession(t *testing.T) {
200207
})
201208
i := new(identity.Identity)
202209
i.State = identity.StateActive
203-
s, _ := session.NewActiveSession(ctx, i, conf, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
210+
s, _ := session.NewActiveSession(req, i, conf, authAt, identity.CredentialsTypePassword, identity.AuthenticatorAssuranceLevel1)
204211
assert.False(t, s.CanBeRefreshed(ctx, conf), "fresh session is not refreshable")
205212

206213
s.ExpiresAt = s.ExpiresAt.Add(-12 * time.Hour)

0 commit comments

Comments
 (0)