Skip to content

Commit 5906633

Browse files
committed
Add unit tests for the wireguard cable driver
Also simplified or removed unnecessary code in the driver to increase coverage and eliminate paths that need to be tested. Signed-off-by: Tom Pantelis <[email protected]>
1 parent fb199cc commit 5906633

File tree

4 files changed

+741
-69
lines changed

4 files changed

+741
-69
lines changed

pkg/cable/wireguard/driver.go

+25-48
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ import (
2323
"crypto/sha256"
2424
"fmt"
2525
"net"
26-
"os"
2726
"slices"
2827
"time"
2928

3029
"github.com/kelseyhightower/envconfig"
3130
"github.com/pkg/errors"
3231
"github.com/submariner-io/admiral/pkg/log"
32+
"github.com/submariner-io/admiral/pkg/resource"
3333
v1 "github.com/submariner-io/submariner/pkg/apis/submariner.io/v1"
3434
"github.com/submariner-io/submariner/pkg/cable"
3535
"github.com/submariner-io/submariner/pkg/endpoint"
@@ -50,18 +50,20 @@ const (
5050
// PublicKey is name (key) of publicKey entry in back-end map.
5151
PublicKey = "publicKey"
5252

53-
// KeepAliveInterval to use for wg peers.
54-
KeepAliveInterval = 10 * time.Second
55-
56-
// handshakeTimeout is maximal time from handshake a connections is still considered connected.
57-
handshakeTimeout = 2*time.Minute + 10*time.Second
58-
5953
cableDriverName = "wireguard"
6054
receiveBytes = "ReceiveBytes" // for peer connection status
6155
transmitBytes = "TransmitBytes" // for peer connection status
6256
lastChecked = "LastChecked" // for connection peer status
6357
)
6458

59+
var (
60+
// KeepAliveInterval to use for wg peers.
61+
KeepAliveInterval = 10 * time.Second
62+
63+
// HandshakeTimeout is maximal time from handshake a connections is still considered connected.
64+
HandshakeTimeout = 2*time.Minute + 10*time.Second
65+
)
66+
6567
var logger = log.Logger{Logger: logf.Log.WithName("wireguard")}
6668

6769
func init() {
@@ -104,10 +106,6 @@ func NewDriver(localEndpoint *endpoint.Local, _ *types.SubmarinerCluster) (cable
104106

105107
// Create the controller.
106108
if w.client, err = NewClient(); err != nil {
107-
if os.IsNotExist(err) {
108-
return nil, errors.New("wgctrl is not available on this system")
109-
}
110-
111109
return nil, errors.Wrap(err, "failed to open wgctl client")
112110
}
113111

@@ -173,10 +171,6 @@ func NewDriver(localEndpoint *endpoint.Local, _ *types.SubmarinerCluster) (cable
173171
func (w *wireguard) Init() error {
174172
logger.V(log.DEBUG).Infof("Initializing WireGuard device for cluster %s", w.localEndpoint.ClusterID)
175173

176-
if len(w.connections) != 0 {
177-
return fmt.Errorf("cannot initialize with existing connections: %+v", w.connections)
178-
}
179-
180174
l, err := w.netLink.InterfaceByName(DefaultDeviceName)
181175
if err != nil {
182176
return errors.Wrapf(err, "cannot get wireguard link by name %s", DefaultDeviceName)
@@ -187,11 +181,7 @@ func (w *wireguard) Init() error {
187181
return errors.Wrap(err, "wgctrl cannot find WireGuard device")
188182
}
189183

190-
k, err := keyFromSpec(&w.localEndpoint)
191-
if err != nil {
192-
return errors.Wrapf(err, "endpoint is missing public key %s", d.PublicKey)
193-
}
194-
184+
k, _ := keyFromSpec(&w.localEndpoint)
195185
if k.String() != d.PublicKey.String() {
196186
return fmt.Errorf("endpoint public key %s is different from device key %s", k, d.PublicKey)
197187
}
@@ -227,10 +217,10 @@ func (w *wireguard) ConnectToEndpoint(endpointInfo *natdiscovery.NATEndpointInfo
227217
// Parse remote public key.
228218
remoteKey, err := keyFromSpec(&remoteEndpoint.Spec)
229219
if err != nil {
230-
return "", errors.Wrap(err, "failed to parse peer public key")
220+
return "", errors.Wrapf(err, "failed to obtain public key for endpoint %s", resource.ToJSON(remoteEndpoint.Spec))
231221
}
232222

233-
logger.V(log.DEBUG).Infof("Connecting cluster %s endpoint %s with publicKey %s",
223+
logger.V(log.DEBUG).Infof("Connecting cluster %q endpoint %q with publicKey %q",
234224
remoteEndpoint.Spec.ClusterID, remoteIP, remoteKey)
235225

236226
// Delete or update old peers for ClusterID.
@@ -240,7 +230,7 @@ func (w *wireguard) ConnectToEndpoint(endpointInfo *natdiscovery.NATEndpointInfo
240230
if oldKey.String() == remoteKey.String() {
241231
// Existing connection, update status and skip.
242232
w.updatePeerStatus(oldCon, oldKey)
243-
logger.V(log.DEBUG).Infof("Skipping connect for existing peer key %s", oldKey)
233+
logger.V(log.DEBUG).Infof("Skipping connect for existing peer key %q", oldKey)
244234

245235
return ip, nil
246236
}
@@ -254,9 +244,11 @@ func (w *wireguard) ConnectToEndpoint(endpointInfo *natdiscovery.NATEndpointInfo
254244
// create connection, overwrite existing connection
255245
connection := v1.NewConnection(&remoteEndpoint.Spec, ip, endpointInfo.UseNAT)
256246
connection.SetStatus(v1.Connecting, "Connection has been created but not yet started")
257-
logger.V(log.DEBUG).Infof("Adding connection for cluster %s, %v", remoteEndpoint.Spec.ClusterID, connection)
258247
w.connections[remoteEndpoint.Spec.ClusterID] = connection
259248

249+
logger.V(log.DEBUG).Infof("Added connection for cluster %q: %s", remoteEndpoint.Spec.ClusterID,
250+
resource.ToJSON(connection))
251+
260252
port, err := remoteEndpoint.Spec.GetBackendPort(v1.UDPPortConfig, w.spec.NATTPort)
261253
if err != nil {
262254
logger.Warningf("Error parsing %q from remote endpoint %q - using port %dº instead: %v", v1.UDPPortConfig,
@@ -293,7 +285,7 @@ func (w *wireguard) ConnectToEndpoint(endpointInfo *natdiscovery.NATEndpointInfo
293285
logger.Errorf(err, "Failed to verify peer configuration")
294286
}
295287

296-
logger.V(log.DEBUG).Infof("Done connecting endpoint peer %s@%s", *remoteKey, remoteIP)
288+
logger.V(log.DEBUG).Infof("Successfully connected endpoint peer %q with IP %q", *remoteKey, remoteIP)
297289

298290
cable.RecordConnection(cableDriverName, &w.localEndpoint, &connection.Endpoint, string(v1.Connected), true, endpointInfo.UseFamily)
299291

@@ -303,20 +295,17 @@ func (w *wireguard) ConnectToEndpoint(endpointInfo *natdiscovery.NATEndpointInfo
303295
func keyFromSpec(ep *v1.EndpointSpec) (*wgtypes.Key, error) {
304296
s, found := ep.BackendConfig[PublicKey]
305297
if !found {
306-
return nil, errors.New("endpoint is missing public key")
298+
return &wgtypes.Key{}, errors.New("endpoint is missing public key")
307299
}
308300

309301
key, err := wgtypes.ParseKey(s)
310-
if err != nil {
311-
return nil, errors.Wrapf(err, "failed to parse public key %s", s)
312-
}
313302

314-
return &key, nil
303+
return &key, errors.Wrapf(err, "failed to parse public key %s", s)
315304
}
316305

317306
func (w *wireguard) DisconnectFromEndpoint(remoteEndpoint *types.SubmarinerEndpoint, family k8snet.IPFamily) error {
318307
// We'll panic if remoteEndpoint is nil, this is intentional
319-
logger.V(log.DEBUG).Infof("Removing IPv%v endpoint %v+", family, remoteEndpoint)
308+
logger.V(log.DEBUG).Infof("Removing IPv%v endpoint %s", family, resource.ToJSON(remoteEndpoint))
320309

321310
// parse remote public key
322311
remoteKey, err := keyFromSpec(&remoteEndpoint.Spec)
@@ -335,7 +324,7 @@ func (w *wireguard) DisconnectFromEndpoint(remoteEndpoint *types.SubmarinerEndpo
335324

336325
delete(w.connections, remoteEndpoint.Spec.ClusterID)
337326

338-
logger.V(log.DEBUG).Infof("Done removing endpoint for cluster %s", remoteEndpoint.Spec.ClusterID)
327+
logger.V(log.DEBUG).Infof("Done removing endpoint for cluster %q", remoteEndpoint.Spec.ClusterID)
339328
cable.RecordDisconnected(cableDriverName, &w.localEndpoint, &remoteEndpoint.Spec, family)
340329

341330
return nil
@@ -384,13 +373,8 @@ func (w *wireguard) removePeer(key *wgtypes.Key) error {
384373
ReplacePeers: false,
385374
Peers: peerCfg,
386375
})
387-
if err != nil {
388-
return errors.Wrapf(err, "failed to remove WireGuard peer with key %s", key)
389-
}
390-
391-
logger.V(log.DEBUG).Infof("Done removing WireGuard peer with key %s", key)
392376

393-
return nil
377+
return errors.Wrapf(err, "failed to remove WireGuard peer with key %s", key)
394378
}
395379

396380
func (w *wireguard) peerByKey(key *wgtypes.Key) (*wgtypes.Peer, error) {
@@ -441,12 +425,7 @@ func (w *wireguard) keyMismatch(cid string, key *wgtypes.Key) bool {
441425
return true
442426
}
443427

444-
oldKey, err := keyFromSpec(&c.Endpoint)
445-
if err != nil {
446-
logger.Warningf("Could not find old key of cluster %s, mismatched endpoint key %s", cid, key)
447-
return true
448-
}
449-
428+
oldKey, _ := keyFromSpec(&c.Endpoint)
450429
if oldKey.String() != key.String() {
451430
logger.Warningf("Key mismatch, cluster %s key is %s, endpoint key is %s", cid, oldKey, key)
452431
return true
@@ -469,9 +448,7 @@ func (w *wireguard) Cleanup() error {
469448
return errors.Wrapf(err, "error retrieving the wireguard interface %q", DefaultDeviceName)
470449
}
471450

472-
if err := w.netLink.LinkDel(link); err != nil {
473-
return errors.Wrapf(err, "failed to delete existing WireGuard device %q", DefaultDeviceName)
474-
}
451+
err = w.netLink.LinkDel(link)
475452

476-
return nil
453+
return errors.Wrapf(err, "failed to delete existing WireGuard device %q", DefaultDeviceName)
477454
}

pkg/cable/wireguard/getconnections.go

+16-21
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ limitations under the License.
1919
package wireguard
2020

2121
import (
22-
"fmt"
2322
"strconv"
2423
"time"
2524

@@ -41,8 +40,8 @@ func (w *wireguard) GetConnections() ([]v1.Connection, error) {
4140
for i := range d.Peers {
4241
key := d.Peers[i].PublicKey
4342

44-
connection, err := w.connectionByKey(&key)
45-
if err != nil {
43+
connection := w.connectionByKey(&key)
44+
if connection == nil {
4645
logger.Warningf("Found unknown peer with key %s, removing", key)
4746

4847
if err := w.removePeer(&key); err != nil {
@@ -59,18 +58,17 @@ func (w *wireguard) GetConnections() ([]v1.Connection, error) {
5958
return connections, nil
6059
}
6160

62-
func (w *wireguard) connectionByKey(key *wgtypes.Key) (*v1.Connection, error) {
61+
func (w *wireguard) connectionByKey(key *wgtypes.Key) *v1.Connection {
6362
for i := range w.connections {
64-
if k, err := keyFromSpec(&w.connections[i].Endpoint); err == nil {
65-
if key.String() == k.String() {
66-
return w.connections[i], nil
67-
}
68-
} else {
69-
logger.Errorf(err, "Could not compare key for cluster %s, skipping", i)
63+
// Since the endpoint was added to the connections list, it must have a valid public key, so we can
64+
// safely ignore the error.
65+
k, _ := keyFromSpec(&w.connections[i].Endpoint)
66+
if key.String() == k.String() {
67+
return w.connections[i]
7068
}
7169
}
7270

73-
return nil, fmt.Errorf("connection not found for key %s", key)
71+
return nil
7472
}
7573

7674
// Update logic, based on delta from last check good state requires a handshake and traffic if no handshake or stale handshake.
@@ -92,21 +90,17 @@ func (w *wireguard) updateConnectionForPeer(p *wgtypes.Peer, connection *v1.Conn
9290
connectionFamily := connection.GetFamily()
9391

9492
if p.LastHandshakeTime.IsZero() {
95-
if lc > handshakeTimeout.Milliseconds() {
93+
if lc > HandshakeTimeout.Milliseconds() {
9694
// No initial handshake for too long.
9795
connection.SetStatus(v1.ConnectionError, "no initial handshake for %.1f seconds", lcSec)
9896
cable.RecordConnection(cableDriverName, &w.localEndpoint, &connection.Endpoint, string(connection.Status), false, connectionFamily)
99-
100-
return
101-
}
102-
103-
if tx > 0 || rx > 0 {
97+
} else if tx > 0 || rx > 0 {
10498
// No handshake, but at least some communication in progress.
10599
connection.SetStatus(v1.Connecting, "no initial handshake yet")
106100
cable.RecordConnection(cableDriverName, &w.localEndpoint, &connection.Endpoint, string(connection.Status), false, connectionFamily)
107-
108-
return
109101
}
102+
103+
return
110104
}
111105

112106
if tx > 0 || rx > 0 {
@@ -120,7 +114,7 @@ func (w *wireguard) updateConnectionForPeer(p *wgtypes.Peer, connection *v1.Conn
120114

121115
handshakeDelta := time.Since(p.LastHandshakeTime)
122116

123-
if handshakeDelta > handshakeTimeout {
117+
if handshakeDelta > HandshakeTimeout {
124118
// Hard error, really long time since handshake.
125119
connection.SetStatus(v1.ConnectionError, "no handshake for %.1f seconds",
126120
handshakeDelta.Seconds())
@@ -131,10 +125,11 @@ func (w *wireguard) updateConnectionForPeer(p *wgtypes.Peer, connection *v1.Conn
131125

132126
if lc < 2*keepAliveMS {
133127
// Grace period, leave status unchanged.
134-
logger.Warningf("No traffic for %.1f seconds; handshake was %.1f seconds ago: %v", lcSec,
128+
logger.Warningf("No traffic for %.1f seconds; handshake was %.1f seconds ago: %#v", lcSec,
135129
handshakeDelta.Seconds(), connection)
136130
return
137131
}
132+
138133
// Soft error, no traffic, stale handshake.
139134
connection.SetStatus(v1.ConnectionError, "no bytes sent or received for %.1f seconds",
140135
lcSec)

0 commit comments

Comments
 (0)