Skip to content

Commit fa641e3

Browse files
authored
Set CSRF cookies for OIDC (#2328)
* set state and nounce in oidc to prevent csrf Fixes #2276 * try to fix new postgres issue Signed-off-by: Kristoffer Dalby <[email protected]> --------- Signed-off-by: Kristoffer Dalby <[email protected]>
1 parent 41bad2b commit fa641e3

File tree

3 files changed

+100
-21
lines changed

3 files changed

+100
-21
lines changed

.github/workflows/test.yml

+6
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,10 @@ jobs:
3434

3535
- name: Run tests
3636
if: steps.changed-files.outputs.files == 'true'
37+
env:
38+
# As of 2025-01-06, these env vars was not automatically
39+
# set anymore which breaks the initdb for postgres on
40+
# some of the database migration tests.
41+
LC_ALL: "en_US.UTF-8"
42+
LC_CTYPE: "en_US.UTF-8"
3743
run: nix develop --command -- gotestsum

hscontrol/oidc.go

+54-7
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ package hscontrol
33
import (
44
"bytes"
55
"context"
6-
"crypto/rand"
76
_ "embed"
8-
"encoding/hex"
97
"errors"
108
"fmt"
119
"html/template"
@@ -157,13 +155,19 @@ func (a *AuthProviderOIDC) RegisterHandler(
157155
return
158156
}
159157

160-
randomBlob := make([]byte, randomByteSize)
161-
if _, err := rand.Read(randomBlob); err != nil {
158+
// Set the state and nonce cookies to protect against CSRF attacks
159+
state, err := setCSRFCookie(writer, req, "state")
160+
if err != nil {
162161
http.Error(writer, "Internal server error", http.StatusInternalServerError)
163162
return
164163
}
165164

166-
stateStr := hex.EncodeToString(randomBlob)[:32]
165+
// Set the state and nonce cookies to protect against CSRF attacks
166+
nonce, err := setCSRFCookie(writer, req, "nonce")
167+
if err != nil {
168+
http.Error(writer, "Internal server error", http.StatusInternalServerError)
169+
return
170+
}
167171

168172
// Initialize registration info with machine key
169173
registrationInfo := RegistrationInfo{
@@ -191,11 +195,12 @@ func (a *AuthProviderOIDC) RegisterHandler(
191195
for k, v := range a.cfg.ExtraParams {
192196
extras = append(extras, oauth2.SetAuthURLParam(k, v))
193197
}
198+
extras = append(extras, oidc.Nonce(nonce))
194199

195200
// Cache the registration info
196-
a.registrationCache.Set(stateStr, registrationInfo)
201+
a.registrationCache.Set(state, registrationInfo)
197202

198-
authURL := a.oauth2Config.AuthCodeURL(stateStr, extras...)
203+
authURL := a.oauth2Config.AuthCodeURL(state, extras...)
199204
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
200205

201206
http.Redirect(writer, req, authURL, http.StatusFound)
@@ -228,11 +233,34 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
228233
return
229234
}
230235

236+
log.Debug().Interface("cookies", req.Cookies()).Msg("Received oidc callback")
237+
cookieState, err := req.Cookie("state")
238+
if err != nil {
239+
http.Error(writer, "state not found", http.StatusBadRequest)
240+
return
241+
}
242+
243+
if state != cookieState.Value {
244+
http.Error(writer, "state did not match", http.StatusBadRequest)
245+
return
246+
}
247+
231248
idToken, err := a.extractIDToken(req.Context(), code, state)
232249
if err != nil {
233250
http.Error(writer, err.Error(), http.StatusBadRequest)
234251
return
235252
}
253+
254+
nonce, err := req.Cookie("nonce")
255+
if err != nil {
256+
http.Error(writer, "nonce not found", http.StatusBadRequest)
257+
return
258+
}
259+
if idToken.Nonce != nonce.Value {
260+
http.Error(writer, "nonce did not match", http.StatusBadRequest)
261+
return
262+
}
263+
236264
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
237265

238266
var claims types.OIDCClaims
@@ -592,3 +620,22 @@ func getUserName(
592620

593621
return userName, nil
594622
}
623+
624+
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
625+
val, err := util.GenerateRandomStringURLSafe(64)
626+
if err != nil {
627+
return val, err
628+
}
629+
630+
c := &http.Cookie{
631+
Path: "/oidc/callback",
632+
Name: name,
633+
Value: val,
634+
MaxAge: int(time.Hour.Seconds()),
635+
Secure: r.TLS != nil,
636+
HttpOnly: true,
637+
}
638+
http.SetCookie(w, c)
639+
640+
return val, nil
641+
}

integration/auth_oidc_test.go

+40-14
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"log"
1111
"net"
1212
"net/http"
13+
"net/http/cookiejar"
14+
"net/http/httptest"
1315
"net/netip"
1416
"sort"
1517
"strconv"
@@ -747,6 +749,24 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc
747749
}, nil
748750
}
749751

752+
type LoggingRoundTripper struct{}
753+
754+
func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
755+
noTls := &http.Transport{
756+
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
757+
}
758+
resp, err := noTls.RoundTrip(req)
759+
if err != nil {
760+
return nil, err
761+
}
762+
763+
log.Printf("---")
764+
log.Printf("method: %s | url: %s", resp.Request.Method, resp.Request.URL.String())
765+
log.Printf("status: %d | cookies: %+v", resp.StatusCode, resp.Cookies())
766+
767+
return resp, nil
768+
}
769+
750770
func (s *AuthOIDCScenario) runTailscaleUp(
751771
userStr, loginServer string,
752772
) error {
@@ -758,44 +778,50 @@ func (s *AuthOIDCScenario) runTailscaleUp(
758778
log.Printf("running tailscale up for user %s", userStr)
759779
if user, ok := s.users[userStr]; ok {
760780
for _, client := range user.Clients {
761-
c := client
781+
tsc := client
762782
user.joinWaitGroup.Go(func() error {
763-
loginURL, err := c.LoginWithURL(loginServer)
783+
loginURL, err := tsc.LoginWithURL(loginServer)
764784
if err != nil {
765-
log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err)
785+
log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err)
766786
}
767787

768-
loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP())
788+
loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetHostname())
769789
loginURL.Scheme = "http"
770790

771791
if len(headscale.GetCert()) > 0 {
772792
loginURL.Scheme = "https"
773793
}
774794

775-
insecureTransport := &http.Transport{
776-
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
795+
httptest.NewRecorder()
796+
hc := &http.Client{
797+
Transport: LoggingRoundTripper{},
798+
}
799+
hc.Jar, err = cookiejar.New(nil)
800+
if err != nil {
801+
log.Printf("failed to create cookie jar: %s", err)
777802
}
778803

779-
log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String())
804+
log.Printf("%s login url: %s\n", tsc.Hostname(), loginURL.String())
780805

781-
log.Printf("%s logging in with url", c.Hostname())
782-
httpClient := &http.Client{Transport: insecureTransport}
806+
log.Printf("%s logging in with url", tsc.Hostname())
783807
ctx := context.Background()
784808
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
785-
resp, err := httpClient.Do(req)
809+
resp, err := hc.Do(req)
786810
if err != nil {
787811
log.Printf(
788812
"%s failed to login using url %s: %s",
789-
c.Hostname(),
813+
tsc.Hostname(),
790814
loginURL,
791815
err,
792816
)
793817

794818
return err
795819
}
796820

821+
log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL))
822+
797823
if resp.StatusCode != http.StatusOK {
798-
log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status)
824+
log.Printf("%s response code of oidc login request was %s", tsc.Hostname(), resp.Status)
799825
body, _ := io.ReadAll(resp.Body)
800826
log.Printf("body: %s", body)
801827

@@ -806,12 +832,12 @@ func (s *AuthOIDCScenario) runTailscaleUp(
806832

807833
_, err = io.ReadAll(resp.Body)
808834
if err != nil {
809-
log.Printf("%s failed to read response body: %s", c.Hostname(), err)
835+
log.Printf("%s failed to read response body: %s", tsc.Hostname(), err)
810836

811837
return err
812838
}
813839

814-
log.Printf("Finished request for %s to join tailnet", c.Hostname())
840+
log.Printf("Finished request for %s to join tailnet", tsc.Hostname())
815841
return nil
816842
})
817843

0 commit comments

Comments
 (0)