Skip to content

Commit 1b5aaf7

Browse files
committed
Add support for custom bakes to databake
1 parent 73cb5ce commit 1b5aaf7

File tree

4 files changed

+273
-29
lines changed

4 files changed

+273
-29
lines changed

utils/databake/derive/src/lib.rs

Lines changed: 109 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
77
use proc_macro::TokenStream;
88
use proc_macro2::TokenStream as TokenStream2;
9+
use quote::format_ident;
910
use quote::quote;
1011
use syn::{
1112
parse::{Parse, ParseStream},
@@ -31,6 +32,40 @@ use synstructure::{AddBounds, Structure};
3132
/// pub age: u32,
3233
/// }
3334
/// ```
35+
///
36+
/// # Custom baked type
37+
///
38+
/// To bake to a different type than this, use `custom_bake`
39+
/// and implement `CustomBake`.
40+
///
41+
/// ```rust
42+
/// use databake::Bake;
43+
/// use databake::CustomBake;
44+
///
45+
/// #[derive(Bake)]
46+
/// #[databake(path = bar::module)]
47+
/// #[databake(path = custom_bake)]
48+
/// pub struct Message<'a> {
49+
/// pub message: &'a str,
50+
/// }
51+
///
52+
/// // Bake to a string:
53+
/// impl CustomBake for Message<'_> {
54+
/// type BakedType<'a> = &'a str where Self: 'a;
55+
/// fn to_custom_bake(&self) -> Self::BakedType<'_> {
56+
/// &self.message
57+
/// }
58+
/// }
59+
///
60+
/// impl<'a> Message<'a> {
61+
/// pub fn from_custom_bake(message: &'a str) -> Self {
62+
/// Self { message }
63+
/// }
64+
/// }
65+
/// ```
66+
///
67+
/// If the constructor is unsafe, use `custom_bake_unsafe`
68+
/// and implement `CustomBakeUnsafe`.
3469
#[proc_macro_derive(Bake, attributes(databake))]
3570
pub fn bake_derive(input: TokenStream) -> TokenStream {
3671
let input = parse_macro_input!(input as DeriveInput);
@@ -40,44 +75,93 @@ pub fn bake_derive(input: TokenStream) -> TokenStream {
4075
fn bake_derive_impl(input: &DeriveInput) -> TokenStream2 {
4176
let mut structure = Structure::new(input);
4277

43-
struct PathAttr(Punctuated<PathSegment, Token![::]>);
78+
enum DatabakeAttr {
79+
Path(Punctuated<PathSegment, Token![::]>),
80+
CustomBake,
81+
CustomBakeUnsafe,
82+
}
4483

45-
impl Parse for PathAttr {
84+
impl Parse for DatabakeAttr {
4685
fn parse(input: ParseStream<'_>) -> syn::parse::Result<Self> {
4786
let i: Ident = input.parse()?;
48-
if i != "path" {
49-
return Err(input.error(format!("expected token \"path\", found {i:?}")));
87+
if i == "path" {
88+
input.parse::<Token![=]>()?;
89+
Ok(Self::Path(input.parse::<Path>()?.segments))
90+
} else if i == "custom_bake" {
91+
Ok(Self::CustomBake)
92+
} else if i == "custom_bake_unsafe" {
93+
Ok(Self::CustomBakeUnsafe)
94+
} else {
95+
Err(input.error(format!("expected token \"path\", found {i:?}")))
5096
}
51-
input.parse::<Token![=]>()?;
52-
Ok(Self(input.parse::<Path>()?.segments))
5397
}
5498
}
5599

56-
let path = input
100+
let attrs = input
57101
.attrs
58102
.iter()
59-
.find(|a| a.path().is_ident("databake"))
60-
.expect("missing databake(path = ...) attribute")
61-
.parse_args::<PathAttr>()
62-
.unwrap()
63-
.0;
103+
.filter(|a| a.path().is_ident("databake"))
104+
.map(|a| a.parse_args::<DatabakeAttr>().unwrap())
105+
.collect::<Vec<_>>();
64106

65-
let bake_body = structure.each_variant(|vi| {
66-
let recursive_calls = vi.bindings().iter().map(|b| {
67-
let ident = b.binding.clone();
68-
quote! { let #ident = #ident.bake(env); }
69-
});
107+
let path = attrs
108+
.iter()
109+
.filter_map(|a| match a {
110+
DatabakeAttr::Path(path) => Some(path),
111+
_ => None,
112+
})
113+
.next()
114+
.expect("missing databake(path = ...) attribute");
70115

71-
let constructor = vi.construct(|_, i| {
72-
let ident = &vi.bindings()[i].binding;
73-
quote! { # #ident }
74-
});
116+
let is_custom_bake = attrs
117+
.iter()
118+
.find(|a| matches!(a, DatabakeAttr::CustomBake))
119+
.is_some();
75120

76-
quote! {
77-
#(#recursive_calls)*
78-
databake::quote! { #path::#constructor }
121+
let is_custom_bake_unsafe = attrs
122+
.iter()
123+
.find(|a| matches!(a, DatabakeAttr::CustomBakeUnsafe))
124+
.is_some();
125+
126+
let bake_body = if is_custom_bake || is_custom_bake_unsafe {
127+
let type_ident = &structure.ast().ident;
128+
let baked_ident = format_ident!("baked");
129+
if is_custom_bake_unsafe {
130+
quote! {
131+
x => {
132+
let baked = databake::CustomBakeUnsafe::to_custom_bake(x).bake(env);
133+
databake::quote! {
134+
// Safety: the bake is generated from `CustomBakeUnsafe::to_custom_bake`
135+
unsafe { #path::#type_ident::from_custom_bake(##baked_ident) }
136+
}
137+
}
138+
}
139+
} else {
140+
quote! {
141+
x => {
142+
let baked = databake::CustomBake::to_custom_bake(x).bake(env);
143+
databake::quote! { #path::#type_ident::from_custom_bake(##baked_ident) }
144+
}
145+
}
79146
}
80-
});
147+
} else {
148+
structure.each_variant(|vi| {
149+
let recursive_calls = vi.bindings().iter().map(|b| {
150+
let ident = b.binding.clone();
151+
quote! { let #ident = #ident.bake(env); }
152+
});
153+
154+
let constructor = vi.construct(|_, i| {
155+
let ident = &vi.bindings()[i].binding;
156+
quote! { # #ident }
157+
});
158+
159+
quote! {
160+
#(#recursive_calls)*
161+
databake::quote! { #path::#constructor }
162+
}
163+
})
164+
};
81165

82166
let borrows_size_body = structure.each_variant(|vi| {
83167
let recursive_calls = vi.bindings().iter().map(|b| {

utils/databake/derive/tests/derive.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,86 @@ fn test_cow_example() {
6060
test
6161
);
6262
}
63+
64+
#[derive(Bake)]
65+
#[databake(path = test)]
66+
#[databake(custom_bake)]
67+
pub struct CustomBakeExample<'a> {
68+
x: usize,
69+
y: alloc::borrow::Cow<'a, str>,
70+
}
71+
72+
impl CustomBake for CustomBakeExample<'_> {
73+
type BakedType<'a>
74+
= (usize, &'a str)
75+
where
76+
Self: 'a;
77+
fn to_custom_bake(&self) -> Self::BakedType<'_> {
78+
(self.x, &*self.y)
79+
}
80+
}
81+
82+
impl<'a> CustomBakeExample<'a> {
83+
pub const fn from_custom_bake(baked: <Self as CustomBake>::BakedType<'a>) -> Self {
84+
Self {
85+
x: baked.0,
86+
y: alloc::borrow::Cow::Borrowed(baked.1),
87+
}
88+
}
89+
}
90+
91+
#[test]
92+
fn test_custom_bake_example() {
93+
test_bake!(
94+
CustomBakeExample<'static>,
95+
const,
96+
crate::CustomBakeExample {
97+
x: 51423usize,
98+
y: alloc::borrow::Cow::Borrowed("bar"),
99+
},
100+
crate::CustomBakeExample::from_custom_bake((51423usize, "bar")),
101+
test
102+
);
103+
}
104+
105+
#[derive(Bake)]
106+
#[databake(path = test)]
107+
#[databake(custom_bake_unsafe)]
108+
pub struct CustomBakeUnsafeExample<'a> {
109+
x: usize,
110+
y: alloc::borrow::Cow<'a, str>,
111+
}
112+
113+
114+
unsafe impl CustomBakeUnsafe for CustomBakeUnsafeExample<'_> {
115+
type BakedType<'a>
116+
= (usize, &'a str)
117+
where
118+
Self: 'a;
119+
fn to_custom_bake(&self) -> Self::BakedType<'_> {
120+
(self.x, &*self.y)
121+
}
122+
}
123+
124+
impl<'a> CustomBakeUnsafeExample<'a> {
125+
pub const unsafe fn from_custom_bake(baked: <Self as CustomBakeUnsafe>::BakedType<'a>) -> Self {
126+
Self {
127+
x: baked.0,
128+
y: alloc::borrow::Cow::Borrowed(baked.1),
129+
}
130+
}
131+
}
132+
133+
#[test]
134+
fn test_custom_bake_unsafe_example() {
135+
test_bake!(
136+
CustomBakeUnsafeExample<'static>,
137+
const,
138+
crate::CustomBakeUnsafeExample {
139+
x: 51423usize,
140+
y: alloc::borrow::Cow::Borrowed("bar"),
141+
},
142+
unsafe { crate::CustomBakeUnsafeExample::from_custom_bake((51423usize, "bar")) },
143+
test
144+
);
145+
}

utils/databake/src/custom_bake.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// This file is part of ICU4X. For terms of use, please see the file
2+
// called LICENSE at the top level of the ICU4X source tree
3+
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4+
5+
use crate::Bake;
6+
7+
/// A trait for an item that can bake to something other than itself.
8+
///
9+
/// For an unsafe version of this trait, see [`CustomBakeUnsafe`].
10+
///
11+
/// The type implementing this trait should have an associated function
12+
/// with the following signature:
13+
///
14+
/// ```ignore
15+
/// /// The argument should have been returned from [`Self::to_custom_bake`].
16+
/// pub fn from_custom_bake(baked: CustomBake::BakedType) -> Self
17+
/// ```
18+
pub trait CustomBake {
19+
/// The type of the custom bake.
20+
type BakedType<'a>: Bake
21+
where
22+
Self: 'a;
23+
/// Returns `self` as the custom bake type.
24+
fn to_custom_bake(&self) -> Self::BakedType<'_>;
25+
}
26+
27+
/// Same as [`CustomBake`] but allows for the constructor to be `unsafe`.
28+
///
29+
/// # Safety
30+
///
31+
/// The type implementing this trait MUST have an associated unsafe function
32+
/// with the following signature:
33+
///
34+
/// ```ignore
35+
/// /// Safety: the argument MUST have been returned from [`Self::to_custom_bake`].
36+
/// pub unsafe fn from_custom_bake(baked: CustomBakeUnsafe::BakedType) -> Self
37+
/// ```
38+
pub unsafe trait CustomBakeUnsafe {
39+
/// The type of the custom bake.
40+
type BakedType<'a>: Bake
41+
where
42+
Self: 'a;
43+
/// Returns `self` as the custom bake type.
44+
fn to_custom_bake(&self) -> Self::BakedType<'_>;
45+
}

utils/databake/src/lib.rs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
7575
mod alloc;
7676
pub mod converter;
77+
mod custom_bake;
7778
mod primitives;
7879

7980
#[doc(no_inline)]
@@ -85,6 +86,8 @@ pub use quote::quote;
8586
#[cfg(feature = "derive")]
8687
pub use databake_derive::Bake;
8788

89+
pub use custom_bake::*;
90+
8891
use std::collections::HashSet;
8992
use std::sync::Mutex;
9093

@@ -145,6 +148,27 @@ pub trait BakeSize: Sized + Bake {
145148
/// test_bake!(usize, const, 18usize);
146149
/// ```
147150
///
151+
/// ## Custom baked type
152+
///
153+
/// If the baked type is different than the input type, pass both as arguments:
154+
///
155+
/// ```no_run
156+
/// # use databake::*;
157+
/// # struct MyStruct(usize);
158+
/// # impl Bake for MyStruct {
159+
/// # fn bake(&self, _: &CrateEnv) -> TokenStream { unimplemented!() }
160+
/// # }
161+
/// # // We need an explicit main to put the struct at the crate root
162+
/// # fn main() {
163+
/// test_bake!(
164+
/// MyStruct,
165+
/// crate::MyStruct(42usize),
166+
/// crate::MyStruct::from_custom_bake(42usize),
167+
/// my_crate,
168+
/// );
169+
/// # }
170+
/// ```
171+
///
148172
/// ## Crates and imports
149173
///
150174
/// As most output will need to reference its crate, and its not possible to name a crate from
@@ -172,16 +196,24 @@ pub trait BakeSize: Sized + Bake {
172196
#[macro_export]
173197
macro_rules! test_bake {
174198
($type:ty, const, $expr:expr $(, $krate:ident)? $(, [$($env_crate:ident),+])? $(,)?) => {
175-
const _: &$type = &$expr;
176-
$crate::test_bake!($type, $expr $(, $krate)? $(, [$($env_crate),+])?);
199+
$crate::test_bake!($type, const, $expr, $expr $(, $krate)? $(, [$($env_crate),+])?);
177200
};
178201

179202
($type:ty, $expr:expr $(, $krate:ident)? $(, [$($env_crate:ident),+])? $(,)?) => {
203+
$crate::test_bake!($type, $expr, $expr $(, $krate)? $(, [$($env_crate),+])?);
204+
};
205+
206+
($type:ty, const, $init_expr:expr, $baked_expr:expr $(, $krate:ident)? $(, [$($env_crate:ident),+])? $(,)?) => {
207+
const _: &$type = &$baked_expr;
208+
$crate::test_bake!($type, $init_expr, $baked_expr $(, $krate)? $(, [$($env_crate),+])?);
209+
};
210+
211+
($type:ty, $init_expr:expr, $baked_expr:expr $(, $krate:ident)? $(, [$($env_crate:ident),+])? $(,)?) => {
180212
let env = Default::default();
181-
let expr: &$type = &$expr;
213+
let expr: &$type = &$init_expr;
182214
let bake = $crate::Bake::bake(expr, &env).to_string();
183215
// For some reason `TokenStream` behaves differently in this line
184-
let expected_bake = $crate::quote!($expr).to_string().replace("::<", ":: <").replace(">::", "> ::");
216+
let expected_bake = $crate::quote!($baked_expr).to_string().replace("::<", ":: <").replace(">::", "> ::");
185217
// Trailing commas are a mess as well
186218
let bake = bake.replace(" ,)", ")").replace(" ,]", "]").replace(" , }", " }").replace(" , >", " >");
187219
let expected_bake = expected_bake.replace(" ,)", ")").replace(" ,]", "]").replace(" , }", " }").replace(" , >", " >");

0 commit comments

Comments
 (0)