|
15 | 15 | // specific language governing permissions and limitations
|
16 | 16 | // under the License.
|
17 | 17 |
|
18 |
| -use proc_macro2::TokenStream; |
| 18 | +use proc_macro2::{Literal, TokenStream}; |
19 | 19 | use quote::{format_ident, quote, quote_spanned, ToTokens};
|
20 | 20 | use syn::spanned::Spanned;
|
21 | 21 | use syn::{
|
22 | 22 | parse::{Parse, ParseStream},
|
23 | 23 | parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
|
24 | 24 | Ident, Index, LitStr, Meta, Token, Type, TypePath,
|
25 | 25 | };
|
26 |
| -use syn::{Path, PathArguments}; |
| 26 | +use syn::{ |
| 27 | + AngleBracketedGenericArguments, DataEnum, DataStruct, FieldsNamed, FieldsUnnamed, |
| 28 | + GenericArgument, MetaList, Path, PathArguments, PathSegment, |
| 29 | +}; |
27 | 30 |
|
28 | 31 | /// Implementation of `[#derive(Visit)]`
|
29 | 32 | #[proc_macro_derive(VisitMut, attributes(visit))]
|
@@ -289,3 +292,323 @@ fn is_option(ty: &Type) -> bool {
|
289 | 292 | }
|
290 | 293 | false
|
291 | 294 | }
|
| 295 | + |
| 296 | +/// Determine the variable type to decide which method in the `Convert` trait to use |
| 297 | +fn get_var_type(ty: &Type) -> TokenStream { |
| 298 | + let span = ty.span(); |
| 299 | + if let Type::Path(TypePath { |
| 300 | + path: Path { segments, .. }, |
| 301 | + .. |
| 302 | + }) = ty |
| 303 | + { |
| 304 | + if let Some(PathSegment { ident, arguments }) = segments.first() { |
| 305 | + return match ident.to_string().as_str() { |
| 306 | + "Option" => { |
| 307 | + if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { |
| 308 | + args, |
| 309 | + .. |
| 310 | + }) = arguments |
| 311 | + { |
| 312 | + if let Some(GenericArgument::Type(Type::Path(TypePath { |
| 313 | + path: Path { segments, .. }, |
| 314 | + .. |
| 315 | + }))) = args.first() |
| 316 | + { |
| 317 | + if let Some(PathSegment { ident, .. }) = segments.first() { |
| 318 | + return match ident.to_string().as_str() { |
| 319 | + "Box" => quote_spanned!(span => Convert::convert_option_box), |
| 320 | + "Vec" => quote_spanned!(span => Convert::convert_option_vec), |
| 321 | + _ => quote_spanned!(span => Convert::convert_option), |
| 322 | + }; |
| 323 | + } |
| 324 | + } |
| 325 | + } |
| 326 | + quote_spanned!(span => Convert::convert_option) |
| 327 | + } |
| 328 | + "Vec" => { |
| 329 | + if let PathArguments::AngleBracketed(AngleBracketedGenericArguments { |
| 330 | + args, |
| 331 | + .. |
| 332 | + }) = arguments |
| 333 | + { |
| 334 | + if let Some(GenericArgument::Type(Type::Path(TypePath { |
| 335 | + path: Path { segments, .. }, |
| 336 | + .. |
| 337 | + }))) = args.first() |
| 338 | + { |
| 339 | + if let Some(PathSegment { ident, .. }) = segments.first() { |
| 340 | + return match ident.to_string().as_str() { |
| 341 | + "Vec" => quote_spanned!(span => Convert::convert_matrix), |
| 342 | + "Box" => quote_spanned!(span => Convert::convert_vec_box), |
| 343 | + _ => quote_spanned!(span => Convert::convert_vec), |
| 344 | + }; |
| 345 | + } |
| 346 | + } |
| 347 | + } |
| 348 | + quote_spanned!(span => Convert::convert_vec) |
| 349 | + } |
| 350 | + "Box" => quote_spanned!(span => Convert::convert_box), |
| 351 | + _ => quote_spanned!(span => Convert::convert), |
| 352 | + }; |
| 353 | + } |
| 354 | + } |
| 355 | + quote_spanned!(span => Convert::convert) |
| 356 | +} |
| 357 | + |
| 358 | +/// Obtain the struct path where `datafusion` `sqlparser` is located from derive macro helper attribute `df_path`, |
| 359 | +/// if value not given, the default return is `df_sqlparser::ast` |
| 360 | +fn get_crate_path(st: &DeriveInput) -> TokenStream { |
| 361 | + let span = st.span(); |
| 362 | + for attr in &st.attrs { |
| 363 | + let Meta::List(MetaList { |
| 364 | + path: Path { segments, .. }, |
| 365 | + tokens, |
| 366 | + .. |
| 367 | + }) = &attr.meta |
| 368 | + else { |
| 369 | + continue; |
| 370 | + }; |
| 371 | + if let Some(PathSegment { ident, .. }) = segments.first() { |
| 372 | + if ident.to_string().as_str() == "df_path" { |
| 373 | + return tokens.clone(); |
| 374 | + } |
| 375 | + } |
| 376 | + } |
| 377 | + quote_spanned!(span => df_sqlparser::ast) |
| 378 | +} |
| 379 | + |
| 380 | +/// Check whether the attribute `ignore_item` exists. If the attribute exists, |
| 381 | +/// the corresponding convert method will not be generated. |
| 382 | +/// If exist attribute `ignore_item` |
| 383 | +/// 1. enum conversion returns panic |
| 384 | +/// 2. struct conversion does not generate the corresponding field |
| 385 | +fn ignore_convert(attrs: &Vec<Attribute>) -> bool { |
| 386 | + for attr in attrs { |
| 387 | + let Meta::Path(Path { segments, .. }) = &attr.meta else { |
| 388 | + continue; |
| 389 | + }; |
| 390 | + if let Some(PathSegment { ident, .. }) = segments.first() { |
| 391 | + if ident.to_string().as_str() == "ignore_item" { |
| 392 | + return true; |
| 393 | + } |
| 394 | + } |
| 395 | + } |
| 396 | + false |
| 397 | +} |
| 398 | + |
| 399 | +fn convert_struct(st: &DeriveInput) -> TokenStream { |
| 400 | + let name = &st.ident; |
| 401 | + let path = get_crate_path(st); |
| 402 | + // for struct pattern like |
| 403 | + // struct xxx { |
| 404 | + // xxx: xxx |
| 405 | + // } |
| 406 | + if let Data::Struct(DataStruct { |
| 407 | + fields: Fields::Named(FieldsNamed { named, .. }), |
| 408 | + .. |
| 409 | + }) = &st.data |
| 410 | + { |
| 411 | + let span = named.span(); |
| 412 | + let mut fields: Vec<TokenStream> = Vec::with_capacity(named.len()); |
| 413 | + for field in named { |
| 414 | + if ignore_convert(&field.attrs) { |
| 415 | + continue; |
| 416 | + } |
| 417 | + let field_name = field.ident.clone().unwrap(); |
| 418 | + let var_type = get_var_type(&field.ty); |
| 419 | + let span = field_name.span(); |
| 420 | + let code = quote_spanned! { span => |
| 421 | + #field_name: #var_type(value.#field_name), |
| 422 | + }; |
| 423 | + fields.push(code); |
| 424 | + } |
| 425 | + return quote_spanned! { span => |
| 426 | + impl From<#name> for #path::#name { |
| 427 | + #[allow(unused_variables)] |
| 428 | + fn from(value: #name) -> Self { |
| 429 | + Self { |
| 430 | + #(#fields)* |
| 431 | + } |
| 432 | + } |
| 433 | + } |
| 434 | + }; |
| 435 | + } |
| 436 | + // for struct pattern like |
| 437 | + // struct xxx(xxxx); |
| 438 | + if let Data::Struct(DataStruct { |
| 439 | + fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }), |
| 440 | + .. |
| 441 | + }) = &st.data |
| 442 | + { |
| 443 | + let span = unnamed.span(); |
| 444 | + let mut fields: Vec<TokenStream> = Vec::with_capacity(unnamed.len()); |
| 445 | + for i in 0..unnamed.len() { |
| 446 | + if ignore_convert(&unnamed[i].attrs) { |
| 447 | + continue; |
| 448 | + } |
| 449 | + let field_name = Literal::usize_unsuffixed(i); |
| 450 | + let var_type = get_var_type(&unnamed[i].ty); |
| 451 | + let span = unnamed[i].span(); |
| 452 | + let code = quote_spanned! { span => |
| 453 | + #var_type(value.#field_name), |
| 454 | + }; |
| 455 | + fields.push(code); |
| 456 | + } |
| 457 | + return quote_spanned! { span => |
| 458 | + impl From<#name> for #path::#name { |
| 459 | + #[allow(unused_variables)] |
| 460 | + fn from(value: #name) -> Self { |
| 461 | + Self(#(#fields)*) |
| 462 | + } |
| 463 | + } |
| 464 | + }; |
| 465 | + } |
| 466 | + panic!("Unrecognised Struct Type{}", st.to_token_stream()) |
| 467 | +} |
| 468 | + |
| 469 | +fn convert_enum(st: &DeriveInput) -> TokenStream { |
| 470 | + let name = &st.ident; |
| 471 | + let path = get_crate_path(st); |
| 472 | + if let Data::Enum(DataEnum { variants, .. }) = &st.data { |
| 473 | + let span = variants.span(); |
| 474 | + let mut fields: Vec<TokenStream> = Vec::with_capacity(variants.len()); |
| 475 | + for field in variants { |
| 476 | + let enum_name = &field.ident; |
| 477 | + let span = enum_name.span(); |
| 478 | + let ignore_convert = ignore_convert(&field.attrs); |
| 479 | + // for enum item like xxxxxx(xxx) |
| 480 | + if let Fields::Unnamed(FieldsUnnamed { unnamed, .. }) = &field.fields { |
| 481 | + let inner_names = ('a'..='z') |
| 482 | + .map(|x| Ident::new(x.to_string().as_str(), unnamed.span())) |
| 483 | + .collect::<Vec<_>>()[..unnamed.len()] |
| 484 | + .to_vec(); |
| 485 | + let mut codes: Vec<TokenStream> = Vec::with_capacity(unnamed.len()); |
| 486 | + let inner_fields: Vec<_> = inner_names.iter().map(|x| quote!(#x,)).collect(); |
| 487 | + for (inner_name, field) in inner_names.iter().zip(unnamed.iter()) { |
| 488 | + let var_type = get_var_type(&field.ty); |
| 489 | + let span = field.span(); |
| 490 | + codes.push(quote_spanned! { span => |
| 491 | + #var_type(#inner_name), |
| 492 | + }); |
| 493 | + } |
| 494 | + fields.push(if ignore_convert { |
| 495 | + quote_spanned! { span => |
| 496 | + #name::#enum_name(#(#inner_fields)*) => panic!("Convert on this item is ignored"), |
| 497 | + } |
| 498 | + } else { |
| 499 | + quote_spanned! { span => |
| 500 | + #name::#enum_name(#(#inner_fields)*) => Self::#enum_name(#(#codes)*), |
| 501 | + } |
| 502 | + }); |
| 503 | + } |
| 504 | + // for enum item like |
| 505 | + // xxxxxx { |
| 506 | + // xxx: xxxx, |
| 507 | + // }, |
| 508 | + if let Fields::Named(FieldsNamed { named, .. }) = &field.fields { |
| 509 | + let mut inner_fields: Vec<TokenStream> = Vec::with_capacity(named.len()); |
| 510 | + let mut codes: Vec<TokenStream> = Vec::with_capacity(named.len()); |
| 511 | + let span = named.span(); |
| 512 | + for field in named { |
| 513 | + let field_name = field.ident.clone().unwrap(); |
| 514 | + let span = field_name.span(); |
| 515 | + let var_type = get_var_type(&field.ty); |
| 516 | + inner_fields.push(quote_spanned!(span => #field_name,)); |
| 517 | + codes.push(quote_spanned! { span => |
| 518 | + #field_name: #var_type(#field_name), |
| 519 | + }); |
| 520 | + } |
| 521 | + fields.push(if ignore_convert { |
| 522 | + quote_spanned! { span => |
| 523 | + #name::#enum_name{#(#inner_fields)*} => panic!("Convert on this item is ignored"), |
| 524 | + } |
| 525 | + } else { |
| 526 | + quote_spanned! { span => |
| 527 | + #name::#enum_name{#(#inner_fields)*} => Self::#enum_name{#(#codes)*}, |
| 528 | + } |
| 529 | + }); |
| 530 | + } |
| 531 | + // for enum item like |
| 532 | + // xxxxxx |
| 533 | + if let Fields::Unit = &field.fields { |
| 534 | + let span = field.span(); |
| 535 | + fields.push(if ignore_convert { |
| 536 | + quote_spanned! { span => |
| 537 | + #name::#enum_name => panic!("Convert on this item is ignored"), |
| 538 | + } |
| 539 | + } else { |
| 540 | + quote_spanned! { span => |
| 541 | + #name::#enum_name => Self::#enum_name, |
| 542 | + } |
| 543 | + }); |
| 544 | + } |
| 545 | + } |
| 546 | + return quote_spanned! { span => |
| 547 | + impl From<#name> for #path::#name { |
| 548 | + #[allow(unused_variables)] |
| 549 | + fn from(value: #name) -> Self { |
| 550 | + match value{ |
| 551 | + #(#fields)* |
| 552 | + } |
| 553 | + } |
| 554 | + } |
| 555 | + }; |
| 556 | + } |
| 557 | + panic!("Unrecognised Enum Type{}", st.to_token_stream()) |
| 558 | +} |
| 559 | + |
| 560 | +fn convert_union(st: &DeriveInput) -> TokenStream { |
| 561 | + let name = &st.ident; |
| 562 | + let path = get_crate_path(st); |
| 563 | + |
| 564 | + if let Data::Union(data_union) = &st.data { |
| 565 | + let span = data_union.fields.span(); |
| 566 | + let mut fields: Vec<TokenStream> = Vec::with_capacity(data_union.fields.named.len()); |
| 567 | + |
| 568 | + for field in &data_union.fields.named { |
| 569 | + if ignore_convert(&field.attrs) { |
| 570 | + continue; |
| 571 | + } |
| 572 | + let field_name = field.ident.clone().unwrap(); |
| 573 | + let var_type = get_var_type(&field.ty); |
| 574 | + let span = field_name.span(); |
| 575 | + let code = quote_spanned! { span => |
| 576 | + #field_name: unsafe { #var_type(value.#field_name) }, |
| 577 | + }; |
| 578 | + fields.push(code); |
| 579 | + } |
| 580 | + |
| 581 | + quote_spanned! { span => |
| 582 | + impl From<#name> for #path::#name { |
| 583 | + #[allow(unused_variables)] |
| 584 | + fn from(value: #name) -> Self { |
| 585 | + unsafe { |
| 586 | + Self { |
| 587 | + #(#fields)* |
| 588 | + } |
| 589 | + } |
| 590 | + } |
| 591 | + } |
| 592 | + } |
| 593 | + } else { |
| 594 | + panic!("Expected Union type") |
| 595 | + } |
| 596 | +} |
| 597 | + |
| 598 | +fn expand_df_convert(st: &DeriveInput) -> TokenStream { |
| 599 | + match st.data { |
| 600 | + Data::Struct(_) => convert_struct(st), |
| 601 | + Data::Enum(_) => convert_enum(st), |
| 602 | + Data::Union(_) => convert_union(st), |
| 603 | + } |
| 604 | +} |
| 605 | + |
| 606 | +/// Derive macro to implement `From` Trait. Convert the current sqlparser struct to the struct used by datafusion sqlparser. |
| 607 | +/// There are two helper attributes that can be marked on the derive struct/enum, affecting the generated Convert function |
| 608 | +/// 1. `#[df_path(....)]`: Most structures are defined in `df_sqlparser::ast`, if the path of some structures is not in this path, |
| 609 | +/// user need to specify `df_path` to tell the compiler the location of this struct/enum |
| 610 | +/// 2. `#[ignore_item]`: Marked on the field of the struct/enum, indicating that the Convert method of the field of the struct/enum is not generated· |
| 611 | +#[proc_macro_derive(DFConvert, attributes(df_path, ignore_item))] |
| 612 | +pub fn derive_df_convert(input: proc_macro::TokenStream) -> proc_macro::TokenStream { |
| 613 | + expand_df_convert(&parse_macro_input!(input as DeriveInput)).into() |
| 614 | +} |
0 commit comments