@@ -2,10 +2,11 @@ use std::iter;
2
2
3
3
use proc_macro2:: TokenStream ;
4
4
use quote:: { quote, quote_spanned, ToTokens } ;
5
+ use syn:: visit_mut:: VisitMut ;
5
6
use syn:: {
6
7
punctuated:: Punctuated , spanned:: Spanned , Block , Expr , ExprAsync , ExprCall , FieldPat , FnArg ,
7
8
Ident , Item , ItemFn , Pat , PatIdent , PatReference , PatStruct , PatTuple , PatTupleStruct , PatType ,
8
- Path , Signature , Stmt , Token , TypePath ,
9
+ Path , ReturnType , Signature , Stmt , Token , Type , TypePath ,
9
10
} ;
10
11
11
12
use crate :: {
@@ -18,7 +19,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
18
19
input : MaybeItemFnRef < ' a , B > ,
19
20
args : InstrumentArgs ,
20
21
instrumented_function_name : & str ,
21
- self_type : Option < & syn :: TypePath > ,
22
+ self_type : Option < & TypePath > ,
22
23
) -> proc_macro2:: TokenStream {
23
24
// these are needed ahead of time, as ItemFn contains the function body _and_
24
25
// isn't representable inside a quote!/quote_spanned! macro
@@ -31,7 +32,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
31
32
} = input;
32
33
33
34
let Signature {
34
- output : return_type ,
35
+ output,
35
36
inputs : params,
36
37
unsafety,
37
38
asyncness,
@@ -49,8 +50,37 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
49
50
50
51
let warnings = args. warnings ( ) ;
51
52
53
+ let block = if let ReturnType :: Type ( _, return_type) = & output {
54
+ let return_type = erase_impl_trait ( return_type) ;
55
+ // Install a fake return statement as the first thing in the function
56
+ // body, so that we eagerly infer that the return type is what we
57
+ // declared in the async fn signature.
58
+ // The `#[allow(unreachable_code)]` is given because the return
59
+ // statement is unreachable, but does affect inference.
60
+ let fake_return_edge = quote_spanned ! { return_type. span( ) =>
61
+ #[ allow( unreachable_code) ]
62
+ if false {
63
+ let __tracing_attr_fake_return: #return_type =
64
+ unreachable!( "this is just for type inference, and is unreachable code" ) ;
65
+ return __tracing_attr_fake_return;
66
+ }
67
+ } ;
68
+ quote ! {
69
+ {
70
+ #fake_return_edge
71
+ #block
72
+ }
73
+ }
74
+ } else {
75
+ quote ! {
76
+ {
77
+ let _: ( ) = #block;
78
+ }
79
+ }
80
+ } ;
81
+
52
82
let body = gen_block (
53
- block,
83
+ & block,
54
84
params,
55
85
asyncness. is_some ( ) ,
56
86
args,
@@ -60,7 +90,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
60
90
61
91
quote ! (
62
92
#( #attrs) *
63
- #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>( #params) #return_type
93
+ #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>( #params) #output
64
94
#where_clause
65
95
{
66
96
#warnings
@@ -76,7 +106,7 @@ fn gen_block<B: ToTokens>(
76
106
async_context : bool ,
77
107
mut args : InstrumentArgs ,
78
108
instrumented_function_name : & str ,
79
- self_type : Option < & syn :: TypePath > ,
109
+ self_type : Option < & TypePath > ,
80
110
) -> proc_macro2:: TokenStream {
81
111
// generate the span's name
82
112
let span_name = args
@@ -393,11 +423,11 @@ impl RecordType {
393
423
"Wrapping" ,
394
424
] ;
395
425
396
- /// Parse `RecordType` from [syn:: Type] by looking up
426
+ /// Parse `RecordType` from [Type] by looking up
397
427
/// the [RecordType::TYPES_FOR_VALUE] array.
398
- fn parse_from_ty ( ty : & syn :: Type ) -> Self {
428
+ fn parse_from_ty ( ty : & Type ) -> Self {
399
429
match ty {
400
- syn :: Type :: Path ( syn :: TypePath { path, .. } )
430
+ Type :: Path ( TypePath { path, .. } )
401
431
if path
402
432
. segments
403
433
. iter ( )
@@ -410,9 +440,7 @@ impl RecordType {
410
440
{
411
441
RecordType :: Value
412
442
}
413
- syn:: Type :: Reference ( syn:: TypeReference { elem, .. } ) => {
414
- RecordType :: parse_from_ty ( & * elem)
415
- }
443
+ Type :: Reference ( syn:: TypeReference { elem, .. } ) => RecordType :: parse_from_ty ( & * elem) ,
416
444
_ => RecordType :: Debug ,
417
445
}
418
446
}
@@ -471,7 +499,7 @@ pub(crate) struct AsyncInfo<'block> {
471
499
// statement that must be patched
472
500
source_stmt : & ' block Stmt ,
473
501
kind : AsyncKind < ' block > ,
474
- self_type : Option < syn :: TypePath > ,
502
+ self_type : Option < TypePath > ,
475
503
input : & ' block ItemFn ,
476
504
}
477
505
@@ -606,11 +634,11 @@ impl<'block> AsyncInfo<'block> {
606
634
if ident == "_self" {
607
635
let mut ty = * ty. ty . clone ( ) ;
608
636
// extract the inner type if the argument is "&self" or "&mut self"
609
- if let syn :: Type :: Reference ( syn:: TypeReference { elem, .. } ) = ty {
637
+ if let Type :: Reference ( syn:: TypeReference { elem, .. } ) = ty {
610
638
ty = * elem;
611
639
}
612
640
613
- if let syn :: Type :: Path ( tp) = ty {
641
+ if let Type :: Path ( tp) = ty {
614
642
self_type = Some ( tp) ;
615
643
break ;
616
644
}
@@ -722,7 +750,7 @@ struct IdentAndTypesRenamer<'a> {
722
750
idents : Vec < ( Ident , Ident ) > ,
723
751
}
724
752
725
- impl < ' a > syn :: visit_mut :: VisitMut for IdentAndTypesRenamer < ' a > {
753
+ impl < ' a > VisitMut for IdentAndTypesRenamer < ' a > {
726
754
// we deliberately compare strings because we want to ignore the spans
727
755
// If we apply clippy's lint, the behavior changes
728
756
#[ allow( clippy:: cmp_owned) ]
@@ -734,11 +762,11 @@ impl<'a> syn::visit_mut::VisitMut for IdentAndTypesRenamer<'a> {
734
762
}
735
763
}
736
764
737
- fn visit_type_mut ( & mut self , ty : & mut syn :: Type ) {
765
+ fn visit_type_mut ( & mut self , ty : & mut Type ) {
738
766
for ( type_name, new_type) in & self . types {
739
- if let syn :: Type :: Path ( TypePath { path, .. } ) = ty {
767
+ if let Type :: Path ( TypePath { path, .. } ) = ty {
740
768
if path_to_string ( path) == * type_name {
741
- * ty = syn :: Type :: Path ( new_type. clone ( ) ) ;
769
+ * ty = Type :: Path ( new_type. clone ( ) ) ;
742
770
}
743
771
}
744
772
}
@@ -751,10 +779,33 @@ struct AsyncTraitBlockReplacer<'a> {
751
779
patched_block : Block ,
752
780
}
753
781
754
- impl < ' a > syn :: visit_mut :: VisitMut for AsyncTraitBlockReplacer < ' a > {
782
+ impl < ' a > VisitMut for AsyncTraitBlockReplacer < ' a > {
755
783
fn visit_block_mut ( & mut self , i : & mut Block ) {
756
784
if i == self . block {
757
785
* i = self . patched_block . clone ( ) ;
758
786
}
759
787
}
760
788
}
789
+
790
+ // Replaces any `impl Trait` with `_` so it can be used as the type in
791
+ // a `let` statement's LHS.
792
+ struct ImplTraitEraser ;
793
+
794
+ impl VisitMut for ImplTraitEraser {
795
+ fn visit_type_mut ( & mut self , t : & mut Type ) {
796
+ if let Type :: ImplTrait ( ..) = t {
797
+ * t = syn:: TypeInfer {
798
+ underscore_token : Token ! [ _] ( t. span ( ) ) ,
799
+ }
800
+ . into ( ) ;
801
+ } else {
802
+ syn:: visit_mut:: visit_type_mut ( self , t) ;
803
+ }
804
+ }
805
+ }
806
+
807
+ fn erase_impl_trait ( ty : & Type ) -> Type {
808
+ let mut ty = ty. clone ( ) ;
809
+ ImplTraitEraser . visit_type_mut ( & mut ty) ;
810
+ ty
811
+ }
0 commit comments