@@ -2,10 +2,11 @@ use std::iter;
2
2
3
3
use proc_macro2:: TokenStream ;
4
4
use quote:: { quote, quote_spanned, ToTokens } ;
5
+ use syn:: visit_mut:: VisitMut ;
5
6
use syn:: {
6
7
punctuated:: Punctuated , spanned:: Spanned , Block , Expr , ExprAsync , ExprCall , FieldPat , FnArg ,
7
8
Ident , Item , ItemFn , Pat , PatIdent , PatReference , PatStruct , PatTuple , PatTupleStruct , PatType ,
8
- Path , Signature , Stmt , Token , TypePath ,
9
+ Path , ReturnType , Signature , Stmt , Token , Type , TypePath ,
9
10
} ;
10
11
11
12
use crate :: {
@@ -18,7 +19,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
18
19
input : MaybeItemFnRef < ' a , B > ,
19
20
args : InstrumentArgs ,
20
21
instrumented_function_name : & str ,
21
- self_type : Option < & syn :: TypePath > ,
22
+ self_type : Option < & TypePath > ,
22
23
) -> proc_macro2:: TokenStream {
23
24
// these are needed ahead of time, as ItemFn contains the function body _and_
24
25
// isn't representable inside a quote!/quote_spanned! macro
@@ -31,7 +32,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
31
32
} = input;
32
33
33
34
let Signature {
34
- output : return_type ,
35
+ output,
35
36
inputs : params,
36
37
unsafety,
37
38
asyncness,
@@ -49,8 +50,34 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
49
50
50
51
let warnings = args. warnings ( ) ;
51
52
53
+ let block = if let ReturnType :: Type ( _, return_type) = & output {
54
+ let return_type = erase_impl_trait ( return_type) ;
55
+ // Install a fake return statement as the first thing in the function
56
+ // body, so that we eagerly infer that the return type is what we
57
+ // declared in the async fn signature.
58
+ let fake_return_edge = quote_spanned ! { return_type. span( ) =>
59
+ #[ allow( unreachable_code) ]
60
+ if false {
61
+ let __tracing_attr_fake_return: #return_type = panic!( ) ;
62
+ return __tracing_attr_fake_return;
63
+ }
64
+ } ;
65
+ quote ! {
66
+ {
67
+ #fake_return_edge
68
+ #block
69
+ }
70
+ }
71
+ } else {
72
+ quote ! {
73
+ {
74
+ let _: ( ) = #block;
75
+ }
76
+ }
77
+ } ;
78
+
52
79
let body = gen_block (
53
- block,
80
+ & block,
54
81
params,
55
82
asyncness. is_some ( ) ,
56
83
args,
@@ -60,7 +87,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
60
87
61
88
quote ! (
62
89
#( #attrs) *
63
- #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>( #params) #return_type
90
+ #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>( #params) #output
64
91
#where_clause
65
92
{
66
93
#warnings
@@ -76,7 +103,7 @@ fn gen_block<B: ToTokens>(
76
103
async_context : bool ,
77
104
mut args : InstrumentArgs ,
78
105
instrumented_function_name : & str ,
79
- self_type : Option < & syn :: TypePath > ,
106
+ self_type : Option < & TypePath > ,
80
107
) -> proc_macro2:: TokenStream {
81
108
// generate the span's name
82
109
let span_name = args
@@ -393,11 +420,11 @@ impl RecordType {
393
420
"Wrapping" ,
394
421
] ;
395
422
396
- /// Parse `RecordType` from [syn:: Type] by looking up
423
+ /// Parse `RecordType` from [Type] by looking up
397
424
/// the [RecordType::TYPES_FOR_VALUE] array.
398
- fn parse_from_ty ( ty : & syn :: Type ) -> Self {
425
+ fn parse_from_ty ( ty : & Type ) -> Self {
399
426
match ty {
400
- syn :: Type :: Path ( syn :: TypePath { path, .. } )
427
+ Type :: Path ( TypePath { path, .. } )
401
428
if path
402
429
. segments
403
430
. iter ( )
@@ -410,9 +437,7 @@ impl RecordType {
410
437
{
411
438
RecordType :: Value
412
439
}
413
- syn:: Type :: Reference ( syn:: TypeReference { elem, .. } ) => {
414
- RecordType :: parse_from_ty ( & * elem)
415
- }
440
+ Type :: Reference ( syn:: TypeReference { elem, .. } ) => RecordType :: parse_from_ty ( & * elem) ,
416
441
_ => RecordType :: Debug ,
417
442
}
418
443
}
@@ -471,7 +496,7 @@ pub(crate) struct AsyncInfo<'block> {
471
496
// statement that must be patched
472
497
source_stmt : & ' block Stmt ,
473
498
kind : AsyncKind < ' block > ,
474
- self_type : Option < syn :: TypePath > ,
499
+ self_type : Option < TypePath > ,
475
500
input : & ' block ItemFn ,
476
501
}
477
502
@@ -606,11 +631,11 @@ impl<'block> AsyncInfo<'block> {
606
631
if ident == "_self" {
607
632
let mut ty = * ty. ty . clone ( ) ;
608
633
// extract the inner type if the argument is "&self" or "&mut self"
609
- if let syn :: Type :: Reference ( syn:: TypeReference { elem, .. } ) = ty {
634
+ if let Type :: Reference ( syn:: TypeReference { elem, .. } ) = ty {
610
635
ty = * elem;
611
636
}
612
637
613
- if let syn :: Type :: Path ( tp) = ty {
638
+ if let Type :: Path ( tp) = ty {
614
639
self_type = Some ( tp) ;
615
640
break ;
616
641
}
@@ -722,7 +747,7 @@ struct IdentAndTypesRenamer<'a> {
722
747
idents : Vec < ( Ident , Ident ) > ,
723
748
}
724
749
725
- impl < ' a > syn :: visit_mut :: VisitMut for IdentAndTypesRenamer < ' a > {
750
+ impl < ' a > VisitMut for IdentAndTypesRenamer < ' a > {
726
751
// we deliberately compare strings because we want to ignore the spans
727
752
// If we apply clippy's lint, the behavior changes
728
753
#[ allow( clippy:: cmp_owned) ]
@@ -734,11 +759,11 @@ impl<'a> syn::visit_mut::VisitMut for IdentAndTypesRenamer<'a> {
734
759
}
735
760
}
736
761
737
- fn visit_type_mut ( & mut self , ty : & mut syn :: Type ) {
762
+ fn visit_type_mut ( & mut self , ty : & mut Type ) {
738
763
for ( type_name, new_type) in & self . types {
739
- if let syn :: Type :: Path ( TypePath { path, .. } ) = ty {
764
+ if let Type :: Path ( TypePath { path, .. } ) = ty {
740
765
if path_to_string ( path) == * type_name {
741
- * ty = syn :: Type :: Path ( new_type. clone ( ) ) ;
766
+ * ty = Type :: Path ( new_type. clone ( ) ) ;
742
767
}
743
768
}
744
769
}
@@ -751,10 +776,33 @@ struct AsyncTraitBlockReplacer<'a> {
751
776
patched_block : Block ,
752
777
}
753
778
754
- impl < ' a > syn :: visit_mut :: VisitMut for AsyncTraitBlockReplacer < ' a > {
779
+ impl < ' a > VisitMut for AsyncTraitBlockReplacer < ' a > {
755
780
fn visit_block_mut ( & mut self , i : & mut Block ) {
756
781
if i == self . block {
757
782
* i = self . patched_block . clone ( ) ;
758
783
}
759
784
}
760
785
}
786
+
787
+ // Replaces any `impl Trait` with `_` so it can be used as the type in
788
+ // a `let` statement's LHS.
789
+ struct ImplTraitEraser ;
790
+
791
+ impl VisitMut for ImplTraitEraser {
792
+ fn visit_type_mut ( & mut self , t : & mut Type ) {
793
+ if let Type :: ImplTrait ( ..) = t {
794
+ * t = syn:: TypeInfer {
795
+ underscore_token : Token ! [ _] ( t. span ( ) ) ,
796
+ }
797
+ . into ( ) ;
798
+ } else {
799
+ syn:: visit_mut:: visit_type_mut ( self , t) ;
800
+ }
801
+ }
802
+ }
803
+
804
+ fn erase_impl_trait ( ty : & Type ) -> Type {
805
+ let mut ty = ty. clone ( ) ;
806
+ ImplTraitEraser . visit_type_mut ( & mut ty) ;
807
+ ty
808
+ }
0 commit comments