proc-macro-workshop:debug-7

审题

// This test case covers one more heuristic that is often worth incorporating
// into derive macros that infer trait bounds. Here we look for the use of an
// associated type of a type parameter.
//
// The generated impl will need to look like:
//
//     impl<T: Trait> Debug for Field<T>
//     where
//         T::Value: Debug,
//     {...}
//
// You can identify associated types as any syn::TypePath in which the first
// path segment is one of the type parameters and there is more than one
// segment.
//
//
// Resources:
//
//   - The relevant types in the input will be represented in this syntax tree
//     node: https://docs.rs/syn/1.0/syn/struct.TypePath.html

use derive_debug::CustomDebug;
use std::fmt::Debug;

// A trait with an associated type. `Field` below stores values of
// `T::Value`, never a `T` itself.
pub trait Trait {
    type Value;
}

// The derive must produce `impl<T: Trait> Debug for Field<T> where T::Value: Debug`.
// Bounding `T: Debug` instead would reject `Field<Id>` in `main`, because `Id`
// does not implement `Debug`.
#[derive(CustomDebug)]
pub struct Field<T: Trait> {
    values: Vec<T::Value>,
}

// Compile-time check only: a call type-checks iff `F: Debug`.
fn assert_debug<F: Debug>() {}

fn main() {
    // Does not implement Debug, but its associated type does.
    struct Id;

    impl Trait for Id {
        type Value = u8;
    }

    // Requires `Field<Id>: Debug`. This holds only if the derived impl bounds
    // the associated type (`u8` here) rather than `Id` itself.
    assert_debug::<Field<Id>>();
}

这里出现了一种特殊的情况:形如 `A::B::C` 的多段类型路径。

之前在为泛型添加约束时,只要泛型参数不是仅被 PhantomData 包裹,我们就会为它加上 `T: Debug` 约束。
但这里出现了 `T::Value` 这种特殊情况:字段涉及到了 `T`,却没有直接使用 `T`,实际使用的是 `T::Value` 这个关联类型。

回顾一下我们对于过滤后的泛型的处理方式

    // Excerpt of the previous step's bound logic: every type parameter gets a
    // `std::fmt::Debug` bound unless the skip condition below holds.
    for generic in generics.params.iter_mut() {
        if let syn::GenericParam::Type(t) = generic {
            let type_param_name = t.ident.to_string();
            // Skip parameters collected as PhantomData generics that are not
            // also listed among the field type names.
            if phantom_generic_type_names.contains(&type_param_name)
                && !fields_type_names.contains(&type_param_name)
            {
                continue;
            }
            t.bounds.push(syn::parse_quote!(std::fmt::Debug));
        }
    }

这里明显存在一个问题:我们会给所有(非 PhantomData 的)泛型参数都加上 `T: Debug` 约束,即使它只出现在泛型声明中,并未直接出现在字段类型里。

因此,我们当前的任务是:更细粒度的控制泛型约束。
在这道题中,我们要约束 `T::Value` 而不约束 `T`,因为 `T` 并没有直接作为字段类型出现。

其中涉及两点

  • 关联类型的提取
  • where_clause修改

提取

自动遍历

// common.rs
/// Syntax-tree visitor that records associated-type paths such as `T::Value`.
///
/// A `syn::TypePath` is recorded when it has at least two segments and its
/// first segment names one of `interst_generic_type_names`. Matches are
/// grouped under that first segment's name, in visit order.
struct TypePathVisitor {
    // Type-parameter names whose associated types we want to collect.
    interst_generic_type_names: Vec<String>,
    // Type-parameter name -> every matching path encountered.
    associated_type_names: std::collections::HashMap<String, Vec<syn::TypePath>>,
}
// Requires syn's "visit" feature: syn = { version =  "1.0.84", features = ["visit"]}
impl<'ast> syn::visit::Visit<'ast> for TypePathVisitor {
    fn visit_type_path(&mut self, node: &'ast syn::TypePath) {
        let mut segments = node.path.segments.iter();
        // Two leading segments exist <=> the path is `A::B...`, not a bare `T`.
        if let (Some(head), Some(_)) = (segments.next(), segments.next()) {
            let root = head.ident.to_string();
            if self.interst_generic_type_names.iter().any(|name| *name == root) {
                self.associated_type_names
                    .entry(root)
                    .or_default()
                    .push(node.clone());
            }
        }
        // Keep descending so nested paths (e.g. the `T::Value` in `Vec<T::Value>`) are found.
        syn::visit::visit_type_path(self, node);
    }
}

