Skip to content

Commit c0ceaf3

Browse files
committed
feat: add pre-hooks to settings, verification, recovery
1 parent 1787e68 commit c0ceaf3

22 files changed

+447
-41
lines changed

driver/config/config.go

+15
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,12 @@ const (
129129
ViperKeySelfServiceLogoutBrowserDefaultReturnTo = "selfservice.flows.logout.after." + DefaultBrowserReturnURL
130130
ViperKeySelfServiceSettingsURL = "selfservice.flows.settings.ui_url"
131131
ViperKeySelfServiceSettingsAfter = "selfservice.flows.settings.after"
132+
ViperKeySelfServiceSettingsBeforeHooks = "selfservice.flows.settings.before.hooks"
132133
ViperKeySelfServiceSettingsRequestLifespan = "selfservice.flows.settings.lifespan"
133134
ViperKeySelfServiceSettingsPrivilegedAuthenticationAfter = "selfservice.flows.settings.privileged_session_max_age"
134135
ViperKeySelfServiceSettingsRequiredAAL = "selfservice.flows.settings.required_aal"
135136
ViperKeySelfServiceRecoveryAfter = "selfservice.flows.recovery.after"
137+
ViperKeySelfServiceRecoveryBeforeHooks = "selfservice.flows.recovery.before.hooks"
136138
ViperKeySelfServiceRecoveryEnabled = "selfservice.flows.recovery.enabled"
137139
ViperKeySelfServiceRecoveryUI = "selfservice.flows.recovery.ui_url"
138140
ViperKeySelfServiceRecoveryRequestLifespan = "selfservice.flows.recovery.lifespan"
@@ -142,6 +144,7 @@ const (
142144
ViperKeySelfServiceVerificationRequestLifespan = "selfservice.flows.verification.lifespan"
143145
ViperKeySelfServiceVerificationBrowserDefaultReturnTo = "selfservice.flows.verification.after." + DefaultBrowserReturnURL
144146
ViperKeySelfServiceVerificationAfter = "selfservice.flows.verification.after"
147+
ViperKeySelfServiceVerificationBeforeHooks = "selfservice.flows.verification.before.hooks"
145148
ViperKeyDefaultIdentitySchemaID = "identity.default_schema_id"
146149
ViperKeyIdentitySchemas = "identity.schemas"
147150
ViperKeyHasherAlgorithm = "hashers.algorithm"
@@ -623,6 +626,18 @@ func (p *Config) SelfServiceFlowLoginBeforeHooks(ctx context.Context) []SelfServ
623626
return p.selfServiceHooks(ctx, ViperKeySelfServiceLoginBeforeHooks)
624627
}
625628

629+
func (p *Config) SelfServiceFlowRecoveryBeforeHooks(ctx context.Context) []SelfServiceHook {
630+
return p.selfServiceHooks(ctx, ViperKeySelfServiceRecoveryBeforeHooks)
631+
}
632+
633+
func (p *Config) SelfServiceFlowVerificationBeforeHooks(ctx context.Context) []SelfServiceHook {
634+
return p.selfServiceHooks(ctx, ViperKeySelfServiceVerificationBeforeHooks)
635+
}
636+
637+
func (p *Config) SelfServiceFlowSettingsBeforeHooks(ctx context.Context) []SelfServiceHook {
638+
return p.selfServiceHooks(ctx, ViperKeySelfServiceSettingsBeforeHooks)
639+
}
640+
626641
func (p *Config) SelfServiceFlowRegistrationBeforeHooks(ctx context.Context) []SelfServiceHook {
627642
return p.selfServiceHooks(ctx, ViperKeySelfServiceRegistrationBeforeHooks)
628643
}

driver/registry_default_recovery.go

+9
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ func (m *RegistryDefault) RecoveryExecutor() *recovery.HookExecutor {
5050
return m.selfserviceRecoveryExecutor
5151
}
5252

53+
func (m *RegistryDefault) PreRecoveryHooks(ctx context.Context) (b []recovery.PreHookExecutor) {
54+
for _, v := range m.getHooks("", m.Config().SelfServiceFlowRecoveryBeforeHooks(ctx)) {
55+
if hook, ok := v.(recovery.PreHookExecutor); ok {
56+
b = append(b, hook)
57+
}
58+
}
59+
return
60+
}
61+
5362
func (m *RegistryDefault) PostRecoveryHooks(ctx context.Context) (b []recovery.PostHookExecutor) {
5463
for _, v := range m.getHooks(config.HookGlobal, m.Config().SelfServiceFlowRecoveryAfterHooks(ctx, config.HookGlobal)) {
5564
if hook, ok := v.(recovery.PostHookExecutor); ok {

driver/registry_default_settings.go

+9
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ func (m *RegistryDefault) PostSettingsPrePersistHooks(ctx context.Context, setti
1616
return
1717
}
1818

19+
func (m *RegistryDefault) PreSettingsHooks(ctx context.Context) (b []settings.PreHookExecutor) {
20+
for _, v := range m.getHooks("", m.Config().SelfServiceFlowSettingsBeforeHooks(ctx)) {
21+
if hook, ok := v.(settings.PreHookExecutor); ok {
22+
b = append(b, hook)
23+
}
24+
}
25+
return
26+
}
27+
1928
func (m *RegistryDefault) PostSettingsPostPersistHooks(ctx context.Context, settingsType string) (b []settings.PostHookPostPersistExecutor) {
2029
initialHookCount := 0
2130
if m.Config().SelfServiceFlowVerificationEnabled(ctx) {

driver/registry_default_test.go

+117
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,45 @@ func TestDriverDefault_Hooks(t *testing.T) {
3131
ctx := context.Background()
3232

3333
t.Run("type=verification", func(t *testing.T) {
34+
// BEFORE hooks
35+
for _, tc := range []struct {
36+
uc string
37+
prep func(conf *config.Config)
38+
expect func(reg *driver.RegistryDefault) []verification.PreHookExecutor
39+
}{
40+
{
41+
uc: "No hooks configured",
42+
prep: func(conf *config.Config) {},
43+
expect: func(reg *driver.RegistryDefault) []verification.PreHookExecutor { return nil },
44+
},
45+
{
46+
uc: "Two web_hooks are configured",
47+
prep: func(conf *config.Config) {
48+
conf.MustSet(ctx, config.ViperKeySelfServiceVerificationBeforeHooks, []map[string]interface{}{
49+
{"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST"}},
50+
{"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET"}},
51+
})
52+
},
53+
expect: func(reg *driver.RegistryDefault) []verification.PreHookExecutor {
54+
return []verification.PreHookExecutor{
55+
hook.NewWebHook(reg, json.RawMessage(`{"method":"POST","url":"foo"}`)),
56+
hook.NewWebHook(reg, json.RawMessage(`{"method":"GET","url":"bar"}`)),
57+
}
58+
},
59+
},
60+
} {
61+
t.Run(fmt.Sprintf("before/uc=%s", tc.uc), func(t *testing.T) {
62+
conf, reg := internal.NewFastRegistryWithMocks(t)
63+
tc.prep(conf)
64+
65+
h := reg.PreVerificationHooks(ctx)
66+
67+
expectedExecutors := tc.expect(reg)
68+
require.Len(t, h, len(expectedExecutors))
69+
assert.Equal(t, expectedExecutors, h)
70+
})
71+
}
72+
3473
// AFTER hooks
3574
for _, tc := range []struct {
3675
uc string
@@ -72,6 +111,45 @@ func TestDriverDefault_Hooks(t *testing.T) {
72111
})
73112

74113
t.Run("type=recovery", func(t *testing.T) {
114+
// BEFORE hooks
115+
for _, tc := range []struct {
116+
uc string
117+
prep func(conf *config.Config)
118+
expect func(reg *driver.RegistryDefault) []recovery.PreHookExecutor
119+
}{
120+
{
121+
uc: "No hooks configured",
122+
prep: func(conf *config.Config) {},
123+
expect: func(reg *driver.RegistryDefault) []recovery.PreHookExecutor { return nil },
124+
},
125+
{
126+
uc: "Two web_hooks are configured",
127+
prep: func(conf *config.Config) {
128+
conf.MustSet(ctx, config.ViperKeySelfServiceRecoveryBeforeHooks, []map[string]interface{}{
129+
{"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST"}},
130+
{"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET"}},
131+
})
132+
},
133+
expect: func(reg *driver.RegistryDefault) []recovery.PreHookExecutor {
134+
return []recovery.PreHookExecutor{
135+
hook.NewWebHook(reg, json.RawMessage(`{"method":"POST","url":"foo"}`)),
136+
hook.NewWebHook(reg, json.RawMessage(`{"method":"GET","url":"bar"}`)),
137+
}
138+
},
139+
},
140+
} {
141+
t.Run(fmt.Sprintf("before/uc=%s", tc.uc), func(t *testing.T) {
142+
conf, reg := internal.NewFastRegistryWithMocks(t)
143+
tc.prep(conf)
144+
145+
h := reg.PreRecoveryHooks(ctx)
146+
147+
expectedExecutors := tc.expect(reg)
148+
require.Len(t, h, len(expectedExecutors))
149+
assert.Equal(t, expectedExecutors, h)
150+
})
151+
}
152+
75153
// AFTER hooks
76154
for _, tc := range []struct {
77155
uc string
@@ -388,6 +466,45 @@ func TestDriverDefault_Hooks(t *testing.T) {
388466
})
389467

390468
t.Run("type=settings", func(t *testing.T) {
469+
// BEFORE hooks
470+
for _, tc := range []struct {
471+
uc string
472+
prep func(conf *config.Config)
473+
expect func(reg *driver.RegistryDefault) []settings.PreHookExecutor
474+
}{
475+
{
476+
uc: "No hooks configured",
477+
prep: func(conf *config.Config) {},
478+
expect: func(reg *driver.RegistryDefault) []settings.PreHookExecutor { return nil },
479+
},
480+
{
481+
uc: "Two web_hooks are configured",
482+
prep: func(conf *config.Config) {
483+
conf.MustSet(ctx, config.ViperKeySelfServiceSettingsBeforeHooks, []map[string]interface{}{
484+
{"hook": "web_hook", "config": map[string]interface{}{"url": "foo", "method": "POST"}},
485+
{"hook": "web_hook", "config": map[string]interface{}{"url": "bar", "method": "GET"}},
486+
})
487+
},
488+
expect: func(reg *driver.RegistryDefault) []settings.PreHookExecutor {
489+
return []settings.PreHookExecutor{
490+
hook.NewWebHook(reg, json.RawMessage(`{"method":"POST","url":"foo"}`)),
491+
hook.NewWebHook(reg, json.RawMessage(`{"method":"GET","url":"bar"}`)),
492+
}
493+
},
494+
},
495+
} {
496+
t.Run(fmt.Sprintf("before/uc=%s", tc.uc), func(t *testing.T) {
497+
conf, reg := internal.NewFastRegistryWithMocks(t)
498+
tc.prep(conf)
499+
500+
h := reg.PreSettingsHooks(ctx)
501+
502+
expectedExecutors := tc.expect(reg)
503+
require.Len(t, h, len(expectedExecutors))
504+
assert.Equal(t, expectedExecutors, h)
505+
})
506+
}
507+
391508
// AFTER hooks
392509
for _, tc := range []struct {
393510
uc string

driver/registry_default_verify.go

+9
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ func (m *RegistryDefault) VerificationExecutor() *verification.HookExecutor {
7373
return m.selfserviceVerificationExecutor
7474
}
7575

76+
func (m *RegistryDefault) PreVerificationHooks(ctx context.Context) (b []verification.PreHookExecutor) {
77+
for _, v := range m.getHooks("", m.Config().SelfServiceFlowVerificationBeforeHooks(ctx)) {
78+
if hook, ok := v.(verification.PreHookExecutor); ok {
79+
b = append(b, hook)
80+
}
81+
}
82+
return
83+
}
84+
7685
func (m *RegistryDefault) PostVerificationHooks(ctx context.Context) (b []verification.PostHookExecutor) {
7786
for _, v := range m.getHooks(config.HookGlobal, m.Config().SelfServiceFlowVerificationAfterHooks(ctx, config.HookGlobal)) {
7887
if hook, ok := v.(verification.PostHookExecutor); ok {

embedx/config.schema.json

+36
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,33 @@
790790
}
791791
}
792792
},
793+
"selfServiceBeforeSettings": {
794+
"type": "object",
795+
"additionalProperties": false,
796+
"properties": {
797+
"hooks": {
798+
"$ref": "#/definitions/selfServiceHooks"
799+
}
800+
}
801+
},
802+
"selfServiceBeforeRecovery": {
803+
"type": "object",
804+
"additionalProperties": false,
805+
"properties": {
806+
"hooks": {
807+
"$ref": "#/definitions/selfServiceHooks"
808+
}
809+
}
810+
},
811+
"selfServiceBeforeVerification": {
812+
"type": "object",
813+
"additionalProperties": false,
814+
"properties": {
815+
"hooks": {
816+
"$ref": "#/definitions/selfServiceHooks"
817+
}
818+
}
819+
},
793820
"selfServiceAfterRegistration": {
794821
"type": "object",
795822
"additionalProperties": false,
@@ -1023,6 +1050,9 @@
10231050
},
10241051
"after": {
10251052
"$ref": "#/definitions/selfServiceAfterSettings"
1053+
},
1054+
"before": {
1055+
"$ref": "#/definitions/selfServiceBeforeSettings"
10261056
}
10271057
}
10281058
},
@@ -1146,6 +1176,9 @@
11461176
"1m",
11471177
"1s"
11481178
]
1179+
},
1180+
"before": {
1181+
"$ref": "#/definitions/selfServiceBeforeVerification"
11491182
}
11501183
}
11511184
},
@@ -1184,6 +1217,9 @@
11841217
"1m",
11851218
"1s"
11861219
]
1220+
},
1221+
"before": {
1222+
"$ref": "#/definitions/selfServiceBeforeRecovery"
11871223
}
11881224
}
11891225
},

