Skip to content

Commit 5bb8fa2

Browse files
authored
AppRole/Identity: Fix for race when creating an entity during login (#3932)
* possible fix for race in approle login while creating entity * Add a test that hits the login request concurrently * address review comments
1 parent e47c7e8 commit 5bb8fa2

5 files changed

+167
-24
lines changed
+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package command
2+
3+
import (
4+
"sync"
5+
"testing"
6+
7+
"github.com/hashicorp/vault/api"
8+
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
9+
vaulthttp "github.com/hashicorp/vault/http"
10+
"github.com/hashicorp/vault/logical"
11+
"github.com/hashicorp/vault/vault"
12+
logxi "github.com/mgutz/logxi/v1"
13+
)
14+
15+
func TestAppRole_Integ_ConcurrentLogins(t *testing.T) {
16+
var err error
17+
coreConfig := &vault.CoreConfig{
18+
DisableMlock: true,
19+
DisableCache: true,
20+
Logger: logxi.NullLog,
21+
CredentialBackends: map[string]logical.Factory{
22+
"approle": credAppRole.Factory,
23+
},
24+
}
25+
26+
cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{
27+
HandlerFunc: vaulthttp.Handler,
28+
})
29+
30+
cluster.Start()
31+
defer cluster.Cleanup()
32+
33+
cores := cluster.Cores
34+
35+
vault.TestWaitActive(t, cores[0].Core)
36+
37+
client := cores[0].Client
38+
39+
err = client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{
40+
Type: "approle",
41+
})
42+
if err != nil {
43+
t.Fatal(err)
44+
}
45+
46+
_, err = client.Logical().Write("auth/approle/role/role1", map[string]interface{}{
47+
"bind_secret_id": "true",
48+
"period": "300",
49+
})
50+
if err != nil {
51+
t.Fatal(err)
52+
}
53+
54+
secret, err := client.Logical().Write("auth/approle/role/role1/secret-id", nil)
55+
if err != nil {
56+
t.Fatal(err)
57+
}
58+
secretID := secret.Data["secret_id"].(string)
59+
60+
secret, err = client.Logical().Read("auth/approle/role/role1/role-id")
61+
if err != nil {
62+
t.Fatal(err)
63+
}
64+
roleID := secret.Data["role_id"].(string)
65+
66+
wg := &sync.WaitGroup{}
67+
68+
for i := 0; i < 100; i++ {
69+
wg.Add(1)
70+
go func() {
71+
defer wg.Done()
72+
secret, err = client.Logical().Write("auth/approle/login", map[string]interface{}{
73+
"role_id": roleID,
74+
"secret_id": secretID,
75+
})
76+
if err != nil {
77+
t.Fatal(err)
78+
}
79+
if secret.Auth.ClientToken == "" {
80+
t.Fatalf("expected a successful login")
81+
}
82+
}()
83+
84+
}
85+
wg.Wait()
86+
}

vault/identity_store.go

+43-6
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,27 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
249249
return nil, fmt.Errorf("missing alias name")
250250
}
251251

252-
alias, err := i.MemDBAliasByFactors(mountAccessor, aliasName, false, false)
252+
txn := i.db.Txn(false)
253+
254+
return i.entityByAliasFactorsInTxn(txn, mountAccessor, aliasName, clone)
255+
}
256+
257+
// entityByAlaisFactorsInTxn fetches the entity based on factors of alias, i.e
258+
// mount accessor and the alias name.
259+
func (i *IdentityStore) entityByAliasFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool) (*identity.Entity, error) {
260+
if txn == nil {
261+
return nil, fmt.Errorf("nil txn")
262+
}
263+
264+
if mountAccessor == "" {
265+
return nil, fmt.Errorf("missing mount accessor")
266+
}
267+
268+
if aliasName == "" {
269+
return nil, fmt.Errorf("missing alias name")
270+
}
271+
272+
alias, err := i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, false, false)
253273
if err != nil {
254274
return nil, err
255275
}
@@ -258,12 +278,12 @@ func (i *IdentityStore) entityByAliasFactors(mountAccessor, aliasName string, cl
258278
return nil, nil
259279
}
260280

261-
return i.MemDBEntityByAliasID(alias.ID, clone)
281+
return i.MemDBEntityByAliasIDInTxn(txn, alias.ID, clone)
262282
}
263283

