Skip to content

Commit 0c2efa2

Browse files
authored
fix: identity sessions list response includes pagination headers (#2763)
Closes #2762
1 parent d8514b5 commit 0c2efa2

File tree

5 files changed

+119
-14
lines changed

5 files changed

+119
-14
lines changed

persistence/sql/persister_session.go

+15-5
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,16 @@ func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID, expandables s
5454
}
5555

5656
// ListSessionsByIdentity retrieves sessions for an identity from the store.
57-
func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables session.Expandables) ([]*session.Session, error) {
57+
func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables session.Expandables) ([]*session.Session, int64, error) {
5858
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListSessionsByIdentity")
5959
defer span.End()
6060

6161
s := make([]*session.Session, 0)
62+
t := int64(0)
6263
nid := p.NetworkID(ctx)
6364

6465
if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
65-
q := c.Where("identity_id = ? AND nid = ?", iID, nid).Paginate(page, perPage)
66+
q := c.Where("identity_id = ? AND nid = ?", iID, nid)
6667
if except != uuid.Nil {
6768
q = q.Where("id != ?", except)
6869
}
@@ -72,7 +73,16 @@ func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, a
7273
if len(expandables) > 0 {
7374
q = q.Eager(expandables.ToEager()...)
7475
}
75-
if err := q.All(&s); err != nil {
76+
77+
// Get the total count of matching items
78+
total, err := q.Count(new(session.Session))
79+
if err != nil {
80+
return sqlcon.HandleError(err)
81+
}
82+
t = int64(total)
83+
84+
// Get the paginated list of matching items
85+
if err := q.Paginate(page, perPage).All(&s); err != nil {
7686
return sqlcon.HandleError(err)
7787
}
7888

@@ -88,10 +98,10 @@ func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, a
8898
}
8999
return nil
90100
}); err != nil {
91-
return nil, err
101+
return nil, 0, err
92102
}
93103

94-
return s, nil
104+
return s, t, nil
95105
}
96106

97107
// UpsertSession creates a session if not found else updates.

session/handler.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/pkg/errors"
1212

1313
"github.com/ory/x/decoderx"
14+
"github.com/ory/x/urlx"
1415

1516
"github.com/ory/herodot"
1617

@@ -304,12 +305,13 @@ func (h *Handler) adminListIdentitySessions(w http.ResponseWriter, r *http.Reque
304305
}
305306

306307
page, perPage := x.ParsePagination(r)
307-
sess, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), iID, active, page, perPage, uuid.Nil, ExpandEverything)
308+
sess, total, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), iID, active, page, perPage, uuid.Nil, ExpandEverything)
308309
if err != nil {
309310
h.r.Writer().WriteError(w, r, err)
310311
return
311312
}
312313

314+
x.PaginationHeader(w, urlx.AppendPaths(h.r.Config().SelfAdminURL(r.Context()), RouteCollection), total, page, perPage)
313315
h.r.Writer().Write(w, r, sess)
314316
}
315317

@@ -448,12 +450,13 @@ func (h *Handler) listSessions(w http.ResponseWriter, r *http.Request, _ httprou
448450
}
449451

450452
page, perPage := x.ParsePagination(r)
451-
sess, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), s.IdentityID, pointerx.Bool(true), page, perPage, s.ID, ExpandEverything)
453+
sess, total, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), s.IdentityID, pointerx.Bool(true), page, perPage, s.ID, ExpandEverything)
452454
if err != nil {
453455
h.r.Writer().WriteError(w, r, err)
454456
return
455457
}
456458

459+
x.PaginationHeader(w, urlx.AppendPaths(h.r.Config().SelfAdminURL(r.Context()), RouteCollection), total, page, perPage)
457460
h.r.Writer().Write(w, r, sess)
458461
}
459462

session/handler_test.go

+88-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net/http"
99
"net/http/httptest"
10+
"strconv"
1011
"strings"
1112
"testing"
1213
"time"
@@ -472,6 +473,61 @@ func TestHandlerAdminSessionManagement(t *testing.T) {
472473
require.Equal(t, http.StatusNotFound, res.StatusCode)
473474
})
474475

476+
t.Run("case=should return pagination headers on list response", func(t *testing.T) {
477+
client := testhelpers.NewClientWithCookies(t)
478+
i := identity.NewIdentity("")
479+
require.NoError(t, reg.IdentityManager().Create(ctx, i))
480+
481+
numSessions := 5
482+
numSessionsActive := 2
483+
484+
sess := make([]Session, numSessions)
485+
for j := range sess {
486+
require.NoError(t, faker.FakeData(&sess[j]))
487+
sess[j].Identity = i
488+
if j < numSessionsActive {
489+
sess[j].Active = true
490+
} else {
491+
sess[j].Active = false
492+
}
493+
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, &sess[j]))
494+
}
495+
496+
for _, tc := range []struct {
497+
activeOnly string
498+
expectedTotalCount int
499+
}{
500+
{
501+
activeOnly: "true",
502+
expectedTotalCount: numSessionsActive,
503+
},
504+
{
505+
activeOnly: "false",
506+
expectedTotalCount: numSessions - numSessionsActive,
507+
},
508+
{
509+
activeOnly: "",
510+
expectedTotalCount: numSessions,
511+
},
512+
} {
513+
t.Run(fmt.Sprintf("active=%#v", tc.activeOnly), func(t *testing.T) {
514+
reqURL := ts.URL + "/admin/identities/" + i.ID.String() + "/sessions"
515+
if tc.activeOnly != "" {
516+
reqURL += "?active=" + tc.activeOnly
517+
}
518+
req, _ := http.NewRequest("GET", reqURL, nil)
519+
res, err := client.Do(req)
520+
require.NoError(t, err)
521+
require.Equal(t, http.StatusOK, res.StatusCode)
522+
523+
totalCount, err := strconv.Atoi(res.Header.Get("X-Total-Count"))
524+
require.NoError(t, err)
525+
require.Equal(t, tc.expectedTotalCount, totalCount)
526+
require.NotEqual(t, "", res.Header.Get("Link"))
527+
})
528+
}
529+
})
530+
475531
t.Run("case=should respect active on list", func(t *testing.T) {
476532
client := testhelpers.NewClientWithCookies(t)
477533
i := identity.NewIdentity("")
@@ -559,6 +615,36 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
559615
}
560616
}
561617

