Skip to content

Commit 700c2b3

Browse files
committed
resolve user identifier to stable ID
currently, the policy approach node to user matching with a quite naive approach looking at the username provided in the policy and matched it with the username on the nodes. This worked ok as long as usernames were unique and did not change. As usernames are no longer guarenteed to be unique in an OIDC environment we cant rely on this. This changes the mechanism that matches the user string (now user token) with nodes: - first find all potential users by looking up: - database ID - provider ID (OIDC) - username/email If more than one user is matching, then the query is rejected, and zero matching nodes are returned. When a single user is found, the node is matched against the User database ID, which are also present on the actual node. This means that from this commit, users can use the following to identify users in the policy: - provider identity (iss + sub) - username - email - database id There are more changes coming to this, so it is not recommended to start using any of these new abilities, with the exception of email, which will not change since it includes an @. Signed-off-by: Kristoffer Dalby <[email protected]>
1 parent e2d5ee0 commit 700c2b3

File tree

8 files changed

+428
-132
lines changed

8 files changed

+428
-132
lines changed

hscontrol/app.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -1026,14 +1026,18 @@ func (h *Headscale) loadACLPolicy() error {
10261026
if err != nil {
10271027
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
10281028
}
1029+
users, err := h.db.ListUsers()
1030+
if err != nil {
1031+
return fmt.Errorf("loading users from database to validate policy: %w", err)
1032+
}
10291033

1030-
_, err = pol.CompileFilterRules(nodes)
1034+
_, err = pol.CompileFilterRules(users, nodes)
10311035
if err != nil {
10321036
return fmt.Errorf("verifying policy rules: %w", err)
10331037
}
10341038

10351039
if len(nodes) > 0 {
1036-
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
1040+
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
10371041
if err != nil {
10381042
return fmt.Errorf("verifying SSH rules: %w", err)
10391043
}

hscontrol/db/node_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
255255
c.Assert(err, check.IsNil)
256256
c.Assert(len(testPeers), check.Equals, 9)
257257

258-
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
258+
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
259259
c.Assert(err, check.IsNil)
260260

261-
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
261+
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
262262
c.Assert(err, check.IsNil)
263263

264264
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)

hscontrol/db/routes.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes(
648648
if approvedAlias == node.User.Username() {
649649
approvedRoutes = append(approvedRoutes, advertisedRoute)
650650
} else {
651+
users, err := ListUsers(tx)
652+
if err != nil {
653+
return fmt.Errorf("looking up users to expand route alias: %w", err)
654+
}
655+
651656
// TODO(kradalby): figure out how to get this to depend on less stuff
652-
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias)
657+
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
653658
if err != nil {
654659
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
655660
}

hscontrol/grpcv1.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -737,14 +737,18 @@ func (api headscaleV1APIServer) SetPolicy(
737737
if err != nil {
738738
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
739739
}
740+
users, err := api.h.db.ListUsers()
741+
if err != nil {
742+
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
743+
}
740744

741-
_, err = pol.CompileFilterRules(nodes)
745+
_, err = pol.CompileFilterRules(users, nodes)
742746
if err != nil {
743747
return nil, fmt.Errorf("verifying policy rules: %w", err)
744748
}
745749

746750
if len(nodes) > 0 {
747-
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
751+
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
748752
if err != nil {
749753
return nil, fmt.Errorf("verifying SSH rules: %w", err)
750754
}

hscontrol/mapper/mapper.go

+16-3
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
153153
func (m *Mapper) fullMapResponse(
154154
node *types.Node,
155155
peers types.Nodes,
156+
users []types.User,
156157
pol *policy.ACLPolicy,
157158
capVer tailcfg.CapabilityVersion,
158159
) (*tailcfg.MapResponse, error) {
@@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse(
167168
pol,
168169
node,
169170
capVer,
171+
users,
170172
peers,
171173
peers,
172174
m.cfg,
@@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse(
189191
if err != nil {
190192
return nil, err
191193
}
194+
users, err := m.db.ListUsers()
195+
if err != nil {
196+
return nil, err
197+
}
192198

193-
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
199+
resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
194200
if err != nil {
195201
return nil, err
196202
}
@@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse(
253259
return nil, err
254260
}
255261

262+
users, err := m.db.ListUsers()
263+
if err != nil {
264+
return nil, fmt.Errorf("listing users for map response: %w", err)
265+
}
266+
256267
var removedIDs []tailcfg.NodeID
257268
var changedIDs []types.NodeID
258269
for nodeID, nodeChanged := range changed {
@@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse(
276287
pol,
277288
node,
278289
mapRequest.Version,
290+
users,
279291
peers,
280292
changedNodes,
281293
m.cfg,
@@ -508,16 +520,17 @@ func appendPeerChanges(
508520
pol *policy.ACLPolicy,
509521
node *types.Node,
510522
capVer tailcfg.CapabilityVersion,
523+
users []types.User,
511524
peers types.Nodes,
512525
changed types.Nodes,
513526
cfg *types.Config,
514527
) error {
515-
packetFilter, err := pol.CompileFilterRules(append(peers, node))
528+
packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
516529
if err != nil {
517530
return err
518531
}
519532

520-
sshPolicy, err := pol.CompileSSHPolicy(node, peers)
533+
sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
521534
if err != nil {
522535
return err
523536
}

hscontrol/mapper/mapper_test.go

+10-6
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
159159
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
160160
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
161161

162+
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
163+
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
164+
162165
mini := &types.Node{
163166
ID: 0,
164167
MachineKey: mustMK(
@@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
173176
IPv4: iap("100.64.0.1"),
174177
Hostname: "mini",
175178
GivenName: "mini",
176-
UserID: 0,
177-
User: types.User{Name: "mini"},
179+
UserID: user1.ID,
180+
User: user1,
178181
ForcedTags: []string{},
179182
AuthKey: &types.PreAuthKey{},
180183
LastSeen: &lastSeen,
@@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
253256
IPv4: iap("100.64.0.2"),
254257
Hostname: "peer1",
255258
GivenName: "peer1",
256-
UserID: 0,
257-
User: types.User{Name: "mini"},
259+
UserID: user1.ID,
260+
User: user1,
258261
ForcedTags: []string{},
259262
LastSeen: &lastSeen,
260263
Expiry: &expire,
@@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
308311
IPv4: iap("100.64.0.3"),
309312
Hostname: "peer2",
310313
GivenName: "peer2",
311-
UserID: 1,
312-
User: types.User{Name: "peer2"},
314+
UserID: user2.ID,
315+
User: user2,
313316
ForcedTags: []string{},
314317
LastSeen: &lastSeen,
315318
Expiry: &expire,
@@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) {
468471
got, err := mappy.fullMapResponse(
469472
tt.node,
470473
tt.peers,
474+
[]types.User{user1, user2},
471475
tt.pol,
472476
0,
473477
)

0 commit comments

Comments
 (0)