Skip to content

Commit 4c8e847

Browse files
authored
use dedicated registration ID for auth flow (#2337)
1 parent 97e5d95 commit 4c8e847

26 files changed

+583
-583
lines changed

cmd/headscale/cli/debug.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ import (
44
"fmt"
55

66
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
7+
"github.com/juanfont/headscale/hscontrol/types"
78
"github.com/rs/zerolog/log"
89
"github.com/spf13/cobra"
910
"google.golang.org/grpc/status"
10-
"tailscale.com/types/key"
1111
)
1212

1313
const (
@@ -79,7 +79,7 @@ var createNodeCmd = &cobra.Command{
7979
)
8080
}
8181

82-
machineKey, err := cmd.Flags().GetString("key")
82+
registrationID, err := cmd.Flags().GetString("key")
8383
if err != nil {
8484
ErrorOutput(
8585
err,
@@ -88,8 +88,7 @@ var createNodeCmd = &cobra.Command{
8888
)
8989
}
9090

91-
var mkey key.MachinePublic
92-
err = mkey.UnmarshalText([]byte(machineKey))
91+
_, err = types.RegistrationIDFromString(registrationID)
9392
if err != nil {
9493
ErrorOutput(
9594
err,
@@ -108,7 +107,7 @@ var createNodeCmd = &cobra.Command{
108107
}
109108

110109
request := &v1.DebugCreateNodeRequest{
111-
Key: machineKey,
110+
Key: registrationID,
112111
Name: name,
113112
User: user,
114113
Routes: routes,

cmd/headscale/cli/nodes.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ var registerNodeCmd = &cobra.Command{
122122
defer cancel()
123123
defer conn.Close()
124124

125-
machineKey, err := cmd.Flags().GetString("key")
125+
registrationID, err := cmd.Flags().GetString("key")
126126
if err != nil {
127127
ErrorOutput(
128128
err,
@@ -132,7 +132,7 @@ var registerNodeCmd = &cobra.Command{
132132
}
133133

134134
request := &v1.RegisterNodeRequest{
135-
Key: machineKey,
135+
Key: registrationID,
136136
User: user,
137137
}
138138

hscontrol/app.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ type Headscale struct {
9696
mapper *mapper.Mapper
9797
nodeNotifier *notifier.Notifier
9898

99-
registrationCache *zcache.Cache[string, types.Node]
99+
registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
100100

101101
authProvider AuthProvider
102102

@@ -123,7 +123,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
123123
return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err)
124124
}
125125

126-
registrationCache := zcache.New[string, types.Node](
126+
registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
127127
registerCacheExpiration,
128128
registerCacheCleanup,
129129
)
@@ -462,7 +462,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
462462

463463
router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
464464
router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
465-
router.HandleFunc("/register/{mkey}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
465+
router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
466466

467467
if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
468468
router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)

hscontrol/auth.go

+89-46
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"errors"
77
"fmt"
88
"net/http"
9+
"net/url"
10+
"strings"
911
"time"
1012

1113
"github.com/juanfont/headscale/hscontrol/db"
@@ -20,16 +22,18 @@ import (
2022

2123
type AuthProvider interface {
2224
RegisterHandler(http.ResponseWriter, *http.Request)
23-
AuthURL(key.MachinePublic) string
25+
AuthURL(types.RegistrationID) string
2426
}
2527

2628
func logAuthFunc(
2729
registerRequest tailcfg.RegisterRequest,
2830
machineKey key.MachinePublic,
31+
registrationId types.RegistrationID,
2932
) (func(string), func(string), func(error, string)) {
3033
return func(msg string) {
3134
log.Info().
3235
Caller().
36+
Str("registration_id", registrationId.String()).
3337
Str("machine_key", machineKey.ShortString()).
3438
Str("node_key", registerRequest.NodeKey.ShortString()).
3539
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
@@ -41,6 +45,7 @@ func logAuthFunc(
4145
func(msg string) {
4246
log.Trace().
4347
Caller().
48+
Str("registration_id", registrationId.String()).
4449
Str("machine_key", machineKey.ShortString()).
4550
Str("node_key", registerRequest.NodeKey.ShortString()).
4651
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
@@ -52,6 +57,7 @@ func logAuthFunc(
5257
func(err error, msg string) {
5358
log.Error().
5459
Caller().
60+
Str("registration_id", registrationId.String()).
5561
Str("machine_key", machineKey.ShortString()).
5662
Str("node_key", registerRequest.NodeKey.ShortString()).
5763
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
@@ -63,16 +69,64 @@ func logAuthFunc(
6369
}
6470
}
6571

72+
func (h *Headscale) waitForFollowup(
73+
req *http.Request,
74+
regReq tailcfg.RegisterRequest,
75+
logTrace func(string),
76+
) {
77+
logTrace("register request is a followup")
78+
fu, err := url.Parse(regReq.Followup)
79+
if err != nil {
80+
logTrace("failed to parse followup URL")
81+
return
82+
}
83+
84+
followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", ""))
85+
if err != nil {
86+
logTrace("followup URL does not contains a valid registration ID")
87+
return
88+
}
89+
90+
logTrace(fmt.Sprintf("followup URL contains a valid registration ID, looking up in cache: %s", followupReg))
91+
92+
if reg, ok := h.registrationCache.Get(followupReg); ok {
93+
logTrace("Node is waiting for interactive login")
94+
95+
select {
96+
case <-req.Context().Done():
97+
logTrace("node went away before it was registered")
98+
return
99+
case <-reg.Registered:
100+
logTrace("node has successfully registered")
101+
return
102+
}
103+
}
104+
}
105+
66106
// handleRegister is the logic for registering a client.
67107
func (h *Headscale) handleRegister(
68108
writer http.ResponseWriter,
69109
req *http.Request,
70110
regReq tailcfg.RegisterRequest,
71111
machineKey key.MachinePublic,
72112
) {
73-
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey)
113+
registrationId, err := types.NewRegistrationID()
114+
if err != nil {
115+
log.Error().
116+
Caller().
117+
Err(err).
118+
Msg("Failed to generate registration ID")
119+
http.Error(writer, "Internal server error", http.StatusInternalServerError)
120+
121+
return
122+
}
123+
124+
logInfo, logTrace, _ := logAuthFunc(regReq, machineKey, registrationId)
74125
now := time.Now().UTC()
75126
logTrace("handleRegister called, looking up machine in DB")
127+
128+
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
129+
// key refreshes. This will allow us to remove the machineKey from the registration request.
76130
node, err := h.db.GetNodeByAnyKey(machineKey, regReq.NodeKey, regReq.OldNodeKey)
77131
logTrace("handleRegister database lookup has returned")
78132
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -84,27 +138,9 @@ func (h *Headscale) handleRegister(
84138
}
85139

86140
// Check if the node is waiting for interactive login.
87-
//
88-
// TODO(juan): We could use this field to improve our protocol implementation,
89-
// and hold the request until the client closes it, or the interactive
90-
// login is completed (i.e., the user registers the node).
91-
// This is not implemented yet, as it is no strictly required. The only side-effect
92-
// is that the client will hammer headscale with requests until it gets a
93-
// successful RegisterResponse.
94141
if regReq.Followup != "" {
95-
logTrace("register request is a followup")
96-
if _, ok := h.registrationCache.Get(machineKey.String()); ok {
97-
logTrace("Node is waiting for interactive login")
98-
99-
select {
100-
case <-req.Context().Done():
101-
return
102-
case <-time.After(registrationHoldoff):
103-
h.handleNewNode(writer, regReq, machineKey)
104-
105-
return
106-
}
107-
}
142+
h.waitForFollowup(req, regReq, logTrace)
143+
return
108144
}
109145

110146
logInfo("Node not found in database, creating new")
@@ -113,25 +149,28 @@ func (h *Headscale) handleRegister(
113149
// that we rely on a method that calls back some how (OpenID or CLI)
114150
// We create the node and then keep it around until a callback
115151
// happens
116-
newNode := types.Node{
117-
MachineKey: machineKey,
118-
Hostname: regReq.Hostinfo.Hostname,
119-
NodeKey: regReq.NodeKey,
120-
LastSeen: &now,
121-
Expiry: &time.Time{},
152+
newNode := types.RegisterNode{
153+
Node: types.Node{
154+
MachineKey: machineKey,
155+
Hostname: regReq.Hostinfo.Hostname,
156+
NodeKey: regReq.NodeKey,
157+
LastSeen: &now,
158+
Expiry: &time.Time{},
159+
},
160+
Registered: make(chan struct{}),
122161
}
123162

124163
if !regReq.Expiry.IsZero() {
125164
logTrace("Non-zero expiry time requested")
126-
newNode.Expiry = &regReq.Expiry
165+
newNode.Node.Expiry = &regReq.Expiry
127166
}
128167

129168
h.registrationCache.Set(
130-
machineKey.String(),
169+
registrationId,
131170
newNode,
132171
)
133172

134-
h.handleNewNode(writer, regReq, machineKey)
173+
h.handleNewNode(writer, regReq, registrationId)
135174

136175
return
137176
}
@@ -206,27 +245,28 @@ func (h *Headscale) handleRegister(
206245
}
207246

208247
if regReq.Followup != "" {
209-
select {
210-
case <-req.Context().Done():
211-
return
212-
case <-time.After(registrationHoldoff):
213-
}
248+
h.waitForFollowup(req, regReq, logTrace)
249+
return
214250
}
215251

216252
// The node has expired or it is logged out
217-
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey)
253+
h.handleNodeExpiredOrLoggedOut(writer, regReq, *node, machineKey, registrationId)
218254

219255
// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
220256
node.Expiry = &time.Time{}
221257

258+
// TODO(kradalby): do we need to rethink this as part of authflow?
222259
// If we are here it means the client needs to be reauthorized,
223260
// we need to make sure the NodeKey matches the one in the request
224261
// TODO(juan): What happens when using fast user switching between two
225262
// headscale-managed tailnets?
226263
node.NodeKey = regReq.NodeKey
227264
h.registrationCache.Set(
228-
machineKey.String(),
229-
*node,
265+
registrationId,
266+
types.RegisterNode{
267+
Node: *node,
268+
Registered: make(chan struct{}),
269+
},
230270
)
231271

232272
return
@@ -296,6 +336,8 @@ func (h *Headscale) handleAuthKey(
296336
// The error is not important, because if it does not
297337
// exist, then this is a new node and we will move
298338
// on to registration.
339+
// TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
340+
// key refreshes. This will allow us to remove the machineKey from the registration request.
299341
node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
300342
if node != nil {
301343
log.Trace().
@@ -444,16 +486,16 @@ func (h *Headscale) handleAuthKey(
444486
func (h *Headscale) handleNewNode(
445487
writer http.ResponseWriter,
446488
registerRequest tailcfg.RegisterRequest,
447-
machineKey key.MachinePublic,
489+
registrationId types.RegistrationID,
448490
) {
449-
logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey)
491+
logInfo, logTrace, logErr := logAuthFunc(registerRequest, key.MachinePublic{}, registrationId)
450492

451493
resp := tailcfg.RegisterResponse{}
452494

453495
// The node registration is new, redirect the client to the registration URL
454-
logTrace("The node seems to be new, sending auth url")
496+
logTrace("The node is new, sending auth url")
455497

456-
resp.AuthURL = h.authProvider.AuthURL(machineKey)
498+
resp.AuthURL = h.authProvider.AuthURL(registrationId)
457499

458500
respBody, err := json.Marshal(resp)
459501
if err != nil {
@@ -660,6 +702,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
660702
regReq tailcfg.RegisterRequest,
661703
node types.Node,
662704
machineKey key.MachinePublic,
705+
registrationId types.RegistrationID,
663706
) {
664707
resp := tailcfg.RegisterResponse{}
665708

@@ -673,12 +716,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
673716
log.Trace().
674717
Caller().
675718
Str("node", node.Hostname).
676-
Str("machine_key", machineKey.ShortString()).
719+
Str("registration_id", registrationId.String()).
677720
Str("node_key", regReq.NodeKey.ShortString()).
678721
Str("node_key_old", regReq.OldNodeKey.ShortString()).
679722
Msg("Node registration has expired or logged out. Sending a auth url to register")
680723

681-
resp.AuthURL = h.authProvider.AuthURL(machineKey)
724+
resp.AuthURL = h.authProvider.AuthURL(registrationId)
682725

683726
respBody, err := json.Marshal(resp)
684727
if err != nil {
@@ -703,7 +746,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
703746

704747
log.Trace().
705748
Caller().
706-
Str("machine_key", machineKey.ShortString()).
749+
Str("registration_id", registrationId.String()).
707750
Str("node_key", regReq.NodeKey.ShortString()).
708751
Str("node_key_old", regReq.OldNodeKey.ShortString()).
709752
Str("node", node.Hostname).

0 commit comments

Comments
 (0)