618+
t.Run("case=list should return pagination headers", func(t *testing.T) {
619+
client, i, _ := setup(t)
620+
621+
numSessions := 5
622+
numSessionsActive := 2
623+
624+
sess := make([]Session, numSessions)
625+
for j := range sess {
626+
require.NoError(t, faker.FakeData(&sess[j]))
627+
sess[j].Identity = i
628+
if j < numSessionsActive {
629+
sess[j].Active = true
630+
} else {
631+
sess[j].Active = false
632+
}
633+
require.NoError(t, reg.SessionPersister().UpsertSession(ctx, &sess[j]))
634+
}
635+
636+
reqURL := ts.URL + "/sessions"
637+
req, _ := http.NewRequest("GET", reqURL, nil)
638+
res, err := client.Do(req)
639+
require.NoError(t, err)
640+
require.Equal(t, http.StatusOK, res.StatusCode)
641+
642+
totalCount, err := strconv.Atoi(res.Header.Get("X-Total-Count"))
643+
require.NoError(t, err)
644+
require.Equal(t, numSessionsActive, totalCount)
645+
require.NotEqual(t, "", res.Header.Get("Link"))
646+
})
647+
562648
t.Run("case=should return 200 and number after invalidating all other sessions", func(t *testing.T) {
563649
client, i, currSess := setup(t)
564650

@@ -601,9 +687,10 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
601687
require.NoError(t, err)
602688
require.Equal(t, http.StatusNoContent, res.StatusCode)
603689

604-
actualOthers, err := reg.SessionPersister().ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, ExpandNothing)
690+
actualOthers, total, err := reg.SessionPersister().ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, ExpandNothing)
605691
require.NoError(t, err)
606692
require.Len(t, actualOthers, 3)
693+
require.Equal(t, int64(3), total)
607694

608695
for _, s := range actualOthers {
609696
if s.ID == others[0].ID {

session/persistence.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type Persister interface {
1616
GetSession(ctx context.Context, sid uuid.UUID, expandables Expandables) (*Session, error)
1717

1818
// ListSessionsByIdentity retrieves sessions for an identity from the store.
19-
ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables Expandables) ([]*Session, error)
19+
ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID, expandables Expandables) ([]*Session, int64, error)
2020

2121
// UpsertSession inserts or updates a session into / in the store.
2222
UpsertSession(ctx context.Context, s *Session) error

session/test/persistence.go

+10-5
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,11 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
178178
},
179179
} {
180180
t.Run("case="+tc.desc, func(t *testing.T) {
181-
actual, err := p.ListSessionsByIdentity(ctx, i.ID, tc.active, 1, 10, tc.except, session.ExpandEverything)
181+
actual, total, err := p.ListSessionsByIdentity(ctx, i.ID, tc.active, 1, 10, tc.except, session.ExpandEverything)
182182
require.NoError(t, err)
183183

184184
require.Equal(t, len(tc.expected), len(actual))
185+
require.Equal(t, int64(len(tc.expected)), total)
185186
for _, es := range tc.expected {
186187
found := false
187188
for _, as := range actual {
@@ -197,8 +198,9 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
197198

198199
t.Run("other network", func(t *testing.T) {
199200
_, other := testhelpers.NewNetwork(t, ctx, p)
200-
actual, err := other.ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
201+
actual, total, err := other.ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
201202
require.NoError(t, err)
203+
require.Equal(t, int64(0), total)
202204
assert.Len(t, actual, 0)
203205
})
204206
})
@@ -322,9 +324,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
322324
require.NoError(t, err)
323325
assert.Equal(t, 1, n)
324326

325-
actual, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
327+
actual, total, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
326328
require.NoError(t, err)
327329
require.Len(t, actual, 2)
330+
require.Equal(t, int64(2), total)
328331

329332
if actual[0].ID == sessions[0].ID {
330333
assert.True(t, actual[0].Active)
@@ -335,9 +338,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
335338
assert.False(t, actual[0].Active)
336339
}
337340

338-
otherIdentitiesSessions, err := p.ListSessionsByIdentity(ctx, sessions[2].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
341+
otherIdentitiesSessions, total, err := p.ListSessionsByIdentity(ctx, sessions[2].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
339342
require.NoError(t, err)
340343
require.Len(t, actual, 2)
344+
require.Equal(t, int64(2), total)
341345

342346
for _, s := range otherIdentitiesSessions {
343347
assert.True(t, s.Active)
@@ -369,9 +373,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
369373

370374
require.NoError(t, p.RevokeSession(ctx, sessions[0].IdentityID, sessions[0].ID))
371375

372-
actual, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
376+
actual, total, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil, session.ExpandNothing)
373377
require.NoError(t, err)
374378
require.Len(t, actual, 2)
379+
require.Equal(t, int64(2), total)
375380

376381
if actual[0].ID == sessions[0].ID {
377382
assert.False(t, actual[0].Active)

0 commit comments

Comments
 (0)