@@ -525,118 +525,167 @@ handle_frps_msg(uint8_t *buf, int len, void *ctx)
525
525
526
526
static struct tmux_stream abandon_stream ;
527
527
528
- // ctx: if recv_cb was called by common control, ctx == NULL
529
- // else ctx == client struct
530
- static void
531
- recv_cb (struct bufferevent * bev , void * ctx )
528
+ static int handle_tcp_mux_header (struct bufferevent * bev ,
529
+ struct tcp_mux_header * tmux_hdr ,
530
+ int * len ,
531
+ uint32_t * stream_len )
532
+ {
533
+ uint8_t * data = (uint8_t * )tmux_hdr ;
534
+ size_t nr = bufferevent_read (bev , data , sizeof (* tmux_hdr ));
535
+ if (nr != sizeof (* tmux_hdr )) {
536
+ debug (LOG_ERR , "Failed to read TCP mux header" );
537
+ return -1 ;
538
+ }
539
+
540
+ * len -= nr ;
541
+
542
+ if (!validate_tcp_mux_protocol (tmux_hdr )) {
543
+ debug (LOG_ERR , "Invalid TCP mux protocol" );
544
+ return -1 ;
545
+ }
546
+
547
+ if (tmux_hdr -> type == DATA ) {
548
+ * stream_len = ntohl (tmux_hdr -> length );
549
+ }
550
+
551
+ return 0 ;
552
+ }
553
+
554
+ static struct tmux_stream * get_stream_for_data (const struct tcp_mux_header * tmux_hdr ,
555
+ uint32_t stream_len )
556
+ {
557
+ struct tmux_stream * cur = NULL ;
558
+ if (tmux_hdr -> type == DATA ) {
559
+ uint32_t stream_id = ntohl (tmux_hdr -> stream_id );
560
+ cur = get_stream_by_id (stream_id );
561
+ if (!cur && stream_len > 0 ) {
562
+ debug (LOG_INFO , "Using abandon stream for id %d, len %d" ,
563
+ stream_id , stream_len );
564
+ cur = & abandon_stream ;
565
+ }
566
+ }
567
+ return cur ;
568
+ }
569
+
570
+ static int process_stream_data (struct bufferevent * bev ,
571
+ struct tmux_stream * cur ,
572
+ int * len ,
573
+ uint32_t * stream_len )
574
+ {
575
+ size_t nr ;
576
+ if (* len >= * stream_len ) {
577
+ nr = tmux_stream_read (bev , cur , * stream_len );
578
+ if (nr != * stream_len ) {
579
+ return -1 ;
580
+ }
581
+ * len -= * stream_len ;
582
+ * stream_len = 0 ;
583
+ } else {
584
+ nr = tmux_stream_read (bev , cur , * len );
585
+ if (nr != * len ) {
586
+ return -1 ;
587
+ }
588
+ * stream_len -= * len ;
589
+ * len = 0 ;
590
+ set_cur_stream (cur );
591
+ }
592
+ return 0 ;
593
+ }
594
+
595
+ static void handle_tcp_mux_message (const struct tcp_mux_header * tmux_hdr ,
596
+ struct tmux_stream * cur )
597
+ {
598
+ if (cur == & abandon_stream ) {
599
+ debug (LOG_INFO , "Abandoning stream data" );
600
+ memset (cur , 0 , sizeof (abandon_stream ));
601
+ set_cur_stream (NULL );
602
+ return ;
603
+ }
604
+
605
+ switch (tmux_hdr -> type ) {
606
+ case DATA :
607
+ case WINDOW_UPDATE :
608
+ handle_tcp_mux_stream (tmux_hdr , handle_frps_msg );
609
+ break ;
610
+ case PING :
611
+ handle_tcp_mux_ping (tmux_hdr );
612
+ break ;
613
+ case GO_AWAY :
614
+ handle_tcp_mux_go_away (tmux_hdr );
615
+ break ;
616
+ default :
617
+ debug (LOG_ERR , "Invalid TCP mux message type" );
618
+ exit (-1 );
619
+ }
620
+ }
621
+
622
+ static void handle_tcp_mux (struct bufferevent * bev , int len )
623
+ {
624
+ static struct tcp_mux_header tmux_hdr ;
625
+ static uint32_t stream_len = 0 ;
626
+
627
+ while (len > 0 ) {
628
+ struct tmux_stream * cur = get_cur_stream ();
629
+
630
+ if (!cur ) {
631
+ if (len < sizeof (tmux_hdr )) {
632
+ debug (LOG_INFO , "Incomplete header: len=%d" , len );
633
+ break ;
634
+ }
635
+
636
+ memset (& tmux_hdr , 0 , sizeof (tmux_hdr ));
637
+ if (handle_tcp_mux_header (bev , & tmux_hdr , & len , & stream_len ) < 0 ) {
638
+ break ;
639
+ }
640
+
641
+ cur = get_stream_for_data (& tmux_hdr , stream_len );
642
+ if (!cur || len == 0 ) {
643
+ set_cur_stream (cur );
644
+ break ;
645
+ }
646
+ }
647
+
648
+ if (process_stream_data (bev , cur , & len , & stream_len ) < 0 ) {
649
+ break ;
650
+ }
651
+
652
+ if (len > 0 ) {
653
+ handle_tcp_mux_message (& tmux_hdr , cur );
654
+ set_cur_stream (NULL );
655
+ }
656
+ }
657
+ }
658
+
659
+ static void handle_direct_message (struct bufferevent * bev , void * ctx )
532
660
{
533
661
struct evbuffer * input = bufferevent_get_input (bev );
534
662
int len = evbuffer_get_length (input );
535
- if (len <= 0 ) {
536
- return ;
663
+
664
+ uint8_t * buf = calloc (len , 1 );
665
+ if (!buf ) {
666
+ debug (LOG_ERR , "Failed to allocate buffer" );
667
+ return ;
537
668
}
538
669
539
- struct common_conf * c_conf = get_common_config ();
540
- if (c_conf -> tcp_mux ) {
541
- static struct tcp_mux_header tmux_hdr ;
542
- static uint32_t stream_len = 0 ;
543
- while (len > 0 ) {
544
- struct tmux_stream * cur = get_cur_stream ();
545
- size_t nr = 0 ;
546
- if (!cur ) {
547
- memset (& tmux_hdr , 0 , sizeof (tmux_hdr ));
548
- uint8_t * data = (uint8_t * )& tmux_hdr ;
549
- if (len < sizeof (tmux_hdr )) {
550
- debug (LOG_INFO , "len [%d] < sizeof tmux_hdr" , len );
551
- break ;
552
- }
553
- nr = bufferevent_read (bev , data , sizeof (tmux_hdr ));
554
- assert (nr == sizeof (tmux_hdr ));
555
- assert (validate_tcp_mux_protocol (& tmux_hdr ) > 0 );
556
- len -= nr ;
557
- if (tmux_hdr .type == DATA ) {
558
- uint32_t stream_id = ntohl (tmux_hdr .stream_id );
559
- stream_len = ntohl (tmux_hdr .length );
560
- cur = get_stream_by_id (stream_id );
561
- if (!cur ) {
562
- debug (LOG_INFO , "cur is NULL stream_id is %d, stream_len is %d len is %d" ,
563
- stream_id , stream_len , len );
564
- if (stream_len > 0 )
565
- cur = & abandon_stream ;
566
- else
567
- continue ;
568
- }
569
-
570
- if (len == 0 ) {
571
- set_cur_stream (cur );
572
- break ;
573
- }
574
- if (len >= stream_len ) {
575
- nr = tmux_stream_read (bev , cur , stream_len );
576
- assert (nr == stream_len );
577
- len -= stream_len ;
578
- } else {
579
- nr = tmux_stream_read (bev , cur , len );
580
- stream_len -= len ;
581
- assert (nr == len );
582
- set_cur_stream (cur );
583
- len -= nr ;
584
- break ;
585
- }
586
- }
587
- } else {
588
- assert (tmux_hdr .type == DATA );
589
- if (len >= stream_len ) {
590
- nr = tmux_stream_read (bev , cur , stream_len );
591
- assert (nr == stream_len );
592
- len -= stream_len ;
593
- } else {
594
- nr = tmux_stream_read (bev , cur , len );
595
- stream_len -= len ;
596
- assert (nr == len );
597
- len -= nr ;
598
- break ;
599
- }
600
- }
601
-
602
- if (cur == & abandon_stream ) {
603
- debug (LOG_INFO , "abandon stream data ..." );
604
- memset (cur , 0 , sizeof (abandon_stream ));
605
- set_cur_stream (NULL );
606
- continue ;
607
- }
608
-
609
- switch (tmux_hdr .type ) {
610
- case DATA :
611
- case WINDOW_UPDATE :
612
- {
613
- handle_tcp_mux_stream (& tmux_hdr , handle_frps_msg );
614
- break ;
615
- }
616
- case PING :
617
- handle_tcp_mux_ping (& tmux_hdr );
618
- break ;
619
- case GO_AWAY :
620
- handle_tcp_mux_go_away (& tmux_hdr );
621
- break ;
622
- default :
623
- debug (LOG_ERR , "impossible here!!!!" );
624
- exit (-1 );
625
- }
626
-
627
- set_cur_stream (NULL );
628
- }
629
- } else {
630
- uint8_t * buf = calloc (len , 1 );
631
- assert (buf );
632
- evbuffer_remove (input , buf , len );
670
+ evbuffer_remove (input , buf , len );
671
+ handle_frps_msg (buf , len , ctx );
672
+ free (buf );
673
+ }
633
674
634
- handle_frps_msg (buf , len , ctx );
635
- SAFE_FREE (buf );
675
+ static void recv_cb (struct bufferevent * bev , void * ctx )
676
+ {
677
+ struct evbuffer * input = bufferevent_get_input (bev );
678
+ int len = evbuffer_get_length (input );
679
+ if (len <= 0 ) {
680
+ return ;
636
681
}
637
-
638
682
639
- return ;
683
+ struct common_conf * c_conf = get_common_config ();
684
+ if (c_conf -> tcp_mux ) {
685
+ handle_tcp_mux (bev , len );
686
+ } else {
687
+ handle_direct_message (bev , ctx );
688
+ }
640
689
}
641
690
642
691
static void handle_connection_failure (struct common_conf * c_conf , int * retry_times ) {
0 commit comments