Skip to content

Commit b49f4a0

Browse files
committed
Implementation
1 parent 8dcef46 commit b49f4a0

File tree

11 files changed

+377
-120
lines changed

11 files changed

+377
-120
lines changed

api/types/access_request.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ type AccessRequest interface {
131131
GetDryRun() bool
132132
// SetDryRun sets the dry run flag on the request.
133133
SetDryRun(bool)
134+
// GetDryRunEnrichment gets the dry run enrichment data.
135+
GetDryRunEnrichment() *AccessRequestDryRunEnrichment
136+
// SetDryRunEnrichment sets the dry run enrichment data.
137+
SetDryRunEnrichment(*AccessRequestDryRunEnrichment)
134138
// Copy returns a copy of the access request resource.
135139
Copy() AccessRequest
136140
}
@@ -514,6 +518,16 @@ func (r *AccessRequestV3) SetDryRun(dryRun bool) {
514518
r.Spec.DryRun = dryRun
515519
}
516520

521+
// GetDryRunEnrichment gets the dry run enrichment data.
522+
func (r *AccessRequestV3) GetDryRunEnrichment() *AccessRequestDryRunEnrichment {
523+
return r.Spec.DryRunEnrichment
524+
}
525+
526+
// SetDryRunEnrichment sets the dry run enrichment data.
527+
func (r *AccessRequestV3) SetDryRunEnrichment(enrichment *AccessRequestDryRunEnrichment) {
528+
r.Spec.DryRunEnrichment = enrichment
529+
}
530+
517531
// Copy returns a copy of the access request resource.
518532
func (r *AccessRequestV3) Copy() AccessRequest {
519533
return utils.CloneProtoMsg(r)

lib/auth/auth.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4641,7 +4641,7 @@ func (a *Server) getValidatedAccessRequest(ctx context.Context, identity tlsca.I
46414641
return nil, trace.AccessDenied("access request %q is awaiting approval", accessRequestID)
46424642
}
46434643

4644-
if err := services.ValidateAccessRequestForUser(ctx, a.clock, a, req, identity); err != nil {
4644+
if _, err := services.ValidateAccessRequestForUser(ctx, a.clock, a, req, identity); err != nil {
46454645
return nil, trace.Wrap(err)
46464646
}
46474647

@@ -5232,10 +5232,12 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ
52325232

52335233
req.SetCreationTime(now)
52345234

5235-
// Always perform variable expansion on creation only; this ensures the
5236-
// access request that is reviewed is the same that is approved.
5237-
expandOpts := services.ExpandVars()
5238-
if err := services.ValidateAccessRequestForUser(ctx, a.clock, a, req, identity, expandOpts); err != nil {
5235+
validateOpts := []services.ValidateRequestOption{
5236+
services.WithExpandVars(true), // always perform variable expansion on creation
5237+
services.WithDryRun(req.GetDryRun()),
5238+
}
5239+
dryRunEnrichment, err := services.ValidateAccessRequestForUser(ctx, a.clock, a, req, identity, validateOpts...)
5240+
if err != nil {
52395241
return nil, trace.Wrap(err)
52405242
}
52415243

@@ -5251,6 +5253,7 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ
52515253
}
52525254

52535255
if req.GetDryRun() {
5256+
req.SetDryRunEnrichment(dryRunEnrichment)
52545257
_, promotions := a.generateAccessRequestPromotions(ctx, req)
52555258
// update the request with additional reviewers if possible.
52565259
updateAccessRequestWithAdditionalReviewers(ctx, req, a.AccessLists, promotions)

lib/auth/auth_test.go

Lines changed: 217 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4196,6 +4196,50 @@ func TestAccessRequestAuditLog(t *testing.T) {
41964196
require.Equal(t, "APPROVED", arc.RequestState)
41974197
}
41984198

4199+
func testCreateRole(t *testing.T, server *TestTLSServer, name string, setup func(*types.RoleSpecV6)) types.Role {
4200+
t.Helper()
4201+
ctx := context.Background()
4202+
4203+
spec := types.RoleSpecV6{
4204+
Allow: types.RoleConditions{
4205+
Request: &types.AccessRequestConditions{
4206+
Reason: &types.AccessRequestConditionsReason{},
4207+
},
4208+
ReviewRequests: &types.AccessReviewConditions{},
4209+
},
4210+
Deny: types.RoleConditions{
4211+
Request: &types.AccessRequestConditions{},
4212+
ReviewRequests: &types.AccessReviewConditions{},
4213+
},
4214+
}
4215+
setup(&spec)
4216+
4217+
role, err := types.NewRole(name, spec)
4218+
require.NoError(t, err, "types.NewRole")
4219+
4220+
createdRole, err := server.AuthServer.AuthServer.UpsertRole(ctx, role)
4221+
require.NoError(t, err, "AuthServer.UpsertRole")
4222+
4223+
return createdRole
4224+
}
4225+
4226+
func testCreateUserWithRoles(t *testing.T, server *TestTLSServer, user string, roles ...string) (TestIdentity, *authclient.Client) {
4227+
t.Helper()
4228+
ctx := context.Background()
4229+
4230+
u, err := types.NewUser(user)
4231+
require.NoError(t, err, "types.NewUser")
4232+
u.SetRoles(roles)
4233+
_, err = server.AuthServer.AuthServer.UpsertUser(ctx, u)
4234+
require.NoError(t, err, "AuthServer.UpsertUser")
4235+
4236+
identity := TestUser(user)
4237+
client, err := server.NewClient(identity)
4238+
require.NoError(t, err, "server.NewClient")
4239+
4240+
return identity, client
4241+
}
4242+
41994243
func TestAccessRequestNotifications(t *testing.T) {
42004244
t.Parallel()
42014245
ctx := context.Background()
@@ -4214,69 +4258,36 @@ func TestAccessRequestNotifications(t *testing.T) {
42144258
requesterUsername := "requester"
42154259
requestRoleName := "requestRole"
42164260

4217-
reviewerRole, err := types.NewRole(reviewerUsername, types.RoleSpecV6{
4218-
Allow: types.RoleConditions{
4219-
Logins: []string{"user"},
4220-
ReviewRequests: &types.AccessReviewConditions{
4221-
Roles: []string{"requestRole"},
4222-
},
4223-
},
4261+
reviewerRole := testCreateRole(t, testTLSServer, reviewerUsername, func(spec *types.RoleSpecV6) {
4262+
spec.Allow.Logins = []string{"user"}
4263+
spec.Allow.ReviewRequests.Roles = []string{"requestRole"}
42244264
})
4225-
require.NoError(t, err)
42264265

4227-
requesterRole, err := types.NewRole(requesterUsername, types.RoleSpecV6{
4228-
Allow: types.RoleConditions{
4229-
Request: &types.AccessRequestConditions{
4230-
Roles: []string{requestRoleName},
4231-
},
4232-
},
4266+
requesterRole := testCreateRole(t, testTLSServer, requesterUsername, func(spec *types.RoleSpecV6) {
4267+
spec.Allow.Request.Roles = []string{requestRoleName}
42334268
})
4234-
require.NoError(t, err)
42354269

4236-
requestedRole, err := types.NewRole(requestRoleName, types.RoleSpecV6{
4237-
Allow: types.RoleConditions{
4238-
Request: &types.AccessRequestConditions{
4239-
Roles: []string{requestRoleName},
4240-
},
4241-
},
4270+
requestRole := testCreateRole(t, testTLSServer, requestRoleName, func(spec *types.RoleSpecV6) {
4271+
spec.Allow.Request.Roles = []string{requestRoleName}
42424272
})
4243-
require.NoError(t, err)
4244-
_, err = testTLSServer.AuthServer.AuthServer.UpsertRole(ctx, requestedRole)
4245-
require.NoError(t, err)
42464273

4247-
_, err = testTLSServer.AuthServer.AuthServer.UpsertRole(ctx, reviewerRole)
4248-
require.NoError(t, err)
4249-
reviewer, err := types.NewUser(reviewerUsername)
4250-
require.NoError(t, err)
4251-
reviewer.SetRoles([]string{reviewerUsername})
4252-
_, err = testTLSServer.AuthServer.AuthServer.UpsertUser(ctx, reviewer)
4253-
require.NoError(t, err)
4274+
reviewer, reviewerClient := testCreateUserWithRoles(t, testTLSServer, reviewerUsername, reviewerRole.GetName())
42544275

4255-
_, err = testTLSServer.AuthServer.AuthServer.UpsertRole(ctx, requesterRole)
4256-
require.NoError(t, err)
4257-
requester, err := types.NewUser(requesterUsername)
4258-
require.NoError(t, err)
4259-
requester.SetRoles([]string{requesterUsername})
4260-
_, err = testTLSServer.AuthServer.AuthServer.UpsertUser(ctx, requester)
4261-
require.NoError(t, err)
4276+
requester, _ := testCreateUserWithRoles(t, testTLSServer, requesterUsername, requesterRole.GetName())
42624277

4263-
accessRequest, err := types.NewAccessRequest(uuid.NewString(), requesterUsername, requestRoleName)
4278+
accessRequest, err := types.NewAccessRequest(uuid.NewString(), requester.GetUsername(), requestRole.GetName())
42644279
require.NoError(t, err)
4265-
req, err := testTLSServer.AuthServer.AuthServer.CreateAccessRequestV2(ctx, accessRequest, TestUser(requesterUsername).I.GetIdentity())
4280+
req, err := testTLSServer.AuthServer.AuthServer.CreateAccessRequestV2(ctx, accessRequest, reviewer.I.GetIdentity())
42664281
require.NoError(t, err)
42674282

42684283
// Verify that a global notification was created which matches for users who can review the requestRole.
42694284
globalNotifsResp, _, err := testTLSServer.AuthServer.AuthServer.Notifications.ListGlobalNotifications(ctx, 100, "")
42704285
require.NoError(t, err)
42714286
require.Len(t, globalNotifsResp, 1)
42724287
require.Equal(t, &types.AccessReviewConditions{
4273-
Roles: []string{requestRoleName},
4288+
Roles: []string{requestRole.GetName()},
42744289
}, globalNotifsResp[0].GetSpec().GetByPermissions().GetRoleConditions()[0].ReviewRequests)
42754290

4276-
reviewerIdentity := TestUser(reviewerUsername)
4277-
reviewerClient, err := testTLSServer.NewClient(reviewerIdentity)
4278-
require.NoError(t, err)
4279-
42804291
// Approve the request
42814292
_, err = reviewerClient.SubmitAccessReview(ctx, types.AccessReviewSubmission{
42824293
RequestID: req.GetName(),
@@ -4292,9 +4303,9 @@ func TestAccessRequestNotifications(t *testing.T) {
42924303
require.Contains(t, userNotifsResp[0].GetMetadata().GetLabels()[types.NotificationTitleLabel], "reviewer approved your access request")
42934304

42944305
// Create another access request.
4295-
accessRequest, err = types.NewAccessRequest(uuid.NewString(), requesterUsername, requestRoleName)
4306+
accessRequest, err = types.NewAccessRequest(uuid.NewString(), requester.GetUsername(), requestRole.GetName())
42964307
require.NoError(t, err)
4297-
req, err = testTLSServer.AuthServer.AuthServer.CreateAccessRequestV2(ctx, accessRequest, TestUser(requesterUsername).I.GetIdentity())
4308+
req, err = testTLSServer.AuthServer.AuthServer.CreateAccessRequestV2(ctx, accessRequest, TestUser(requester.GetUsername()).I.GetIdentity())
42984309
require.NoError(t, err)
42994310

43004311
// Deny the request.
@@ -4312,6 +4323,165 @@ func TestAccessRequestNotifications(t *testing.T) {
43124323
require.Contains(t, userNotifsResp[1].GetMetadata().GetLabels()[types.NotificationTitleLabel], "reviewer denied your access request")
43134324
}
43144325

4326+
func testNewAccessRequest(t *testing.T, user string, roles ...string) types.AccessRequest {
4327+
t.Helper()
4328+
r, err := types.NewAccessRequest(uuid.NewString(), user, roles...)
4329+
require.NoError(t, err, "types.NewAccessRequest")
4330+
return r
4331+
}
4332+
4333+
func TestAccessRequestDryRunEnrichment(t *testing.T) {
4334+
t.Parallel()
4335+
ctx := context.Background()
4336+
4337+
testAuthServer, err := NewTestAuthServer(TestAuthServerConfig{
4338+
Dir: t.TempDir(),
4339+
Clock: clockwork.NewFakeClock(),
4340+
})
4341+
require.NoError(t, err)
4342+
testTLSServer, err := testAuthServer.NewTestTLSServer()
4343+
require.NoError(t, err)
4344+
4345+
someRole := testCreateRole(t, testTLSServer, "some-role", func(spec *types.RoleSpecV6) {})
4346+
4347+
someRoleRequesterRole := testCreateRole(t, testTLSServer, "some-role-requester", func(spec *types.RoleSpecV6) {
4348+
spec.Allow.Request.Roles = []string{someRole.GetName()}
4349+
})
4350+
4351+
someRoleRequesterRoleRequiringReason := testCreateRole(t, testTLSServer, "some-role-requester-requiring-reason", func(spec *types.RoleSpecV6) {
4352+
spec.Allow.Request.Roles = []string{someRole.GetName()}
4353+
spec.Allow.Request.Reason.Mode = types.RequestReasonModeRequired
4354+
})
4355+
4356+
globalPromptRole1 := testCreateRole(t, testTLSServer, "prompt-role-1", func(spec *types.RoleSpecV6) {
4357+
spec.Options.RequestPrompt = "test prompt #1"
4358+
})
4359+
globalPromptRole2 := testCreateRole(t, testTLSServer, "prompt-role-2", func(spec *types.RoleSpecV6) {
4360+
spec.Options.RequestPrompt = "test prompt #2"
4361+
})
4362+
4363+
t.Run("requesting-role-no-reason-required-no-prompts", func(t *testing.T) {
4364+
requester, requesterClient := testCreateUserWithRoles(t, testTLSServer, "requester",
4365+
someRoleRequesterRole.GetName(),
4366+
)
4367+
4368+
dryRunAccessRequest := testNewAccessRequest(t, requester.GetUsername(), someRole.GetName())
4369+
dryRunAccessRequest.SetDryRun(true)
4370+
4371+
resp, err := requesterClient.CreateAccessRequestV2(ctx, dryRunAccessRequest)
4372+
require.NoError(t, err)
4373+
4374+
require.NotNil(t, resp.GetDryRunEnrichment())
4375+
// check reason mode
4376+
require.Equal(t, types.RequestReasonModeOptional, resp.GetDryRunEnrichment().ReasonMode)
4377+
// check prompts
4378+
require.Len(t, resp.GetDryRunEnrichment().ReasonPrompts, 0)
4379+
})
4380+
4381+
t.Run("requesting-role-reason-required", func(t *testing.T) {
4382+
requester, requesterClient := testCreateUserWithRoles(t, testTLSServer, "requester",
4383+
someRoleRequesterRoleRequiringReason.GetName(),
4384+
)
4385+
4386+
dryRunAccessRequest := testNewAccessRequest(t, requester.GetUsername(), someRole.GetName())
4387+
dryRunAccessRequest.SetDryRun(true)
4388+
4389+
resp, err := requesterClient.CreateAccessRequestV2(ctx, dryRunAccessRequest)
4390+
require.NoError(t, err)
4391+
4392+
require.NotNil(t, resp.GetDryRunEnrichment())
4393+
// check reason mode
4394+
require.Equal(t, types.RequestReasonModeRequired, resp.GetDryRunEnrichment().ReasonMode)
4395+
// check prompts
4396+
require.Len(t, resp.GetDryRunEnrichment().ReasonPrompts, 0)
4397+
})
4398+
4399+
t.Run("requesting-role-multiple-prompts", func(t *testing.T) {
4400+
requester, requesterClient := testCreateUserWithRoles(t, testTLSServer, "requester",
4401+
someRoleRequesterRole.GetName(),
4402+
globalPromptRole1.GetName(),
4403+
globalPromptRole2.GetName(),
4404+
)
4405+
4406+
dryRunAccessRequest := testNewAccessRequest(t, requester.GetUsername(), someRole.GetName())
4407+
dryRunAccessRequest.SetDryRun(true)
4408+
4409+
resp, err := requesterClient.CreateAccessRequestV2(ctx, dryRunAccessRequest)
4410+
require.NoError(t, err)
4411+
4412+
require.NotNil(t, resp.GetDryRunEnrichment())
4413+
// check reason mode
4414+
require.Equal(t, types.RequestReasonModeOptional, resp.GetDryRunEnrichment().ReasonMode)
4415+
// check prompts
4416+
require.Len(t, resp.GetDryRunEnrichment().ReasonPrompts, 2)
4417+
require.Contains(t, resp.GetDryRunEnrichment().ReasonPrompts, globalPromptRole1.GetOptions().RequestPrompt)
4418+
require.Contains(t, resp.GetDryRunEnrichment().ReasonPrompts, globalPromptRole2.GetOptions().RequestPrompt)
4419+
})
4420+
4421+
t.Run("requesting-role-reason-required-and-multiple-prompts", func(t *testing.T) {
4422+
requester, requesterClient := testCreateUserWithRoles(t, testTLSServer, "requester",
4423+
someRoleRequesterRole.GetName(),
4424+
someRoleRequesterRoleRequiringReason.GetName(),
4425+
globalPromptRole1.GetName(),
4426+
globalPromptRole2.GetName(),
4427+
)
4428+
4429+
dryRunAccessRequest := testNewAccessRequest(t, requester.GetUsername(), someRole.GetName())
4430+
dryRunAccessRequest.SetDryRun(true)
4431+
4432+
resp, err := requesterClient.CreateAccessRequestV2(ctx, dryRunAccessRequest)
4433+
require.NoError(t, err)
4434+
4435+
require.NotNil(t, resp.GetDryRunEnrichment())
4436+
// check reason mode
4437+
require.Equal(t, types.RequestReasonModeRequired, resp.GetDryRunEnrichment().ReasonMode)
4438+
// check prompts
4439+
require.Len(t, resp.GetDryRunEnrichment().ReasonPrompts, 2)
4440+
require.Contains(t, resp.GetDryRunEnrichment().ReasonPrompts, globalPromptRole1.GetOptions().RequestPrompt)
4441+
require.Contains(t, resp.GetDryRunEnrichment().ReasonPrompts, globalPromptRole2.GetOptions().RequestPrompt)
4442+
})
4443+
4444+
t.Run("requesting-role-prompts-sorted-and-duplicated", func(t *testing.T) {
4445+
globalPromptRole1 := testCreateRole(t, testTLSServer, "prompt-role-1", func(spec *types.RoleSpecV6) {
4446+
spec.Options.RequestPrompt = "C test prompt"
4447+
})
4448+
globalPromptRole2 := testCreateRole(t, testTLSServer, "prompt-role-2", func(spec *types.RoleSpecV6) {
4449+
spec.Options.RequestPrompt = "A test prompt"
4450+
})
4451+
globalPromptRole3 := testCreateRole(t, testTLSServer, "prompt-role-3", func(spec *types.RoleSpecV6) {
4452+
spec.Options.RequestPrompt = "B test prompt"
4453+
})
4454+
globalPromptRole4 := testCreateRole(t, testTLSServer, "prompt-role-4", func(spec *types.RoleSpecV6) {
4455+
spec.Options.RequestPrompt = "B test prompt"
4456+
})
4457+
globalPromptRole5 := testCreateRole(t, testTLSServer, "prompt-role-5", func(spec *types.RoleSpecV6) {
4458+
spec.Options.RequestPrompt = "C test prompt"
4459+
})
4460+
4461+
requester, requesterClient := testCreateUserWithRoles(t, testTLSServer, "requester",
4462+
someRoleRequesterRole.GetName(),
4463+
globalPromptRole1.GetName(),
4464+
globalPromptRole2.GetName(),
4465+
globalPromptRole3.GetName(),
4466+
globalPromptRole4.GetName(),
4467+
globalPromptRole5.GetName(),
4468+
)
4469+
4470+
dryRunAccessRequest := testNewAccessRequest(t, requester.GetUsername(), someRole.GetName())
4471+
dryRunAccessRequest.SetDryRun(true)
4472+
4473+
resp, err := requesterClient.CreateAccessRequestV2(ctx, dryRunAccessRequest)
4474+
require.NoError(t, err)
4475+
4476+
require.NotNil(t, resp.GetDryRunEnrichment())
4477+
// check prompts
4478+
require.Len(t, resp.GetDryRunEnrichment().ReasonPrompts, 3)
4479+
require.Equal(t, "A test prompt", resp.GetDryRunEnrichment().ReasonPrompts[0])
4480+
require.Equal(t, "B test prompt", resp.GetDryRunEnrichment().ReasonPrompts[1])
4481+
require.Equal(t, "C test prompt", resp.GetDryRunEnrichment().ReasonPrompts[2])
4482+
})
4483+
}
4484+
43154485
func TestCleanupNotifications(t *testing.T) {
43164486
ctx, cancel := context.WithCancel(context.Background())
43174487
t.Cleanup(cancel)

lib/auth/helpers.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,10 @@ func TestRemoteBuiltin(role types.SystemRole, remoteCluster string) TestIdentity
10551055
}
10561056
}
10571057

1058+
func (i TestIdentity) GetUsername() string {
1059+
return i.I.GetIdentity().Username
1060+
}
1061+
10581062
// NewClientFromWebSession returns new authenticated client from web session
10591063
func (t *TestTLSServer) NewClientFromWebSession(sess types.WebSession) (*authclient.Client, error) {
10601064
tlsConfig, err := t.Identity.TLSConfig(t.AuthServer.CipherSuites)

0 commit comments

Comments
 (0)