thiserror_impl/
expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::fallback;
4use crate::generics::InferredBounds;
5use crate::unraw::MemberUnraw;
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::collections::BTreeSet as Set;
9use syn::{DeriveInput, GenericArgument, PathArguments, Result, Token, Type};
10
11pub fn derive(input: &DeriveInput) -> TokenStream {
12    match try_expand(input) {
13        Ok(expanded) => expanded,
14        // If there are invalid attributes in the input, expand to an Error impl
15        // anyway to minimize spurious secondary errors in other code that uses
16        // this type as an Error.
17        Err(error) => fallback::expand(input, error),
18    }
19}
20
21fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
22    let input = Input::from_syn(input)?;
23    input.validate()?;
24    Ok(match input {
25        Input::Struct(input) => impl_struct(input),
26        Input::Enum(input) => impl_enum(input),
27    })
28}
29
/// Expands `#[derive(Error)]` for a struct.
///
/// Produces, in order:
/// - an `Error` impl whose `source` method is generated when the struct is
///   transparent or has a source field, and whose `provide` method is
///   generated when it has a backtrace field;
/// - a `Display` impl, when the struct is transparent or has a display
///   attribute;
/// - a `From` impl, when a field carries the `from` attribute.
///
/// Fields flagged `contains_generic` add inferred trait bounds to the
/// where-clause of whichever impl needs them. The `From` impl reuses the
/// type's declared where-clause unchanged.
fn impl_struct(input: Struct) -> TokenStream {
    // Type name re-spanned so that impls on a deprecated type do not trigger
    // deprecation warnings (see `call_site_ident`).
    let ty = call_site_ident(&input.ident);
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds collected for the `Error` impl's where-clause.
    let mut error_inferred_bounds = InferredBounds::new();

    // Body of `Error::source`. A transparent struct forwards to its only
    // field's own `source()`. Otherwise the source field itself is returned,
    // short-circuiting via `.as_ref()?` when the field is an `Option`.
    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
        let only_field = &input.fields[0];
        if only_field.contains_generic {
            error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::__private::Error));
        }
        let member = &only_field.member;
        Some(quote_spanned! {transparent_attr.span=>
            ::thiserror::__private::Error::source(self.#member.as_dyn_error())
        })
    } else if let Some(source_field) = input.source_field() {
        let source = &source_field.member;
        if source_field.contains_generic {
            // Bound the inner type, not `Option<...>`, since that is what is
            // converted to `dyn Error`.
            let ty = unoptional_type(source_field.ty);
            error_inferred_bounds.insert(ty, quote!(::thiserror::__private::Error + 'static));
        }
        let asref = if type_is_option(source_field.ty) {
            Some(quote_spanned!(source.span()=> .as_ref()?))
        } else {
            None
        };
        let dyn_error = quote_spanned! {source_field.source_span()=>
            self.#source #asref.as_dyn_error()
        };
        Some(quote! {
            ::core::option::Option::Some(#dyn_error)
        })
    } else {
        None
    };
    let source_method = source_body.map(|body| {
        quote! {
            fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::__private::Error + 'static)> {
                use ::thiserror::__private::AsDynError as _;
                #body
            }
        }
    });

    // `provide` is generated only when there is a backtrace field. If there
    // is also a source field, the request goes to the source first and this
    // struct's own backtrace is offered afterwards. The second step is
    // skipped when the backtrace field is the source field itself.
    let provide_method = input.backtrace_field().map(|backtrace_field| {
        let request = quote!(request);
        let backtrace = &backtrace_field.member;
        let body = if let Some(source_field) = input.source_field() {
            let source = &source_field.member;
            let source_provide = if type_is_option(source_field.ty) {
                quote_spanned! {source.span()=>
                    if let ::core::option::Option::Some(source) = &self.#source {
                        source.thiserror_provide(#request);
                    }
                }
            } else {
                quote_spanned! {source.span()=>
                    self.#source.thiserror_provide(#request);
                }
            };
            let self_provide = if source == backtrace {
                None
            } else if type_is_option(backtrace_field.ty) {
                Some(quote! {
                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                        #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
                    }
                })
            } else {
                Some(quote! {
                    #request.provide_ref::<::thiserror::__private::Backtrace>(&self.#backtrace);
                })
            };
            quote! {
                use ::thiserror::__private::ThiserrorProvide as _;
                #source_provide
                #self_provide
            }
        } else if type_is_option(backtrace_field.ty) {
            quote! {
                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
                    #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
                }
            }
        } else {
            quote! {
                #request.provide_ref::<::thiserror::__private::Backtrace>(&self.#backtrace);
            }
        };
        quote! {
            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
                #body
            }
        }
    });

    // (field index, trait) pairs naming which fields the Display body formats
    // and with which trait. They become bounds below for generic fields.
    let mut display_implied_bounds = Set::new();
    // A transparent struct delegates to its only field's Display. Otherwise
    // `self` is destructured so the display attribute's tokens can refer to
    // fields by name. `AsDisplay` is imported only when the attribute's
    // `has_bonus_display` flag is set.
    let display_body = if input.attrs.transparent.is_some() {
        let only_field = &input.fields[0].member;
        display_implied_bounds.insert((0, Trait::Display));
        Some(quote! {
            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
        })
    } else if let Some(display) = &input.attrs.display {
        display_implied_bounds.clone_from(&display.implied_bounds);
        let use_as_display = use_as_display(display.has_bonus_display);
        let pat = fields_pat(&input.fields);
        Some(quote! {
            #use_as_display
            #[allow(unused_variables, deprecated)]
            let Self #pat = self;
            #display
        })
    } else {
        None
    };
    let display_impl = display_body.map(|body| {
        let mut display_inferred_bounds = InferredBounds::new();
        for (field, bound) in display_implied_bounds {
            let field = &input.fields[field];
            if field.contains_generic {
                display_inferred_bounds.insert(field.ty, bound);
            }
        }
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        quote! {
            #[allow(unused_qualifications)]
            #[automatically_derived]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                #[allow(clippy::used_underscore_binding)]
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #body
                }
            }
        }
    });

    // `From<T>` for the field carrying the `from` attribute, where `T` is the
    // field type with any `Option<...>` layer removed. The impl is spanned at
    // the attribute. A distinct backtrace field, if any, is filled in by
    // `from_initializer`.
    let from_impl = input.from_field().map(|from_field| {
        let span = from_field.attrs.from.unwrap().span;
        let backtrace_field = input.distinct_backtrace_field();
        let from = unoptional_type(from_field.ty);
        let source_var = Ident::new("source", span);
        let body = from_initializer(from_field, backtrace_field, &source_var);
        let from_function = quote! {
            fn from(#source_var: #from) -> Self {
                #ty #body
            }
        };
        let from_impl = quote_spanned! {span=>
            #[automatically_derived]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #from_function
            }
        };
        Some(quote! {
            #[allow(
                deprecated,
                unused_qualifications,
                clippy::elidable_lifetime_names,
                clippy::needless_lifetimes,
            )]
            #from_impl
        })
    });

    // Generic types additionally require `Self: Debug + Display` on the
    // `Error` impl.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        #[automatically_derived]
        impl #impl_generics ::thiserror::__private::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #from_impl
    }
}
212
/// Expands `#[derive(Error)]` for an enum.
///
/// Follows the same plan as `impl_struct`, but each generated method matches
/// on `self` with one arm per variant:
/// - `Error::source`, emitted when `input.has_source()`;
/// - `Error::provide`, emitted when `input.has_backtrace()`;
/// - a `Display` impl, emitted when `input.has_display()`;
/// - one `From` impl per variant that has a `from` field.
fn impl_enum(input: Enum) -> TokenStream {
    // Type name re-spanned so that impls on a deprecated type do not trigger
    // deprecation warnings (see `call_site_ident`).
    let ty = call_site_ident(&input.ident);
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    // Bounds collected for the `Error` impl's where-clause.
    let mut error_inferred_bounds = InferredBounds::new();

    // One arm per variant: transparent variants forward to the wrapped
    // field's `source()`, variants with a source field return it (via
    // `.as_ref()?` when it is an `Option`), and all others return `None`.
    // The closure records bounds in `error_inferred_bounds` as a side effect.
    // It runs when `quote!` below consumes `arms`, which is before the
    // where-clause is built.
    let source_method = if input.has_source() {
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            if let Some(transparent_attr) = &variant.attrs.transparent {
                let only_field = &variant.fields[0];
                if only_field.contains_generic {
                    error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::__private::Error));
                }
                let member = &only_field.member;
                let source = quote_spanned! {transparent_attr.span=>
                    ::thiserror::__private::Error::source(transparent.as_dyn_error())
                };
                quote! {
                    #ty::#ident {#member: transparent} => #source,
                }
            } else if let Some(source_field) = variant.source_field() {
                let source = &source_field.member;
                if source_field.contains_generic {
                    let ty = unoptional_type(source_field.ty);
                    error_inferred_bounds.insert(ty, quote!(::thiserror::__private::Error + 'static));
                }
                let asref = if type_is_option(source_field.ty) {
                    Some(quote_spanned!(source.span()=> .as_ref()?))
                } else {
                    None
                };
                let varsource = quote!(source);
                let dyn_error = quote_spanned! {source_field.source_span()=>
                    #varsource #asref.as_dyn_error()
                };
                quote! {
                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
                }
            } else {
                quote! {
                    #ty::#ident {..} => ::core::option::Option::None,
                }
            }
        });
        Some(quote! {
            fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::__private::Error + 'static)> {
                use ::thiserror::__private::AsDynError as _;
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    // Per-variant `provide` arms:
    // 1. backtrace field without an explicit `backtrace` attribute, plus a
    //    source: forward to the source, then offer the variant's own
    //    backtrace;
    // 2. backtrace field that is also the source field: forward to it only;
    // 3. any other backtrace field: offer it directly;
    // 4. no backtrace field: do nothing.
    // Arm order matters because the match guards are evaluated top-down.
    let provide_method = if input.has_backtrace() {
        let request = quote!(request);
        let arms = input.variants.iter().map(|variant| {
            let ident = &variant.ident;
            match (variant.backtrace_field(), variant.source_field()) {
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.attrs.backtrace.is_none() =>
                {
                    let backtrace = &backtrace_field.member;
                    let source = &source_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {source.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {source.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    let self_provide = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {
                            #backtrace: backtrace,
                            #source: #varsource,
                            ..
                        } => {
                            use ::thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                            #self_provide
                        }
                    }
                }
                (Some(backtrace_field), Some(source_field))
                    if backtrace_field.member == source_field.member =>
                {
                    let backtrace = &backtrace_field.member;
                    let varsource = quote!(source);
                    let source_provide = if type_is_option(source_field.ty) {
                        quote_spanned! {backtrace.span()=>
                            if let ::core::option::Option::Some(source) = #varsource {
                                source.thiserror_provide(#request);
                            }
                        }
                    } else {
                        quote_spanned! {backtrace.span()=>
                            #varsource.thiserror_provide(#request);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: #varsource, ..} => {
                            use ::thiserror::__private::ThiserrorProvide as _;
                            #source_provide
                        }
                    }
                }
                (Some(backtrace_field), _) => {
                    let backtrace = &backtrace_field.member;
                    let body = if type_is_option(backtrace_field.ty) {
                        quote! {
                            if let ::core::option::Option::Some(backtrace) = backtrace {
                                #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
                            }
                        }
                    } else {
                        quote! {
                            #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
                        }
                    };
                    quote! {
                        #ty::#ident {#backtrace: backtrace, ..} => {
                            #body
                        }
                    }
                }
                (None, _) => quote! {
                    #ty::#ident {..} => {}
                },
            }
        });
        Some(quote! {
            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
                #[allow(deprecated)]
                match self {
                    #(#arms)*
                }
            }
        })
    } else {
        None
    };

    let display_impl = if input.has_display() {
        let mut display_inferred_bounds = InferredBounds::new();
        // `AsDisplay` is imported once for the whole match if any variant's
        // display attribute sets `has_bonus_display`.
        let has_bonus_display = input.variants.iter().any(|v| {
            v.attrs
                .display
                .as_ref()
                .map_or(false, |display| display.has_bonus_display)
        });
        let use_as_display = use_as_display(has_bonus_display);
        // A variantless enum is matched as `*self`, so the empty `match` is
        // accepted.
        let void_deref = if input.variants.is_empty() {
            Some(quote!(*))
        } else {
            None
        };
        // Each variant is formatted in one of three ways: with its display
        // attribute if it has one; else by calling its `fmt` path with every
        // field binding plus the formatter; else by delegating to the first
        // field's Display (presumably the transparent case; validation lives
        // elsewhere). The closure records bounds in `display_inferred_bounds`
        // and is forced by the `collect` below.
        let arms = input.variants.iter().map(|variant| {
            let mut display_implied_bounds = Set::new();
            let display = if let Some(display) = &variant.attrs.display {
                display_implied_bounds.clone_from(&display.implied_bounds);
                display.to_token_stream()
            } else if let Some(fmt) = &variant.attrs.fmt {
                let fmt_path = &fmt.path;
                // Same binding names that `fields_pat` introduces.
                let vars = variant.fields.iter().map(|field| match &field.member {
                    MemberUnraw::Named(ident) => ident.to_local(),
                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
                });
                quote!(#fmt_path(#(#vars,)* __formatter))
            } else {
                let only_field = match &variant.fields[0].member {
                    MemberUnraw::Named(ident) => ident.to_local(),
                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
                };
                display_implied_bounds.insert((0, Trait::Display));
                quote!(::core::fmt::Display::fmt(#only_field, __formatter))
            };
            for (field, bound) in display_implied_bounds {
                let field = &variant.fields[field];
                if field.contains_generic {
                    display_inferred_bounds.insert(field.ty, bound);
                }
            }
            let ident = &variant.ident;
            let pat = fields_pat(&variant.fields);
            quote! {
                #ty::#ident #pat => #display
            }
        });
        // Collect eagerly so every bound is recorded before the where-clause
        // is built.
        let arms = arms.collect::<Vec<_>>();
        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
        Some(quote! {
            #[allow(unused_qualifications)]
            #[automatically_derived]
            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    #use_as_display
                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
                    match #void_deref self {
                        #(#arms,)*
                    }
                }
            }
        })
    } else {
        None
    };

    // One `From<T>` impl per variant with a `from` field, where `T` is the
    // field type with any `Option<...>` layer removed. Each impl is spanned
    // at its attribute.
    let from_impls = input.variants.iter().filter_map(|variant| {
        let from_field = variant.from_field()?;
        let span = from_field.attrs.from.unwrap().span;
        let backtrace_field = variant.distinct_backtrace_field();
        let variant = &variant.ident;
        let from = unoptional_type(from_field.ty);
        let source_var = Ident::new("source", span);
        let body = from_initializer(from_field, backtrace_field, &source_var);
        let from_function = quote! {
            fn from(#source_var: #from) -> Self {
                #ty::#variant #body
            }
        };
        let from_impl = quote_spanned! {span=>
            #[automatically_derived]
            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
                #from_function
            }
        };
        Some(quote! {
            #[allow(
                deprecated,
                unused_qualifications,
                clippy::elidable_lifetime_names,
                clippy::needless_lifetimes,
            )]
            #from_impl
        })
    });

    // Generic types additionally require `Self: Debug + Display` on the
    // `Error` impl.
    if input.generics.type_params().next().is_some() {
        let self_token = <Token![Self]>::default();
        error_inferred_bounds.insert(self_token, Trait::Debug);
        error_inferred_bounds.insert(self_token, Trait::Display);
    }
    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);

    quote! {
        #[allow(unused_qualifications)]
        #[automatically_derived]
        impl #impl_generics ::thiserror::__private::Error for #ty #ty_generics #error_where_clause {
            #source_method
            #provide_method
        }
        #display_impl
        #(#from_impls)*
    }
}
487
488// Create an ident with which we can expand `impl Trait for #ident {}` on a
489// deprecated type without triggering deprecation warning on the generated impl.
490pub(crate) fn call_site_ident(ident: &Ident) -> Ident {
491    let mut ident = ident.clone();
492    ident.set_span(ident.span().resolved_at(Span::call_site()));
493    ident
494}
495
496fn fields_pat(fields: &[Field]) -> TokenStream {
497    let mut members = fields.iter().map(|field| &field.member).peekable();
498    match members.peek() {
499        Some(MemberUnraw::Named(_)) => quote!({ #(#members),* }),
500        Some(MemberUnraw::Unnamed(_)) => {
501            let vars = members.map(|member| match member {
502                MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
503                MemberUnraw::Named(_) => unreachable!(),
504            });
505            quote!((#(#vars),*))
506        }
507        None => quote!({}),
508    }
509}
510
511fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
512    if needs_as_display {
513        Some(quote! {
514            use ::thiserror::__private::AsDisplay as _;
515        })
516    } else {
517        None
518    }
519}
520
521fn from_initializer(
522    from_field: &Field,
523    backtrace_field: Option<&Field>,
524    source_var: &Ident,
525) -> TokenStream {
526    let from_member = &from_field.member;
527    let some_source = if type_is_option(from_field.ty) {
528        quote!(::core::option::Option::Some(#source_var))
529    } else {
530        quote!(#source_var)
531    };
532    let backtrace = backtrace_field.map(|backtrace_field| {
533        let backtrace_member = &backtrace_field.member;
534        if type_is_option(backtrace_field.ty) {
535            quote! {
536                #backtrace_member: ::core::option::Option::Some(::thiserror::__private::Backtrace::capture()),
537            }
538        } else {
539            quote! {
540                #backtrace_member: ::core::convert::From::from(::thiserror::__private::Backtrace::capture()),
541            }
542        }
543    });
544    quote!({
545        #from_member: #some_source,
546        #backtrace
547    })
548}
549
550fn type_is_option(ty: &Type) -> bool {
551    type_parameter_of_option(ty).is_some()
552}
553
554fn unoptional_type(ty: &Type) -> TokenStream {
555    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
556    quote!(#unoptional)
557}
558
559fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
560    let path = match ty {
561        Type::Path(ty) => &ty.path,
562        _ => return None,
563    };
564
565    let last = path.segments.last().unwrap();
566    if last.ident != "Option" {
567        return None;
568    }
569
570    let bracketed = match &last.arguments {
571        PathArguments::AngleBracketed(bracketed) => bracketed,
572        _ => return None,
573    };
574
575    if bracketed.args.len() != 1 {
576        return None;
577    }
578
579    match &bracketed.args[0] {
580        GenericArgument::Type(arg) => Some(arg),
581        _ => None,
582    }
583}