Skip to content

Commit 0942654

Browse files
committed
move validation of pre auth key out of db
This move separates the logic a bit and allow us to write specific errors for the caller, in this case the web layer so we can present the user with the correct error codes without bleeding web stuff into a generic validate. Signed-off-by: Kristoffer Dalby <[email protected]>
1 parent 91c6ec8 commit 0942654

File tree

5 files changed

+184
-180
lines changed

5 files changed

+184
-180
lines changed

hscontrol/auth.go

+31-4
Original file line numberDiff line numberDiff line change
@@ -155,15 +155,42 @@ func (h *Headscale) waitForFollowup(
155155
}
156156
}
157157

158+
// canUsePreAuthKey checks if a pre auth key can be used.
159+
func canUsePreAuthKey(pak *types.PreAuthKey) error {
160+
if pak == nil {
161+
return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil)
162+
}
163+
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
164+
return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil)
165+
}
166+
167+
// we don't need to check if has been used before
168+
if pak.Reusable {
169+
return nil
170+
}
171+
172+
if pak.Used {
173+
return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil)
174+
}
175+
176+
return nil
177+
}
178+
158179
func (h *Headscale) handleRegisterWithAuthKey(
159180
regReq tailcfg.RegisterRequest,
160181
machineKey key.MachinePublic,
161182
) (*tailcfg.RegisterResponse, error) {
162-
// TODO(kradalby) Refactor and get the validate away from the database
163-
// so we can return nice http errors.
164-
pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey)
183+
pak, err := h.db.GetPreAuthKey(regReq.Auth.AuthKey)
165184
if err != nil {
166-
return nil, fmt.Errorf("invalid pre auth key: %w", err)
185+
if errors.Is(err, gorm.ErrRecordNotFound) {
186+
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
187+
}
188+
return nil, err
189+
}
190+
191+
err = canUsePreAuthKey(pak)
192+
if err != nil {
193+
return nil, err
167194
}
168195

169196
nodeToRegister := types.Node{

hscontrol/auth_test.go

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package hscontrol
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/go-cmp/cmp"
9+
"github.com/juanfont/headscale/hscontrol/types"
10+
)
11+
12+
func TestCanUsePreAuthKey(t *testing.T) {
13+
now := time.Now()
14+
past := now.Add(-time.Hour)
15+
future := now.Add(time.Hour)
16+
17+
tests := []struct {
18+
name string
19+
pak *types.PreAuthKey
20+
wantErr bool
21+
err HTTPError
22+
}{
23+
{
24+
name: "valid reusable key",
25+
pak: &types.PreAuthKey{
26+
Reusable: true,
27+
Used: false,
28+
Expiration: &future,
29+
},
30+
wantErr: false,
31+
},
32+
{
33+
name: "valid non-reusable key",
34+
pak: &types.PreAuthKey{
35+
Reusable: false,
36+
Used: false,
37+
Expiration: &future,
38+
},
39+
wantErr: false,
40+
},
41+
{
42+
name: "expired key",
43+
pak: &types.PreAuthKey{
44+
Reusable: false,
45+
Used: false,
46+
Expiration: &past,
47+
},
48+
wantErr: true,
49+
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
50+
},
51+
{
52+
name: "used non-reusable key",
53+
pak: &types.PreAuthKey{
54+
Reusable: false,
55+
Used: true,
56+
Expiration: &future,
57+
},
58+
wantErr: true,
59+
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
60+
},
61+
{
62+
name: "used reusable key",
63+
pak: &types.PreAuthKey{
64+
Reusable: true,
65+
Used: true,
66+
Expiration: &future,
67+
},
68+
wantErr: false,
69+
},
70+
{
71+
name: "no expiration date",
72+
pak: &types.PreAuthKey{
73+
Reusable: false,
74+
Used: false,
75+
Expiration: nil,
76+
},
77+
wantErr: false,
78+
},
79+
{
80+
name: "nil preauth key",
81+
pak: nil,
82+
wantErr: true,
83+
err: NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil),
84+
},
85+
{
86+
name: "expired and used key",
87+
pak: &types.PreAuthKey{
88+
Reusable: false,
89+
Used: true,
90+
Expiration: &past,
91+
},
92+
wantErr: true,
93+
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
94+
},
95+
{
96+
name: "no expiration and used key",
97+
pak: &types.PreAuthKey{
98+
Reusable: false,
99+
Used: true,
100+
Expiration: nil,
101+
},
102+
wantErr: true,
103+
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
104+
},
105+
}
106+
107+
for _, tt := range tests {
108+
t.Run(tt.name, func(t *testing.T) {
109+
err := canUsePreAuthKey(tt.pak)
110+
if tt.wantErr {
111+
if err == nil {
112+
t.Errorf("expected error but got none")
113+
} else {
114+
httpErr, ok := err.(HTTPError)
115+
if !ok {
116+
t.Errorf("expected HTTPError but got %T", err)
117+
} else {
118+
if diff := cmp.Diff(tt.err, httpErr); diff != "" {
119+
t.Errorf("unexpected error (-want +got):\n%s", diff)
120+
}
121+
}
122+
}
123+
} else {
124+
if err != nil {
125+
t.Errorf("expected no error but got %v", err)
126+
}
127+
}
128+
})
129+
}
130+
}

