Skip to content

Commit f93c63e

Browse files
committed
remove type wrappers for db
Signed-off-by: Kristoffer Dalby <[email protected]>
1 parent 06a5ec0 commit f93c63e

14 files changed

+55
-143
lines changed

hscontrol/db/db.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ func NewHeadscaleDatabase(
197197

198198
type NodeAux struct {
199199
ID uint64
200-
EnabledRoutes types.IPPrefixes
200+
EnabledRoutes []netip.Prefix `gorm:"serializer:json"`
201201
}
202202

203203
nodesAux := []NodeAux{}
@@ -220,7 +220,7 @@ func NewHeadscaleDatabase(
220220
}
221221

222222
err = tx.Preload("Node").
223-
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
223+
Where("node_id = ? AND prefix = ?", node.ID, prefix).
224224
First(&types.Route{}).
225225
Error
226226
if err == nil {
@@ -235,7 +235,7 @@ func NewHeadscaleDatabase(
235235
NodeID: node.ID,
236236
Advertised: true,
237237
Enabled: true,
238-
Prefix: types.IPPrefix(prefix),
238+
Prefix: prefix,
239239
}
240240
if err := tx.Create(&route).Error; err != nil {
241241
log.Error().Err(err).Msg("Error creating route")

hscontrol/db/db_test.go

+5-8
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ import (
1313
"github.com/google/go-cmp/cmp"
1414
"github.com/google/go-cmp/cmp/cmpopts"
1515
"github.com/juanfont/headscale/hscontrol/types"
16+
"github.com/juanfont/headscale/hscontrol/util"
1617
"github.com/stretchr/testify/assert"
1718
"gorm.io/gorm"
1819
)
1920

2021
func TestMigrations(t *testing.T) {
21-
ipp := func(p string) types.IPPrefix {
22-
return types.IPPrefix(netip.MustParsePrefix(p))
22+
ipp := func(p string) netip.Prefix {
23+
return netip.MustParsePrefix(p)
2324
}
2425
r := func(id uint64, p string, a, e, i bool) types.Route {
2526
return types.Route{
@@ -56,9 +57,7 @@ func TestMigrations(t *testing.T) {
5657
r(31, "::/0", true, false, false),
5758
r(32, "192.168.0.24/32", true, true, true),
5859
}
59-
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
60-
return x == y
61-
})); diff != "" {
60+
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
6261
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
6362
}
6463
},
@@ -103,9 +102,7 @@ func TestMigrations(t *testing.T) {
103102
r(13, "::/0", true, true, false),
104103
r(13, "10.18.80.2/32", true, true, true),
105104
}
106-
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), cmp.Comparer(func(x, y types.IPPrefix) bool {
107-
return x == y
108-
})); diff != "" {
105+
if diff := cmp.Diff(want, routes, cmpopts.IgnoreFields(types.Route{}, "Model", "Node"), util.PrefixComparer); diff != "" {
109106
t.Errorf("TestMigrations() mismatch (-want +got):\n%s", diff)
110107
}
111108
},

hscontrol/db/node.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package db
22

33
import (
4+
"encoding/json"
45
"errors"
56
"fmt"
67
"net/netip"
@@ -206,21 +207,26 @@ func SetTags(
206207
) error {
207208
if len(tags) == 0 {
208209
// if no tags are provided, we remove all forced tags
209-
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", types.StringList{}).Error; err != nil {
210+
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
210211
return fmt.Errorf("failed to remove tags for node in the database: %w", err)
211212
}
212213

213214
return nil
214215
}
215216

216-
var newTags types.StringList
217+
var newTags []string
217218
for _, tag := range tags {
218219
if !util.StringOrPrefixListContains(newTags, tag) {
219220
newTags = append(newTags, tag)
220221
}
221222
}
222223

223-
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", newTags).Error; err != nil {
224+
b, err := json.Marshal(newTags)
225+
if err != nil {
226+
return err
227+
}
228+
229+
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
224230
return fmt.Errorf("failed to update tags for node in the database: %w", err)
225231
}
226232

@@ -578,7 +584,7 @@ func enableRoutes(tx *gorm.DB,
578584
for _, prefix := range newRoutes {
579585
route := types.Route{}
580586
err := tx.Preload("Node").
581-
Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)).
587+
Where("node_id = ? AND prefix = ?", node.ID, prefix.String()).
582588
First(&route).Error
583589
if err == nil {
584590
route.Enabled = true

hscontrol/db/node_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ func (s *Suite) TestSetTags(c *check.C) {
350350
c.Assert(err, check.IsNil)
351351
node, err = db.getNode("test", "testnode")
352352
c.Assert(err, check.IsNil)
353-
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags))
353+
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
354354

355355
// assign duplicate tags, expect no errors but no doubles in DB
356356
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
@@ -361,15 +361,15 @@ func (s *Suite) TestSetTags(c *check.C) {
361361
c.Assert(
362362
node.ForcedTags,
363363
check.DeepEquals,
364-
types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}),
364+
[]string{"tag:bar", "tag:test", "tag:unknown"},
365365
)
366366