internal/testhelpers/selfservice.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestSelfServicePreHook(
4646

4747
t.Run("case=err if hooks err", func(t *testing.T) {
4848
t.Cleanup(SelfServiceHookConfigReset(t, conf))
49-
conf.MustSet(ctx, configKey, []config.SelfServiceHook{{Name: "err", Config: []byte(`{"ExecuteLoginPreHook": "err","ExecuteRegistrationPreHook": "err"}`)}})
49+
conf.MustSet(ctx, configKey, []config.SelfServiceHook{{Name: "err", Config: []byte(`{"ExecuteLoginPreHook": "err","ExecuteRegistrationPreHook": "err","ExecuteSettingsPreHook": "err","ExecuteVerificationPreHook": "err","ExecuteRecoveryPreHook": "err"}`)}})
5050

5151
res, body := makeRequestPre(t, newServer(t))
5252
assert.EqualValues(t, http.StatusInternalServerError, res.StatusCode, "%s", body)
@@ -55,7 +55,7 @@ func TestSelfServicePreHook(
5555

5656
t.Run("case=abort if hooks aborts", func(t *testing.T) {
5757
t.Cleanup(SelfServiceHookConfigReset(t, conf))
58-
conf.MustSet(ctx, configKey, []config.SelfServiceHook{{Name: "err", Config: []byte(`{"ExecuteLoginPreHook": "abort","ExecuteRegistrationPreHook": "abort"}`)}})
58+
conf.MustSet(ctx, configKey, []config.SelfServiceHook{{Name: "err", Config: []byte(`{"ExecuteLoginPreHook": "abort","ExecuteRegistrationPreHook": "abort","ExecuteSettingsPreHook": "abort","ExecuteVerificationPreHook": "abort","ExecuteRecoveryPreHook": "abort"}`)}})
5959

6060
res, body := makeRequestPre(t, newServer(t))
6161
assert.EqualValues(t, http.StatusOK, res.StatusCode)
@@ -154,7 +154,7 @@ func SelfServiceHookRegistrationErrorHandler(t *testing.T, w http.ResponseWriter
154154
}
155155

156156
func SelfServiceHookSettingsErrorHandler(t *testing.T, w http.ResponseWriter, r *http.Request, err error) bool {
157-
return SelfServiceHookErrorHandler(t, w, r, settings.ErrHookAbortRequest, err)
157+
return SelfServiceHookErrorHandler(t, w, r, settings.ErrHookAbortFlow, err)
158158
}
159159

160160
func SelfServiceHookErrorHandler(t *testing.T, w http.ResponseWriter, r *http.Request, abortErr error, actualErr error) bool {
@@ -182,6 +182,18 @@ func SelfServiceMakeRegistrationPreHookRequest(t *testing.T, ts *httptest.Server
182182
return SelfServiceMakeHookRequest(t, ts, "/registration/pre", false, url.Values{})
183183
}
184184

185+
func SelfServiceMakeSettingsPreHookRequest(t *testing.T, ts *httptest.Server) (*http.Response, string) {
186+
return SelfServiceMakeHookRequest(t, ts, "/settings/pre", false, url.Values{})
187+
}
188+
189+
func SelfServiceMakeRecoveryPreHookRequest(t *testing.T, ts *httptest.Server) (*http.Response, string) {
190+
return SelfServiceMakeHookRequest(t, ts, "/recovery/pre", false, url.Values{})
191+
}
192+
193+
func SelfServiceMakeVerificationPreHookRequest(t *testing.T, ts *httptest.Server) (*http.Response, string) {
194+
return SelfServiceMakeHookRequest(t, ts, "/verification/pre", false, url.Values{})
195+
}
196+
185197
func SelfServiceMakeRegistrationPostHookRequest(t *testing.T, ts *httptest.Server, asAPI bool, query url.Values) (*http.Response, string) {
186198
return SelfServiceMakeHookRequest(t, ts, "/registration/post", asAPI, query)
187199
}

selfservice/flow/recovery/handler.go

+11
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ type (
5151
x.CSRFProvider
5252
config.Provider
5353
ErrorHandlerProvider
54+
HookExecutorProvider
5455
}
5556
Handler struct {
5657
d handlerDependencies
@@ -127,6 +128,11 @@ func (h *Handler) initAPIFlow(w http.ResponseWriter, r *http.Request, _ httprout
127128
return
128129
}
129130

131+
if err := h.d.RecoveryExecutor().PreRecoveryHook(w, r, req); err != nil {
132+
h.d.Writer().WriteError(w, r, err)
133+
return
134+
}
135+
130136
if err := h.d.RecoveryFlowPersister().CreateRecoveryFlow(r.Context(), req); err != nil {
131137
h.d.Writer().WriteError(w, r, err)
132138
return
@@ -178,6 +184,11 @@ func (h *Handler) initBrowserFlow(w http.ResponseWriter, r *http.Request, _ http
178184
return
179185
}
180186

187+
if err := h.d.RecoveryExecutor().PreRecoveryHook(w, r, f); err != nil {
188+
h.d.Writer().WriteError(w, r, err)
189+
return
190+
}
191+
181192
if err := h.d.RecoveryFlowPersister().CreateRecoveryFlow(r.Context(), f); err != nil {
182193
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
183194
return

0 commit comments

Comments
 (0)