Skip to content

Commit 5b7016e

Browse files
committed
add support for websocket
1 parent 235a2da commit 5b7016e

File tree

11 files changed

+207
-17
lines changed

11 files changed

+207
-17
lines changed

core/network/conn.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type ConnErrorCode uint32
1616

1717
type ConnError struct {
1818
Remote bool
19-
ErrorCode uint32
19+
ErrorCode ConnErrorCode
2020
}
2121

2222
func (c *ConnError) Error() string {

core/network/mux.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ type MuxedConn interface {
109109
AcceptStream() (MuxedStream, error)
110110
}
111111

112-
type ConnWithErrorer interface {
112+
type CloseWithErrorer interface {
113113
// CloseWithError closes the connection with errCode. The errCode is sent
114114
// to the peer.
115-
ConnWithError(errCode ConnErrorCode) error
115+
CloseWithError(errCode ConnErrorCode) error
116116
}
117117

118118
// Multiplexer wraps a net.Conn with a stream multiplexing

p2p/muxer/yamux/conn.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (c *conn) IsClosed() bool {
3636
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
3737
s, err := c.yamux().OpenStream(ctx)
3838
if err != nil {
39-
return nil, err
39+
return nil, parseResetError(err)
4040
}
4141

4242
return (*stream)(s), nil
@@ -45,7 +45,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
4545
// AcceptStream accepts a stream opened by the other side.
4646
func (c *conn) AcceptStream() (network.MuxedStream, error) {
4747
s, err := c.yamux().AcceptStream()
48-
return (*stream)(s), err
48+
return (*stream)(s), parseResetError(err)
4949
}
5050

5151
func (c *conn) yamux() *yamux.Session {

p2p/muxer/yamux/stream.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ func parseResetError(err error) error {
1818
if err == nil {
1919
return err
2020
}
21-
if errors.Is(err, yamux.ErrStreamReset) {
22-
se := &yamux.StreamError{}
23-
if errors.As(err, &se) {
24-
return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode)}
25-
}
21+
se := &yamux.StreamError{}
22+
if errors.As(err, &se) {
23+
return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode)}
24+
}
25+
ce := &yamux.GoAwayError{}
26+
if errors.As(err, &ce) {
27+
return &network.ConnError{Remote: ce.Remote, ErrorCode: network.ConnErrorCode(ce.ErrorCode)}
2628
}
2729
return err
2830
}

p2p/net/swarm/swarm.go

+8
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,14 @@ func (c connWithMetrics) Close() error {
838838
return c.CapableConn.Close()
839839
}
840840

841+
func (c connWithMetrics) CloseWithError(errCode network.ConnErrorCode) error {
842+
c.metricsTracer.ClosedConnection(c.dir, time.Since(c.opened), c.ConnState(), c.LocalMultiaddr())
843+
if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok {
844+
return ce.CloseWithError(errCode)
845+
}
846+
return c.CapableConn.Close()
847+
}
848+
841849
func (c connWithMetrics) Stat() network.ConnStats {
842850
if cs, ok := c.CapableConn.(network.ConnStat); ok {
843851
return cs.Stat()

p2p/net/swarm/swarm_conn.go

+1-3
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,7 @@ func (c *Conn) doClose(errCode network.ConnErrorCode) {
8181
c.streams.Unlock()
8282

8383
if errCode != 0 {
84-
if ce, ok := c.conn.(interface {
85-
CloseWithError(network.ConnErrorCode) error
86-
}); ok {
84+
if ce, ok := c.conn.(network.CloseWithErrorer); ok {
8785
c.err = ce.CloseWithError(errCode)
8886
} else {
8987
c.err = c.conn.Close()

p2p/net/upgrader/conn.go

+7
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,10 @@ func (t *transportConn) ConnState() network.ConnectionState {
6363
UsedEarlyMuxerNegotiation: t.usedEarlyMuxerNegotiation,
6464
}
6565
}
66+
67+
func (t *transportConn) CloseWithError(errCode network.ConnErrorCode) error {
68+
if ce, ok := t.MuxedConn.(network.CloseWithErrorer); ok {
69+
return ce.CloseWithError(errCode)
70+
}
71+
return t.Close()
72+
}

p2p/test/transport/transport_test.go

+160-2
Original file line numberDiff line numberDiff line change
@@ -804,8 +804,8 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
804804
func TestStreamErrorCode(t *testing.T) {
805805
for _, tc := range transportsToTest {
806806
t.Run(tc.Name, func(t *testing.T) {
807-
if tc.Name != "QUIC" && tc.Name != "TCP / TLS / Yamux" && tc.Name != "WebRTC" {
808-
t.Skipf("skipping: %s, only implemented for QUIC", tc.Name)
807+
if tc.Name == "WebTransport" {
808+
t.Skipf("skipping: %s, not implemented", tc.Name)
809809
return
810810
}
811811
server := tc.HostGenerator(t, TransportTestCaseOpts{})
@@ -841,6 +841,9 @@ func TestStreamErrorCode(t *testing.T) {
841841
}
842842
_, err = s.Read(b)
843843
errCh <- err
844+
845+
_, err = s.Write(b)
846+
errCh <- err
844847
})
845848
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
846849
defer cancel()
@@ -865,8 +868,163 @@ func TestStreamErrorCode(t *testing.T) {
865868
_, err = s.Write(buf)
866869
checkError(err, 42, false)
867870

871+
err = <-errCh // read error
872+
checkError(err, 42, true)
873+
874+
err = <-errCh // write error
875+
checkError(err, 42, true)
876+
})
877+
}
878+
}
879+
880+
// TestStreamErrorCodeConnClosed tests correctness for resetting stream with errors
881+
func TestStreamErrorCodeConnClosed(t *testing.T) {
882+
for _, tc := range transportsToTest {
883+
t.Run(tc.Name, func(t *testing.T) {
884+
if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
885+
t.Skipf("skipping: %s, not implemented", tc.Name)
886+
return
887+
}
888+
server := tc.HostGenerator(t, TransportTestCaseOpts{})
889+
client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
890+
defer server.Close()
891+
defer client.Close()
892+
893+
checkError := func(err error, code network.ConnErrorCode, remote bool) {
894+
t.Helper()
895+
if err == nil {
896+
t.Fatal("expected non nil error")
897+
}
898+
ce := &network.ConnError{}
899+
if errors.As(err, &ce) {
900+
require.Equal(t, code, ce.ErrorCode)
901+
require.Equal(t, remote, ce.Remote)
902+
return
903+
}
904+
t.Fatal("expected network.ConnError, got:", err)
905+
}
906+
907+
errCh := make(chan error)
908+
server.SetStreamHandler("/test", func(s network.Stream) {
909+
defer s.Reset()
910+
b := make([]byte, 10)
911+
n, err := s.Read(b)
912+
if !assert.NoError(t, err) {
913+
return
914+
}
915+
_, err = s.Write(b[:n])
916+
if !assert.NoError(t, err) {
917+
return
918+
}
919+
_, err = s.Read(b)
920+
errCh <- err
921+
922+
_, err = s.Write(b)
923+
errCh <- err
924+
})
925+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
926+
defer cancel()
927+
client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
928+
s, err := client.NewStream(ctx, server.ID(), "/test")
929+
require.NoError(t, err)
930+
931+
_, err = s.Write([]byte("hello"))
932+
require.NoError(t, err)
933+
934+
buf := make([]byte, 10)
935+
n, err := s.Read(buf)
936+
require.NoError(t, err)
937+
require.Equal(t, []byte("hello"), buf[:n])
938+
939+
err = s.Conn().CloseWithError(42)
940+
require.NoError(t, err)
941+
942+
_, err = s.Read(buf)
943+
checkError(err, 42, false)
944+
945+
_, err = s.Write(buf)
946+
checkError(err, 42, false)
947+
948+
err = <-errCh
949+
checkError(err, 42, true)
950+
951+
err = <-errCh
952+
checkError(err, 42, true)
953+
})
954+
}
955+
}
956+
957+
// TestConnectionErrorCode tests correctness for resetting stream with errors
958+
func TestConnectionErrorCode(t *testing.T) {
959+
for _, tc := range transportsToTest {
960+
t.Run(tc.Name, func(t *testing.T) {
961+
if tc.Name == "WebTransport" || tc.Name == "WebRTC" {
962+
t.Skipf("skipping: %s, not implemented", tc.Name)
963+
return
964+
}
965+
server := tc.HostGenerator(t, TransportTestCaseOpts{})
966+
client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
967+
defer server.Close()
968+
defer client.Close()
969+
970+
checkError := func(err error, code network.ConnErrorCode, remote bool) {
971+
t.Helper()
972+
if err == nil {
973+
t.Fatal("expected non nil error")
974+
}
975+
ce := &network.ConnError{}
976+
if errors.As(err, &ce) {
977+
require.Equal(t, code, ce.ErrorCode)
978+
require.Equal(t, remote, ce.Remote)
979+
return
980+
}
981+
t.Fatal("expected network.ConnError, got:", err)
982+
}
983+
984+
errCh := make(chan error)
985+
server.SetStreamHandler("/test", func(s network.Stream) {
986+
defer s.Reset()
987+
b := make([]byte, 10)
988+
n, err := s.Read(b)
989+
if !assert.NoError(t, err) {
990+
return
991+
}
992+
_, err = s.Write(b[:n])
993+
if !assert.NoError(t, err) {
994+
return
995+
}
996+
997+
_, err = s.Read(b)
998+
if !assert.Error(t, err) {
999+
return
1000+
}
1001+
_, err = s.Conn().NewStream(context.Background())
1002+
errCh <- err
1003+
})
1004+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
1005+
defer cancel()
1006+
client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
1007+
s, err := client.NewStream(ctx, server.ID(), "/test")
1008+
require.NoError(t, err)
1009+
1010+
_, err = s.Write([]byte("hello"))
1011+
require.NoError(t, err)
1012+
1013+
buf := make([]byte, 10)
1014+
n, err := s.Read(buf)
1015+
require.NoError(t, err)
1016+
require.Equal(t, []byte("hello"), buf[:n])
1017+
1018+
err = s.Conn().CloseWithError(42)
1019+
require.NoError(t, err)
1020+
1021+
str, err := s.Conn().NewStream(context.Background())
1022+
require.Nil(t, str)
1023+
checkError(err, 42, false)
1024+
8681025
err = <-errCh
8691026
checkError(err, 42, true)
1027+
8701028
})
8711029
}
8721030
}

p2p/transport/quic/conn.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func (c *conn) allowWindowIncrease(size uint64) bool {
6161
func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
6262
qstr, err := c.quicConn.OpenStreamSync(ctx)
6363
if err != nil {
64-
return nil, err
64+
return nil, parseStreamError(err)
6565
}
6666
return &stream{Stream: qstr}, nil
6767
}
@@ -70,7 +70,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) {
7070
func (c *conn) AcceptStream() (network.MuxedStream, error) {
7171
qstr, err := c.quicConn.AcceptStream(context.Background())
7272
if err != nil {
73-
return nil, err
73+
return nil, parseStreamError(err)
7474
}
7575
return &stream{Stream: qstr}, nil
7676
}

p2p/transport/quic/stream.go

+9
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,22 @@ func parseStreamError(err error) error {
2727
if errors.As(err, &se) {
2828
code := se.ErrorCode
2929
if code > math.MaxUint32 {
30+
// TODO(sukunrt): do we need this?
3031
code = reset
3132
}
3233
err = &network.StreamError{
3334
ErrorCode: network.StreamErrorCode(code),
3435
Remote: se.Remote,
3536
}
3637
}
38+
ae := &quic.ApplicationError{}
39+
if errors.As(err, &ae) {
40+
code := ae.ErrorCode
41+
err = &network.ConnError{
42+
ErrorCode: network.ConnErrorCode(code),
43+
Remote: ae.Remote,
44+
}
45+
}
3746
return err
3847
}
3948

p2p/transport/websocket/conn.go

+8
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,11 @@ func (c *capableConn) ConnState() network.ConnectionState {
162162
cs.Transport = "websocket"
163163
return cs
164164
}
165+
166+
// CloseWithError implements network.CloseWithErrorer
167+
func (c *capableConn) CloseWithError(errCode network.ConnErrorCode) error {
168+
if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok {
169+
return ce.CloseWithError(errCode)
170+
}
171+
return c.Close()
172+
}

0 commit comments

Comments
 (0)