Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return better web errors to the user #2398

Merged
merged 4 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
[#2396](https://github.com/juanfont/headscale/pull/2396)
- Pre auth keys that are used by a node can no longer be deleted
[#2396](https://github.com/juanfont/headscale/pull/2396)
- Rehaul HTTP errors, return better status code and errors to users
[#2398](https://github.com/juanfont/headscale/pull/2398)

## 0.24.2 (2025-01-30)

Expand Down
37 changes: 33 additions & 4 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (h *Headscale) handleExistingNode(
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
if node.MachineKey != machineKey {
return nil, errors.New("node already exists with different machine key")
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
}

expired := node.IsExpired()
Expand All @@ -81,7 +81,7 @@ func (h *Headscale) handleExistingNode(

// The client is trying to extend their key, this is not allowed.
if requestExpiry.After(time.Now()) {
return nil, errors.New("extending key is not allowed")
return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil)
}

// If the request expiry is in the past, we consider it a logout.
Expand Down Expand Up @@ -155,13 +155,42 @@ func (h *Headscale) waitForFollowup(
}
}

// canUsePreAuthKey checks if a pre auth key can be used.
func canUsePreAuthKey(pak *types.PreAuthKey) error {
if pak == nil {
return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil)
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil)
}

// we don't need to check if has been used before
if pak.Reusable {
return nil
}

if pak.Used {
return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil)
}

return nil
}

func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey)
pak, err := h.db.GetPreAuthKey(regReq.Auth.AuthKey)
if err != nil {
return nil, fmt.Errorf("invalid pre auth key: %w", err)
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
}
return nil, err
}

err = canUsePreAuthKey(pak)
if err != nil {
return nil, err
}

nodeToRegister := types.Node{
Expand Down
130 changes: 130 additions & 0 deletions hscontrol/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package hscontrol

import (
"net/http"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
)

func TestCanUsePreAuthKey(t *testing.T) {
now := time.Now()
past := now.Add(-time.Hour)
future := now.Add(time.Hour)

tests := []struct {
name string
pak *types.PreAuthKey
wantErr bool
err HTTPError
}{
{
name: "valid reusable key",
pak: &types.PreAuthKey{
Reusable: true,
Used: false,
Expiration: &future,
},
wantErr: false,
},
{
name: "valid non-reusable key",
pak: &types.PreAuthKey{
Reusable: false,
Used: false,
Expiration: &future,
},
wantErr: false,
},
{
name: "expired key",
pak: &types.PreAuthKey{
Reusable: false,
Used: false,
Expiration: &past,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
},
{
name: "used non-reusable key",
pak: &types.PreAuthKey{
Reusable: false,
Used: true,
Expiration: &future,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
},
{
name: "used reusable key",
pak: &types.PreAuthKey{
Reusable: true,
Used: true,
Expiration: &future,
},
wantErr: false,
},
{
name: "no expiration date",
pak: &types.PreAuthKey{
Reusable: false,
Used: false,
Expiration: nil,
},
wantErr: false,
},
{
name: "nil preauth key",
pak: nil,
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil),
},
{
name: "expired and used key",
pak: &types.PreAuthKey{
Reusable: false,
Used: true,
Expiration: &past,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
},
{
name: "no expiration and used key",
pak: &types.PreAuthKey{
Reusable: false,
Used: true,
Expiration: nil,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := canUsePreAuthKey(tt.pak)
if tt.wantErr {
if err == nil {
t.Errorf("expected error but got none")
} else {
httpErr, ok := err.(HTTPError)
if !ok {
t.Errorf("expected HTTPError but got %T", err)
} else {
if diff := cmp.Diff(tt.err, httpErr); diff != "" {
t.Errorf("unexpected error (-want +got):\n%s", diff)
}
}
}
} else {
if err != nil {
t.Errorf("expected no error but got %v", err)
}
}
})
}
}
73 changes: 18 additions & 55 deletions hscontrol/db/preauth_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/juanfont/headscale/hscontrol/types"
"gorm.io/gorm"
"tailscale.com/types/ptr"
"tailscale.com/util/set"
)

Expand Down Expand Up @@ -64,6 +63,7 @@ func CreatePreAuthKey(
}

now := time.Now().UTC()
// TODO(kradalby): unify the key generations spread all over the code.
kstr, err := generateKey()
if err != nil {
return nil, err
Expand Down Expand Up @@ -108,18 +108,21 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
return keys, nil
}

// GetPreAuthKey returns a PreAuthKey for a given key.
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
pak, err := ValidatePreAuthKey(tx, key)
if err != nil {
return nil, err
}
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return GetPreAuthKey(rx, key)
})
}

if pak.User.Name != user {
return nil, ErrUserMismatch
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
// for checking if the key is usable (expired or used).
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{}
if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil {
return nil, ErrPreAuthKeyNotFound
}

return pak, nil
return &pak, nil
}

// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
Expand All @@ -140,15 +143,6 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
})
}

// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}

return nil
}

// UsePreAuthKey marks a PreAuthKey as used.
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
k.Used = true
Expand All @@ -159,44 +153,13 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
return nil
}

func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return ValidatePreAuthKey(rx, k)
})
}

// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used.
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{}
if result := tx.Preload("User").First(&pak, "key = ?", k); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrPreAuthKeyNotFound
}

if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return nil, ErrPreAuthKeyExpired
}

if pak.Reusable { // we don't need to check if has been used before
return &pak, nil
}

nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}).
Find(&nodes).Error; err != nil {
return nil, err
}

if len(nodes) != 0 || pak.Used {
return nil, ErrSingleUseAuthKeyHasBeenUsed
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}

return &pak, nil
return nil
}

func generateKey() (string, error) {
Expand Down
Loading
Loading