@@ -804,8 +804,8 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
804
804
func TestStreamErrorCode (t * testing.T ) {
805
805
for _ , tc := range transportsToTest {
806
806
t .Run (tc .Name , func (t * testing.T ) {
807
- if tc .Name != "QUIC" && tc . Name ! = "TCP / TLS / Yamux" && tc . Name != "WebRTC " {
808
- t .Skipf ("skipping: %s, only implemented for QUIC " , tc .Name )
807
+ if tc .Name = = "WebTransport " {
808
+ t .Skipf ("skipping: %s, not implemented" , tc .Name )
809
809
return
810
810
}
811
811
server := tc .HostGenerator (t , TransportTestCaseOpts {})
@@ -841,6 +841,9 @@ func TestStreamErrorCode(t *testing.T) {
841
841
}
842
842
_ , err = s .Read (b )
843
843
errCh <- err
844
+
845
+ _ , err = s .Write (b )
846
+ errCh <- err
844
847
})
845
848
ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
846
849
defer cancel ()
@@ -865,8 +868,163 @@ func TestStreamErrorCode(t *testing.T) {
865
868
_ , err = s .Write (buf )
866
869
checkError (err , 42 , false )
867
870
871
+ err = <- errCh // read error
872
+ checkError (err , 42 , true )
873
+
874
+ err = <- errCh // write error
875
+ checkError (err , 42 , true )
876
+ })
877
+ }
878
+ }
879
+
880
+ // TestStreamErrorCodeConnClosed tests correctness for resetting stream with errors
881
+ func TestStreamErrorCodeConnClosed (t * testing.T ) {
882
+ for _ , tc := range transportsToTest {
883
+ t .Run (tc .Name , func (t * testing.T ) {
884
+ if tc .Name == "WebTransport" || tc .Name == "WebRTC" {
885
+ t .Skipf ("skipping: %s, not implemented" , tc .Name )
886
+ return
887
+ }
888
+ server := tc .HostGenerator (t , TransportTestCaseOpts {})
889
+ client := tc .HostGenerator (t , TransportTestCaseOpts {NoListen : true })
890
+ defer server .Close ()
891
+ defer client .Close ()
892
+
893
+ checkError := func (err error , code network.ConnErrorCode , remote bool ) {
894
+ t .Helper ()
895
+ if err == nil {
896
+ t .Fatal ("expected non nil error" )
897
+ }
898
+ ce := & network.ConnError {}
899
+ if errors .As (err , & ce ) {
900
+ require .Equal (t , code , ce .ErrorCode )
901
+ require .Equal (t , remote , ce .Remote )
902
+ return
903
+ }
904
+ t .Fatal ("expected network.ConnError, got:" , err )
905
+ }
906
+
907
+ errCh := make (chan error )
908
+ server .SetStreamHandler ("/test" , func (s network.Stream ) {
909
+ defer s .Reset ()
910
+ b := make ([]byte , 10 )
911
+ n , err := s .Read (b )
912
+ if ! assert .NoError (t , err ) {
913
+ return
914
+ }
915
+ _ , err = s .Write (b [:n ])
916
+ if ! assert .NoError (t , err ) {
917
+ return
918
+ }
919
+ _ , err = s .Read (b )
920
+ errCh <- err
921
+
922
+ _ , err = s .Write (b )
923
+ errCh <- err
924
+ })
925
+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
926
+ defer cancel ()
927
+ client .Peerstore ().AddAddrs (server .ID (), server .Addrs (), peerstore .PermanentAddrTTL )
928
+ s , err := client .NewStream (ctx , server .ID (), "/test" )
929
+ require .NoError (t , err )
930
+
931
+ _ , err = s .Write ([]byte ("hello" ))
932
+ require .NoError (t , err )
933
+
934
+ buf := make ([]byte , 10 )
935
+ n , err := s .Read (buf )
936
+ require .NoError (t , err )
937
+ require .Equal (t , []byte ("hello" ), buf [:n ])
938
+
939
+ err = s .Conn ().CloseWithError (42 )
940
+ require .NoError (t , err )
941
+
942
+ _ , err = s .Read (buf )
943
+ checkError (err , 42 , false )
944
+
945
+ _ , err = s .Write (buf )
946
+ checkError (err , 42 , false )
947
+
948
+ err = <- errCh
949
+ checkError (err , 42 , true )
950
+
951
+ err = <- errCh
952
+ checkError (err , 42 , true )
953
+ })
954
+ }
955
+ }
956
+
957
+ // TestConnectionErrorCode tests correctness for resetting stream with errors
958
+ func TestConnectionErrorCode (t * testing.T ) {
959
+ for _ , tc := range transportsToTest {
960
+ t .Run (tc .Name , func (t * testing.T ) {
961
+ if tc .Name == "WebTransport" || tc .Name == "WebRTC" {
962
+ t .Skipf ("skipping: %s, not implemented" , tc .Name )
963
+ return
964
+ }
965
+ server := tc .HostGenerator (t , TransportTestCaseOpts {})
966
+ client := tc .HostGenerator (t , TransportTestCaseOpts {NoListen : true })
967
+ defer server .Close ()
968
+ defer client .Close ()
969
+
970
+ checkError := func (err error , code network.ConnErrorCode , remote bool ) {
971
+ t .Helper ()
972
+ if err == nil {
973
+ t .Fatal ("expected non nil error" )
974
+ }
975
+ ce := & network.ConnError {}
976
+ if errors .As (err , & ce ) {
977
+ require .Equal (t , code , ce .ErrorCode )
978
+ require .Equal (t , remote , ce .Remote )
979
+ return
980
+ }
981
+ t .Fatal ("expected network.ConnError, got:" , err )
982
+ }
983
+
984
+ errCh := make (chan error )
985
+ server .SetStreamHandler ("/test" , func (s network.Stream ) {
986
+ defer s .Reset ()
987
+ b := make ([]byte , 10 )
988
+ n , err := s .Read (b )
989
+ if ! assert .NoError (t , err ) {
990
+ return
991
+ }
992
+ _ , err = s .Write (b [:n ])
993
+ if ! assert .NoError (t , err ) {
994
+ return
995
+ }
996
+
997
+ _ , err = s .Read (b )
998
+ if ! assert .Error (t , err ) {
999
+ return
1000
+ }
1001
+ _ , err = s .Conn ().NewStream (context .Background ())
1002
+ errCh <- err
1003
+ })
1004
+ ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
1005
+ defer cancel ()
1006
+ client .Peerstore ().AddAddrs (server .ID (), server .Addrs (), peerstore .PermanentAddrTTL )
1007
+ s , err := client .NewStream (ctx , server .ID (), "/test" )
1008
+ require .NoError (t , err )
1009
+
1010
+ _ , err = s .Write ([]byte ("hello" ))
1011
+ require .NoError (t , err )
1012
+
1013
+ buf := make ([]byte , 10 )
1014
+ n , err := s .Read (buf )
1015
+ require .NoError (t , err )
1016
+ require .Equal (t , []byte ("hello" ), buf [:n ])
1017
+
1018
+ err = s .Conn ().CloseWithError (42 )
1019
+ require .NoError (t , err )
1020
+
1021
+ str , err := s .Conn ().NewStream (context .Background ())
1022
+ require .Nil (t , str )
1023
+ checkError (err , 42 , false )
1024
+
868
1025
err = <- errCh
869
1026
checkError (err , 42 , true )
1027
+
870
1028
})
871
1029
}
872
1030
}
0 commit comments