借助 syn 的 `Visit` trait 自动遍历语法树,并收集我们感兴趣的数据,用法可以参考 syn 官方文档中的 `syn::visit` 模块。

// common.rs
/// Collects every associated-type path rooted at one of the input's type
/// parameters (e.g. `T::Value` for `struct Field<T: Trait>`), grouped by the
/// name of that root type parameter.
///
/// Only the item body (`ast.data`, i.e. the field types) is walked. The
/// generated `Debug` impl formats field values only, so associated types that
/// appear solely in the generics list or the `where` clause need no `Debug`
/// bound and are deliberately ignored.
pub(crate) fn parse_generic_associated_types(
    ast: &syn::DeriveInput,
) -> std::collections::HashMap<String, Vec<syn::TypePath>> {
    let origin_generic_type_names: Vec<String> = ast
        .generics
        .type_params()
        .map(|param| param.ident.to_string())
        .collect();
    let mut visitor = TypePathVisitor {
        interst_generic_type_names: origin_generic_type_names,
        associated_type_names: std::collections::HashMap::new(),
    };
    // Fully-qualified call: compiles whether or not `syn::visit::Visit` is in scope.
    syn::visit::Visit::visit_data(&mut visitor, &ast.data);
    visitor.associated_type_names
}

到这里,我们已经收集到了所有以顶层泛型参数开头的关联类型路径(如 `T::Value`),接下来就是定制 where_clause。

题解

// solution7.rs
/// Generates the `impl std::fmt::Debug` for the debug-7 test case.
///
/// Consider a type parameter `T` that `crate::common::parse_field_type_name`
/// does not report for any field. It gets no `T: Debug` bound when it was
/// collected as a PhantomData generic, or when it is only reached through
/// associated-type paths such as `T::Value`. Every other type parameter gets
/// `T: Debug`.
///
/// Each collected associated-type path also receives a
/// `where <path>: std::fmt::Debug` predicate. Predicates are deduplicated and
/// emitted in a deterministic order, so the macro output is reproducible.
///
/// # Errors
/// Propagates any `syn::Error` returned by the `crate::common` field parsers
/// or by `generate_field_stream_vec`.
pub(super) fn soution(
    fields: &crate::common::FieldsType,
    origin_ident: &syn::Ident,
    ast: &syn::DeriveInput,
) -> syn::Result<proc_macro2::TokenStream> {
    let mut origin_field_type_names = vec![];
    let mut phantom_generic_type_names = vec![];
    for field in fields.iter() {
        if let Some(name) = crate::common::parse_field_type_name(field)? {
            origin_field_type_names.push(name);
        }
        if let Some(name) = crate::common::parse_phantom_generic_type_name(field)? {
            phantom_generic_type_names.push(name);
        }
    }
    let associated_generics_type_map = crate::common::parse_generic_associated_types(ast);
    let mut generics = crate::common::parse_generic_type(ast);

    // Bound `T: Debug` unless `T` is used only via PhantomData or only via `T::Assoc`.
    for generic in generics.params.iter_mut() {
        if let syn::GenericParam::Type(t) = generic {
            let type_name = t.ident.to_string();
            let used_directly = origin_field_type_names.contains(&type_name);
            let phantom_only = phantom_generic_type_names.contains(&type_name) && !used_directly;
            let associated_only =
                associated_generics_type_map.contains_key(&type_name) && !used_directly;
            if phantom_only || associated_only {
                continue;
            }
            t.bounds.push(syn::parse_quote!(std::fmt::Debug));
        }
    }

    // HashMap iteration order is randomized per process. Sort by type-parameter
    // name so the generated where clause is identical across builds.
    let mut associated_entries: Vec<_> = associated_generics_type_map.into_iter().collect();
    associated_entries.sort_unstable_by(|a, b| a.0.cmp(&b.0));

    let where_clause_mut = generics.make_where_clause();
    // The same path (e.g. `T::Value` in two fields) is collected once per
    // occurrence; emit a single predicate for it.
    let mut seen_paths = std::collections::HashSet::new();
    for (_, associated_types) in associated_entries {
        for associated_type in associated_types {
            let key = quote::ToTokens::to_token_stream(&associated_type).to_string();
            if seen_paths.insert(key) {
                where_clause_mut
                    .predicates
                    .push(syn::parse_quote!(#associated_type: std::fmt::Debug));
            }
        }
    }

    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
    let field_stream_vec = generate_field_stream_vec(fields)?;
    let origin_ident_string = origin_ident.to_string();
    syn::Result::Ok(quote::quote! {
        impl #impl_generics std::fmt::Debug for #origin_ident #type_generics #where_clause {
            fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
                fmt.debug_struct(#origin_ident_string)
                #(
                    #field_stream_vec
                )*
                .finish()
            }
        }
    })
}
// 无关的debug字段设置
fn generate_field_stream_vec(
    fields: &crate::common::FieldsType,
) -> syn::Result<Vec<proc_macro2::TokenStream>> {
    fields
        .iter()
        .map(|f| {
            let ident = &f.ident;
            let ident_string = ident.as_ref().unwrap().to_string();
            let mut format = "{:?}".to_string();
            if let Some(customer_format) = crate::common::parse_format(f)? {
                format = customer_format;
            }
            syn::Result::Ok(quote::quote! {
                .field(#ident_string, &format_args!(#format, &self.#ident))
            })
        })
        .collect()
}

完整

mod common;
mod solution2;
mod solution3;
mod solution4;
mod solution56;
mod solution7;
mod solution8;

/// Entry point of `#[derive(CustomDebug)]`.
///
/// Parses the input item and delegates to `solution1`. Any `syn::Error` is
/// turned into a `compile_error!` at the offending span instead of a panic.
#[proc_macro_derive(CustomDebug, attributes(debug))]
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
    let expanded = solution1(&ast).unwrap_or_else(|err| err.into_compile_error());
    proc_macro::TokenStream::from(expanded)
}

