Skip to content

Commit 91c6ec8

Browse files
committed
classify user errors in http handlers
Signed-off-by: Kristoffer Dalby <[email protected]>
1 parent 3a9dae7 commit 91c6ec8

File tree

5 files changed

+55
-48
lines changed

5 files changed

+55
-48
lines changed

hscontrol/auth.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func (h *Headscale) handleExistingNode(
7272
machineKey key.MachinePublic,
7373
) (*tailcfg.RegisterResponse, error) {
7474
if node.MachineKey != machineKey {
75-
return nil, errors.New("node already exists with different machine key")
75+
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
7676
}
7777

7878
expired := node.IsExpired()
@@ -81,7 +81,7 @@ func (h *Headscale) handleExistingNode(
8181

8282
// The client is trying to extend their key, this is not allowed.
8383
if requestExpiry.After(time.Now()) {
84-
return nil, errors.New("extending key is not allowed")
84+
return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil)
8585
}
8686

8787
// If the request expiry is in the past, we consider it a logout.
@@ -159,6 +159,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
159159
regReq tailcfg.RegisterRequest,
160160
machineKey key.MachinePublic,
161161
) (*tailcfg.RegisterResponse, error) {
162+
// TODO(kradalby) Refactor and get the validate away from the database
163+
// so we can return nice http errors.
162164
pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey)
163165
if err != nil {
164166
return nil, fmt.Errorf("invalid pre auth key: %w", err)

hscontrol/handlers.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error)
7070
clientCapabilityStr := req.URL.Query().Get("v")
7171

7272
if clientCapabilityStr == "" {
73-
return 0, ErrNoCapabilityVersion
73+
return 0, NewHTTPError(http.StatusBadRequest, "capability version must be set", nil)
7474
}
7575

7676
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
7777
if err != nil {
78-
return 0, fmt.Errorf("failed to parse capability version: %w", err)
78+
return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("failed to parse capability version: %w", err))
7979
}
8080

8181
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
@@ -108,13 +108,13 @@ func (h *Headscale) VerifyHandler(
108108
req *http.Request,
109109
) {
110110
if req.Method != http.MethodPost {
111-
httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed)
111+
httpError(writer, errMethodNotAllowed)
112112
return
113113
}
114114

115115
allow, err := h.derpRequestIsAllowed(req)
116116
if err != nil {
117-
httpError(writer, err, "Internal error", http.StatusInternalServerError)
117+
httpError(writer, err)
118118
return
119119
}
120120

