1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::fallback;
4use crate::generics::InferredBounds;
5use crate::unraw::MemberUnraw;
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::{format_ident, quote, quote_spanned, ToTokens};
8use std::collections::BTreeSet as Set;
9use syn::{DeriveInput, GenericArgument, PathArguments, Result, Token, Type};
10
11pub fn derive(input: &DeriveInput) -> TokenStream {
12 match try_expand(input) {
13 Ok(expanded) => expanded,
14 Err(error) => fallback::expand(input, error),
18 }
19}
20
21fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
22 let input = Input::from_syn(input)?;
23 input.validate()?;
24 Ok(match input {
25 Input::Struct(input) => impl_struct(input),
26 Input::Enum(input) => impl_enum(input),
27 })
28}
29
30fn impl_struct(input: Struct) -> TokenStream {
31 let ty = call_site_ident(&input.ident);
32 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
33 let mut error_inferred_bounds = InferredBounds::new();
34
35 let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
36 let only_field = &input.fields[0];
37 if only_field.contains_generic {
38 error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::__private::Error));
39 }
40 let member = &only_field.member;
41 Some(quote_spanned! {transparent_attr.span=>
42 ::thiserror::__private::Error::source(self.#member.as_dyn_error())
43 })
44 } else if let Some(source_field) = input.source_field() {
45 let source = &source_field.member;
46 if source_field.contains_generic {
47 let ty = unoptional_type(source_field.ty);
48 error_inferred_bounds.insert(ty, quote!(::thiserror::__private::Error + 'static));
49 }
50 let asref = if type_is_option(source_field.ty) {
51 Some(quote_spanned!(source.span()=> .as_ref()?))
52 } else {
53 None
54 };
55 let dyn_error = quote_spanned! {source_field.source_span()=>
56 self.#source #asref.as_dyn_error()
57 };
58 Some(quote! {
59 ::core::option::Option::Some(#dyn_error)
60 })
61 } else {
62 None
63 };
64 let source_method = source_body.map(|body| {
65 quote! {
66 fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::__private::Error + 'static)> {
67 use ::thiserror::__private::AsDynError as _;
68 #body
69 }
70 }
71 });
72
73 let provide_method = input.backtrace_field().map(|backtrace_field| {
74 let request = quote!(request);
75 let backtrace = &backtrace_field.member;
76 let body = if let Some(source_field) = input.source_field() {
77 let source = &source_field.member;
78 let source_provide = if type_is_option(source_field.ty) {
79 quote_spanned! {source.span()=>
80 if let ::core::option::Option::Some(source) = &self.#source {
81 source.thiserror_provide(#request);
82 }
83 }
84 } else {
85 quote_spanned! {source.span()=>
86 self.#source.thiserror_provide(#request);
87 }
88 };
89 let self_provide = if source == backtrace {
90 None
91 } else if type_is_option(backtrace_field.ty) {
92 Some(quote! {
93 if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
94 #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
95 }
96 })
97 } else {
98 Some(quote! {
99 #request.provide_ref::<::thiserror::__private::Backtrace>(&self.#backtrace);
100 })
101 };
102 quote! {
103 use ::thiserror::__private::ThiserrorProvide as _;
104 #source_provide
105 #self_provide
106 }
107 } else if type_is_option(backtrace_field.ty) {
108 quote! {
109 if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
110 #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
111 }
112 }
113 } else {
114 quote! {
115 #request.provide_ref::<::thiserror::__private::Backtrace>(&self.#backtrace);
116 }
117 };
118 quote! {
119 fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
120 #body
121 }
122 }
123 });
124
125 let mut display_implied_bounds = Set::new();
126 let display_body = if input.attrs.transparent.is_some() {
127 let only_field = &input.fields[0].member;
128 display_implied_bounds.insert((0, Trait::Display));
129 Some(quote! {
130 ::core::fmt::Display::fmt(&self.#only_field, __formatter)
131 })
132 } else if let Some(display) = &input.attrs.display {
133 display_implied_bounds.clone_from(&display.implied_bounds);
134 let use_as_display = use_as_display(display.has_bonus_display);
135 let pat = fields_pat(&input.fields);
136 Some(quote! {
137 #use_as_display
138 #[allow(unused_variables, deprecated)]
139 let Self #pat = self;
140 #display
141 })
142 } else {
143 None
144 };
145 let display_impl = display_body.map(|body| {
146 let mut display_inferred_bounds = InferredBounds::new();
147 for (field, bound) in display_implied_bounds {
148 let field = &input.fields[field];
149 if field.contains_generic {
150 display_inferred_bounds.insert(field.ty, bound);
151 }
152 }
153 let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
154 quote! {
155 #[allow(unused_qualifications)]
156 #[automatically_derived]
157 impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
158 #[allow(clippy::used_underscore_binding)]
159 fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
160 #body
161 }
162 }
163 }
164 });
165
166 let from_impl = input.from_field().map(|from_field| {
167 let span = from_field.attrs.from.unwrap().span;
168 let backtrace_field = input.distinct_backtrace_field();
169 let from = unoptional_type(from_field.ty);
170 let source_var = Ident::new("source", span);
171 let body = from_initializer(from_field, backtrace_field, &source_var);
172 let from_function = quote! {
173 fn from(#source_var: #from) -> Self {
174 #ty #body
175 }
176 };
177 let from_impl = quote_spanned! {span=>
178 #[automatically_derived]
179 impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
180 #from_function
181 }
182 };
183 Some(quote! {
184 #[allow(
185 deprecated,
186 unused_qualifications,
187 clippy::elidable_lifetime_names,
188 clippy::needless_lifetimes,
189 )]
190 #from_impl
191 })
192 });
193
194 if input.generics.type_params().next().is_some() {
195 let self_token = <Token![Self]>::default();
196 error_inferred_bounds.insert(self_token, Trait::Debug);
197 error_inferred_bounds.insert(self_token, Trait::Display);
198 }
199 let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
200
201 quote! {
202 #[allow(unused_qualifications)]
203 #[automatically_derived]
204 impl #impl_generics ::thiserror::__private::Error for #ty #ty_generics #error_where_clause {
205 #source_method
206 #provide_method
207 }
208 #display_impl
209 #from_impl
210 }
211}
212
213fn impl_enum(input: Enum) -> TokenStream {
214 let ty = call_site_ident(&input.ident);
215 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
216 let mut error_inferred_bounds = InferredBounds::new();
217
218 let source_method = if input.has_source() {
219 let arms = input.variants.iter().map(|variant| {
220 let ident = &variant.ident;
221 if let Some(transparent_attr) = &variant.attrs.transparent {
222 let only_field = &variant.fields[0];
223 if only_field.contains_generic {
224 error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::__private::Error));
225 }
226 let member = &only_field.member;
227 let source = quote_spanned! {transparent_attr.span=>
228 ::thiserror::__private::Error::source(transparent.as_dyn_error())
229 };
230 quote! {
231 #ty::#ident {#member: transparent} => #source,
232 }
233 } else if let Some(source_field) = variant.source_field() {
234 let source = &source_field.member;
235 if source_field.contains_generic {
236 let ty = unoptional_type(source_field.ty);
237 error_inferred_bounds.insert(ty, quote!(::thiserror::__private::Error + 'static));
238 }
239 let asref = if type_is_option(source_field.ty) {
240 Some(quote_spanned!(source.span()=> .as_ref()?))
241 } else {
242 None
243 };
244 let varsource = quote!(source);
245 let dyn_error = quote_spanned! {source_field.source_span()=>
246 #varsource #asref.as_dyn_error()
247 };
248 quote! {
249 #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
250 }
251 } else {
252 quote! {
253 #ty::#ident {..} => ::core::option::Option::None,
254 }
255 }
256 });
257 Some(quote! {
258 fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::__private::Error + 'static)> {
259 use ::thiserror::__private::AsDynError as _;
260 #[allow(deprecated)]
261 match self {
262 #(#arms)*
263 }
264 }
265 })
266 } else {
267 None
268 };
269
270 let provide_method = if input.has_backtrace() {
271 let request = quote!(request);
272 let arms = input.variants.iter().map(|variant| {
273 let ident = &variant.ident;
274 match (variant.backtrace_field(), variant.source_field()) {
275 (Some(backtrace_field), Some(source_field))
276 if backtrace_field.attrs.backtrace.is_none() =>
277 {
278 let backtrace = &backtrace_field.member;
279 let source = &source_field.member;
280 let varsource = quote!(source);
281 let source_provide = if type_is_option(source_field.ty) {
282 quote_spanned! {source.span()=>
283 if let ::core::option::Option::Some(source) = #varsource {
284 source.thiserror_provide(#request);
285 }
286 }
287 } else {
288 quote_spanned! {source.span()=>
289 #varsource.thiserror_provide(#request);
290 }
291 };
292 let self_provide = if type_is_option(backtrace_field.ty) {
293 quote! {
294 if let ::core::option::Option::Some(backtrace) = backtrace {
295 #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
296 }
297 }
298 } else {
299 quote! {
300 #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
301 }
302 };
303 quote! {
304 #ty::#ident {
305 #backtrace: backtrace,
306 #source: #varsource,
307 ..
308 } => {
309 use ::thiserror::__private::ThiserrorProvide as _;
310 #source_provide
311 #self_provide
312 }
313 }
314 }
315 (Some(backtrace_field), Some(source_field))
316 if backtrace_field.member == source_field.member =>
317 {
318 let backtrace = &backtrace_field.member;
319 let varsource = quote!(source);
320 let source_provide = if type_is_option(source_field.ty) {
321 quote_spanned! {backtrace.span()=>
322 if let ::core::option::Option::Some(source) = #varsource {
323 source.thiserror_provide(#request);
324 }
325 }
326 } else {
327 quote_spanned! {backtrace.span()=>
328 #varsource.thiserror_provide(#request);
329 }
330 };
331 quote! {
332 #ty::#ident {#backtrace: #varsource, ..} => {
333 use ::thiserror::__private::ThiserrorProvide as _;
334 #source_provide
335 }
336 }
337 }
338 (Some(backtrace_field), _) => {
339 let backtrace = &backtrace_field.member;
340 let body = if type_is_option(backtrace_field.ty) {
341 quote! {
342 if let ::core::option::Option::Some(backtrace) = backtrace {
343 #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
344 }
345 }
346 } else {
347 quote! {
348 #request.provide_ref::<::thiserror::__private::Backtrace>(backtrace);
349 }
350 };
351 quote! {
352 #ty::#ident {#backtrace: backtrace, ..} => {
353 #body
354 }
355 }
356 }
357 (None, _) => quote! {
358 #ty::#ident {..} => {}
359 },
360 }
361 });
362 Some(quote! {
363 fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
364 #[allow(deprecated)]
365 match self {
366 #(#arms)*
367 }
368 }
369 })
370 } else {
371 None
372 };
373
374 let display_impl = if input.has_display() {
375 let mut display_inferred_bounds = InferredBounds::new();
376 let has_bonus_display = input.variants.iter().any(|v| {
377 v.attrs
378 .display
379 .as_ref()
380 .map_or(false, |display| display.has_bonus_display)
381 });
382 let use_as_display = use_as_display(has_bonus_display);
383 let void_deref = if input.variants.is_empty() {
384 Some(quote!(*))
385 } else {
386 None
387 };
388 let arms = input.variants.iter().map(|variant| {
389 let mut display_implied_bounds = Set::new();
390 let display = if let Some(display) = &variant.attrs.display {
391 display_implied_bounds.clone_from(&display.implied_bounds);
392 display.to_token_stream()
393 } else if let Some(fmt) = &variant.attrs.fmt {
394 let fmt_path = &fmt.path;
395 let vars = variant.fields.iter().map(|field| match &field.member {
396 MemberUnraw::Named(ident) => ident.to_local(),
397 MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
398 });
399 quote!(#fmt_path(#(#vars,)* __formatter))
400 } else {
401 let only_field = match &variant.fields[0].member {
402 MemberUnraw::Named(ident) => ident.to_local(),
403 MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
404 };
405 display_implied_bounds.insert((0, Trait::Display));
406 quote!(::core::fmt::Display::fmt(#only_field, __formatter))
407 };
408 for (field, bound) in display_implied_bounds {
409 let field = &variant.fields[field];
410 if field.contains_generic {
411 display_inferred_bounds.insert(field.ty, bound);
412 }
413 }
414 let ident = &variant.ident;
415 let pat = fields_pat(&variant.fields);
416 quote! {
417 #ty::#ident #pat => #display
418 }
419 });
420 let arms = arms.collect::<Vec<_>>();
421 let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
422 Some(quote! {
423 #[allow(unused_qualifications)]
424 #[automatically_derived]
425 impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
426 fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
427 #use_as_display
428 #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
429 match #void_deref self {
430 #(#arms,)*
431 }
432 }
433 }
434 })
435 } else {
436 None
437 };
438
439 let from_impls = input.variants.iter().filter_map(|variant| {
440 let from_field = variant.from_field()?;
441 let span = from_field.attrs.from.unwrap().span;
442 let backtrace_field = variant.distinct_backtrace_field();
443 let variant = &variant.ident;
444 let from = unoptional_type(from_field.ty);
445 let source_var = Ident::new("source", span);
446 let body = from_initializer(from_field, backtrace_field, &source_var);
447 let from_function = quote! {
448 fn from(#source_var: #from) -> Self {
449 #ty::#variant #body
450 }
451 };
452 let from_impl = quote_spanned! {span=>
453 #[automatically_derived]
454 impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
455 #from_function
456 }
457 };
458 Some(quote! {
459 #[allow(
460 deprecated,
461 unused_qualifications,
462 clippy::elidable_lifetime_names,
463 clippy::needless_lifetimes,
464 )]
465 #from_impl
466 })
467 });
468
469 if input.generics.type_params().next().is_some() {
470 let self_token = <Token![Self]>::default();
471 error_inferred_bounds.insert(self_token, Trait::Debug);
472 error_inferred_bounds.insert(self_token, Trait::Display);
473 }
474 let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
475
476 quote! {
477 #[allow(unused_qualifications)]
478 #[automatically_derived]
479 impl #impl_generics ::thiserror::__private::Error for #ty #ty_generics #error_where_clause {
480 #source_method
481 #provide_method
482 }
483 #display_impl
484 #(#from_impls)*
485 }
486}
487
488pub(crate) fn call_site_ident(ident: &Ident) -> Ident {
491 let mut ident = ident.clone();
492 ident.set_span(ident.span().resolved_at(Span::call_site()));
493 ident
494}
495
496fn fields_pat(fields: &[Field]) -> TokenStream {
497 let mut members = fields.iter().map(|field| &field.member).peekable();
498 match members.peek() {
499 Some(MemberUnraw::Named(_)) => quote!({ #(#members),* }),
500 Some(MemberUnraw::Unnamed(_)) => {
501 let vars = members.map(|member| match member {
502 MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
503 MemberUnraw::Named(_) => unreachable!(),
504 });
505 quote!((#(#vars),*))
506 }
507 None => quote!({}),
508 }
509}
510
511fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
512 if needs_as_display {
513 Some(quote! {
514 use ::thiserror::__private::AsDisplay as _;
515 })
516 } else {
517 None
518 }
519}
520
521fn from_initializer(
522 from_field: &Field,
523 backtrace_field: Option<&Field>,
524 source_var: &Ident,
525) -> TokenStream {
526 let from_member = &from_field.member;
527 let some_source = if type_is_option(from_field.ty) {
528 quote!(::core::option::Option::Some(#source_var))
529 } else {
530 quote!(#source_var)
531 };
532 let backtrace = backtrace_field.map(|backtrace_field| {
533 let backtrace_member = &backtrace_field.member;
534 if type_is_option(backtrace_field.ty) {
535 quote! {
536 #backtrace_member: ::core::option::Option::Some(::thiserror::__private::Backtrace::capture()),
537 }
538 } else {
539 quote! {
540 #backtrace_member: ::core::convert::From::from(::thiserror::__private::Backtrace::capture()),
541 }
542 }
543 });
544 quote!({
545 #from_member: #some_source,
546 #backtrace
547 })
548}
549
550fn type_is_option(ty: &Type) -> bool {
551 type_parameter_of_option(ty).is_some()
552}
553
554fn unoptional_type(ty: &Type) -> TokenStream {
555 let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
556 quote!(#unoptional)
557}
558
559fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
560 let path = match ty {
561 Type::Path(ty) => &ty.path,
562 _ => return None,
563 };
564
565 let last = path.segments.last().unwrap();
566 if last.ident != "Option" {
567 return None;
568 }
569
570 let bracketed = match &last.arguments {
571 PathArguments::AngleBracketed(bracketed) => bracketed,
572 _ => return None,
573 };
574
575 if bracketed.args.len() != 1 {
576 return None;
577 }
578
579 match &bracketed.args[0] {
580 GenericArgument::Type(arg) => Some(arg),
581 _ => None,
582 }
583}