Skip to content

Commit fffd236

Browse files
authored
Resolve user to stable unique ID in policy (#2205)
1 parent 3a2589f commit fffd236

File tree

9 files changed

+507
-145
lines changed

9 files changed

+507
-145
lines changed

hscontrol/app.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -1029,14 +1029,18 @@ func (h *Headscale) loadACLPolicy() error {
10291029
if err != nil {
10301030
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
10311031
}
1032+
users, err := h.db.ListUsers()
1033+
if err != nil {
1034+
return fmt.Errorf("loading users from database to validate policy: %w", err)
1035+
}
10321036

1033-
_, err = pol.CompileFilterRules(nodes)
1037+
_, err = pol.CompileFilterRules(users, nodes)
10341038
if err != nil {
10351039
return fmt.Errorf("verifying policy rules: %w", err)
10361040
}
10371041

10381042
if len(nodes) > 0 {
1039-
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
1043+
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
10401044
if err != nil {
10411045
return fmt.Errorf("verifying SSH rules: %w", err)
10421046
}

hscontrol/db/node_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,10 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
256256
c.Assert(err, check.IsNil)
257257
c.Assert(len(testPeers), check.Equals, 9)
258258

259-
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
259+
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
260260
c.Assert(err, check.IsNil)
261261

262-
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers)
262+
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
263263
c.Assert(err, check.IsNil)
264264

265265
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)

hscontrol/db/routes.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -648,8 +648,13 @@ func EnableAutoApprovedRoutes(
648648
if approvedAlias == node.User.Username() {
649649
approvedRoutes = append(approvedRoutes, advertisedRoute)
650650
} else {
651+
users, err := ListUsers(tx)
652+
if err != nil {
653+
return fmt.Errorf("looking up users to expand route alias: %w", err)
654+
}
655+
651656
// TODO(kradalby): figure out how to get this to depend on less stuff
652-
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias)
657+
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
653658
if err != nil {
654659
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
655660
}

hscontrol/grpcv1.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -773,14 +773,18 @@ func (api headscaleV1APIServer) SetPolicy(
773773
if err != nil {
774774
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
775775
}
776+
users, err := api.h.db.ListUsers()
777+
if err != nil {
778+
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
779+
}
776780

777-
_, err = pol.CompileFilterRules(nodes)
781+
_, err = pol.CompileFilterRules(users, nodes)
778782
if err != nil {
779783
return nil, fmt.Errorf("verifying policy rules: %w", err)
780784
}
781785

782786
if len(nodes) > 0 {
783-
_, err = pol.CompileSSHPolicy(nodes[0], nodes)
787+
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
784788
if err != nil {
785789
return nil, fmt.Errorf("verifying SSH rules: %w", err)
786790
}

hscontrol/mapper/mapper.go

+16-3
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
153153
func (m *Mapper) fullMapResponse(
154154
node *types.Node,
155155
peers types.Nodes,
156+
users []types.User,
156157
pol *policy.ACLPolicy,
157158
capVer tailcfg.CapabilityVersion,
158159
) (*tailcfg.MapResponse, error) {
@@ -167,6 +168,7 @@ func (m *Mapper) fullMapResponse(
167168
pol,
168169
node,
169170
capVer,
171+
users,
170172
peers,
171173
peers,
172174
m.cfg,
@@ -189,8 +191,12 @@ func (m *Mapper) FullMapResponse(
189191
if err != nil {
190192
return nil, err
191193
}
194+
users, err := m.db.ListUsers()
195+
if err != nil {
196+
return nil, err
197+
}
192198

193-
resp, err := m.fullMapResponse(node, peers, pol, mapRequest.Version)
199+
resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
194200
if err != nil {
195201
return nil, err
196202
}
@@ -253,6 +259,11 @@ func (m *Mapper) PeerChangedResponse(
253259
return nil, err
254260
}
255261

262+
users, err := m.db.ListUsers()
263+
if err != nil {
264+
return nil, fmt.Errorf("listing users for map response: %w", err)
265+
}
266+
256267
var removedIDs []tailcfg.NodeID
257268
var changedIDs []types.NodeID
258269
for nodeID, nodeChanged := range changed {
@@ -276,6 +287,7 @@ func (m *Mapper) PeerChangedResponse(
276287
pol,
277288
node,
278289
mapRequest.Version,
290+
users,
279291
peers,
280292
changedNodes,
281293
m.cfg,
@@ -508,16 +520,17 @@ func appendPeerChanges(
508520
pol *policy.ACLPolicy,
509521
node *types.Node,
510522
capVer tailcfg.CapabilityVersion,
523+
users []types.User,
511524
peers types.Nodes,
512525
changed types.Nodes,
513526
cfg *types.Config,
514527
) error {
515-
packetFilter, err := pol.CompileFilterRules(append(peers, node))
528+
packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
516529
if err != nil {
517530
return err
518531
}
519532

520-
sshPolicy, err := pol.CompileSSHPolicy(node, peers)
533+
sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
521534
if err != nil {
522535
return err
523536
}

hscontrol/mapper/mapper_test.go

+10-6
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ func Test_fullMapResponse(t *testing.T) {
159159
lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC)
160160
expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC)
161161

162+
user1 := types.User{Model: gorm.Model{ID: 0}, Name: "mini"}
163+
user2 := types.User{Model: gorm.Model{ID: 1}, Name: "peer2"}
164+
162165
mini := &types.Node{
163166
ID: 0,
164167
MachineKey: mustMK(
@@ -173,8 +176,8 @@ func Test_fullMapResponse(t *testing.T) {
173176
IPv4: iap("100.64.0.1"),
174177
Hostname: "mini",
175178
GivenName: "mini",
176-
UserID: 0,
177-
User: types.User{Name: "mini"},
179+
UserID: user1.ID,
180+
User: user1,
178181
ForcedTags: []string{},
179182
AuthKey: &types.PreAuthKey{},
180183
LastSeen: &lastSeen,
@@ -253,8 +256,8 @@ func Test_fullMapResponse(t *testing.T) {
253256
IPv4: iap("100.64.0.2"),
254257
Hostname: "peer1",
255258
GivenName: "peer1",
256-
UserID: 0,
257-
User: types.User{Name: "mini"},
259+
UserID: user1.ID,
260+
User: user1,
258261
ForcedTags: []string{},
259262
LastSeen: &lastSeen,
260263
Expiry: &expire,
@@ -308,8 +311,8 @@ func Test_fullMapResponse(t *testing.T) {
308311
IPv4: iap("100.64.0.3"),
309312
Hostname: "peer2",
310313
GivenName: "peer2",
311-
UserID: 1,
312-
User: types.User{Name: "peer2"},
314+
UserID: user2.ID,
315+
User: user2,
313316
ForcedTags: []string{},
314317
LastSeen: &lastSeen,
315318
Expiry: &expire,
@@ -468,6 +471,7 @@ func Test_fullMapResponse(t *testing.T) {
468471
got, err := mappy.fullMapResponse(
469472
tt.node,
470473
tt.peers,
474+
[]types.User{user1, user2},
471475
tt.pol,
472476
0,
473477
)

0 commit comments

Comments
 (0)