@@ -135,7 +135,7 @@ func (h *Headscale) KeyHandler(
135135
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
136136
capVer, err := parseCabailityVersion(req)
137137
if err != nil {
138-
httpError(writer, err, "Internal error", http.StatusInternalServerError)
138+
httpError(writer, err)
139139
return
140140
}
141141

@@ -222,7 +222,7 @@ func (a *AuthProviderWeb) RegisterHandler(
222222
// the template and log an error.
223223
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
224224
if err != nil {
225-
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
225+
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
226226
return
227227
}
228228

hscontrol/noise.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package hscontrol
33
import (
44
"encoding/binary"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"io"
89
"net/http"
@@ -12,6 +13,7 @@ import (
1213
"github.com/juanfont/headscale/hscontrol/types"
1314
"github.com/rs/zerolog/log"
1415
"golang.org/x/net/http2"
16+
"gorm.io/gorm"
1517
"tailscale.com/control/controlbase"
1618
"tailscale.com/control/controlhttp/controlhttpserver"
1719
"tailscale.com/tailcfg"
@@ -81,7 +83,7 @@ func (h *Headscale) NoiseUpgradeHandler(
8183
noiseServer.earlyNoise,
8284
)
8385
if err != nil {
84-
httpError(writer, err, "noise upgrade failed", http.StatusInternalServerError)
86+
httpError(writer, fmt.Errorf("noise upgrade failed: %w", err))
8587
return
8688
}
8789

@@ -198,7 +200,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
198200

199201
var mapRequest tailcfg.MapRequest
200202
if err := json.Unmarshal(body, &mapRequest); err != nil {
201-
httpError(writer, err, "Internal error", http.StatusInternalServerError)
203+
httpError(writer, err)
202204
return
203205
}
204206

@@ -211,7 +213,11 @@ func (ns *noiseServer) NoisePollNetMapHandler(
211213

212214
node, err := ns.headscale.db.GetNodeByNodeKey(mapRequest.NodeKey)
213215
if err != nil {
214-
httpError(writer, err, "Internal error", http.StatusInternalServerError)
216+
if errors.Is(err, gorm.ErrRecordNotFound) {
217+
httpError(writer, NewHTTPError(http.StatusNotFound, "node not found", nil))
218+
return
219+
}
220+
httpError(writer, err)
215221
return
216222
}
217223

@@ -230,7 +236,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
230236
req *http.Request,
231237
) {
232238
if req.Method != http.MethodPost {
233-
httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed)
239+
httpError(writer, errMethodNotAllowed)
234240

235241
return
236242
}

hscontrol/oidc.go

+27-28
Original file line numberDiff line numberDiff line change
@@ -141,21 +141,21 @@ func (a *AuthProviderOIDC) RegisterHandler(
141141
// the template and log an error.
142142
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
143143
if err != nil {
144-
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
144+
httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err))
145145
return
146146
}
147147

148148
// Set the state and nonce cookies to protect against CSRF attacks
149149
state, err := setCSRFCookie(writer, req, "state")
150150
if err != nil {
151-
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
151+
httpError(writer, err)
152152
return
153153
}
154154

155155
// Set the state and nonce cookies to protect against CSRF attacks
156156
nonce, err := setCSRFCookie(writer, req, "nonce")
157157
if err != nil {
158-
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
158+
httpError(writer, err)
159159
return
160160
}
161161

@@ -219,64 +219,63 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
219219
) {
220220
code, state, err := extractCodeAndStateParamFromRequest(req)
221221
if err != nil {
222-
httpError(writer, err, err.Error(), http.StatusBadRequest)
222+
httpError(writer, err)
223223
return
224224
}
225225

226226
cookieState, err := req.Cookie("state")
227227
if err != nil {
228-
httpError(writer, err, "state not found", http.StatusBadRequest)
228+
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
229229
return
230230
}
231231

232232
if state != cookieState.Value {
233-
httpError(writer, err, "state did not match", http.StatusBadRequest)
233+
httpError(writer, NewHTTPError(http.StatusForbidden, "state did not match", nil))
234234
return
235235
}
236236

237237
idToken, err := a.extractIDToken(req.Context(), code, state)
238238
if err != nil {
239-
httpError(writer, err, err.Error(), http.StatusBadRequest)
239+
httpError(writer, err)
240240
return
241241
}
242242

243243
nonce, err := req.Cookie("nonce")
244244
if err != nil {
245-
httpError(writer, err, "nonce not found", http.StatusBadRequest)
245+
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
246246
return
247247
}
248248
if idToken.Nonce != nonce.Value {
249-
httpError(writer, err, "nonce did not match", http.StatusBadRequest)
249+
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
250250
return
251251
}
252252

253253
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
254254

255255
var claims types.OIDCClaims
256256
if err := idToken.Claims(&claims); err != nil {
257-
err = fmt.Errorf("decoding ID token claims: %w", err)
258-
httpError(writer, err, err.Error(), http.StatusInternalServerError)
257+
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
259258
return
260259
}
261260

262261
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
263-
httpError(writer, err, err.Error(), http.StatusUnauthorized)
262+
httpError(writer, err)
264263
return
265264
}
266265

267266
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
268-
httpError(writer, err, err.Error(), http.StatusUnauthorized)
267+
httpError(writer, err)
269268
return
270269
}
271270

272271
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
273-
httpError(writer, err, err.Error(), http.StatusUnauthorized)
272+
httpError(writer, err)
274273
return
275274
}
276275

277276
user, err := a.createOrUpdateUserFromClaim(&claims)
278277
if err != nil {
279-
httpError(writer, err, err.Error(), http.StatusInternalServerError)
278+
httpError(writer, err)
280279
return
281280
}
282281

@@ -289,9 +288,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
289288
// Register the node if it does not exist.
290289
if registrationId != nil {
291290
verb := "Reauthenticated"
292-
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
291+
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
293292
if err != nil {
294-
httpError(writer, err, err.Error(), http.StatusInternalServerError)
293+
httpError(writer, err)
295294
return
296295
}
297296

@@ -302,7 +301,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
302301
// TODO(kradalby): replace with go-elem
303302
content, err := renderOIDCCallbackTemplate(user, verb)
304303
if err != nil {
305-
httpError(writer, err, err.Error(), http.StatusInternalServerError)
304+
httpError(writer, err)
306305
return
307306
}
308307

@@ -317,7 +316,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
317316

318317
// Neither node nor machine key was found in the state cache meaning
319318
// that we could not reauth nor register the node.
320-
httpError(writer, nil, "login session expired, try again", http.StatusInternalServerError)
319+
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
321320
return
322321
}
323322

