Skip to content

Commit ece488c

Browse files
committed
test using gorm serialiser instead of custom hooks
Signed-off-by: Kristoffer Dalby <[email protected]>
1 parent 7309efd commit ece488c

20 files changed

+242
-353
lines changed

hscontrol/db/db.go

+16-5
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,14 @@ import (
2020
"gorm.io/driver/postgres"
2121
"gorm.io/gorm"
2222
"gorm.io/gorm/logger"
23+
"gorm.io/gorm/schema"
2324
"tailscale.com/util/set"
2425
)
2526

27+
func init() {
28+
schema.RegisterSerializer("text", TextSerialiser{})
29+
}
30+
2631
var errDatabaseNotSupported = errors.New("database type not supported")
2732

2833
// KV is a key-value store in a psql table. For future use...
@@ -33,7 +38,8 @@ type KV struct {
3338
}
3439

3540
type HSDatabase struct {
36-
DB *gorm.DB
41+
DB *gorm.DB
42+
cfg *types.DatabaseConfig
3743

3844
baseDomain string
3945
}
@@ -191,7 +197,7 @@ func NewHeadscaleDatabase(
191197

192198
type NodeAux struct {
193199
ID uint64
194-
EnabledRoutes types.IPPrefixes
200+
EnabledRoutes []netip.Prefix `gorm:"serializer:json"`
195201
}
196202

197203
nodesAux := []NodeAux{}
@@ -214,7 +220,7 @@ func NewHeadscaleDatabase(
214220
}
215221

216222
err = tx.Preload("Node").
217-
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
223+
Where("node_id = ? AND prefix = ?", node.ID, prefix).
218224
First(&types.Route{}).
219225
Error
220226
if err == nil {
@@ -229,7 +235,7 @@ func NewHeadscaleDatabase(
229235
NodeID: node.ID,
230236
Advertised: true,
231237
Enabled: true,
232-
Prefix: types.IPPrefix(prefix),
238+
Prefix: prefix,
233239
}
234240
if err := tx.Create(&route).Error; err != nil {
235241
log.Error().Err(err).Msg("Error creating route")
@@ -476,7 +482,8 @@ func NewHeadscaleDatabase(
476482
}
477483

478484
db := HSDatabase{
479-
DB: dbConn,
485+
DB: dbConn,
486+
cfg: &cfg,
480487

481488
baseDomain: baseDomain,
482489
}
@@ -676,6 +683,10 @@ func (hsdb *HSDatabase) Close() error {
676683
return err
677684
}
678685

686+
if hsdb.cfg.Type == types.DatabaseSqlite && hsdb.cfg.Sqlite.WriteAheadLog {
687+
db.Exec("VACUUM")
688+
}
689+
679690
return db.Close()
680691
}
681692

hscontrol/db/db_test.go

+28-8
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ import (
1313
"github.com/google/go-cmp/cmp"
1414
"github.com/google/go-cmp/cmp/cmpopts"
1515
"github.com/juanfont/headscale/hscontrol/types"
16+
"github.com/juanfont/headscale/hscontrol/util"
1617
"github.com/stretchr/testify/assert"
1718
"gorm.io/gorm"
1819
)
1920

2021
func TestMigrations(t *testing.T) {
21-
ipp := func(p string) types.IPPrefix {
22-
return types.IPPrefix(netip.MustParsePrefix(p))
22+
ipp := func(p string) netip.Prefix {
23+
return netip.MustParsePrefix(p)
2324
}
2425
r := func(id uint64, p string, a, e, i bool) types.Route {
2526
return types.Route{
@@ -56,9 +57,7 @@ func TestMigrations(t *testing.T) {
5657
r(31, "::/0", true, false, false),
5758
r(32, "192.168.0.24/32", true, true, true),
5859
}
59-
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
60-
return x == y
61-
})); diff != "" {
60+
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
6261
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
6362
}
6463
},
@@ -103,9 +102,7 @@ func TestMigrations(t *testing.T) {
103102
r(13, "::/0", true, true, false),
104103
r(13, "10.18.80.2/32", true, true, true),
105104
}
106-
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
107-
return x == y
108-
})); diff != "" {
105+
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
109106
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
110107
}
111108
},
@@ -172,6 +169,29 @@ func TestMigrations(t *testing.T) {
172169
}
173170
},
174171
},
172+
{
173+
dbPath: "testdata/0-23-0-to-0-24-0-no-more-special-types.sqlite",
174+
wantFunc: func(t *testing.T, h *HSDatabase) {
175+
nodes, err := Read(h.DB, func(rx *gorm.DB) (types.Nodes, error) {
176+
return ListNodes(rx)
177+
})
178+
assert.NoError(t, err)
179+
180+
for _, node := range nodes {
181+
assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey")
182+
assert.Contains(t, node.MachineKey.String(), "mkey:")
183+
assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey")
184+
assert.Contains(t, node.NodeKey.String(), "nodekey:")
185+
assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey")
186+
assert.Contains(t, node.DiscoKey.String(), "discokey:")
187+
assert.NotNil(t, node.IPv4)
188+
assert.NotNil(t, node.IPv4)
189+
assert.Len(t, node.Endpoints, 1)
190+
assert.NotNil(t, node.Hostinfo)
191+
assert.NotNil(t, node.MachineKey)
192+
}
193+
},
194+
},
175195
}
176196

