proc-macro-workshop:debug-7
审题
// This test case covers one more heuristic that is often worth incorporating
// into derive macros that infer trait bounds. Here we look for the use of an
// associated type of a type parameter.
//
// The generated impl will need to look like:
//
// impl<T: Trait> Debug for Field<T>
// where
// T::Value: Debug,
// {...}
//
// You can identify associated types as any syn::TypePath in which the first
// path segment is one of the type parameters and there is more than one
// segment.
//
//
// Resources:
//
// - The relevant types in the input will be represented in this syntax tree
// node: https://docs.rs/syn/1.0/syn/struct.TypePath.html
use derive_debug::CustomDebug;
use std::fmt::Debug;
/// A trait whose only purpose here is to expose an associated type.
pub trait Trait {
    type Value;
}
// `T` never appears as a field type on its own; only its associated type
// `T::Value` does. The derive therefore has to require `T::Value: Debug`
// rather than `T: Debug`.
#[derive(CustomDebug)]
pub struct Field<T: Trait> {
    values: Vec<T::Value>,
}
/// Compiles only if `F: Debug`; used as a compile-time assertion.
fn assert_debug<F: Debug>() {}
fn main() {
    // Does not implement Debug, but its associated type does.
    struct Id;
    impl Trait for Id {
        type Value = u8;
    }
    // Fails to compile if the derived impl demands `Id: Debug`.
    assert_debug::<Field<Id>>();
}
这里出现了一种特殊的情况:字段类型是形如 `A::B::C` 的多段路径。
之前处理泛型约束时,只要某个泛型参数不是仅出现在 `PhantomData` 中,就会为它添加 `T: Debug` 约束。
但这里出现的是 `T::Value`:也就是说,字段涉及了 `T`,却没有直接使用 `T`,实际使用的是 `T::Value` 这种关联类型。
回顾一下之前我们为泛型参数添加约束的处理方式:
// Previous approach: add `Debug` to every type parameter except those
// skipped by the PhantomData check below.
for generic in generics.params.iter_mut() {
    if let syn::GenericParam::Type(t) = generic {
        let type_param_name = t.ident.to_string();
        // Skip a parameter that was collected from a `PhantomData` field and is
        // not also in `fields_type_names` (presumably the bare field type names).
        if phantom_generic_type_names.contains(&type_param_name)
            && !fields_type_names.contains(&type_param_name)
        {
            continue;
        }
        // Every other parameter gets `T: Debug`, including a `T` that is only
        // used through an associated type such as `T::Value`.
        t.bounds.push(syn::parse_quote!(std::fmt::Debug));
    }
}
很明显这里有一个问题:我们会为泛型参数统一添加 `T: Debug` 约束,即使它只出现在泛型声明中、并未直接参与字段声明。
因此,我们当前的任务是:更细粒度的控制泛型约束。
这道题中,我们应当约束 `T::Value` 而忽略 `T`,因为 `T` 并没有直接作为字段类型出现。
其中涉及两点:
- 关联类型的提取
- `where_clause` 的修改

提取

自动遍历
// common.rs
/// Collects associated-type paths (e.g. `T::Value`) rooted at one of the
/// derive input's own type parameters, grouped by that parameter's name.
struct TypePathVisitor {
    // Names of the input's type parameters, e.g. `["T"]`. The misspelling of
    // "interest" is kept because the field is constructed by name elsewhere.
    interst_generic_type_names: Vec<String>,
    // Type-parameter name -> every associated-type path found that is rooted at it.
    associated_type_names: std::collections::HashMap<String, Vec<syn::TypePath>>,
}

impl TypePathVisitor {
    /// Returns the name of the type parameter that `node` is an associated
    /// type of, or `None` if `node` is not such a path.
    ///
    /// Two spellings are recognised:
    /// - `T::Value`: more than one segment, and the first segment is a type parameter;
    /// - `<T as Trait>::Value`: the qualified self type is the bare parameter `T`.
    fn root_generic_name(&self, node: &syn::TypePath) -> Option<String> {
        let root = match &node.qself {
            // Fully qualified form: the parameter lives in `qself.ty`, while
            // `node.path` holds `Trait::Value`.
            Some(qself) => match &*qself.ty {
                syn::Type::Path(inner)
                    if inner.qself.is_none() && inner.path.segments.len() == 1 =>
                {
                    inner.path.segments[0].ident.to_string()
                }
                _ => return None,
            },
            // Shorthand form: a single-segment path like `T` is not an associated type.
            None if node.path.segments.len() > 1 => node.path.segments[0].ident.to_string(),
            None => return None,
        };
        if self.interst_generic_type_names.contains(&root) {
            Some(root)
        } else {
            None
        }
    }
}

