Skip to content

Commit 866b472

Browse files
jonas-jonasaeneasr
authored andcommitted
fix: ignore CSRF for session extension on public route
1 parent 576f9c0 commit 866b472

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

session/handler.go

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ func (h *Handler) RegisterPublicRoutes(public *x.RouterPublic) {
7878
h.r.CSRFHandler().IgnorePath(RouteWhoami)
7979
h.r.CSRFHandler().IgnorePath(RouteCollection)
8080
h.r.CSRFHandler().IgnoreGlob(RouteCollection + "/*")
81+
h.r.CSRFHandler().IgnoreGlob(RouteCollection + "/*/extend")
8182
h.r.CSRFHandler().IgnoreGlob(AdminRouteIdentity + "/*/sessions")
8283

8384
for _, m := range []string{http.MethodGet, http.MethodHead, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodConnect, http.MethodOptions, http.MethodTrace} {

session/handler_test.go

+20-9
Original file line numberDiff line numberDiff line change
@@ -687,21 +687,22 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
687687

688688
func TestHandlerRefreshSessionBySessionID(t *testing.T) {
689689
conf, reg := internal.NewFastRegistryWithMocks(t)
690-
_, ts, _, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
690+
publicServer, adminServer, _, _ := testhelpers.NewKratosServerWithCSRFAndRouters(t, reg)
691691

692692
// set this intermediate because kratos needs some valid url for CRUDE operations
693693
conf.MustSet(config.ViperKeyPublicBaseURL, "http://example.com")
694694
testhelpers.SetDefaultIdentitySchema(conf, "file://./stub/identity.schema.json")
695-
conf.MustSet(config.ViperKeyPublicBaseURL, ts.URL)
695+
conf.MustSet(config.ViperKeyPublicBaseURL, adminServer.URL)
696+
697+
i := identity.NewIdentity("")
698+
require.NoError(t, reg.IdentityManager().Create(context.Background(), i))
699+
s := &Session{Identity: i, ExpiresAt: time.Now().Add(5 * time.Minute)}
700+
require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), s))
696701

697702
t.Run("case=should return 200 after refreshing one session", func(t *testing.T) {
698703
client := testhelpers.NewClientWithCookies(t)
699-
i := identity.NewIdentity("")
700-
require.NoError(t, reg.IdentityManager().Create(context.Background(), i))
701-
s := &Session{Identity: i, ExpiresAt: time.Now().Add(5 * time.Minute)}
702-
require.NoError(t, reg.SessionPersister().UpsertSession(context.Background(), s))
703704

704-
req, _ := http.NewRequest("PATCH", ts.URL+"/admin/sessions/"+s.ID.String()+"/extend", nil)
705+
req, _ := http.NewRequest("PATCH", adminServer.URL+"/admin/sessions/"+s.ID.String()+"/extend", nil)
705706
res, err := client.Do(req)
706707
require.NoError(t, err)
707708
require.Equal(t, http.StatusOK, res.StatusCode)
@@ -712,7 +713,7 @@ func TestHandlerRefreshSessionBySessionID(t *testing.T) {
712713

713714
t.Run("case=should return 400 when bad UUID is sent", func(t *testing.T) {
714715
client := testhelpers.NewClientWithCookies(t)
715-
req, _ := http.NewRequest("PATCH", ts.URL+"/admin/sessions/BADUUID/extend", nil)
716+
req, _ := http.NewRequest("PATCH", adminServer.URL+"/admin/sessions/BADUUID/extend", nil)
716717
res, err := client.Do(req)
717718
require.NoError(t, err)
718719
require.Equal(t, http.StatusBadRequest, res.StatusCode)
@@ -721,9 +722,19 @@ func TestHandlerRefreshSessionBySessionID(t *testing.T) {
721722
t.Run("case=should return 404 when calling with missing UUID", func(t *testing.T) {
722723
client := testhelpers.NewClientWithCookies(t)
723724
someID, _ := uuid.NewV4()
724-
req, _ := http.NewRequest("PATCH", ts.URL+"/admin/sessions/"+someID.String()+"/extend", nil)
725+
req, _ := http.NewRequest("PATCH", adminServer.URL+"/admin/sessions/"+someID.String()+"/extend", nil)
725726
res, err := client.Do(req)
726727
require.NoError(t, err)
727728
require.Equal(t, http.StatusNotFound, res.StatusCode)
728729
})
730+
731+
t.Run("case=should return 404 when calling puplic server", func(t *testing.T) {
732+
req := x.NewTestHTTPRequest(t, "PATCH", publicServer.URL+"/sessions/"+s.ID.String()+"/extend", nil)
733+
734+
res, err := publicServer.Client().Do(req)
735+
require.NoError(t, err)
736+
assert.Equal(t, http.StatusNotFound, res.StatusCode)
737+
body := ioutilx.MustReadAll(res.Body)
738+
assert.NotEqual(t, gjson.GetBytes(body, "error.id").String(), "security_csrf_violation")
739+
})
729740
}

0 commit comments

Comments
 (0)