1
1
use std:: { cell:: Cell , cell:: RefCell , fmt, mem, rc:: Rc } ;
2
- use std:: { collections:: VecDeque , time:: Duration , time :: Instant } ;
2
+ use std:: { collections:: VecDeque , time:: Instant } ;
3
3
4
4
use ntex_bytes:: { ByteString , Bytes } ;
5
5
use ntex_http:: { HeaderMap , Method } ;
@@ -408,7 +408,7 @@ impl Connection {
408
408
}
409
409
410
410
// Add ids to pending queue
411
- self . 0 . local_pending_reset . add ( id) ;
411
+ self . 0 . local_pending_reset . add ( id, & self . 0 . local_config ) ;
412
412
}
413
413
414
414
pub ( crate ) fn recv_half ( & self ) -> RecvHalfConnection {
@@ -722,7 +722,7 @@ impl RecvHalfConnection {
722
722
stream,
723
723
StreamError :: Reset ( frm. reason ( ) ) ,
724
724
) ) )
725
- } else if self . 0 . local_pending_reset . is_pending ( id) {
725
+ } else if self . 0 . local_pending_reset . remove ( id) {
726
726
self . update_rst_count ( )
727
727
} else {
728
728
self . update_rst_count ( ) ?;
@@ -866,79 +866,71 @@ async fn ping(st: Connection, timeout: time::Seconds, io: IoRef) {
866
866
}
867
867
}
868
868
869
- const BLOCKS : usize = 5 ;
870
- const LAST_BLOCK : usize = 4 ;
869
+ struct Pending ( Cell < Option < Box < PendingInner > > > ) ;
871
870
872
- #[ cfg( not( test) ) ]
873
- const SECS : u64 = 30 ;
874
- #[ cfg( test) ]
875
- const SECS : u64 = 1 ;
876
-
877
- const BLOCK_SIZE : Duration = Duration :: from_secs ( SECS ) ;
878
- const ALL_BLOCKS : Duration = Duration :: from_secs ( ( BLOCKS as u64 ) * SECS ) ;
879
-
880
- #[ derive( Default ) ]
881
- struct Pending {
882
- idx : Cell < u8 > ,
883
- blocks : RefCell < [ Block ; BLOCKS ] > ,
884
- }
885
-
886
- #[ derive( Debug ) ]
887
- struct Block {
888
- start_time : Instant ,
871
+ struct PendingInner {
889
872
ids : HashSet < StreamId > ,
873
+ queue : VecDeque < ( StreamId , Instant ) > ,
890
874
}
891
875
892
- impl Pending {
893
- fn add ( & self , id : StreamId ) {
894
- let cur = now ( ) ;
895
- let idx = self . idx . get ( ) as usize ;
896
- let mut blocks = self . blocks . borrow_mut ( ) ;
897
-
898
- // check if we need to insert new block
899
- if blocks[ idx] . start_time < ( cur - BLOCK_SIZE ) {
900
- // shift blocks
901
- let idx = if idx == 0 { LAST_BLOCK } else { idx - 1 } ;
902
- // insert new item
903
- blocks[ idx] . start_time = cur;
904
- blocks[ idx] . ids . clear ( ) ;
905
- blocks[ idx] . ids . insert ( id) ;
906
- self . idx . set ( idx as u8 ) ;
907
- } else {
908
- blocks[ idx] . ids . insert ( id) ;
909
- }
876
+ impl Default for Pending {
877
+ fn default ( ) -> Self {
878
+ Self ( Cell :: new ( Some ( Box :: new ( PendingInner {
879
+ ids : HashSet :: default ( ) ,
880
+ queue : VecDeque :: with_capacity ( 16 ) ,
881
+ } ) ) ) )
910
882
}
883
+ }
911
884
912
- fn is_pending ( & self , id : StreamId ) -> bool {
913
- let blocks = self . blocks . borrow_mut ( ) ;
885
+ impl Pending {
886
+ fn add ( & self , id : StreamId , config : & Config ) {
887
+ let mut inner = self . 0 . take ( ) . unwrap ( ) ;
914
888
915
- let max = now ( ) - ALL_BLOCKS ;
916
- let mut idx = self . idx . get ( ) as usize ;
889
+ let current_time = now ( ) ;
917
890
918
- loop {
919
- let item = & blocks[ idx] ;
920
- if item. start_time < max {
891
+ // remove old ids
892
+ let max_time = current_time - config. 0 . reset_duration . get ( ) ;
893
+ while let Some ( item) = inner. queue . front ( ) {
894
+ if item. 1 < max_time {
895
+ inner. ids . remove ( & item. 0 ) ;
896
+ inner. queue . pop_front ( ) ;
897
+ } else {
921
898
break ;
922
- } else if item. ids . contains ( & id) {
923
- return true ;
924
899
}
925
- idx += 1 ;
926
- if idx == BLOCKS {
927
- idx = 0 ;
928
- } else if idx == self . idx . get ( ) as usize {
929
- break ;
900
+ }
901
+
902
+ // shrink size of ids
903
+ while inner. queue . len ( ) >= config. 0 . reset_max . get ( ) {
904
+ if let Some ( ( id, _) ) = inner. queue . pop_front ( ) {
905
+ inner. ids . remove ( & id) ;
930
906
}
931
907
}
932
- false
908
+
909
+ inner. ids . insert ( id) ;
910
+ inner. queue . push_back ( ( id, current_time) ) ;
911
+ self . 0 . set ( Some ( inner) ) ;
933
912
}
934
- }
935
913
936
- impl Default for Block {
937
- fn default ( ) -> Self {
938
- Self {
939
- ids : HashSet :: default ( ) ,
940
- start_time : now ( ) - ALL_BLOCKS ,
914
+ fn remove ( & self , id : StreamId ) -> bool {
915
+ let mut inner = self . 0 . take ( ) . unwrap ( ) ;
916
+ let removed = inner. ids . remove ( & id) ;
917
+ if removed {
918
+ for idx in 0 ..inner. queue . len ( ) {
919
+ if inner. queue [ idx] . 0 == id {
920
+ inner. queue . remove ( idx) ;
921
+ break ;
922
+ }
923
+ }
941
924
}
925
+ self . 0 . set ( Some ( inner) ) ;
926
+ removed
927
+ }
928
+
929
+ fn is_pending ( & self , id : StreamId ) -> bool {
930
+ let inner = self . 0 . take ( ) . unwrap ( ) ;
931
+ let pending = inner. ids . contains ( & id) ;
932
+ self . 0 . set ( Some ( inner) ) ;
933
+ pending
942
934
}
943
935
}
944
936
@@ -1017,11 +1009,16 @@ mod tests {
1017
1009
1018
1010
#[ ntex:: test]
1019
1011
async fn test_delay_reset_queue ( ) {
1012
+ let _ = env_logger:: try_init ( ) ;
1013
+
1020
1014
let srv = test_server ( || {
1021
1015
fn_service ( |io : Io < _ > | async move {
1022
1016
let _ = h2:: server:: handle_one (
1023
1017
io. into ( ) ,
1024
- h2:: Config :: server ( ) . ping_timeout ( Seconds :: ZERO ) . clone ( ) ,
1018
+ h2:: Config :: server ( )
1019
+ . ping_timeout ( Seconds :: ZERO )
1020
+ . reset_stream_duration ( Seconds ( 1 ) )
1021
+ . clone ( ) ,
1025
1022
fn_service ( |msg : h2:: ControlMessage < h2:: StreamError > | async move {
1026
1023
Ok :: < _ , ( ) > ( msg. ack ( ) )
1027
1024
} ) ,
@@ -1037,7 +1034,7 @@ mod tests {
1037
1034
} ) ;
1038
1035
1039
1036
let addr = ntex:: connect:: Connect :: new ( "localhost" ) . set_addr ( Some ( srv. addr ( ) ) ) ;
1040
- let io = ntex:: connect:: connect ( addr) . await . unwrap ( ) ;
1037
+ let io = ntex:: connect:: connect ( addr. clone ( ) ) . await . unwrap ( ) ;
1041
1038
let codec = Codec :: default ( ) ;
1042
1039
let _ = io. with_write_buf ( |buf| buf. extend_from_slice ( & PREFACE ) ) ;
1043
1040
@@ -1081,14 +1078,45 @@ mod tests {
1081
1078
assert_eq ! ( res. reason( ) , Reason :: NO_ERROR ) ;
1082
1079
1083
1080
// prev closed stream
1084
- io. send ( pl. clone ( ) . into ( ) , & codec) . await . unwrap ( ) ;
1081
+ io. send ( pl. into ( ) , & codec) . await . unwrap ( ) ;
1082
+ let res = goaway ( io. recv ( & codec) . await . unwrap ( ) . unwrap ( ) ) ;
1083
+ assert_eq ! ( res. reason( ) , Reason :: PROTOCOL_ERROR ) ;
1084
+
1085
+ // SECOND connection
1086
+ let io = ntex:: connect:: connect ( addr) . await . unwrap ( ) ;
1087
+ let codec = Codec :: default ( ) ;
1088
+ let _ = io. with_write_buf ( |buf| buf. extend_from_slice ( & PREFACE ) ) ;
1089
+
1090
+ let settings = frame:: Settings :: default ( ) ;
1091
+ io. encode ( settings. into ( ) , & codec) . unwrap ( ) ;
1092
+
1093
+ // settings & window
1094
+ let _ = io. recv ( & codec) . await ;
1095
+ let _ = io. recv ( & codec) . await ;
1096
+ let _ = io. recv ( & codec) . await ;
1097
+
1098
+ let id = frame:: StreamId :: CLIENT ;
1099
+ let pseudo = frame:: PseudoHeaders {
1100
+ method : Some ( Method :: GET ) ,
1101
+ scheme : Some ( "HTTPS" . into ( ) ) ,
1102
+ authority : Some ( "localhost" . into ( ) ) ,
1103
+ path : Some ( "/" . into ( ) ) ,
1104
+ ..Default :: default ( )
1105
+ } ;
1106
+ let hdrs = frame:: Headers :: new ( id, pseudo. clone ( ) , HeaderMap :: new ( ) , false ) ;
1107
+ io. send ( hdrs. into ( ) , & codec) . await . unwrap ( ) ;
1108
+
1109
+ // server resets stream
1085
1110
let res = get_reset ( io. recv ( & codec) . await . unwrap ( ) . unwrap ( ) ) ;
1086
- assert_eq ! ( res. reason( ) , Reason :: STREAM_CLOSED ) ;
1111
+ assert_eq ! ( res. reason( ) , Reason :: NO_ERROR ) ;
1087
1112
1088
- sleep ( Millis ( 5100 ) ) . await ;
1113
+ // after server receives remote reset, any next frame cause protocol error
1114
+ io. send ( frame:: Reset :: new ( id, Reason :: NO_ERROR ) . into ( ) , & codec)
1115
+ . await
1116
+ . unwrap ( ) ;
1089
1117
1090
- // prev closed stream
1091
- io. send ( pl. into ( ) , & codec) . await . unwrap ( ) ;
1118
+ let pl = frame :: Data :: new ( id , Bytes :: from_static ( b"data" ) ) ;
1119
+ io. send ( pl. clone ( ) . into ( ) , & codec) . await . unwrap ( ) ;
1092
1120
let res = goaway ( io. recv ( & codec) . await . unwrap ( ) . unwrap ( ) ) ;
1093
1121
assert_eq ! ( res. reason( ) , Reason :: PROTOCOL_ERROR ) ;
1094
1122
}
0 commit comments