367367
// test removing tags
368368
err = db.SetTags(node.ID, []string{})
369369
c.Assert(err, check.IsNil)
370370
node, err = db.getNode("test", "testnode")
371371
c.Assert(err, check.IsNil)
372-
c.Assert(node.ForcedTags, check.DeepEquals, types.StringList([]string{}))
372+
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
373373
}
374374

375375
func TestHeadscale_generateGivenName(t *testing.T) {

hscontrol/db/preauth_keys.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func CreatePreAuthKey(
7777
Ephemeral: ephemeral,
7878
CreatedAt: &now,
7979
Expiration: expiration,
80-
Tags: types.StringList(aclTags),
80+
Tags: aclTags,
8181
}
8282

8383
if err := tx.Save(&key).Error; err != nil {

hscontrol/db/routes.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func getRoutesByPrefix(tx *gorm.DB, pref netip.Prefix) (types.Routes, error) {
4848
err := tx.
4949
Preload("Node").
5050
Preload("Node.User").
51-
Where("prefix = ?", types.IPPrefix(pref)).
51+
Where("prefix = ?", pref.String()).
5252
Find(&routes).Error
5353
if err != nil {
5454
return nil, err
@@ -285,7 +285,7 @@ func isUniquePrefix(tx *gorm.DB, route types.Route) bool {
285285
var count int64
286286
tx.Model(&types.Route{}).
287287
Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?",
288-
route.Prefix,
288+
route.Prefix.String(),
289289
route.NodeID,
290290
true, true).Count(&count)
291291

@@ -296,7 +296,7 @@ func getPrimaryRoute(tx *gorm.DB, prefix netip.Prefix) (*types.Route, error) {
296296
var route types.Route
297297
err := tx.
298298
Preload("Node").
299-
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true).
299+
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", prefix.String(), true, true, true).
300300
First(&route).Error
301301
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
302302
return nil, err
@@ -391,7 +391,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) {
391391
if !exists {
392392
route := types.Route{
393393
NodeID: node.ID.Uint64(),
394-
Prefix: types.IPPrefix(prefix),
394+
Prefix: prefix,
395395
Advertised: true,
396396
Enabled: false,
397397
}

hscontrol/db/routes_test.go

+6-13
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
286286
}
287287

288288
var (
289-
ipp = func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) }
289+
ipp = func(s string) netip.Prefix { return netip.MustParsePrefix(s) }
290290
mkNode = func(nid types.NodeID) types.Node {
291291
return types.Node{ID: nid}
292292
}
@@ -297,7 +297,7 @@ var np = func(nid types.NodeID) *types.Node {
297297
return &no
298298
}
299299

300-
var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
300+
var r = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
301301
return types.Route{
302302
Model: gorm.Model{
303303
ID: id,
@@ -309,7 +309,7 @@ var r = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary
309309
}
310310
}
311311

312-
var rp = func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
312+
var rp = func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
313313
ro := r(id, nid, prefix, enabled, primary)
314314
return &ro
315315
}
@@ -1065,7 +1065,7 @@ func TestFailoverRouteTx(t *testing.T) {
10651065
}
10661066

10671067
func TestFailoverRoute(t *testing.T) {
1068-
r := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) types.Route {
1068+
r := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) types.Route {
10691069
return types.Route{
10701070
Model: gorm.Model{
10711071
ID: id,
@@ -1078,7 +1078,7 @@ func TestFailoverRoute(t *testing.T) {
10781078
IsPrimary: primary,
10791079
}
10801080
}
1081-
rp := func(id uint, nid types.NodeID, prefix types.IPPrefix, enabled, primary bool) *types.Route {
1081+
rp := func(id uint, nid types.NodeID, prefix netip.Prefix, enabled, primary bool) *types.Route {
10821082
ro := r(id, nid, prefix, enabled, primary)
10831083
return &ro
10841084
}
@@ -1201,13 +1201,6 @@ func TestFailoverRoute(t *testing.T) {
12011201
},
12021202
}
12031203

1204-
cmps := append(
1205-
util.Comparers,
1206-
cmp.Comparer(func(x, y types.IPPrefix) bool {
1207-
return netip.Prefix(x) == netip.Prefix(y)
1208-
}),
1209-
)
1210-
12111204
for _, tt := range tests {
12121205
t.Run(tt.name, func(t *testing.T) {
12131206
gotf := failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes)
@@ -1231,7 +1224,7 @@ func TestFailoverRoute(t *testing.T) {
12311224
"old": gotf.old,
12321225
}
12331226

1234-
if diff := cmp.Diff(want, got, cmps...); diff != "" {
1227+
if diff := cmp.Diff(want, got, util.Comparers...); diff != "" {
12351228
t.Fatalf("failoverRoute unexpected result (-want +got):\n%s", diff)
12361229
}
12371230
}

hscontrol/mapper/mapper_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -195,19 +195,19 @@ func Test_fullMapResponse(t *testing.T) {
195195
Hostinfo: &tailcfg.Hostinfo{},
196196
Routes: []types.Route{
197197
{
198-
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
198+
Prefix: netip.MustParsePrefix("0.0.0.0/0"),
199199
Advertised: true,
200200
Enabled: true,
201201
IsPrimary: false,
202202
},
203203
{
204-
Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")),
204+
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
205205
Advertised: true,
206206
Enabled: true,
207207
IsPrimary: true,
208208
},
209209
{
210-
Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")),
210+
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
211211
Advertised: true,
212212
Enabled: false,
213213
IsPrimary: true,

hscontrol/mapper/tail_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -108,19 +108,19 @@ func TestTailNode(t *testing.T) {
108108
Hostinfo: &tailcfg.Hostinfo{},
109109
Routes: []types.Route{
110110
{
111-
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
111+
Prefix: netip.MustParsePrefix("0.0.0.0/0"),
112112
Advertised: true,
113113
Enabled: true,
114114
IsPrimary: false,
115115
},
116116
{
117-
Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")),
117+
Prefix: netip.MustParsePrefix("192.168.0.0/24"),
118118
Advertised: true,
119119
Enabled: true,
120120
IsPrimary: true,
121121
},
122122
{
123-
Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")),
123+
Prefix: netip.MustParsePrefix("172.0.0.0/10"),
124124
Advertised: true,
125125
Enabled: false,
126126
IsPrimary: true,

hscontrol/policy/acls_test.go

+5-12
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ func TestParsing(t *testing.T) {
341341
],
342342
},
343343
],
344-
}
344+
}
345345
`,
346346
want: []tailcfg.FilterRule{
347347
{
@@ -2383,7 +2383,7 @@ func TestReduceFilterRules(t *testing.T) {
23832383
Hostinfo: &tailcfg.Hostinfo{
23842384
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
23852385
},
2386-
ForcedTags: types.StringList{"tag:access-servers"},
2386+
ForcedTags: []string{"tag:access-servers"},
23872387
},
23882388
peers: types.Nodes{
23892389
&types.Node{
@@ -3180,7 +3180,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
31803180
Routes: types.Routes{
31813181
types.Route{
31823182
NodeID: 2,
3183-
Prefix: types.IPPrefix(netip.MustParsePrefix("10.33.0.0/16")),
3183+
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
31843184
IsPrimary: true,
31853185
Enabled: true,
31863186
},
@@ -3213,7 +3213,7 @@ func Test_getFilteredByACLPeers(t *testing.T) {
32133213
Routes: types.Routes{
32143214
types.Route{
32153215
NodeID: 2,
3216-
Prefix: types.IPPrefix(netip.MustParsePrefix("10.33.0.0/16")),
3216+
Prefix: netip.MustParsePrefix("10.33.0.0/16"),
32173217
IsPrimary: true,
32183218
Enabled: true,
32193219
},
@@ -3223,21 +3223,14 @@ func Test_getFilteredByACLPeers(t *testing.T) {
32233223
},
32243224
}
32253225

3226-
// TODO(kradalby): Remove when we have gotten rid of IPPrefix type
3227-
prefixComparer := cmp.Comparer(func(x, y types.IPPrefix) bool {
3228-
return x == y
3229-
})
3230-
comparers := append([]cmp.Option{}, util.Comparers...)
3231-
comparers = append(comparers, prefixComparer)
3232-
32333226
for _, tt := range tests {
32343227
t.Run(tt.name, func(t *testing.T) {
32353228
got := FilterNodesByACL(
32363229
tt.args.node,
32373230
tt.args.nodes,
32383231
tt.args.rules,
32393232
)
3240-
if diff := cmp.Diff(tt.want, got, comparers...); diff != "" {
3233+
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
32413234
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
32423235
}
32433236
})

0 commit comments

Comments
 (0)