Skip to content

Commit 5522935

Browse files
authored
refactor: move IP validation to the PacketListener (#229)
* refactor: move IP validation to the PacketListener * Rename `validatePacket()` to reflect what it currently does. * Reorder for consistency. * Handle the `EOF` case and stop reading from the connection. * Pass through the connection errors from the IP validator. * Address review comments. * Rename `timeout` to `natTimeout` in test. * Only resolve the `net.Addr` if it's not already a `UDPAddr`. * Use `MakeNetAddr()` instead of resolving when extracting the addr.
1 parent c9f2547 commit 5522935

File tree

6 files changed

+78
-37
lines changed

6 files changed

+78
-37
lines changed

cmd/outline-ss-server/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
268268
service.WithCiphers(ciphers),
269269
service.WithMetrics(s.serviceMetrics),
270270
service.WithReplayCache(&s.replayCache),
271-
service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, 0)),
271+
service.WithPacketListener(service.MakeTargetUDPListener(onet.RequirePublicIP, s.natTimeout, 0)),
272272
service.WithLogger(slog.Default()),
273273
)
274274
ln, err := lnSet.ListenStream(addr)
@@ -301,7 +301,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) {
301301
service.WithMetrics(s.serviceMetrics),
302302
service.WithReplayCache(&s.replayCache),
303303
service.WithStreamDialer(service.MakeValidatingTCPStreamDialer(onet.RequirePublicIP, serviceConfig.Dialer.Fwmark)),
304-
service.WithPacketListener(service.MakeTargetUDPListener(s.natTimeout, serviceConfig.Dialer.Fwmark)),
304+
service.WithPacketListener(service.MakeTargetUDPListener(onet.RequirePublicIP, s.natTimeout, serviceConfig.Dialer.Fwmark)),
305305
service.WithLogger(slog.Default()),
306306
)
307307
if err != nil {

internal/integration_test/integration_test.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ import (
3535
"github.com/stretchr/testify/require"
3636
)
3737

38-
const maxUDPPacketSize = 64 * 1024
38+
const (
39+
maxUDPPacketSize = 64 * 1024
40+
natTimeout = 5 * time.Minute
41+
)
3942

4043
func init() {
4144
logging.SetLevel(logging.INFO, "")
@@ -321,7 +324,7 @@ func TestUDPEcho(t *testing.T) {
321324
}
322325
proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{})
323326

324-
proxy.SetTargetIPValidator(allowAll)
327+
proxy.SetTargetPacketListener(service.MakeTargetUDPListener(allowAll, natTimeout, 0))
325328
natMetrics := &natTestMetrics{}
326329
associationMetrics := &fakeUDPAssociationMetrics{}
327330
go service.PacketServe(proxyConn, func(ctx context.Context, conn net.Conn) {
@@ -548,7 +551,7 @@ func BenchmarkUDPEcho(b *testing.B) {
548551
b.Fatal(err)
549552
}
550553
proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{})
551-
proxy.SetTargetIPValidator(allowAll)
554+
proxy.SetTargetPacketListener(service.MakeTargetUDPListener(allowAll, natTimeout, 0))
552555
done := make(chan struct{})
553556
go func() {
554557
service.PacketServe(server, func(ctx context.Context, conn net.Conn) {
@@ -594,7 +597,7 @@ func BenchmarkUDPManyKeys(b *testing.B) {
594597
b.Fatal(err)
595598
}
596599
proxy := service.NewAssociationHandler(cipherList, &fakeShadowsocksMetrics{})
597-
proxy.SetTargetIPValidator(allowAll)
600+
proxy.SetTargetPacketListener(service.MakeTargetUDPListener(allowAll, natTimeout, 0))
598601
done := make(chan struct{})
599602
go func() {
600603
service.PacketServe(proxyConn, func(ctx context.Context, conn net.Conn) {

service/udp.go

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func NewAssociationHandler(cipherList CipherList, ssMetrics ShadowsocksConnMetri
109109
ciphers: cipherList,
110110
ssm: ssMetrics,
111111
targetIPValidator: onet.RequirePublicIP,
112-
targetListener: MakeTargetUDPListener(defaultNatTimeout, 0),
112+
targetListener: MakeTargetUDPListener(onet.RequirePublicIP, defaultNatTimeout, 0),
113113
}
114114
}
115115

@@ -118,8 +118,6 @@ type AssociationHandler interface {
118118
HandleAssociation(ctx context.Context, conn net.Conn, assocMetrics UDPAssociationMetrics)
119119
// SetLogger sets the logger used to log messages. Uses a no-op logger if nil.
120120
SetLogger(l *slog.Logger)
121-
// SetTargetIPValidator sets the function to be used to validate the target IP addresses.
122-
SetTargetIPValidator(targetIPValidator onet.TargetIPValidator)
123121
// SetTargetPacketListener sets the packet listener to use for target connections.
124122
SetTargetPacketListener(targetListener transport.PacketListener)
125123
}
@@ -131,10 +129,6 @@ func (h *associationHandler) SetLogger(l *slog.Logger) {
131129
h.logger = l
132130
}
133131

134-
func (h *associationHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) {
135-
h.targetIPValidator = targetIPValidator
136-
}
137-
138132
func (h *associationHandler) SetTargetPacketListener(targetListener transport.PacketListener) {
139133
h.targetListener = targetListener
140134
}
@@ -176,7 +170,7 @@ func (h *associationHandler) HandleAssociation(ctx context.Context, clientConn n
176170
}
177171

178172
var payload []byte
179-
var tgtUDPAddr *net.UDPAddr
173+
var tgtAddr net.Addr
180174
if targetConn == nil {
181175
ip := clientConn.RemoteAddr().(*net.UDPAddr).AddrPort().Addr()
182176
var textData []byte
@@ -194,7 +188,7 @@ func (h *associationHandler) HandleAssociation(ctx context.Context, clientConn n
194188
assocMetrics.AddAuthentication(keyID)
195189

196190
var onetErr *onet.ConnectionError
197-
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
191+
if payload, tgtAddr, onetErr = h.extractPayloadAndDestination(textData); onetErr != nil {
198192
return onetErr
199193
}
200194

@@ -219,15 +213,15 @@ func (h *associationHandler) HandleAssociation(ctx context.Context, clientConn n
219213
}
220214

221215
var onetErr *onet.ConnectionError
222-
if payload, tgtUDPAddr, onetErr = h.validatePacket(textData); onetErr != nil {
216+
if payload, tgtAddr, onetErr = h.extractPayloadAndDestination(textData); onetErr != nil {
223217
return onetErr
224218
}
225219
}
226220

227221
debugUDP(l, "Proxy exit.")
228-
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
222+
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtAddr)
229223
if err != nil {
230-
return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
224+
return ensureConnectionError(err, "ERR_WRITE", "Failed to write to target")
231225
}
232226
return nil
233227
}()
@@ -246,21 +240,17 @@ func (h *associationHandler) HandleAssociation(ctx context.Context, clientConn n
246240
}
247241
}
248242

249-
// Given the decrypted contents of a UDP packet, return
250-
// the payload and the destination address, or an error if
251-
// this packet cannot or should not be forwarded.
252-
func (h *associationHandler) validatePacket(textData []byte) ([]byte, *net.UDPAddr, *onet.ConnectionError) {
243+
// extractPayloadAndDestination processes a decrypted Shadowsocks UDP packet and
244+
// extracts the payload data and destination address.
245+
func (h *associationHandler) extractPayloadAndDestination(textData []byte) ([]byte, net.Addr, *onet.ConnectionError) {
253246
tgtAddr := socks.SplitAddr(textData)
254247
if tgtAddr == nil {
255248
return nil, nil, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", nil)
256249
}
257250

258-
tgtUDPAddr, err := net.ResolveUDPAddr("udp", tgtAddr.String())
251+
tgtUDPAddr, err := transport.MakeNetAddr("udp", tgtAddr.String())
259252
if err != nil {
260-
return nil, nil, onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", tgtAddr), err)
261-
}
262-
if err := h.targetIPValidator(tgtUDPAddr.IP); err != nil {
263-
return nil, nil, ensureConnectionError(err, "ERR_ADDRESS_INVALID", "invalid address")
253+
return nil, nil, onet.NewConnectionError("ERR_CONVERT_ADDRESS", fmt.Sprintf("Failed to convert target address %v", tgtAddr), err)
264254
}
265255

266256
payload := textData[len(tgtAddr):]
@@ -408,6 +398,23 @@ func isDNS(addr net.Addr) bool {
408398
return port == "53"
409399
}
410400

401+
type validatingPacketConn struct {
402+
net.PacketConn
403+
targetIPValidator onet.TargetIPValidator
404+
}
405+
406+
func (vpc *validatingPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
407+
udpAddr, err := net.ResolveUDPAddr("udp", addr.String())
408+
if err != nil {
409+
return 0, onet.NewConnectionError("ERR_RESOLVE_ADDRESS", fmt.Sprintf("Failed to resolve target address %v", udpAddr), err)
410+
}
411+
if err := vpc.targetIPValidator(udpAddr.IP); err != nil {
412+
return 0, ensureConnectionError(err, "ERR_ADDRESS_INVALID", "invalid address")
413+
}
414+
415+
return vpc.PacketConn.WriteTo(p, udpAddr) // accept only `net.UDPAddr` despite the signature
416+
}
417+
411418
type timedPacketConn struct {
412419
net.PacketConn
413420
// Connection timeout to apply for non-DNS packets.

service/udp_linux.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,18 @@ import (
2323
"time"
2424

2525
"github.com/Jigsaw-Code/outline-sdk/transport"
26+
27+
onet "github.com/Jigsaw-Code/outline-ss-server/net"
2628
)
2729

2830
type udpListener struct {
31+
// The validator to be used to validate target IP addresses.
32+
targetIPValidator onet.TargetIPValidator
33+
2934
// NAT mapping timeout is the default time a mapping will stay active
3035
// without packets traversing the NAT, applied to non-DNS packets.
3136
timeout time.Duration
37+
3238
// fwmark can be used in conjunction with other Linux networking features like cgroups, network
3339
// namespaces, and TC (Traffic Control) for sophisticated network management.
3440
// Value of 0 disables fwmark (SO_MARK) (Linux only)
@@ -37,14 +43,14 @@ type udpListener struct {
3743

3844
// NewPacketListener creates a new PacketListener that listens on UDP
3945
// and optionally sets a firewall mark on the socket (Linux only).
40-
func MakeTargetUDPListener(timeout time.Duration, fwmark uint) transport.PacketListener {
41-
return &udpListener{timeout: timeout, fwmark: fwmark}
46+
func MakeTargetUDPListener(targetIPValidator onet.TargetIPValidator, timeout time.Duration, fwmark uint) transport.PacketListener {
47+
return &udpListener{timeout: timeout, targetIPValidator: targetIPValidator, fwmark: fwmark}
4248
}
4349

4450
func (ln *udpListener) ListenPacket(ctx context.Context) (net.PacketConn, error) {
4551
conn, err := net.ListenUDP("udp", nil)
4652
if err != nil {
47-
return nil, fmt.Errorf("Failed to create UDP socket: %w", err)
53+
return nil, fmt.Errorf("failed to create UDP socket: %w", err)
4854
}
4955

5056
if ln.fwmark > 0 {
@@ -57,9 +63,12 @@ func (ln *udpListener) ListenPacket(ctx context.Context) (net.PacketConn, error)
5763
err = SetFwmark(rawConn, ln.fwmark)
5864
if err != nil {
5965
conn.Close()
60-
return nil, fmt.Errorf("Failed to set `fwmark`: %w", err)
66+
return nil, fmt.Errorf("failed to set `fwmark`: %w", err)
6167

6268
}
6369
}
64-
return &timedPacketConn{PacketConn: conn, defaultTimeout: ln.timeout}, nil
70+
return &validatingPacketConn{
71+
PacketConn: &timedPacketConn{PacketConn: conn, defaultTimeout: ln.timeout},
72+
targetIPValidator: ln.targetIPValidator,
73+
}, nil
6574
}

service/udp_other.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,29 +22,41 @@ import (
2222
"time"
2323

2424
"github.com/Jigsaw-Code/outline-sdk/transport"
25+
26+
onet "github.com/Jigsaw-Code/outline-ss-server/net"
2527
)
2628

2729
type udpListener struct {
2830
*transport.UDPListener
2931

32+
// The validator to be used to validate target IP addresses.
33+
targetIPValidator onet.TargetIPValidator
34+
3035
// NAT mapping timeout is the default time a mapping will stay active
3136
// without packets traversing the NAT, applied to non-DNS packets.
3237
timeout time.Duration
3338
}
3439

3540
// fwmark can be used in conjunction with other Linux networking features like cgroups, network namespaces, and TC (Traffic Control) for sophisticated network management.
3641
// Value of 0 disables fwmark (SO_MARK)
37-
func MakeTargetUDPListener(timeout time.Duration, fwmark uint) transport.PacketListener {
42+
func MakeTargetUDPListener(targetIPValidator onet.TargetIPValidator, timeout time.Duration, fwmark uint) transport.PacketListener {
3843
if fwmark != 0 {
3944
panic("fwmark is linux-specific feature and should be 0")
4045
}
41-
return &udpListener{UDPListener: &transport.UDPListener{Address: ""}}
46+
return &udpListener{
47+
targetIPValidator: targetIPValidator,
48+
timeout: timeout,
49+
UDPListener: &transport.UDPListener{Address: ""},
50+
}
4251
}
4352

4453
func (ln *udpListener) ListenPacket(ctx context.Context) (net.PacketConn, error) {
4554
conn, err := ln.UDPListener.ListenPacket(ctx)
4655
if err != nil {
4756
return nil, err
4857
}
49-
return &timedPacketConn{PacketConn: conn, defaultTimeout: ln.timeout}, nil
58+
return &validatingPacketConn{
59+
PacketConn: &timedPacketConn{PacketConn: conn, defaultTimeout: ln.timeout},
60+
targetIPValidator: ln.targetIPValidator,
61+
}, nil
5062
}

service/udp_test.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"testing"
2525
"time"
2626

27+
"github.com/Jigsaw-Code/outline-sdk/transport"
2728
"github.com/Jigsaw-Code/outline-sdk/transport/shadowsocks"
2829
logging "github.com/op/go-logging"
2930
"github.com/shadowsocks/go-shadowsocks2/socks"
@@ -62,6 +63,15 @@ func (ln *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, err
6263
return ln.conn, nil
6364
}
6465

66+
func WrapWithValidatingPacketListener(conn net.PacketConn, targetIPValidator onet.TargetIPValidator) transport.PacketListener {
67+
return &packetListener{
68+
&validatingPacketConn{
69+
PacketConn: conn,
70+
targetIPValidator: targetIPValidator,
71+
},
72+
}
73+
}
74+
6575
type fakePacketConn struct {
6676
net.PacketConn
6777
send chan fakePacket
@@ -225,7 +235,7 @@ func TestAssociationCloseWhileReading(t *testing.T) {
225235
func TestAssociationHandler_Handle_IPFilter(t *testing.T) {
226236
t.Run("RequirePublicIP blocks localhost", func(t *testing.T) {
227237
handler, sendPayload, targetConn := startTestHandler()
228-
handler.SetTargetIPValidator(onet.RequirePublicIP)
238+
handler.SetTargetPacketListener(WrapWithValidatingPacketListener(targetConn, onet.RequirePublicIP))
229239

230240
sendPayload(&localAddr, []byte{1, 2, 3})
231241

@@ -239,7 +249,7 @@ func TestAssociationHandler_Handle_IPFilter(t *testing.T) {
239249

240250
t.Run("allowAll allows localhost", func(t *testing.T) {
241251
handler, sendPayload, targetConn := startTestHandler()
242-
handler.SetTargetIPValidator(allowAll)
252+
handler.SetTargetPacketListener(WrapWithValidatingPacketListener(targetConn, allowAll))
243253

244254
sendPayload(&localAddr, []byte{1, 2, 3})
245255

0 commit comments

Comments
 (0)