serde_derive/
bound.rs

1use std::collections::HashSet;
2
3use syn;
4use syn::punctuated::{Pair, Punctuated};
5use syn::visit::{self, Visit};
6
7use internals::ast::{Container, Data};
8use internals::attr;
9
10use proc_macro2::Span;
11
12// Remove the default from every type parameter because in the generated impls
13// they look like associated types: "error: associated type bindings are not
14// allowed here".
15pub fn without_defaults(generics: &syn::Generics) -> syn::Generics {
16    syn::Generics {
17        params: generics
18            .params
19            .iter()
20            .map(|param| match param {
21                syn::GenericParam::Type(param) => syn::GenericParam::Type(syn::TypeParam {
22                    eq_token: None,
23                    default: None,
24                    ..param.clone()
25                }),
26                _ => param.clone(),
27            })
28            .collect(),
29        ..generics.clone()
30    }
31}
32
33pub fn with_where_predicates(
34    generics: &syn::Generics,
35    predicates: &[syn::WherePredicate],
36) -> syn::Generics {
37    let mut generics = generics.clone();
38    generics
39        .make_where_clause()
40        .predicates
41        .extend(predicates.iter().cloned());
42    generics
43}
44
45pub fn with_where_predicates_from_fields(
46    cont: &Container,
47    generics: &syn::Generics,
48    from_field: fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
49) -> syn::Generics {
50    let predicates = cont
51        .data
52        .all_fields()
53        .flat_map(|field| from_field(&field.attrs))
54        .flat_map(|predicates| predicates.to_vec());
55
56    let mut generics = generics.clone();
57    generics.make_where_clause().predicates.extend(predicates);
58    generics
59}
60
61pub fn with_where_predicates_from_variants(
62    cont: &Container,
63    generics: &syn::Generics,
64    from_variant: fn(&attr::Variant) -> Option<&[syn::WherePredicate]>,
65) -> syn::Generics {
66    let variants = match &cont.data {
67        Data::Enum(variants) => variants,
68        Data::Struct(_, _) => {
69            return generics.clone();
70        }
71    };
72
73    let predicates = variants
74        .iter()
75        .flat_map(|variant| from_variant(&variant.attrs))
76        .flat_map(|predicates| predicates.to_vec());
77
78    let mut generics = generics.clone();
79    generics.make_where_clause().predicates.extend(predicates);
80    generics
81}
82
83// Puts the given bound on any generic type parameters that are used in fields
84// for which filter returns true.
85//
86// For example, the following struct needs the bound `A: Serialize, B:
87// Serialize`.
88//
89//     struct S<'b, A, B: 'b, C> {
90//         a: A,
91//         b: Option<&'b B>
92//         #[serde(skip_serializing)]
93//         c: C,
94//     }
95pub fn with_bound(
96    cont: &Container,
97    generics: &syn::Generics,
98    filter: fn(&attr::Field, Option<&attr::Variant>) -> bool,
99    bound: &syn::Path,
100) -> syn::Generics {
101    struct FindTyParams<'ast> {
102        // Set of all generic type parameters on the current struct (A, B, C in
103        // the example). Initialized up front.
104        all_type_params: HashSet<syn::Ident>,
105
106        // Set of generic type parameters used in fields for which filter
107        // returns true (A and B in the example). Filled in as the visitor sees
108        // them.
109        relevant_type_params: HashSet<syn::Ident>,
110
111        // Fields whose type is an associated type of one of the generic type
112        // parameters.
113        associated_type_usage: Vec<&'ast syn::TypePath>,
114    }
115    impl<'ast> Visit<'ast> for FindTyParams<'ast> {
116        fn visit_field(&mut self, field: &'ast syn::Field) {
117            if let syn::Type::Path(ty) = &field.ty {
118                if let Some(Pair::Punctuated(t, _)) = ty.path.segments.pairs().next() {
119                    if self.all_type_params.contains(&t.ident) {
120                        self.associated_type_usage.push(ty);
121                    }
122                }
123            }
124            self.visit_type(&field.ty);
125        }
126
127        fn visit_path(&mut self, path: &'ast syn::Path) {
128            if let Some(seg) = path.segments.last() {
129                if seg.ident == "PhantomData" {
130                    // Hardcoded exception, because PhantomData<T> implements
131                    // Serialize and Deserialize whether or not T implements it.
132                    return;
133                }
134            }
135            if path.leading_colon.is_none() && path.segments.len() == 1 {
136                let id = &path.segments[0].ident;
137                if self.all_type_params.contains(id) {
138                    self.relevant_type_params.insert(id.clone());
139                }
140            }
141            visit::visit_path(self, path);
142        }
143
144        // Type parameter should not be considered used by a macro path.
145        //
146        //     struct TypeMacro<T> {
147        //         mac: T!(),
148        //         marker: PhantomData<T>,
149        //     }
150        fn visit_macro(&mut self, _mac: &'ast syn::Macro) {}
151    }
152
153    let all_type_params = generics
154        .type_params()
155        .map(|param| param.ident.clone())
156        .collect();
157
158    let mut visitor = FindTyParams {
159        all_type_params: all_type_params,
160        relevant_type_params: HashSet::new(),
161        associated_type_usage: Vec::new(),
162    };
163    match &cont.data {
164        Data::Enum(variants) => {
165            for variant in variants.iter() {
166                let relevant_fields = variant
167                    .fields
168                    .iter()
169                    .filter(|field| filter(&field.attrs, Some(&variant.attrs)));
170                for field in relevant_fields {
171                    visitor.visit_field(field.original);
172                }
173            }
174        }
175        Data::Struct(_, fields) => {
176            for field in fields.iter().filter(|field| filter(&field.attrs, None)) {
177                visitor.visit_field(field.original);
178            }
179        }
180    }
181
182    let relevant_type_params = visitor.relevant_type_params;
183    let associated_type_usage = visitor.associated_type_usage;
184    let new_predicates = generics
185        .type_params()
186        .map(|param| param.ident.clone())
187        .filter(|id| relevant_type_params.contains(id))
188        .map(|id| syn::TypePath {
189            qself: None,
190            path: id.into(),
191        })
192        .chain(associated_type_usage.into_iter().cloned())
193        .map(|bounded_ty| {
194            syn::WherePredicate::Type(syn::PredicateType {
195                lifetimes: None,
196                // the type parameter that is being bounded e.g. T
197                bounded_ty: syn::Type::Path(bounded_ty),
198                colon_token: <Token![:]>::default(),
199                // the bound e.g. Serialize
200                bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
201                    paren_token: None,
202                    modifier: syn::TraitBoundModifier::None,
203                    lifetimes: None,
204                    path: bound.clone(),
205                })]
206                .into_iter()
207                .collect(),
208            })
209        });
210
211    let mut generics = generics.clone();
212    generics
213        .make_where_clause()
214        .predicates
215        .extend(new_predicates);
216    generics
217}
218
219pub fn with_self_bound(
220    cont: &Container,
221    generics: &syn::Generics,
222    bound: &syn::Path,
223) -> syn::Generics {
224    let mut generics = generics.clone();
225    generics
226        .make_where_clause()
227        .predicates
228        .push(syn::WherePredicate::Type(syn::PredicateType {
229            lifetimes: None,
230            // the type that is being bounded e.g. MyStruct<'a, T>
231            bounded_ty: type_of_item(cont),
232            colon_token: <Token![:]>::default(),
233            // the bound e.g. Default
234            bounds: vec![syn::TypeParamBound::Trait(syn::TraitBound {
235                paren_token: None,
236                modifier: syn::TraitBoundModifier::None,
237                lifetimes: None,
238                path: bound.clone(),
239            })]
240            .into_iter()
241            .collect(),
242        }));
243    generics
244}
245
246pub fn with_lifetime_bound(generics: &syn::Generics, lifetime: &str) -> syn::Generics {
247    let bound = syn::Lifetime::new(lifetime, Span::call_site());
248    let def = syn::LifetimeDef {
249        attrs: Vec::new(),
250        lifetime: bound.clone(),
251        colon_token: None,
252        bounds: Punctuated::new(),
253    };
254
255    let params = Some(syn::GenericParam::Lifetime(def))
256        .into_iter()
257        .chain(generics.params.iter().cloned().map(|mut param| {
258            match &mut param {
259                syn::GenericParam::Lifetime(param) => {
260                    param.bounds.push(bound.clone());
261                }
262                syn::GenericParam::Type(param) => {
263                    param
264                        .bounds
265                        .push(syn::TypeParamBound::Lifetime(bound.clone()));
266                }
267                syn::GenericParam::Const(_) => {}
268            }
269            param
270        }))
271        .collect();
272
273    syn::Generics {
274        params: params,
275        ..generics.clone()
276    }
277}
278
279fn type_of_item(cont: &Container) -> syn::Type {
280    syn::Type::Path(syn::TypePath {
281        qself: None,
282        path: syn::Path {
283            leading_colon: None,
284            segments: vec![syn::PathSegment {
285                ident: cont.ident.clone(),
286                arguments: syn::PathArguments::AngleBracketed(
287                    syn::AngleBracketedGenericArguments {
288                        colon2_token: None,
289                        lt_token: <Token![<]>::default(),
290                        args: cont
291                            .generics
292                            .params
293                            .iter()
294                            .map(|param| match param {
295                                syn::GenericParam::Type(param) => {
296                                    syn::GenericArgument::Type(syn::Type::Path(syn::TypePath {
297                                        qself: None,
298                                        path: param.ident.clone().into(),
299                                    }))
300                                }
301                                syn::GenericParam::Lifetime(param) => {
302                                    syn::GenericArgument::Lifetime(param.lifetime.clone())
303                                }
304                                syn::GenericParam::Const(_) => {
305                                    panic!("Serde does not support const generics yet");
306                                }
307                            })
308                            .collect(),
309                        gt_token: <Token![>]>::default(),
310                    },
311                ),
312            }]
313            .into_iter()
314            .collect(),
315        },
316    })
317}