Skip to content

Commit 0fde4dc

Browse files
MarcoPolosukunrt
authored andcommitted
Dial from your own listener
1 parent 9c49960 commit 0fde4dc

File tree

7 files changed

+156
-25
lines changed

7 files changed

+156
-25
lines changed

libp2p_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"crypto/rand"
66
"errors"
77
"fmt"
8+
"io"
89
"net"
910
"net/netip"
1011
"regexp"
@@ -587,3 +588,70 @@ func TestWebRTCReuseAddrWithQUIC(t *testing.T) {
587588
require.Contains(t, h1.Addrs()[0].String(), "quic-v1")
588589
})
589590
}
591+
592+
func TestUseCorrectTransportForDialOut(t *testing.T) {
593+
listAddrOrder := [][]string{
594+
{"/ip4/127.0.0.1/udp/0/quic-v1", "/ip4/127.0.0.1/udp/0/quic-v1/webtransport"},
595+
{"/ip4/127.0.0.1/udp/0/quic-v1/webtransport", "/ip4/127.0.0.1/udp/0/quic-v1"},
596+
{"/ip4/0.0.0.0/udp/0/quic-v1", "/ip4/0.0.0.0/udp/0/quic-v1/webtransport"},
597+
{"/ip4/0.0.0.0/udp/0/quic-v1/webtransport", "/ip4/0.0.0.0/udp/0/quic-v1"},
598+
}
599+
for _, order := range listAddrOrder {
600+
h1, err := New(ListenAddrStrings(order...), Transport(quic.NewTransport), Transport(webtransport.New))
601+
require.NoError(t, err)
602+
t.Cleanup(func() {
603+
h1.Close()
604+
})
605+
606+
go func() {
607+
h1.SetStreamHandler("/echo-port", func(s network.Stream) {
608+
m := s.Conn().RemoteMultiaddr()
609+
v, err := m.ValueForProtocol(ma.P_UDP)
610+
if err != nil {
611+
s.Reset()
612+
return
613+
}
614+
s.Write([]byte(v))
615+
s.Close()
616+
})
617+
}()
618+
619+
for _, addr := range h1.Addrs() {
620+
t.Run("order "+strings.Join(order, ",")+" Dial to "+addr.String(), func(t *testing.T) {
621+
h2, err := New(ListenAddrStrings(
622+
"/ip4/0.0.0.0/udp/0/quic-v1",
623+
"/ip4/0.0.0.0/udp/0/quic-v1/webtransport",
624+
), Transport(quic.NewTransport), Transport(webtransport.New))
625+
require.NoError(t, err)
626+
defer h2.Close()
627+
t.Log("H2 Addrs", h2.Addrs())
628+
var myExpectedDialOutAddr ma.Multiaddr
629+
addrIsWT, _ := webtransport.IsWebtransportMultiaddr(addr)
630+
isLocal := func(a ma.Multiaddr) bool {
631+
return strings.Contains(a.String(), "127.0.0.1")
632+
}
633+
addrIsLocal := isLocal(addr)
634+
for _, a := range h2.Addrs() {
635+
aIsWT, _ := webtransport.IsWebtransportMultiaddr(a)
636+
if addrIsWT == aIsWT && isLocal(a) == addrIsLocal {
637+
myExpectedDialOutAddr = a
638+
break
639+
}
640+
}
641+
642+
err = h2.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: []ma.Multiaddr{addr}})
643+
require.NoError(t, err)
644+
645+
s, err := h2.NewStream(context.Background(), h1.ID(), "/echo-port")
646+
require.NoError(t, err)
647+
648+
port, err := io.ReadAll(s)
649+
require.NoError(t, err)
650+
651+
myExpectedPort, err := myExpectedDialOutAddr.ValueForProtocol(ma.P_UDP)
652+
require.NoError(t, err)
653+
require.Equal(t, myExpectedPort, string(port))
654+
})
655+
}
656+
}
657+
}

p2p/transport/quic/transport.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
136136
}
137137