/// Runs each exercise's generator in order and returns the output of the
/// latest one (debug-7).
///
/// The earlier generators are still invoked so that any `syn::Error` they
/// report aborts the expansion. Their token streams are discarded.
fn solution1(ast: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let fields = crate::common::parse_fields(ast)?;
    let origin_ident = &ast.ident;

    // Earlier steps: evaluated only for the errors they may return.
    solution2::solution(fields, origin_ident)?;
    solution3::solution(fields, origin_ident)?;
    solution4::solution(fields, origin_ident, ast)?;
    solution56::solution(fields, origin_ident, ast)?;

    solution7::soution(fields, origin_ident, ast)
}

可以用 `cargo expand` 观察一下宏展开的结果:

// Verbatim `cargo expand` output of the debug-7 test; shown for inspection and
// not meant to be compiled as-is (it uses the nightly-only `prelude_import`).
#![feature(prelude_import)]
#[prelude_import]
use std::prelude::rust_2021::*;
#[macro_use]
extern crate std;
use derive_debug::CustomDebug;
use std::fmt::Debug;
pub trait Trait {
    type Value;
}
pub struct Field<T: Trait> {
    values: Vec<T::Value>,
}
// Generated impl: `T` keeps only its declared `Trait` bound, while the
// associated type `T::Value` is bounded by `std::fmt::Debug` in the where clause.
impl<T: Trait> std::fmt::Debug for Field<T>
where
    T::Value: std::fmt::Debug,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        fmt.debug_struct("Field")
            .field(
                "values",
                // Expansion of `format_args!("{:?}", &self.values)`. These
                // `core::fmt` internals are compiler-version specific.
                &::core::fmt::Arguments::new_v1(
                    &[""],
                    &match (&&self.values,) {
                        _args => [::core::fmt::ArgumentV1::new(
                            _args.0,
                            ::core::fmt::Debug::fmt,
                        )],
                    },
                ),
            )
            .finish()
    }
}
fn assert_debug<F: Debug>() {}
fn main() {
    struct Id;
    impl Trait for Id {
        type Value = u8;
    }
    assert_debug::<Field<Id>>();
}

可以看到,`where T::Value: std::fmt::Debug` 单独限定了 `T::Value`,而 `T` 本身只保留了原有的 `Trait` 约束。

本作品采用《CC 协议》,转载必须注明作者和本文链接
讨论数量: 0
(= ̄ω ̄=)··· 暂无内容!

讨论应以学习和精进为目的。请勿发布不友善或者负能量的内容,与人为善,比聪明更重要!