6
6
"errors"
7
7
"fmt"
8
8
"net/http"
9
+ "net/url"
10
+ "strings"
9
11
"time"
10
12
11
13
"github.com/juanfont/headscale/hscontrol/db"
@@ -20,16 +22,18 @@ import (
20
22
21
23
type AuthProvider interface {
22
24
RegisterHandler (http.ResponseWriter , * http.Request )
23
- AuthURL (key. MachinePublic ) string
25
+ AuthURL (types. RegistrationID ) string
24
26
}
25
27
26
28
func logAuthFunc (
27
29
registerRequest tailcfg.RegisterRequest ,
28
30
machineKey key.MachinePublic ,
31
+ registrationId types.RegistrationID ,
29
32
) (func (string ), func (string ), func (error , string )) {
30
33
return func (msg string ) {
31
34
log .Info ().
32
35
Caller ().
36
+ Str ("registration_id" , registrationId .String ()).
33
37
Str ("machine_key" , machineKey .ShortString ()).
34
38
Str ("node_key" , registerRequest .NodeKey .ShortString ()).
35
39
Str ("node_key_old" , registerRequest .OldNodeKey .ShortString ()).
@@ -41,6 +45,7 @@ func logAuthFunc(
41
45
func (msg string ) {
42
46
log .Trace ().
43
47
Caller ().
48
+ Str ("registration_id" , registrationId .String ()).
44
49
Str ("machine_key" , machineKey .ShortString ()).
45
50
Str ("node_key" , registerRequest .NodeKey .ShortString ()).
46
51
Str ("node_key_old" , registerRequest .OldNodeKey .ShortString ()).
@@ -52,6 +57,7 @@ func logAuthFunc(
52
57
func (err error , msg string ) {
53
58
log .Error ().
54
59
Caller ().
60
+ Str ("registration_id" , registrationId .String ()).
55
61
Str ("machine_key" , machineKey .ShortString ()).
56
62
Str ("node_key" , registerRequest .NodeKey .ShortString ()).
57
63
Str ("node_key_old" , registerRequest .OldNodeKey .ShortString ()).
@@ -63,16 +69,64 @@ func logAuthFunc(
63
69
}
64
70
}
65
71
72
+ func (h * Headscale ) waitForFollowup (
73
+ req * http.Request ,
74
+ regReq tailcfg.RegisterRequest ,
75
+ logTrace func (string ),
76
+ ) {
77
+ logTrace ("register request is a followup" )
78
+ fu , err := url .Parse (regReq .Followup )
79
+ if err != nil {
80
+ logTrace ("failed to parse followup URL" )
81
+ return
82
+ }
83
+
84
+ followupReg , err := types .RegistrationIDFromString (strings .ReplaceAll (fu .Path , "/register/" , "" ))
85
+ if err != nil {
86
+ logTrace ("followup URL does not contains a valid registration ID" )
87
+ return
88
+ }
89
+
90
+ logTrace (fmt .Sprintf ("followup URL contains a valid registration ID, looking up in cache: %s" , followupReg ))
91
+
92
+ if reg , ok := h .registrationCache .Get (followupReg ); ok {
93
+ logTrace ("Node is waiting for interactive login" )
94
+
95
+ select {
96
+ case <- req .Context ().Done ():
97
+ logTrace ("node went away before it was registered" )
98
+ return
99
+ case <- reg .Registered :
100
+ logTrace ("node has successfully registered" )
101
+ return
102
+ }
103
+ }
104
+ }
105
+
66
106
// handleRegister is the logic for registering a client.
67
107
func (h * Headscale ) handleRegister (
68
108
writer http.ResponseWriter ,
69
109
req * http.Request ,
70
110
regReq tailcfg.RegisterRequest ,
71
111
machineKey key.MachinePublic ,
72
112
) {
73
- logInfo , logTrace , _ := logAuthFunc (regReq , machineKey )
113
+ registrationId , err := types .NewRegistrationID ()
114
+ if err != nil {
115
+ log .Error ().
116
+ Caller ().
117
+ Err (err ).
118
+ Msg ("Failed to generate registration ID" )
119
+ http .Error (writer , "Internal server error" , http .StatusInternalServerError )
120
+
121
+ return
122
+ }
123
+
124
+ logInfo , logTrace , _ := logAuthFunc (regReq , machineKey , registrationId )
74
125
now := time .Now ().UTC ()
75
126
logTrace ("handleRegister called, looking up machine in DB" )
127
+
128
+ // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
129
+ // key refreshes. This will allow us to remove the machineKey from the registration request.
76
130
node , err := h .db .GetNodeByAnyKey (machineKey , regReq .NodeKey , regReq .OldNodeKey )
77
131
logTrace ("handleRegister database lookup has returned" )
78
132
if errors .Is (err , gorm .ErrRecordNotFound ) {
@@ -84,27 +138,9 @@ func (h *Headscale) handleRegister(
84
138
}
85
139
86
140
// Check if the node is waiting for interactive login.
87
- //
88
- // TODO(juan): We could use this field to improve our protocol implementation,
89
- // and hold the request until the client closes it, or the interactive
90
- // login is completed (i.e., the user registers the node).
91
- // This is not implemented yet, as it is no strictly required. The only side-effect
92
- // is that the client will hammer headscale with requests until it gets a
93
- // successful RegisterResponse.
94
141
if regReq .Followup != "" {
95
- logTrace ("register request is a followup" )
96
- if _ , ok := h .registrationCache .Get (machineKey .String ()); ok {
97
- logTrace ("Node is waiting for interactive login" )
98
-
99
- select {
100
- case <- req .Context ().Done ():
101
- return
102
- case <- time .After (registrationHoldoff ):
103
- h .handleNewNode (writer , regReq , machineKey )
104
-
105
- return
106
- }
107
- }
142
+ h .waitForFollowup (req , regReq , logTrace )
143
+ return
108
144
}
109
145
110
146
logInfo ("Node not found in database, creating new" )
@@ -113,25 +149,28 @@ func (h *Headscale) handleRegister(
113
149
// that we rely on a method that calls back some how (OpenID or CLI)
114
150
// We create the node and then keep it around until a callback
115
151
// happens
116
- newNode := types.Node {
117
- MachineKey : machineKey ,
118
- Hostname : regReq .Hostinfo .Hostname ,
119
- NodeKey : regReq .NodeKey ,
120
- LastSeen : & now ,
121
- Expiry : & time.Time {},
152
+ newNode := types.RegisterNode {
153
+ Node : types.Node {
154
+ MachineKey : machineKey ,
155
+ Hostname : regReq .Hostinfo .Hostname ,
156
+ NodeKey : regReq .NodeKey ,
157
+ LastSeen : & now ,
158
+ Expiry : & time.Time {},
159
+ },
160
+ Registered : make (chan struct {}),
122
161
}
123
162
124
163
if ! regReq .Expiry .IsZero () {
125
164
logTrace ("Non-zero expiry time requested" )
126
- newNode .Expiry = & regReq .Expiry
165
+ newNode .Node . Expiry = & regReq .Expiry
127
166
}
128
167
129
168
h .registrationCache .Set (
130
- machineKey . String () ,
169
+ registrationId ,
131
170
newNode ,
132
171
)
133
172
134
- h .handleNewNode (writer , regReq , machineKey )
173
+ h .handleNewNode (writer , regReq , registrationId )
135
174
136
175
return
137
176
}
@@ -206,27 +245,28 @@ func (h *Headscale) handleRegister(
206
245
}
207
246
208
247
if regReq .Followup != "" {
209
- select {
210
- case <- req .Context ().Done ():
211
- return
212
- case <- time .After (registrationHoldoff ):
213
- }
248
+ h .waitForFollowup (req , regReq , logTrace )
249
+ return
214
250
}
215
251
216
252
// The node has expired or it is logged out
217
- h .handleNodeExpiredOrLoggedOut (writer , regReq , * node , machineKey )
253
+ h .handleNodeExpiredOrLoggedOut (writer , regReq , * node , machineKey , registrationId )
218
254
219
255
// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
220
256
node .Expiry = & time.Time {}
221
257
258
+ // TODO(kradalby): do we need to rethink this as part of authflow?
222
259
// If we are here it means the client needs to be reauthorized,
223
260
// we need to make sure the NodeKey matches the one in the request
224
261
// TODO(juan): What happens when using fast user switching between two
225
262
// headscale-managed tailnets?
226
263
node .NodeKey = regReq .NodeKey
227
264
h .registrationCache .Set (
228
- machineKey .String (),
229
- * node ,
265
+ registrationId ,
266
+ types.RegisterNode {
267
+ Node : * node ,
268
+ Registered : make (chan struct {}),
269
+ },
230
270
)
231
271
232
272
return
@@ -296,6 +336,8 @@ func (h *Headscale) handleAuthKey(
296
336
// The error is not important, because if it does not
297
337
// exist, then this is a new node and we will move
298
338
// on to registration.
339
+ // TODO(kradalby): Use reqs NodeKey and OldNodeKey as indicators for new registrations vs
340
+ // key refreshes. This will allow us to remove the machineKey from the registration request.
299
341
node , _ := h .db .GetNodeByAnyKey (machineKey , registerRequest .NodeKey , registerRequest .OldNodeKey )
300
342
if node != nil {
301
343
log .Trace ().
@@ -444,16 +486,16 @@ func (h *Headscale) handleAuthKey(
444
486
func (h * Headscale ) handleNewNode (
445
487
writer http.ResponseWriter ,
446
488
registerRequest tailcfg.RegisterRequest ,
447
- machineKey key. MachinePublic ,
489
+ registrationId types. RegistrationID ,
448
490
) {
449
- logInfo , logTrace , logErr := logAuthFunc (registerRequest , machineKey )
491
+ logInfo , logTrace , logErr := logAuthFunc (registerRequest , key. MachinePublic {}, registrationId )
450
492
451
493
resp := tailcfg.RegisterResponse {}
452
494
453
495
// The node registration is new, redirect the client to the registration URL
454
- logTrace ("The node seems to be new, sending auth url" )
496
+ logTrace ("The node is new, sending auth url" )
455
497
456
- resp .AuthURL = h .authProvider .AuthURL (machineKey )
498
+ resp .AuthURL = h .authProvider .AuthURL (registrationId )
457
499
458
500
respBody , err := json .Marshal (resp )
459
501
if err != nil {
@@ -660,6 +702,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
660
702
regReq tailcfg.RegisterRequest ,
661
703
node types.Node ,
662
704
machineKey key.MachinePublic ,
705
+ registrationId types.RegistrationID ,
663
706
) {
664
707
resp := tailcfg.RegisterResponse {}
665
708
@@ -673,12 +716,12 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
673
716
log .Trace ().
674
717
Caller ().
675
718
Str ("node" , node .Hostname ).
676
- Str ("machine_key " , machineKey . ShortString ()).
719
+ Str ("registration_id " , registrationId . String ()).
677
720
Str ("node_key" , regReq .NodeKey .ShortString ()).
678
721
Str ("node_key_old" , regReq .OldNodeKey .ShortString ()).
679
722
Msg ("Node registration has expired or logged out. Sending a auth url to register" )
680
723
681
- resp .AuthURL = h .authProvider .AuthURL (machineKey )
724
+ resp .AuthURL = h .authProvider .AuthURL (registrationId )
682
725
683
726
respBody , err := json .Marshal (resp )
684
727
if err != nil {
@@ -703,7 +746,7 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut(
703
746
704
747
log .Trace ().
705
748
Caller ().
706
- Str ("machine_key " , machineKey . ShortString ()).
749
+ Str ("registration_id " , registrationId . String ()).
707
750
Str ("node_key" , regReq .NodeKey .ShortString ()).
708
751
Str ("node_key_old" , regReq .OldNodeKey .ShortString ()).
709
752
Str ("node" , node .Hostname ).
0 commit comments