138138
tlsConf, keyCh := t.identity.ConfigForPeer(p)
139+
ctx = quicreuse.WithAssociation(ctx, t)
139140
pconn, err := t.connManager.DialQUIC(ctx, raddr, tlsConf, t.allowWindowIncrease)
140141
if err != nil {
141142
return nil, err
@@ -196,7 +197,7 @@ func (t *transport) holePunch(ctx context.Context, raddr ma.Multiaddr, p peer.ID
196197
if err != nil {
197198
return nil, err
198199
}
199-
tr, err := t.connManager.TransportForDial(network, addr)
200+
tr, err := t.connManager.TransportWithAssociationForDial(t, network, addr)
200201
if err != nil {
201202
return nil, err
202203
}
@@ -313,7 +314,7 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) {
313314
return nil, fmt.Errorf("can't listen on quic version %v, underlying listener doesn't support it", version)
314315
}
315316
} else {
316-
ln, err := t.connManager.ListenQUIC(addr, &tlsConf, t.allowWindowIncrease)
317+
ln, err := t.connManager.ListenQUICAndAssociate(t, addr, &tlsConf, t.allowWindowIncrease)
317318
if err != nil {
318319
return nil, err
319320
}

p2p/transport/quicreuse/connmgr.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ func (c *ConnManager) getReuse(network string) (*reuse, error) {
102102
}
103103

104104
func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) {
105+
return c.ListenQUICAndAssociate(nil, addr, tlsConf, allowWindowIncrease)
106+
}
107+
108+
// ListenQUICAndAssociate returns a QUIC listener and associates the underlying transport with the given association.
109+
func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) {
105110
netw, host, err := manet.DialArgs(addr)
106111
if err != nil {
107112
return nil, err
@@ -117,7 +122,7 @@ func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWi
117122
key := laddr.String()
118123
entry, ok := c.quicListeners[key]
119124
if !ok {
120-
tr, err := c.transportForListen(netw, laddr)
125+
tr, err := c.transportForListen(association, netw, laddr)
121126
if err != nil {
122127
return nil, err
123128
}
@@ -176,13 +181,18 @@ func (c *ConnManager) SharedNonQUICPacketConn(network string, laddr *net.UDPAddr
176181
return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set")
177182
}
178183

179-
func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) {
184+
func (c *ConnManager) transportForListen(association any, network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) {
180185
if c.enableReuseport {
181186
reuse, err := c.getReuse(network)
182187
if err != nil {
183188
return nil, err
184189
}
185-
return reuse.TransportForListen(network, laddr)
190+
tr, err := reuse.TransportForListen(network, laddr)
191+
if err != nil {
192+
return nil, err
193+
}
194+
tr.associate(association)
195+
return tr, nil
186196
}
187197

188198
conn, err := net.ListenUDP(network, laddr)
@@ -199,6 +209,14 @@ func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (re
199209
}, nil
200210
}
201211

212+
type associationKey struct{}
213+
214+
// WithAssociation returns a new context with the given association. Used in
215+
// DialQUIC to prefer a transport that has the given association.
216+
func WithAssociation(ctx context.Context, association any) context.Context {
217+
return context.WithValue(ctx, associationKey{}, association)
218+
}
219+
202220
func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (quic.Connection, error) {
203221
naddr, v, err := FromQuicMultiaddr(raddr)
204222
if err != nil {
@@ -219,7 +237,12 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf
219237
return nil, errors.New("unknown QUIC version")
220238
}
221239

222-
tr, err := c.TransportForDial(netw, naddr)
240+
var tr refCountedQuicTransport
241+
if association := ctx.Value(associationKey{}); association != nil {
242+
tr, err = c.TransportWithAssociationForDial(association, netw, naddr)
243+
} else {
244+
tr, err = c.TransportForDial(netw, naddr)
245+
}
223246
if err != nil {
224247
return nil, err
225248
}
@@ -232,12 +255,17 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf
232255
}
233256

234257
func (c *ConnManager) TransportForDial(network string, raddr *net.UDPAddr) (refCountedQuicTransport, error) {
258+
return c.TransportWithAssociationForDial(nil, network, raddr)
259+
}
260+
261+
// TransportWithAssociationForDial returns a QUIC transport for dialing, preferring a transport with the given association.
262+
func (c *ConnManager) TransportWithAssociationForDial(association any, network string, raddr *net.UDPAddr) (refCountedQuicTransport, error) {
235263
if c.enableReuseport {
236264
reuse, err := c.getReuse(network)
237265
if err != nil {
238266
return nil, err
239267
}
240-
return reuse.TransportForDial(network, raddr)
268+
return reuse.transportWithAssociationForDial(association, network, raddr)
241269
}
242270

243271
var laddr *net.UDPAddr

p2p/transport/quicreuse/connmgr_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ func testListenOnSameProto(t *testing.T, enableReuseport bool) {
6161

6262
const alpn = "proto"
6363

64-
var tlsConf tls.Config
65-
tlsConf.NextProtos = []string{alpn}
6664
ln1, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{alpn}}, nil)
6765
require.NoError(t, err)
6866
defer ln1.Close()
@@ -96,7 +94,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) {
9694

9795
_, err = cm.ListenQUIC(raddr, &tls.Config{NextProtos: []string{"proto"}}, nil)
9896
require.NoError(t, err)
99-
quicTr, err := cm.transportForListen(netw, naddr)
97+
quicTr, err := cm.transportForListen(nil, netw, naddr)
10098
require.NoError(t, err)
10199
defer quicTr.Close()
102100
if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok {

p2p/transport/quicreuse/reuse.go

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,36 @@ type refcountedTransport struct {
6969
mutex sync.Mutex
7070
refCount int
7171
unusedSince time.Time
72+
73+
assocations map[any]struct{}
74+
}
75+
76+
// associate an arbitrary value with this transport.
77+
// This lets us "tag" the refcountedTransport when listening so we can use it
78+
// later for dialing. Necessary for holepunching and learning about our own
79+
// observed listening address.
80+
func (c *refcountedTransport) associate(a any) {
81+
if a == nil {
82+
return
83+
}
84+
c.mutex.Lock()
85+
defer c.mutex.Unlock()
86+
if c.assocations == nil {
87+
c.assocations = make(map[any]struct{})
88+
}
89+
c.assocations[a] = struct{}{}
90+
}
91+
92+
// hasAssociation returns true if the transport has the given association.
93+
// If it is a nil association, it will always return true.
94+
func (c *refcountedTransport) hasAssociation(a any) bool {
95+
if a == nil {
96+
return true
97+
}
98+
c.mutex.Lock()
99+
defer c.mutex.Unlock()
100+
_, ok := c.assocations[a]
101+
return ok
72102
}
73103

74104
func (c *refcountedTransport) IncreaseCount() {
@@ -204,7 +234,7 @@ func (r *reuse) gc() {
204234
}
205235
}
206236

207-
func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcountedTransport, error) {
237+
func (r *reuse) transportWithAssociationForDial(association any, network string, raddr *net.UDPAddr) (*refcountedTransport, error) {
208238
var ip *net.IP
209239

210240
// Only bother looking up the source address if we actually _have_ non 0.0.0.0 listeners.
@@ -224,29 +254,34 @@ func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcounte
224254
r.mutex.Lock()
225255
defer r.mutex.Unlock()
226256

227-
tr, err := r.transportForDialLocked(network, ip)
257+
tr, err := r.transportForDialLocked(association, network, ip)
228258
if err != nil {
229259
return nil, err
230260
}
231261
tr.IncreaseCount()
232262
return tr, nil
233263
}
234264

235-
func (r *reuse) transportForDialLocked(network string, source *net.IP) (*refcountedTransport, error) {
265+
func (r *reuse) transportForDialLocked(association any, network string, source *net.IP) (*refcountedTransport, error) {
236266
if source != nil {
237267
// We already have at least one suitable transport...
238268
if trs, ok := r.unicast[source.String()]; ok {
239-
// ... we don't care which port we're dialing from. Just use the first.
269+
// Prefer a transport that has the given association. We want to
270+
// reuse the transport the association used for listening.
240271
for _, tr := range trs {
241-
return tr, nil
272+
if tr.hasAssociation(association) {
273+
return tr, nil
274+
}
242275
}
243276
}
244277
}
245278

246279
// Use a transport listening on 0.0.0.0 (or ::).
247-
// Again, we don't care about the port number.
280+
// Again, prefer a transport that has the given association.
248281
for _, tr := range r.globalListeners {
249-
return tr, nil
282+
if tr.hasAssociation(association) {
283+
return tr, nil
284+
}
250285
}
251286

252287
// Use a transport we've previously dialed from

p2p/transport/quicreuse/reuse_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func TestReuseCreateNewGlobalConnOnDial(t *testing.T) {
9191

9292
addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
9393
require.NoError(t, err)
94-
conn, err := reuse.TransportForDial("udp4", addr)
94+
conn, err := reuse.transportWithAssociationForDial(nil, "udp4", addr)
9595
require.NoError(t, err)
9696
require.Equal(t, 1, conn.GetCount())
9797
laddr := conn.LocalAddr().(*net.UDPAddr)
@@ -111,7 +111,7 @@ func TestReuseConnectionWhenDialing(t *testing.T) {
111111
// dial
112112
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
113113
require.NoError(t, err)
114-
conn, err := reuse.TransportForDial("udp4", raddr)
114+
conn, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
115115
require.NoError(t, err)
116116
require.Equal(t, 2, conn.GetCount())
117117
}
@@ -122,7 +122,7 @@ func TestReuseConnectionWhenListening(t *testing.T) {
122122

123123
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
124124
require.NoError(t, err)
125-
tr, err := reuse.TransportForDial("udp4", raddr)
125+
tr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
126126
require.NoError(t, err)
127127
laddr := &net.UDPAddr{IP: net.IPv4zero, Port: tr.LocalAddr().(*net.UDPAddr).Port}
128128
lconn, err := reuse.TransportForListen("udp4", laddr)
@@ -138,7 +138,7 @@ func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
138138
// dial any address
139139
raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234")
140140
require.NoError(t, err)
141-
rTr, err := reuse.TransportForDial("udp4", raddr)
141+
rTr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
142142
require.NoError(t, err)
143143

144144
// open a listener
@@ -149,7 +149,7 @@ func TestReuseConnectionWhenDialBeforeListen(t *testing.T) {
149149
// new dials should go via the listener connection
150150
raddr, err = net.ResolveUDPAddr("udp4", "1.1.1.1:1235")
151151
require.NoError(t, err)
152-
tr, err := reuse.TransportForDial("udp4", raddr)
152+
tr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
153153
require.NoError(t, err)
154154
require.Equal(t, lTr, tr)
155155
require.Equal(t, 2, tr.GetCount())
@@ -183,7 +183,7 @@ func TestReuseListenOnSpecificInterface(t *testing.T) {
183183
require.NoError(t, err)
184184
require.Equal(t, 1, lconn.GetCount())
185185
// dial
186-
conn, err := reuse.TransportForDial("udp4", raddr)
186+
conn, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
187187
require.NoError(t, err)
188188
require.Equal(t, 1, conn.GetCount())
189189
}
@@ -214,7 +214,7 @@ func TestReuseGarbageCollect(t *testing.T) {
214214

215215
raddr, err := net.ResolveUDPAddr("udp4", "1.2.3.4:1234")
216216
require.NoError(t, err)
217-
dTr, err := reuse.TransportForDial("udp4", raddr)
217+
dTr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr)
218218
require.NoError(t, err)
219219
require.Equal(t, 1, dTr.GetCount())
220220

p2p/transport/webtransport/transport.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string
207207
return verifyRawCerts(rawCerts, certHashes)
208208
}
209209
}
210+
ctx = quicreuse.WithAssociation(ctx, t)
210211
conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease)
211212
if err != nil {
212213
return nil, nil, err
@@ -331,7 +332,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) {
331332
}
332333
tlsConf.NextProtos = append(tlsConf.NextProtos, http3.NextProtoH3)
333334

334-
ln, err := t.connManager.ListenQUIC(laddr, tlsConf, t.allowWindowIncrease)
335+
ln, err := t.connManager.ListenQUICAndAssociate(t, laddr, tlsConf, t.allowWindowIncrease)
335336
if err != nil {
336337
return nil, err
337338
}

0 commit comments

Comments
 (0)