// Requires syn's `visit` feature: syn = { version = "1.0.84", features = ["visit"] }
impl<'ast> syn::visit::Visit<'ast> for TypePathVisitor {
    fn visit_type_path(&mut self, node: &'ast syn::TypePath) {
        if let Some(generic_type_name) = self.root_generic_name(node) {
            self.associated_type_names
                .entry(generic_type_name)
                .or_default()
                .push(node.clone());
        }
        // Keep descending so paths nested in generic arguments, such as the
        // `T::Value` inside `Vec<T::Value>`, are visited too.
        syn::visit::visit_type_path(self, node);
    }
}
借助 `syn::visit::Visit` 可以自动遍历语法树并收集我们感兴趣的数据,具体用法可以参考 syn 的官方文档。
// common.rs
/// Finds every associated-type path rooted at one of the input's type
/// parameters, e.g. `T::Value` in `struct Field<T: Trait> { values: Vec<T::Value> }`.
///
/// Returns a map from the type parameter's name (`"T"`) to every such path.
/// The whole derive input is walked, so paths in generic bounds or the `where`
/// clause are collected as well as those in field types. A parameter with no
/// associated-type uses has no entry in the map.
pub(crate) fn parse_generic_associated_types(
    ast: &syn::DeriveInput,
) -> std::collections::HashMap<String, Vec<syn::TypePath>> {
    // Only type parameters can own associated types; lifetimes and consts are skipped.
    let origin_generic_type_names: Vec<String> = ast
        .generics
        .params
        .iter()
        .filter_map(|param| match param {
            syn::GenericParam::Type(t) => Some(t.ident.to_string()),
            _ => None,
        })
        .collect();
    let mut visitor = TypePathVisitor {
        interst_generic_type_names: origin_generic_type_names,
        associated_type_names: std::collections::HashMap::new(),
    };
    // Fully qualified call, so this compiles without `use syn::visit::Visit` in scope.
    syn::visit::Visit::visit_derive_input(&mut visitor, ast);
    visitor.associated_type_names
}
到这里,我们已经收集到了所有以结构体自身泛型参数开头的关联类型路径(如 `T::Value`),接下来就是定制 `where_clause`。
题解
// solution7.rs
/// Builds the `impl std::fmt::Debug` for the derive input, inferring bounds
/// per type parameter `T`:
///
/// - `T: Debug` is skipped when `T` is reported by `parse_phantom_generic_type_name`
///   or has associated-type uses (e.g. `T::Value`), unless `parse_field_type_name`
///   also reports `T` as some field's type;
/// - every associated-type path found gets its own `Path: Debug` predicate
///   in the `where` clause.
///
/// # Errors
/// Propagates errors from the field-parsing helpers and from
/// `generate_field_stream_vec`.
pub(super) fn soution(
    fields: &crate::common::FieldsType,
    origin_ident: &syn::Ident,
    ast: &syn::DeriveInput,
) -> syn::Result<proc_macro2::TokenStream> {
    let mut origin_field_type_names = vec![];
    let mut phantom_generic_type_names = vec![];
    for field in fields.iter() {
        if let Some(origin_field_type_name) = crate::common::parse_field_type_name(field)? {
            origin_field_type_names.push(origin_field_type_name);
        }
        if let Some(phantom_generic_type_name) =
            crate::common::parse_phantom_generic_type_name(field)?
        {
            phantom_generic_type_names.push(phantom_generic_type_name);
        }
    }

    let associated_generics_type_map = crate::common::parse_generic_associated_types(ast);
    let mut generics = crate::common::parse_generic_type(ast);

    // Add `T: Debug` only to parameters that are needed directly.
    for generic in generics.params.iter_mut() {
        if let syn::GenericParam::Type(t) = generic {
            let type_name = t.ident.to_string();
            let used_directly = origin_field_type_names.contains(&type_name);
            let used_indirectly = phantom_generic_type_names.contains(&type_name)
                || associated_generics_type_map.contains_key(&type_name);
            if used_indirectly && !used_directly {
                continue;
            }
            t.bounds.push(syn::parse_quote!(std::fmt::Debug));
        }
    }

    // HashMap iteration order is randomized per process. Sort by parameter
    // name so the generated `where` clause is identical on every compilation.
    let mut associated_entries: Vec<_> = associated_generics_type_map.into_iter().collect();
    associated_entries.sort_unstable_by(|a, b| a.0.cmp(&b.0));

    // Bound each associated type (e.g. `T::Value: Debug`) instead of its root parameter.
    let predicates = &mut generics.make_where_clause().predicates;
    for (_, associated_types) in associated_entries {
        for associated_type in associated_types {
            predicates.push(syn::parse_quote!(#associated_type: std::fmt::Debug));
        }
    }

    let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
    let field_streams = generate_field_stream_vec(fields)?;
    let origin_ident_string = origin_ident.to_string();
    Ok(quote::quote! {
        impl #impl_generics std::fmt::Debug for #origin_ident #type_generics #where_clause {
            fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
                fmt.debug_struct(#origin_ident_string)
                    #(
                        #field_streams
                    )*
                    .finish()
            }
        }
    })
}
// 无关的debug字段设置
fn generate_field_stream_vec(
fields: &crate::common::FieldsType,
) -> syn::Result<Vec<proc_macro2::TokenStream>> {
fields
.iter()
.map(|f| {
let ident = &f.ident;
let ident_string = ident.as_ref().unwrap().to_string();
let mut format = "{:?}".to_string();
if let Some(customer_format) = crate::common::parse_format(f)? {
format = customer_format;
}
syn::Result::Ok(quote::quote! {
.field(#ident_string, &format_args!(#format, &self.#ident))
})
})
.collect()
}
完整代码(lib.rs)
mod common;
mod solution2;
mod solution3;
mod solution4;
mod solution56;
mod solution7;
mod solution8;
#[proc_macro_derive(CustomDebug, attributes(debug))]
pub fn derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let ast = syn::parse_macro_input!(input as syn::DeriveInput);
match solution1(&ast) {
syn::Result::Ok(token_stream) => {
return proc_macro::TokenStream::from(token_stream);
}
syn::Result::Err(e) => {
return e.into_compile_error().into();
}
}
}
/// Runs each workshop step in order and returns the token stream of the
/// latest one (`solution7`).
///
/// The earlier steps' outputs are discarded. They still run, so any error
/// one of them reports is propagated via `?`.
fn solution1(ast: &syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let fields = crate::common::parse_fields(&ast)?;
    let origin_ident = &ast.ident;

    // Earlier steps: outputs ignored, errors still propagated.
    solution2::solution(fields, origin_ident)?;
    solution3::solution(fields, origin_ident)?;
    solution4::solution(fields, origin_ident, ast)?;
    solution56::solution(fields, origin_ident, ast)?;

    solution7::soution(fields, origin_ident, ast)
}
可以用 `cargo expand` 观察一下展开结果:
#![feature(prelude_import)]
#[prelude_import]
use std::prelude::rust_2021::*;
#[macro_use]
extern crate std;
use derive_debug::CustomDebug;
use std::fmt::Debug;
pub trait Trait {
    type Value;
}
// The `#[derive(CustomDebug)]` attribute is gone after expansion; its output follows.
pub struct Field<T: Trait> {
    values: Vec<T::Value>,
}
// Generated by CustomDebug. `T` keeps only its declared `Trait` bound, while
// the associated type stored in the field gets the `Debug` requirement.
impl<T: Trait> std::fmt::Debug for Field<T>
where
    T::Value: std::fmt::Debug,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        fmt.debug_struct("Field")
            .field(
                "values",
                // Compiler expansion of `format_args!("{:?}", &self.values)`.
                // `Arguments::new_v1` / `ArgumentV1` are compiler-internal
                // (unstable) APIs of this toolchain, not meant to be written by hand.
                &::core::fmt::Arguments::new_v1(
                    &[""],
                    &match (&&self.values,) {
                        _args => [::core::fmt::ArgumentV1::new(
                            _args.0,
                            ::core::fmt::Debug::fmt,
                        )],
                    },
                ),
            )
            .finish()
    }
}
fn assert_debug<F: Debug>() {}
fn main() {
    struct Id;
    impl Trait for Id {
        type Value = u8;
    }
    // Compiles because the impl above requires only `u8: Debug`, not `Id: Debug`.
    assert_debug::<Field<Id>>();
}
可以看到 `where T::Value: std::fmt::Debug` 单独约束了 `T::Value`,而 `T` 本身只保留了原有的 `T: Trait` 约束,没有被要求实现 `Debug`。
本作品采用《CC 协议》,转载必须注明作者和本文链接
推荐文章: