Skip to content

[client] fix/proxy close #2873

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 11, 2024
6 changes: 5 additions & 1 deletion client/iface/wgproxy/bind/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package bind

import (
"context"
"errors"
"fmt"
"net"
"net/netip"
Expand Down Expand Up @@ -94,7 +95,10 @@ func (p *ProxyBind) close() error {

p.Bind.RemoveEndpoint(p.wgAddr)

return p.remoteConn.Close()
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
}
return nil
}

func (p *ProxyBind) proxyToLocal(ctx context.Context) {
Expand Down
2 changes: 1 addition & 1 deletion client/iface/wgproxy/ebpf/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {

e.cancel()

if err := e.remoteConn.Close(); err != nil {
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion client/iface/wgproxy/udp/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (p *WGUDPProxy) close() error {
p.cancel()

var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}

Expand Down
13 changes: 11 additions & 2 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {

if conn.iceP2PIsActive() {
conn.log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
conn.wgProxyRelay = wgProxy
conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
return
Expand All @@ -465,7 +465,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.currentConnPriority = connPriorityRelay
conn.statusRelay.Set(StatusConnected)
conn.wgProxyRelay = wgProxy
conn.setRelayedProxy(wgProxy)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
Expand Down Expand Up @@ -736,6 +736,15 @@ func (conn *Conn) logTraceConnState() {
}
}

func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
if conn.wgProxyRelay != nil {
if err := conn.wgProxyRelay.CloseConn(); err != nil {
conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
}
}
conn.wgProxyRelay = proxy
}

func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}
Expand Down
88 changes: 53 additions & 35 deletions client/internal/peer/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,6 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.IP = receivedState.IP
}

if receivedState.GetRoutes() != nil {
peerState.SetRoutes(receivedState.GetRoutes())
}

skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)

if receivedState.ConnStatus != peerState.ConnStatus {
Expand All @@ -261,12 +257,40 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
d.notifyPeerListChanged()
return nil
}

func (d *Status) AddPeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()

peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}

peerState.AddRoute(route)
d.peers[peer] = peerState

// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}

func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock()
defer d.mux.Unlock()

peerState, ok := d.peers[peer]
if !ok {
return errors.New("peer doesn't exist")
}

peerState.DeleteRoute(route)
d.peers[peer] = peerState

// todo: consider to make sense of this notification or not
d.notifyPeerListChanged()
return nil
}
Expand Down Expand Up @@ -301,12 +325,7 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}

d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
Expand Down Expand Up @@ -334,12 +353,7 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}

d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
Expand All @@ -366,12 +380,7 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}

d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
Expand Down Expand Up @@ -401,12 +410,7 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil
}

ch, found := d.changeNotify[receivedState.PubKey]
if found && ch != nil {
close(ch)
d.changeNotify[receivedState.PubKey] = nil
}

d.notifyPeerStateChangeListeners(receivedState.PubKey)
d.notifyPeerListChanged()
return nil
}
Expand Down Expand Up @@ -477,11 +481,14 @@ func (d *Status) FinishPeerListModifications() {
func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
d.mux.Lock()
defer d.mux.Unlock()

ch, found := d.changeNotify[peer]
if !found || ch == nil {
ch = make(chan struct{})
d.changeNotify[peer] = ch
if found {
return ch
}

ch = make(chan struct{})
d.changeNotify[peer] = ch
return ch
}

Expand Down Expand Up @@ -755,6 +762,17 @@ func (d *Status) onConnectionChanged() {
d.notifier.updateServerStates(d.managementState, d.signalState)
}

// notifyPeerStateChangeListeners notifies route manager about the change in peer state
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
ch, found := d.changeNotify[peerID]
if !found {
return
}

close(ch)
delete(d.changeNotify, peerID)
}

