@@ -4,6 +4,7 @@ use std::io::{Read, Write};
4
4
use std:: net:: { SocketAddr , TcpListener } ;
5
5
use std:: pin:: Pin ;
6
6
use std:: sync:: atomic:: Ordering ;
7
+ use std:: sync:: Arc ;
7
8
use std:: task:: Poll ;
8
9
use std:: thread;
9
10
use std:: time:: Duration ;
@@ -891,7 +892,6 @@ fn capture_connection_on_client() {
891
892
let addr = server. local_addr ( ) . unwrap ( ) ;
892
893
thread:: spawn ( move || {
893
894
let mut sock = server. accept ( ) . unwrap ( ) . 0 ;
894
- //drop(server);
895
895
sock. set_read_timeout ( Some ( Duration :: from_secs ( 5 ) ) ) . unwrap ( ) ;
896
896
sock. set_write_timeout ( Some ( Duration :: from_secs ( 5 ) ) )
897
897
. unwrap ( ) ;
@@ -908,3 +908,74 @@ fn capture_connection_on_client() {
908
908
rt. block_on ( client. request ( req) ) . expect ( "200 OK" ) ;
909
909
assert ! ( captured_conn. connection_metadata( ) . is_some( ) ) ;
910
910
}
911
+
912
+ #[ cfg( not( miri) ) ]
913
+ #[ test]
914
+ fn connection_poisoning ( ) {
915
+ use std:: sync:: atomic:: AtomicUsize ;
916
+
917
+ let _ = pretty_env_logger:: try_init ( ) ;
918
+
919
+ let rt = runtime ( ) ;
920
+ let connector = DebugConnector :: new ( ) ;
921
+
922
+ let client = Client :: builder ( TokioExecutor :: new ( ) ) . build ( connector) ;
923
+
924
+ let server = TcpListener :: bind ( "127.0.0.1:0" ) . unwrap ( ) ;
925
+ let addr = server. local_addr ( ) . unwrap ( ) ;
926
+ let num_conns: Arc < AtomicUsize > = Default :: default ( ) ;
927
+ let num_requests: Arc < AtomicUsize > = Default :: default ( ) ;
928
+ let num_requests_tracker = num_requests. clone ( ) ;
929
+ let num_conns_tracker = num_conns. clone ( ) ;
930
+ thread:: spawn ( move || loop {
931
+ let mut sock = server. accept ( ) . unwrap ( ) . 0 ;
932
+ num_conns_tracker. fetch_add ( 1 , Ordering :: Relaxed ) ;
933
+ let num_requests_tracker = num_requests_tracker. clone ( ) ;
934
+ thread:: spawn ( move || {
935
+ sock. set_read_timeout ( Some ( Duration :: from_secs ( 5 ) ) ) . unwrap ( ) ;
936
+ sock. set_write_timeout ( Some ( Duration :: from_secs ( 5 ) ) )
937
+ . unwrap ( ) ;
938
+ let mut buf = [ 0 ; 4096 ] ;
939
+ loop {
940
+ if sock. read ( & mut buf) . expect ( "read 1" ) > 0 {
941
+ num_requests_tracker. fetch_add ( 1 , Ordering :: Relaxed ) ;
942
+ sock. write_all ( b"HTTP/1.1 200 OK\r \n Content-Length: 0\r \n \r \n " )
943
+ . expect ( "write 1" ) ;
944
+ }
945
+ }
946
+ } ) ;
947
+ } ) ;
948
+ let make_request = || {
949
+ Request :: builder ( )
950
+ . uri ( & * format ! ( "http://{}/a" , addr) )
951
+ . body ( Empty :: < Bytes > :: new ( ) )
952
+ . unwrap ( )
953
+ } ;
954
+ let mut req = make_request ( ) ;
955
+ let captured_conn = capture_connection ( & mut req) ;
956
+ rt. block_on ( client. request ( req) ) . expect ( "200 OK" ) ;
957
+ assert_eq ! ( num_conns. load( Ordering :: SeqCst ) , 1 ) ;
958
+ assert_eq ! ( num_requests. load( Ordering :: SeqCst ) , 1 ) ;
959
+
960
+ rt. block_on ( client. request ( make_request ( ) ) ) . expect ( "200 OK" ) ;
961
+ rt. block_on ( client. request ( make_request ( ) ) ) . expect ( "200 OK" ) ;
962
+ // Before poisoning the connection is reused
963
+ assert_eq ! ( num_conns. load( Ordering :: SeqCst ) , 1 ) ;
964
+ assert_eq ! ( num_requests. load( Ordering :: SeqCst ) , 3 ) ;
965
+ captured_conn
966
+ . connection_metadata ( )
967
+ . as_ref ( )
968
+ . unwrap ( )
969
+ . poison ( ) ;
970
+
971
+ rt. block_on ( client. request ( make_request ( ) ) ) . expect ( "200 OK" ) ;
972
+
973
+ // After poisoning, a new connection is established
974
+ assert_eq ! ( num_conns. load( Ordering :: SeqCst ) , 2 ) ;
975
+ assert_eq ! ( num_requests. load( Ordering :: SeqCst ) , 4 ) ;
976
+
977
+ rt. block_on ( client. request ( make_request ( ) ) ) . expect ( "200 OK" ) ;
978
+ // another request can still reuse:
979
+ assert_eq ! ( num_conns. load( Ordering :: SeqCst ) , 2 ) ;
980
+ assert_eq ! ( num_requests. load( Ordering :: SeqCst ) , 5 ) ;
981
+ }
0 commit comments