hscontrol/db/preauth_keys.go

+18-55
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010

1111
"github.com/juanfont/headscale/hscontrol/types"
1212
"gorm.io/gorm"
13-
"tailscale.com/types/ptr"
1413
"tailscale.com/util/set"
1514
)
1615

@@ -64,6 +63,7 @@ func CreatePreAuthKey(
6463
}
6564

6665
now := time.Now().UTC()
66+
// TODO(kradalby): unify the key generations spread all over the code.
6767
kstr, err := generateKey()
6868
if err != nil {
6969
return nil, err
@@ -108,18 +108,21 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
108108
return keys, nil
109109
}
110110

111-
// GetPreAuthKey returns a PreAuthKey for a given key.
112-
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
113-
pak, err := ValidatePreAuthKey(tx, key)
114-
if err != nil {
115-
return nil, err
116-
}
111+
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
112+
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
113+
return GetPreAuthKey(rx, key)
114+
})
115+
}
117116

118-
if pak.User.Name != user {
119-
return nil, ErrUserMismatch
117+
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
118+
// for checking if the key is usable (expired or used).
119+
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
120+
pak := types.PreAuthKey{}
121+
if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil {
122+
return nil, ErrPreAuthKeyNotFound
120123
}
121124

122-
return pak, nil
125+
return &pak, nil
123126
}
124127

125128
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
@@ -140,15 +143,6 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
140143
})
141144
}
142145

143-
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
144-
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
145-
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
146-
return err
147-
}
148-
149-
return nil
150-
}
151-
152146
// UsePreAuthKey marks a PreAuthKey as used.
153147
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
154148
k.Used = true
@@ -159,44 +153,13 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
159153
return nil
160154
}
161155

162-
func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
163-
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
164-
return ValidatePreAuthKey(rx, k)
165-
})
166-
}
167-
168-
// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
169-
// If returns no error and a PreAuthKey, it can be used.
170-
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
171-
pak := types.PreAuthKey{}
172-
if result := tx.Preload("User").First(&pak, "key = ?", k); errors.Is(
173-
result.Error,
174-
gorm.ErrRecordNotFound,
175-
) {
176-
return nil, ErrPreAuthKeyNotFound
177-
}
178-
179-
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
180-
return nil, ErrPreAuthKeyExpired
181-
}
182-
183-
if pak.Reusable { // we don't need to check if has been used before
184-
return &pak, nil
185-
}
186-
187-
nodes := types.Nodes{}
188-
if err := tx.
189-
Preload("AuthKey").
190-
Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}).
191-
Find(&nodes).Error; err != nil {
192-
return nil, err
193-
}
194-
195-
if len(nodes) != 0 || pak.Used {
196-
return nil, ErrSingleUseAuthKeyHasBeenUsed
156+
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
157+
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
158+
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
159+
return err
197160
}
198161

199-
return &pak, nil
162+
return nil
200163
}
201164

202165
func generateKey() (string, error) {

0 commit comments

Comments
 (0)