Skip to content

Commit 4a26981

Browse files
committed
fix: identity sessions list response includes pagination headers
This resolves issue #2762
1 parent 439f015 commit 4a26981

File tree

5 files changed

+33
-14
lines changed

5 files changed

+33
-14
lines changed

persistence/sql/persister_session.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,32 @@ func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID) (*session.Ses
4141
}
4242

4343
// ListSessionsByIdentity retrieves sessions for an identity from the store.
44-
func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID) ([]*session.Session, error) {
44+
func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID) ([]*session.Session, int64, error) {
4545
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListSessionsByIdentity")
4646
defer span.End()
4747

4848
s := make([]*session.Session, 0)
49+
t := int64(0)
4950
nid := p.NetworkID(ctx)
5051

5152
if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
52-
q := c.Where("identity_id = ? AND nid = ?", iID, nid).Paginate(page, perPage)
53+
q := c.Where("identity_id = ? AND nid = ?", iID, nid)
5354
if except != uuid.Nil {
5455
q = q.Where("id != ?", except)
5556
}
5657
if active != nil {
5758
q = q.Where("active = ?", *active)
5859
}
59-
if err := q.All(&s); err != nil {
60+
61+
// Get the total count of matching items
62+
total, err := q.Count(new(session.Session))
63+
if err != nil {
64+
return sqlcon.HandleError(err)
65+
}
66+
t = int64(total)
67+
68+
// Get the paginated list of matching items
69+
if err := q.Paginate(page, perPage).All(&s); err != nil {
6070
return sqlcon.HandleError(err)
6171
}
6272

@@ -70,10 +80,10 @@ func (p *Persister) ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, a
7080
}
7181
return nil
7282
}); err != nil {
73-
return nil, err
83+
return nil, 0, err
7484
}
7585

76-
return s, nil
86+
return s, t, nil
7787
}
7888

7989
func (p *Persister) UpsertSession(ctx context.Context, s *session.Session) error {

session/handler.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/pkg/errors"
1212

1313
"github.com/ory/x/decoderx"
14+
"github.com/ory/x/urlx"
1415

1516
"github.com/ory/herodot"
1617

@@ -304,12 +305,13 @@ func (h *Handler) adminListIdentitySessions(w http.ResponseWriter, r *http.Reque
304305
}
305306

306307
page, perPage := x.ParsePagination(r)
307-
sess, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), iID, active, page, perPage, uuid.Nil)
308+
sess, total, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), iID, active, page, perPage, uuid.Nil)
308309
if err != nil {
309310
h.r.Writer().WriteError(w, r, err)
310311
return
311312
}
312313

314+
x.PaginationHeader(w, urlx.AppendPaths(h.r.Config().SelfAdminURL(r.Context()), RouteCollection), total, page, perPage)
313315
h.r.Writer().Write(w, r, sess)
314316
}
315317

@@ -448,12 +450,13 @@ func (h *Handler) listSessions(w http.ResponseWriter, r *http.Request, _ httprou
448450
}
449451

450452
page, perPage := x.ParsePagination(r)
451-
sess, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), s.IdentityID, pointerx.Bool(true), page, perPage, s.ID)
453+
sess, total, err := h.r.SessionPersister().ListSessionsByIdentity(r.Context(), s.IdentityID, pointerx.Bool(true), page, perPage, s.ID)
452454
if err != nil {
453455
h.r.Writer().WriteError(w, r, err)
454456
return
455457
}
456458

459+
x.PaginationHeader(w, urlx.AppendPaths(h.r.Config().SelfAdminURL(r.Context()), RouteCollection), total, page, perPage)
457460
h.r.Writer().Write(w, r, sess)
458461
}
459462

session/handler_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,9 +601,10 @@ func TestHandlerSelfServiceSessionManagement(t *testing.T) {
601601
require.NoError(t, err)
602602
require.Equal(t, http.StatusNoContent, res.StatusCode)
603603

604-
actualOthers, err := reg.SessionPersister().ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil)
604+
actualOthers, total, err := reg.SessionPersister().ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil)
605605
require.NoError(t, err)
606606
require.Len(t, actualOthers, 3)
607+
require.Equal(t, int64(3), total)
607608

608609
for _, s := range actualOthers {
609610
if s.ID == others[0].ID {

session/persistence.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ type Persister interface {
2424
GetSession(ctx context.Context, sid uuid.UUID) (*Session, error)
2525

2626
// ListSessionsByIdentity retrieves sessions for an identity from the store.
27-
ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID) ([]*Session, error)
27+
ListSessionsByIdentity(ctx context.Context, iID uuid.UUID, active *bool, page, perPage int, except uuid.UUID) ([]*Session, int64, error)
2828

2929
// UpsertSession inserts or updates a session into / in the store.
3030
UpsertSession(ctx context.Context, s *Session) error

session/test/persistence.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,11 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
149149
},
150150
} {
151151
t.Run("case="+tc.desc, func(t *testing.T) {
152-
actual, err := p.ListSessionsByIdentity(ctx, i.ID, tc.active, 1, 10, tc.except)
152+
actual, total, err := p.ListSessionsByIdentity(ctx, i.ID, tc.active, 1, 10, tc.except)
153153
require.NoError(t, err)
154154

155155
require.Equal(t, len(tc.expected), len(actual))
156+
require.Equal(t, int64(len(tc.expected)), total)
156157
for _, es := range tc.expected {
157158
found := false
158159
for _, as := range actual {
@@ -167,8 +168,9 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
167168

168169
t.Run("other network", func(t *testing.T) {
169170
_, other := testhelpers.NewNetwork(t, ctx, p)
170-
actual, err := other.ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil)
171+
actual, total, err := other.ListSessionsByIdentity(ctx, i.ID, nil, 1, 10, uuid.Nil)
171172
require.NoError(t, err)
173+
require.Equal(t, int64(0), total)
172174
assert.Len(t, actual, 0)
173175
})
174176
})
@@ -292,9 +294,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
292294
require.NoError(t, err)
293295
assert.Equal(t, 1, n)
294296

295-
actual, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil)
297+
actual, total, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil)
296298
require.NoError(t, err)
297299
require.Len(t, actual, 2)
300+
require.Equal(t, int64(2), total)
298301

299302
if actual[0].ID == sessions[0].ID {
300303
assert.True(t, actual[0].Active)
@@ -305,9 +308,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
305308
assert.False(t, actual[0].Active)
306309
}
307310

308-
otherIdentitiesSessions, err := p.ListSessionsByIdentity(ctx, sessions[2].IdentityID, nil, 1, 10, uuid.Nil)
311+
otherIdentitiesSessions, total, err := p.ListSessionsByIdentity(ctx, sessions[2].IdentityID, nil, 1, 10, uuid.Nil)
309312
require.NoError(t, err)
310313
require.Len(t, actual, 2)
314+
require.Equal(t, int64(2), total)
311315

312316
for _, s := range otherIdentitiesSessions {
313317
assert.True(t, s.Active)
@@ -339,9 +343,10 @@ func TestPersister(ctx context.Context, conf *config.Config, p interface {
339343

340344
require.NoError(t, p.RevokeSession(ctx, sessions[0].IdentityID, sessions[0].ID))
341345

342-
actual, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil)
346+
actual, total, err := p.ListSessionsByIdentity(ctx, sessions[0].IdentityID, nil, 1, 10, uuid.Nil)
343347
require.NoError(t, err)
344348
require.Len(t, actual, 2)
349+
require.Equal(t, int64(2), total)
345350

346351
if actual[0].ID == sessions[0].ID {
347352
assert.False(t, actual[0].Active)

0 commit comments

Comments
 (0)