Skip to content

Commit cd3b8e6

Browse files
authored
clean up handler methods, common logging (#2384)
* clean up handler methods, common logging Signed-off-by: Kristoffer Dalby <[email protected]> * streamline http.Error calls Signed-off-by: Kristoffer Dalby <[email protected]> --------- Signed-off-by: Kristoffer Dalby <[email protected]>
1 parent f44b1d3 commit cd3b8e6

File tree

4 files changed

+53
-241
lines changed

4 files changed

+53
-241
lines changed

hscontrol/handlers.go

+17-56
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ const (
3232
reservedResponseHeaderSize = 4
3333
)
3434

35+
// httpError logs an error and sends an HTTP error response with the given
36+
func httpError(w http.ResponseWriter, err error, userError string, code int) {
37+
log.Error().Err(err).Msg(userError)
38+
http.Error(w, userError, code)
39+
}
40+
3541
var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New(
3642
"machines registered with CLI does not support expire",
3743
)
@@ -52,7 +58,7 @@ func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error)
5258
return tailcfg.CapabilityVersion(clientCapabilityVersion), nil
5359
}
5460

55-
func (h *Headscale) handleVerifyRequest(
61+
func (h *Headscale) derpRequestIsAllowed(
5662
req *http.Request,
5763
) (bool, error) {
5864
body, err := io.ReadAll(req.Body)
@@ -79,36 +85,22 @@ func (h *Headscale) VerifyHandler(
7985
req *http.Request,
8086
) {
8187
if req.Method != http.MethodPost {
82-
http.Error(writer, "Wrong method", http.StatusMethodNotAllowed)
83-
88+
httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed)
8489
return
8590
}
86-
log.Debug().
87-
Str("handler", "/verify").
88-
Msg("verify client")
8991

90-
allow, err := h.handleVerifyRequest(req)
92+
allow, err := h.derpRequestIsAllowed(req)
9193
if err != nil {
92-
log.Error().
93-
Caller().
94-
Err(err).
95-
Msg("Failed to verify client")
96-
http.Error(writer, "Internal error", http.StatusInternalServerError)
94+
httpError(writer, err, "Internal error", http.StatusInternalServerError)
95+
return
9796
}
9897

9998
resp := tailcfg.DERPAdmitClientResponse{
10099
Allow: allow,
101100
}
102101

103102
writer.Header().Set("Content-Type", "application/json")
104-
writer.WriteHeader(http.StatusOK)
105-
err = json.NewEncoder(writer).Encode(resp)
106-
if err != nil {
107-
log.Error().
108-
Caller().
109-
Err(err).
110-
Msg("Failed to write response")
111-
}
103+
json.NewEncoder(writer).Encode(resp)
112104
}
113105

114106
// KeyHandler provides the Headscale pub key
@@ -120,35 +112,17 @@ func (h *Headscale) KeyHandler(
120112
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
121113
capVer, err := parseCabailityVersion(req)
122114
if err != nil {
123-
log.Error().
124-
Caller().
125-
Err(err).
126-
Msg("could not get capability version")
127-
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
128-
writer.WriteHeader(http.StatusInternalServerError)
129-
115+
httpError(writer, err, "Internal error", http.StatusInternalServerError)
130116
return
131117
}
132118

133-
log.Debug().
134-
Str("handler", "/key").
135-
Int("cap_ver", int(capVer)).
136-
Msg("New noise client")
137-
138119
// TS2021 (Tailscale v2 protocol) requires to have a different key
139120
if capVer >= NoiseCapabilityVersion {
140121
resp := tailcfg.OverTLSPublicKeyResponse{
141122
PublicKey: h.noisePrivateKey.Public(),
142123
}
143124
writer.Header().Set("Content-Type", "application/json")
144-
writer.WriteHeader(http.StatusOK)
145-
err = json.NewEncoder(writer).Encode(resp)
146-
if err != nil {
147-
log.Error().
148-
Caller().
149-
Err(err).
150-
Msg("Failed to write response")
151-
}
125+
json.NewEncoder(writer).Encode(resp)
152126

153127
return
154128
}
@@ -169,18 +143,10 @@ func (h *Headscale) HealthHandler(
169143

170144
if err != nil {
171145
writer.WriteHeader(http.StatusInternalServerError)
172-
log.Error().Caller().Err(err).Msg("health check failed")
173146
res.Status = "fail"
174147
}
175148

176-
buf, err := json.Marshal(res)
177-
if err != nil {
178-
log.Error().Caller().Err(err).Msg("marshal failed")
179-
}
180-
_, err = writer.Write(buf)
181-
if err != nil {
182-
log.Error().Caller().Err(err).Msg("write failed")
183-
}
149+
json.NewEncoder(writer).Encode(res)
184150
}
185151

186152
if err := h.db.PingDB(req.Context()); err != nil {
@@ -233,16 +199,11 @@ func (a *AuthProviderWeb) RegisterHandler(
233199
// the template and log an error.
234200
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
235201
if err != nil {
236-
http.Error(writer, "invalid registration ID", http.StatusBadRequest)
202+
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
237203
return
238204
}
239205

240206
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
241207
writer.WriteHeader(http.StatusOK)
242-
if _, err := writer.Write([]byte(templates.RegisterWeb(registrationId).Render())); err != nil {
243-
log.Error().
244-
Caller().
245-
Err(err).
246-
Msg("Failed to write response")
247-
}
208+
writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
248209
}

hscontrol/noise.go

+6-40
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ func (h *Headscale) NoiseUpgradeHandler(
8080
noiseServer.earlyNoise,
8181
)
8282
if err != nil {
83-
log.Error().Err(err).Msg("noise upgrade failed")
84-
http.Error(writer, err.Error(), http.StatusInternalServerError)
85-
83+
httpError(writer, err, "noise upgrade failed", http.StatusInternalServerError)
8684
return
8785
}
8886

@@ -160,12 +158,7 @@ func isSupportedVersion(version tailcfg.CapabilityVersion) bool {
160158
func rejectUnsupported(writer http.ResponseWriter, version tailcfg.CapabilityVersion) bool {
161159
// Reject unsupported versions
162160
if !isSupportedVersion(version) {
163-
log.Info().
164-
Caller().
165-
Int("min_version", int(MinimumCapVersion)).
166-
Int("client_version", int(version)).
167-
Msg("unsupported client connected")
168-
http.Error(writer, "unsupported client version", http.StatusBadRequest)
161+
httpError(writer, nil, "unsupported client version", http.StatusBadRequest)
169162

170163
return true
171164
}
@@ -190,23 +183,10 @@ func (ns *noiseServer) NoisePollNetMapHandler(
190183

191184
var mapRequest tailcfg.MapRequest
192185
if err := json.Unmarshal(body, &mapRequest); err != nil {
193-
log.Error().
194-
Caller().
195-
Err(err).
196-
Msg("Cannot parse MapRequest")
197-
http.Error(writer, "Internal error", http.StatusInternalServerError)
198-
186+
httpError(writer, err, "Internal error", http.StatusInternalServerError)
199187
return
200188
}
201189

202-
log.Trace().
203-
Caller().
204-
Str("handler", "NoisePollNetMap").
205-
Any("headers", req.Header).
206-
Str("node", mapRequest.Hostinfo.Hostname).
207-
Int("capver", int(mapRequest.Version)).
208-
Msg("PollNetMapHandler called")
209-
210190
// Reject unsupported versions
211191
if rejectUnsupported(writer, mapRequest.Version) {
212192
return
@@ -220,11 +200,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
220200
key.NodePublic{},
221201
)
222202
if err != nil {
223-
log.Error().
224-
Str("handler", "NoisePollNetMap").
225-
Msgf("Failed to fetch node from the database with node key: %s", mapRequest.NodeKey.String())
226-
http.Error(writer, "Internal error", http.StatusInternalServerError)
227-
203+
httpError(writer, err, "Internal error", http.StatusInternalServerError)
228204
return
229205
}
230206

@@ -242,26 +218,16 @@ func (ns *noiseServer) NoiseRegistrationHandler(
242218
writer http.ResponseWriter,
243219
req *http.Request,
244220
) {
245-
log.Trace().Caller().Msgf("Noise registration handler for client %s", req.RemoteAddr)
246221
if req.Method != http.MethodPost {
247-
http.Error(writer, "Wrong method", http.StatusMethodNotAllowed)
222+
httpError(writer, nil, "Wrong method", http.StatusMethodNotAllowed)
248223

249224
return
250225
}
251226

252-
log.Trace().
253-
Any("headers", req.Header).
254-
Caller().
255-
Msg("Headers")
256-
257227
body, _ := io.ReadAll(req.Body)
258228
var registerRequest tailcfg.RegisterRequest
259229
if err := json.Unmarshal(body, &registerRequest); err != nil {
260-
log.Error().
261-
Caller().
262-
Err(err).
263-
Msg("Cannot parse RegisterRequest")
264-
http.Error(writer, "Internal error", http.StatusInternalServerError)
230+
httpError(writer, err, "Internal error", http.StatusInternalServerError)
265231

266232
return
267233
}

hscontrol/oidc.go

+19-26
Original file line numberDiff line numberDiff line change
@@ -134,34 +134,28 @@ func (a *AuthProviderOIDC) RegisterHandler(
134134
req *http.Request,
135135
) {
136136
vars := mux.Vars(req)
137-
registrationIdStr, ok := vars["registration_id"]
137+
registrationIdStr, _ := vars["registration_id"]
138138

139139
// We need to make sure we dont open for XSS style injections, if the parameter that
140140
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
141141
// the template and log an error.
142142
registrationId, err := types.RegistrationIDFromString(registrationIdStr)
143143
if err != nil {
144-
http.Error(writer, "invalid registration ID", http.StatusBadRequest)
144+
httpError(writer, err, "invalid registration ID", http.StatusBadRequest)
145145
return
146146
}
147147

148-
log.Debug().
149-
Caller().
150-
Str("registration_id", registrationId.String()).
151-
Bool("ok", ok).
152-
Msg("Received oidc register call")
153-
154148
// Set the state and nonce cookies to protect against CSRF attacks
155149
state, err := setCSRFCookie(writer, req, "state")
156150
if err != nil {
157-
http.Error(writer, "Internal server error", http.StatusInternalServerError)
151+
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
158152
return
159153
}
160154

161155
// Set the state and nonce cookies to protect against CSRF attacks
162156
nonce, err := setCSRFCookie(writer, req, "nonce")
163157
if err != nil {
164-
http.Error(writer, "Internal server error", http.StatusInternalServerError)
158+
httpError(writer, err, "Internal server error", http.StatusInternalServerError)
165159
return
166160
}
167161

@@ -225,64 +219,64 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
225219
) {
226220
code, state, err := extractCodeAndStateParamFromRequest(req)
227221
if err != nil {
228-
http.Error(writer, err.Error(), http.StatusBadRequest)
222+
httpError(writer, err, err.Error(), http.StatusBadRequest)
229223
return
230224
}
231225

232-
log.Debug().Interface("cookies", req.Cookies()).Msg("Received oidc callback")
233226
cookieState, err := req.Cookie("state")
234227
if err != nil {
235-
http.Error(writer, "state not found", http.StatusBadRequest)
228+
httpError(writer, err, "state not found", http.StatusBadRequest)
236229
return
237230
}
238231

239232
if state != cookieState.Value {
240-
http.Error(writer, "state did not match", http.StatusBadRequest)
233+
httpError(writer, err, "state did not match", http.StatusBadRequest)
241234
return
242235
}
243236

244237
idToken, err := a.extractIDToken(req.Context(), code, state)
245238
if err != nil {
246-
http.Error(writer, err.Error(), http.StatusBadRequest)
239+
httpError(writer, err, err.Error(), http.StatusBadRequest)
247240
return
248241
}
249242

250243
nonce, err := req.Cookie("nonce")
251244
if err != nil {
252-
http.Error(writer, "nonce not found", http.StatusBadRequest)
245+
httpError(writer, err, "nonce not found", http.StatusBadRequest)
253246
return
254247
}
255248
if idToken.Nonce != nonce.Value {
256-
http.Error(writer, "nonce did not match", http.StatusBadRequest)
249+
httpError(writer, err, "nonce did not match", http.StatusBadRequest)
257250
return
258251
}
259252

260253
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
261254

262255
var claims types.OIDCClaims
263256
if err := idToken.Claims(&claims); err != nil {
264-
http.Error(writer, fmt.Errorf("failed to decode ID token claims: %w", err).Error(), http.StatusInternalServerError)
257+
err = fmt.Errorf("decoding ID token claims: %w", err)
258+
httpError(writer, err, err.Error(), http.StatusInternalServerError)
265259
return
266260
}
267261

268262
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
269-
http.Error(writer, err.Error(), http.StatusUnauthorized)
263+
httpError(writer, err, err.Error(), http.StatusUnauthorized)
270264
return
271265
}
272266

273267
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
274-
http.Error(writer, err.Error(), http.StatusUnauthorized)
268+
httpError(writer, err, err.Error(), http.StatusUnauthorized)
275269
return
276270
}
277271

278272
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
279-
http.Error(writer, err.Error(), http.StatusUnauthorized)
273+
httpError(writer, err, err.Error(), http.StatusUnauthorized)
280274
return
281275
}
282276

283277
user, err := a.createOrUpdateUserFromClaim(&claims)
284278
if err != nil {
285-
http.Error(writer, err.Error(), http.StatusInternalServerError)
279+
httpError(writer, err, err.Error(), http.StatusInternalServerError)
286280
return
287281
}
288282

@@ -297,7 +291,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
297291
verb := "Reauthenticated"
298292
newNode, err := a.handleRegistrationID(user, *registrationId, nodeExpiry)
299293
if err != nil {
300-
http.Error(writer, err.Error(), http.StatusInternalServerError)
294+
httpError(writer, err, err.Error(), http.StatusInternalServerError)
301295
return
302296
}
303297

@@ -308,7 +302,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
308302
// TODO(kradalby): replace with go-elem
309303
content, err := renderOIDCCallbackTemplate(user, verb)
310304
if err != nil {
311-
http.Error(writer, err.Error(), http.StatusInternalServerError)
305+
httpError(writer, err, err.Error(), http.StatusInternalServerError)
312306
return
313307
}
314308

@@ -323,7 +317,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
323317

324318
// Neither node nor machine key was found in the state cache meaning
325319
// that we could not reauth nor register the node.
326-
http.Error(writer, "login session expired, try again", http.StatusInternalServerError)
320+
httpError(writer, nil, "login session expired, try again", http.StatusInternalServerError)
327321
return
328322
}
329323

@@ -423,7 +417,6 @@ func validateOIDCAllowedUsers(
423417
) error {
424418
if len(allowedUsers) > 0 &&
425419
!slices.Contains(allowedUsers, claims.Email) {
426-
log.Trace().Msg("authenticated principal does not match any allowed user")
427420
return errOIDCAllowedUsers
428421
}
429422

0 commit comments

Comments
 (0)