177197
for _, tt := range tests {

hscontrol/db/ip_test.go

-37
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package db
22

33
import (
4-
"database/sql"
54
"fmt"
65
"net/netip"
76
"strings"
@@ -294,15 +293,7 @@ func TestBackfillIPAddresses(t *testing.T) {
294293
v4 := fmt.Sprintf("100.64.0.%d", i)
295294
v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i)
296295
return &types.Node{
297-
IPv4DatabaseField: sql.NullString{
298-
Valid: true,
299-
String: v4,
300-
},
301296
IPv4: nap(v4),
302-
IPv6DatabaseField: sql.NullString{
303-
Valid: true,
304-
String: v6,
305-
},
306297
IPv6: nap(v6),
307298
}
308299
}
@@ -334,15 +325,7 @@ func TestBackfillIPAddresses(t *testing.T) {
334325

335326
want: types.Nodes{
336327
&types.Node{
337-
IPv4DatabaseField: sql.NullString{
338-
Valid: true,
339-
String: "100.64.0.1",
340-
},
341328
IPv4: nap("100.64.0.1"),
342-
IPv6DatabaseField: sql.NullString{
343-
Valid: true,
344-
String: "fd7a:115c:a1e0::1",
345-
},
346329
IPv6: nap("fd7a:115c:a1e0::1"),
347330
},
348331
},
@@ -367,15 +350,7 @@ func TestBackfillIPAddresses(t *testing.T) {
367350

368351
want: types.Nodes{
369352
&types.Node{
370-
IPv4DatabaseField: sql.NullString{
371-
Valid: true,
372-
String: "100.64.0.1",
373-
},
374353
IPv4: nap("100.64.0.1"),
375-
IPv6DatabaseField: sql.NullString{
376-
Valid: true,
377-
String: "fd7a:115c:a1e0::1",
378-
},
379354
IPv6: nap("fd7a:115c:a1e0::1"),
380355
},
381356
},
@@ -400,10 +375,6 @@ func TestBackfillIPAddresses(t *testing.T) {
400375

401376
want: types.Nodes{
402377
&types.Node{
403-
IPv4DatabaseField: sql.NullString{
404-
Valid: true,
405-
String: "100.64.0.1",
406-
},
407378
IPv4: nap("100.64.0.1"),
408379
},
409380
},
@@ -428,10 +399,6 @@ func TestBackfillIPAddresses(t *testing.T) {
428399

429400
want: types.Nodes{
430401
&types.Node{
431-
IPv6DatabaseField: sql.NullString{
432-
Valid: true,
433-
String: "fd7a:115c:a1e0::1",
434-
},
435402
IPv6: nap("fd7a:115c:a1e0::1"),
436403
},
437404
},
@@ -477,13 +444,9 @@ func TestBackfillIPAddresses(t *testing.T) {
477444

478445
comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{},
479446
"ID",
480-
"MachineKeyDatabaseField",
481-
"NodeKeyDatabaseField",
482-
"DiscoKeyDatabaseField",
483447
"User",
484448
"UserID",
485449
"Endpoints",
486-
"HostinfoDatabaseField",
487450
"Hostinfo",
488451
"Routes",
489452
"CreatedAt",

hscontrol/db/node.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package db
22

33
import (
4+
"encoding/json"
45
"errors"
56
"fmt"
67
"net/netip"
@@ -207,21 +208,26 @@ func SetTags(
207208
) error {
208209
if len(tags) == 0 {
209210
// if no tags are provided, we remove all forced tags
210-
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil {
211+
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
211212
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
212213
}
213214

214215
return nil
215216
}
216217

217-
var newTags types.StringList
218+
var newTags []string
218219
for _, tag := range tags {
219220
if !slices.Contains(newTags, tag) {
220221
newTags = append(newTags, tag)
221222
}
222223
}
223224

224-
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
225+
b, err := json.Marshal(newTags)
226+
if err != nil {
227+
return err
228+
}
229+
230+
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
225231
return fmt.Errorf("failed to update tags for node in the database: %w", err)
226232
}
227233

@@ -569,7 +575,7 @@ func enableRoutes(tx *gorm.DB,
569575
for _, prefix := range newRoutes {
570576
route := types.Route{}
571577
err := tx.Preload("Node").
572-
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
578+
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
573579
First(&route).Error
574580
if err == nil {
575581
route.Enabled = true

hscontrol/db/node_test.go

+10-6
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
201201
nodeKey := key.NewNode()
202202
machineKey := key.NewMachine()
203203

204-
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1)))
204+
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1))
205205
node := types.Node{
206206
ID: types.NodeID(index),
207207
MachineKey: machineKey.Public(),
@@ -239,6 +239,8 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
239239

240240
adminNode, err := db.GetNodeByID(1)
241241
c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User)
242+
c.Assert(adminNode.IPv4, check.NotNil)
243+
c.Assert(adminNode.IPv6, check.IsNil)
242244
c.Assert(err, check.IsNil)
243245

244246
testNode, err := db.GetNodeByID(2)
@@ -247,9 +249,11 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
247249

248250
adminPeers, err := db.ListPeers(adminNode.ID)
249251
c.Assert(err, check.IsNil)
252+
c.Assert(len(adminPeers), check.Equals, 9)
250253

251254
testPeers, err := db.ListPeers(testNode.ID)
252255
c.Assert(err, check.IsNil)
256+
c.Assert(len(testPeers), check.Equals, 9)
253257

254258
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers)
255259
c.Assert(err, check.IsNil)
@@ -259,14 +263,14 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
259263

260264
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
261265
peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules)
262-
266+
c.Log(peersOfAdminNode)
263267
c.Log(peersOfTestNode)
268+
264269
c.Assert(len(peersOfTestNode), check.Equals, 9)
265270
c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1")
266271
c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3")
267272
c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5")
268273

269-
c.Log(peersOfAdminNode)
270274
c.Assert(len(peersOfAdminNode), check.Equals, 9)
271275
c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2")
272276
c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4")
@@ -346,7 +350,7 @@ func (s *Suite) TestSetTags(c *check.C) {
346350
c.Assert(err, check.IsNil)
347351
node, err = db.getNode("test", "testnode")
348352
c.Assert(err, check.IsNil)
349-
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
353+
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
350354

351355
// assign duplicate tags, expect no errors but no doubles in DB
352356
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
@@ -357,15 +361,15 @@ func (s *Suite) TestSetTags(c *check.C) {
357361
c.Assert(
358362
node.ForcedTags,
359363
check.DeepEquals,
360-
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
364+
[]string{"tag:bar", "tag:test", "tag:unknown"},
361365
)
362366

363367
// test removing tags
364368
err = db.SetTags(node.ID, []string{})
365369
c.Assert(err, check.IsNil)
366370
node, err = db.getNode("test", "testnode")
367371
c.Assert(err, check.IsNil)
368-
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{}))
372+
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
369373
}
370374

371375
func TestHeadscale_generateGivenName(t *testing.T) {

hscontrol/db/preauth_keys.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func CreatePreAuthKey(
7777
Ephemeral: ephemeral,
7878
CreatedAt: &now,
7979
Expiration: expiration,
80-
Tags: types.StringList(aclTags),
80+
Tags: aclTags,
8181
}
8282

8383
if err := tx.Save(&key).Error; err != nil {

hscontrol/db/routes.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
4949
err := tx.
5050
Preload("Node").
5151
Preload("Node.User").
52-
Where("prefix = ?", types.IPPrefix(pref)).
52+
Where("prefix = ?", pref.String()).
5353
Find(&routes).Error
5454
if err != nil {
5555
return nil, err
@@ -286,7 +286,7 @@ func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
286286
var count int64
287287
tx.Model(&types.Route{}).
288288
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
289-
route.Prefix,
289+
route.Prefix.String(),
290290
route.NodeID,
291291
true, true).Count(&count)
292292

@@ -297,7 +297,7 @@ func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
297297
var route types.Route
298298
err := tx.
299299
Preload("Node").
300-
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
300+
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", prefix.String(), true, true, true).
301301
First(&route).Error
302302
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
303303
return nil, err
@@ -392,7 +392,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
392392
if !exists {
393393
route := types.Route{
394394
NodeID: node.ID.Uint64(),
395-
Prefix: types.IPPrefix(prefix),
395+
Prefix: prefix,
396396
Advertised: true,
397397
Enabled: false,
398398
}

0 commit comments

Comments
 (0)