func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}
Expand Down
2 changes: 1 addition & 1 deletion client/internal/peer/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {

peerState.IP = ip

err := status.UpdatePeerState(peerState)
err := status.UpdatePeerRelayedStateToDisconnected(peerState)
assert.NoError(t, err, "shouldn't return error")

select {
Expand Down
33 changes: 24 additions & 9 deletions client/internal/peer/worker_ice.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ type WorkerICE struct {

localUfrag string
localPwd string

// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
}

func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
Expand Down Expand Up @@ -215,15 +218,18 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i

err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
w.conn.OnStatusChanged(StatusDisconnected)

w.muxAgent.Lock()
agentCancel()
_ = agent.Close()
w.agent = nil

w.muxAgent.Unlock()
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState != ice.ConnectionStateDisconnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.OnStatusChanged(StatusDisconnected)
}
w.closeAgent(agentCancel)
default:
return
}
})
if err != nil {
Expand All @@ -249,6 +255,15 @@ func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []i
return agent, nil
}

func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()

cancel()
_ = w.agent.Close()
w.agent = nil
}

func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
// wait local endpoint configuration
time.Sleep(time.Second)
Expand Down
64 changes: 27 additions & 37 deletions client/internal/routemanager/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,20 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore = float64(metricDiff) * 10
}

// in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route
latency := time.Second
// in some temporal cases, latency can be 0, so we set it to 999ms to not block but try to avoid this route
latency := 999 * time.Millisecond
if peerStatus.latency != 0 {
latency = peerStatus.latency
} else {
log.Warnf("peer %s has 0 latency", r.Peer)
log.Infof("peer %s has 0 latency, range %s", r.Peer, c.handler)
}

// avoid negative tempScore on the higher latency calculation
if latency > 1*time.Second {
latency = 999 * time.Millisecond
}

// higher latency is worse score
tempScore += 1 - latency.Seconds()

if !peerStatus.relayed {
Expand All @@ -150,6 +157,8 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
}
}

log.Debugf("chosen route: %s, chosen score: %f, current route: %s, current score: %f", chosen, chosenScore, currID, currScore)

switch {
case chosen == "":
var peers []string
Expand Down Expand Up @@ -195,15 +204,20 @@ func (c *clientNetwork) watchPeerStatusChanges(ctx context.Context, peerKey stri
func (c *clientNetwork) startPeersStatusChangeWatcher() {
for _, r := range c.routes {
_, found := c.routePeersNotifiers[r.Peer]
if !found {
c.routePeersNotifiers[r.Peer] = make(chan struct{})
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, c.routePeersNotifiers[r.Peer])
if found {
continue
}

closerChan := make(chan struct{})
c.routePeersNotifiers[r.Peer] = closerChan
go c.watchPeerStatusChanges(c.ctx, r.Peer, c.peerStateUpdate, closerChan)
}
}

func (c *clientNetwork) removeRouteFromWireguardPeer() error {
c.removeStateRoute()
func (c *clientNetwork) removeRouteFromWireGuardPeer() error {
if err := c.statusRecorder.RemovePeerStateRoute(c.currentChosen.Peer, c.handler.String()); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}

if err := c.handler.RemoveAllowedIPs(); err != nil {
return fmt.Errorf("remove allowed IPs: %w", err)
Expand All @@ -218,7 +232,7 @@ func (c *clientNetwork) removeRouteFromPeerAndSystem() error {

var merr *multierror.Error

if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err))
}
if err := c.handler.RemoveRoute(); err != nil {
Expand Down Expand Up @@ -257,7 +271,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
}
} else {
// Otherwise, remove the allowed IPs from the previous peer first
if err := c.removeRouteFromWireguardPeer(); err != nil {
if err := c.removeRouteFromWireGuardPeer(); err != nil {
return fmt.Errorf("remove allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}
}
Expand All @@ -268,35 +282,11 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
return fmt.Errorf("add allowed IPs for peer %s: %w", c.currentChosen.Peer, err)
}

c.addStateRoute()

return nil
}

func (c *clientNetwork) addStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}

state.AddRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
}
}

func (c *clientNetwork) removeStateRoute() {
state, err := c.statusRecorder.GetPeer(c.currentChosen.Peer)
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String())
if err != nil {
log.Errorf("Failed to get peer state: %v", err)
return
}

state.DeleteRoute(c.handler.String())
if err := c.statusRecorder.UpdatePeerState(state); err != nil {
log.Warnf("Failed to update peer state: %v", err)
return fmt.Errorf("add peer state route: %w", err)
}
return nil
}

func (c *clientNetwork) sendUpdateToClientNetworkWatcher(update routesUpdate) {
Expand Down
Loading
Loading