264-
// CreateEntity creates a new entity. This is used by core to
284+
// CreateOrFetchEntity creates a new entity. This is used by core to
265285
// associate each login attempt by an alias to a unified entity in Vault.
266-
func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, error) {
286+
func (i *IdentityStore) CreateOrFetchEntity(alias *logical.Alias) (*identity.Entity, error) {
267287
var entity *identity.Entity
268288
var err error
269289

@@ -290,9 +310,24 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
290310
return nil, err
291311
}
292312
if entity != nil {
293-
return nil, fmt.Errorf("alias already belongs to a different entity")
313+
return entity, nil
294314
}
295315

316+
// Create a MemDB transaction to update both alias and entity
317+
txn := i.db.Txn(true)
318+
defer txn.Abort()
319+
320+
// Check if an entity was created before acquiring the lock
321+
entity, err = i.entityByAliasFactorsInTxn(txn, alias.MountAccessor, alias.Name, false)
322+
if err != nil {
323+
return nil, err
324+
}
325+
if entity != nil {
326+
return entity, nil
327+
}
328+
329+
i.logger.Debug("identity: creating a new entity", "alias", alias)
330+
296331
entity = &identity.Entity{}
297332

298333
err = i.sanitizeEntity(entity)
@@ -320,10 +355,12 @@ func (i *IdentityStore) CreateEntity(alias *logical.Alias) (*identity.Entity, er
320355
}
321356

322357
// Update MemDB and persist entity object
323-
err = i.upsertEntity(entity, nil, true)
358+
err = i.upsertEntityInTxn(txn, entity, nil, true, false)
324359
if err != nil {
325360
return nil, err
326361
}
327362

363+
txn.Commit()
364+
328365
return entity, nil
329366
}

vault/identity_store_test.go

+16-6
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ import (
99
"github.com/hashicorp/vault/logical"
1010
)
1111

12-
func TestIdentityStore_CreateEntity(t *testing.T) {
12+
func TestIdentityStore_CreateOrFetchEntity(t *testing.T) {
1313
is, ghAccessor, _ := testIdentityStoreWithGithubAuth(t)
1414
alias := &logical.Alias{
1515
MountType: "github",
1616
MountAccessor: ghAccessor,
1717
Name: "githubuser",
1818
}
1919

20-
entity, err := is.CreateEntity(alias)
20+
entity, err := is.CreateOrFetchEntity(alias)
2121
if err != nil {
2222
t.Fatal(err)
2323
}
@@ -33,10 +33,20 @@ func TestIdentityStore_CreateEntity(t *testing.T) {
3333
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
3434
}
3535

36-
// Try recreating an entity with the same alias details. It should fail.
37-
entity, err = is.CreateEntity(alias)
38-
if err == nil {
39-
t.Fatalf("expected an error")
36+
entity, err = is.CreateOrFetchEntity(alias)
37+
if err != nil {
38+
t.Fatal(err)
39+
}
40+
if entity == nil {
41+
t.Fatalf("expected a non-nil entity")
42+
}
43+
44+
if len(entity.Aliases) != 1 {
45+
t.Fatalf("bad: length of aliases; expected: 1, actual: %d", len(entity.Aliases))
46+
}
47+
48+
if entity.Aliases[0].Name != alias.Name {
49+
t.Fatalf("bad: alias name; expected: %q, actual: %q", alias.Name, entity.Aliases[0].Name)
4050
}
4151
}
4252

vault/identity_store_util.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -666,12 +666,29 @@ func (i *IdentityStore) MemDBAliasByFactors(mountAccessor, aliasName string, clo
666666
return nil, fmt.Errorf("missing mount accessor")
667667
}
668668

669+
txn := i.db.Txn(false)
670+
671+
return i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, clone, groupAlias)
672+
}
673+
674+
func (i *IdentityStore) MemDBAliasByFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) {
675+
if txn == nil {
676+
return nil, fmt.Errorf("nil txn")
677+
}
678+
679+
if aliasName == "" {
680+
return nil, fmt.Errorf("missing alias name")
681+
}
682+
683+
if mountAccessor == "" {
684+
return nil, fmt.Errorf("missing mount accessor")
685+
}
686+
669687
tableName := entityAliasesTable
670688
if groupAlias {
671689
tableName = groupAliasesTable
672690
}
673691

674-
txn := i.db.Txn(false)
675692
aliasRaw, err := txn.First(tableName, "factors", mountAccessor, aliasName)
676693
if err != nil {
677694
return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %v", err)

vault/request_handling.go

+4-11
Original file line numberDiff line numberDiff line change
@@ -436,22 +436,15 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re
436436

437437
var err error
438438

439-
// Check if an entity already exists for the given alias
440-
entity, err = c.identityStore.entityByAliasFactors(auth.Alias.MountAccessor, auth.Alias.Name, false)
439+
// Fetch the entity for the alias, or create an entity if one
440+
// doesn't exist.
441+
entity, err = c.identityStore.CreateOrFetchEntity(auth.Alias)
441442
if err != nil {
442443
return nil, nil, err
443444
}
444445

445-
// If not, create one.
446446
if entity == nil {
447-
c.logger.Debug("core: creating a new entity", "alias", auth.Alias)
448-
entity, err = c.identityStore.CreateEntity(auth.Alias)
449-
if err != nil {
450-
return nil, nil, err
451-
}
452-
if entity == nil {
453-
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
454-
}
447+
return nil, nil, fmt.Errorf("failed to create an entity for the authenticated alias")
455448
}
456449

457450
auth.EntityID = entity.ID

0 commit comments

Comments
 (0)