@@ -328,7 +327,7 @@ func extractCodeAndStateParamFromRequest(
328327
state := req.URL.Query().Get("state")
329328

330329
if code == "" || state == "" {
331-
return "", "", errEmptyOIDCCallbackParams
330+
return "", "", NewHTTPError(http.StatusBadRequest, "missing code or state parameter", errEmptyOIDCCallbackParams)
332331
}
333332

334333
return code, state, nil
@@ -346,7 +345,7 @@ func (a *AuthProviderOIDC) extractIDToken(
346345
if a.cfg.PKCE.Enabled {
347346
regInfo, ok := a.registrationCache.Get(state)
348347
if !ok {
349-
return nil, errNoOIDCRegistrationInfo
348+
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
350349
}
351350
if regInfo.Verifier != nil {
352351
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
@@ -355,18 +354,18 @@ func (a *AuthProviderOIDC) extractIDToken(
355354

356355
oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...)
357356
if err != nil {
358-
return nil, fmt.Errorf("could not exchange code for token: %w", err)
357+
return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err))
359358
}
360359

361360
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
362361
if !ok {
363-
return nil, errNoOIDCIDToken
362+
return nil, NewHTTPError(http.StatusBadRequest, "no id_token", errNoOIDCIDToken)
364363
}
365364

366365
verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID})
367366
idToken, err := verifier.Verify(ctx, rawIDToken)
368367
if err != nil {
369-
return nil, fmt.Errorf("failed to verify ID token: %w", err)
368+
return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err))
370369
}
371370

372371
return idToken, nil
@@ -381,7 +380,7 @@ func validateOIDCAllowedDomains(
381380
if len(allowedDomains) > 0 {
382381
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
383382
!slices.Contains(allowedDomains, claims.Email[at+1:]) {
384-
return errOIDCAllowedDomains
383+
return NewHTTPError(http.StatusUnauthorized, "unauthorised domain", errOIDCAllowedDomains)
385384
}
386385
}
387386

@@ -403,7 +402,7 @@ func validateOIDCAllowedGroups(
403402
}
404403
}
405404

406-
return errOIDCAllowedGroups
405+
return NewHTTPError(http.StatusUnauthorized, "unauthorised group", errOIDCAllowedGroups)
407406
}
408407

409408
return nil
@@ -417,7 +416,7 @@ func validateOIDCAllowedUsers(
417416
) error {
418417
if len(allowedUsers) > 0 &&
419418
!slices.Contains(allowedUsers, claims.Email) {
420-
return errOIDCAllowedUsers
419+
return NewHTTPError(http.StatusUnauthorized, "unauthorised user", errOIDCAllowedUsers)
421420
}
422421

423422
return nil
@@ -488,7 +487,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
488487
return user, nil
489488
}
490489

491-
func (a *AuthProviderOIDC) handleRegistrationID(
490+
func (a *AuthProviderOIDC) handleRegistration(
492491
user *types.User,
493492
registrationID types.RegistrationID,
494493
expiry time.Time,

hscontrol/platform_config.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,19 @@ func (h *Headscale) ApplePlatformConfig(
3939
vars := mux.Vars(req)
4040
platform, ok := vars["platform"]
4141
if !ok {
42-
httpError(writer, nil, "No platform specified", http.StatusBadRequest)
42+
httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil))
4343
return
4444
}
4545

4646
id, err := uuid.NewV4()
4747
if err != nil {
48-
httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError)
48+
httpError(writer, err)
4949
return
5050
}
5151

5252
contentID, err := uuid.NewV4()
5353
if err != nil {
54-
httpError(writer, nil, "Failed to create UUID", http.StatusInternalServerError)
54+
httpError(writer, err)
5555
return
5656
}
5757

@@ -65,21 +65,21 @@ func (h *Headscale) ApplePlatformConfig(
6565
switch platform {
6666
case "macos-standalone":
6767
if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil {
68-
httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError)
68+
httpError(writer, err)
6969
return
7070
}
7171
case "macos-app-store":
7272
if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil {
73-
httpError(writer, err, "Could not render Apple macOS template", http.StatusInternalServerError)
73+
httpError(writer, err)
7474
return
7575
}
7676
case "ios":
7777
if err := iosTemplate.Execute(&payload, platformConfig); err != nil {
78-
httpError(writer, err, "Could not render Apple iOS template", http.StatusInternalServerError)
78+
httpError(writer, err)
7979
return
8080
}
8181
default:
82-
httpError(writer, err, "Invalid platform. Only ios, macos-app-store and macos-standalone are supported", http.StatusInternalServerError)
82+
httpError(writer, NewHTTPError(http.StatusBadRequest, "platform must be ios, macos-app-store or macos-standalone", nil))
8383
return
8484
}
8585

@@ -91,7 +91,7 @@ func (h *Headscale) ApplePlatformConfig(
9191

9292
var content bytes.Buffer
9393
if err := commonTemplate.Execute(&content, config); err != nil {
94-
httpError(writer, err, "Could not render platform iOS template", http.StatusInternalServerError)
94+
httpError(writer, err)
9595
return
9696
}
9797

0 commit comments

Comments
 (0)