// thiserror_impl/attr.rs

1use proc_macro2::{Delimiter, Group, Literal, Punct, Spacing, Span, TokenStream, TokenTree};
2use quote::{format_ident, quote, quote_spanned, ToTokens};
3use std::collections::BTreeSet as Set;
4use syn::parse::discouraged::Speculative;
5use syn::parse::{End, ParseStream};
6use syn::{
7    braced, bracketed, parenthesized, token, Attribute, Error, ExprPath, Ident, Index, LitFloat,
8    LitInt, LitStr, Meta, Result, Token,
9};
10
11pub struct Attrs<'a> {
12    pub display: Option<Display<'a>>,
13    pub source: Option<Source<'a>>,
14    pub backtrace: Option<&'a Attribute>,
15    pub from: Option<From<'a>>,
16    pub transparent: Option<Transparent<'a>>,
17    pub fmt: Option<Fmt<'a>>,
18}
19
20#[derive(Clone)]
21pub struct Display<'a> {
22    pub original: &'a Attribute,
23    pub fmt: LitStr,
24    pub args: TokenStream,
25    pub requires_fmt_machinery: bool,
26    pub has_bonus_display: bool,
27    pub infinite_recursive: bool,
28    pub implied_bounds: Set<(usize, Trait)>,
29    pub bindings: Vec<(Ident, TokenStream)>,
30}
31
32#[derive(Copy, Clone)]
33pub struct Source<'a> {
34    pub original: &'a Attribute,
35    pub span: Span,
36}
37
38#[derive(Copy, Clone)]
39pub struct From<'a> {
40    pub original: &'a Attribute,
41    pub span: Span,
42}
43
44#[derive(Copy, Clone)]
45pub struct Transparent<'a> {
46    pub original: &'a Attribute,
47    pub span: Span,
48}
49
50#[derive(Clone)]
51pub struct Fmt<'a> {
52    pub original: &'a Attribute,
53    pub path: ExprPath,
54}
55
/// One of the `core::fmt` formatting traits; its `ToTokens` impl emits the
/// path `::core::fmt::<Name>`.
///
/// Variant order is significant: the derived `Ord` follows declaration order,
/// which fixes the iteration order of `Set<(usize, Trait)>` values.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Trait {
    /// `{:?}`
    Debug,
    /// `{}`
    Display,
    /// `{:o}`
    Octal,
    /// `{:x}`
    LowerHex,
    /// `{:X}`
    UpperHex,
    /// `{:p}`
    Pointer,
    /// `{:b}`
    Binary,
    /// `{:e}`
    LowerExp,
    /// `{:E}`
    UpperExp,
}

69pub fn get(input: &[Attribute]) -> Result<Attrs> {
70    let mut attrs = Attrs {
71        display: None,
72        source: None,
73        backtrace: None,
74        from: None,
75        transparent: None,
76        fmt: None,
77    };
78
79    for attr in input {
80        if attr.path().is_ident("error") {
81            parse_error_attribute(&mut attrs, attr)?;
82        } else if attr.path().is_ident("source") {
83            attr.meta.require_path_only()?;
84            if attrs.source.is_some() {
85                return Err(Error::new_spanned(attr, "duplicate #[source] attribute"));
86            }
87            let span = (attr.pound_token.span)
88                .join(attr.bracket_token.span.join())
89                .unwrap_or(attr.path().get_ident().unwrap().span());
90            attrs.source = Some(Source {
91                original: attr,
92                span,
93            });
94        } else if attr.path().is_ident("backtrace") {
95            attr.meta.require_path_only()?;
96            if attrs.backtrace.is_some() {
97                return Err(Error::new_spanned(attr, "duplicate #[backtrace] attribute"));
98            }
99            attrs.backtrace = Some(attr);
100        } else if attr.path().is_ident("from") {
101            match attr.meta {
102                Meta::Path(_) => {}
103                Meta::List(_) | Meta::NameValue(_) => {
104                    // Assume this is meant for derive_more crate or something.
105                    continue;
106                }
107            }
108            if attrs.from.is_some() {
109                return Err(Error::new_spanned(attr, "duplicate #[from] attribute"));
110            }
111            let span = (attr.pound_token.span)
112                .join(attr.bracket_token.span.join())
113                .unwrap_or(attr.path().get_ident().unwrap().span());
114            attrs.from = Some(From {
115                original: attr,
116                span,
117            });
118        }
119    }
120
121    Ok(attrs)
122}
123
124fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
125    mod kw {
126        syn::custom_keyword!(transparent);
127        syn::custom_keyword!(fmt);
128    }
129
130    attr.parse_args_with(|input: ParseStream| {
131        let lookahead = input.lookahead1();
132        let fmt = if lookahead.peek(LitStr) {
133            input.parse::<LitStr>()?
134        } else if lookahead.peek(kw::transparent) {
135            let kw: kw::transparent = input.parse()?;
136            if attrs.transparent.is_some() {
137                return Err(Error::new_spanned(
138                    attr,
139                    "duplicate #[error(transparent)] attribute",
140                ));
141            }
142            attrs.transparent = Some(Transparent {
143                original: attr,
144                span: kw.span,
145            });
146            return Ok(());
147        } else if lookahead.peek(kw::fmt) {
148            input.parse::<kw::fmt>()?;
149            input.parse::<Token![=]>()?;
150            let path: ExprPath = input.parse()?;
151            if attrs.fmt.is_some() {
152                return Err(Error::new_spanned(
153                    attr,
154                    "duplicate #[error(fmt = ...)] attribute",
155                ));
156            }
157            attrs.fmt = Some(Fmt {
158                original: attr,
159                path,
160            });
161            return Ok(());
162        } else {
163            return Err(lookahead.error());
164        };
165
166        let args = if input.is_empty() || input.peek(Token![,]) && input.peek2(End) {
167            input.parse::<Option<Token![,]>>()?;
168            TokenStream::new()
169        } else {
170            parse_token_expr(input, false)?
171        };
172
173        let requires_fmt_machinery = !args.is_empty();
174
175        let display = Display {
176            original: attr,
177            fmt,
178            args,
179            requires_fmt_machinery,
180            has_bonus_display: false,
181            infinite_recursive: false,
182            implied_bounds: Set::new(),
183            bindings: Vec::new(),
184        };
185        if attrs.display.is_some() {
186            return Err(Error::new_spanned(
187                attr,
188                "only one #[error(...)] attribute is allowed",
189            ));
190        }
191        attrs.display = Some(display);
192        Ok(())
193    })
194}
195
196fn parse_token_expr(input: ParseStream, mut begin_expr: bool) -> Result<TokenStream> {
197    let mut tokens = Vec::new();
198    while !input.is_empty() {
199        if input.peek(token::Group) {
200            let group: TokenTree = input.parse()?;
201            tokens.push(group);
202            begin_expr = false;
203            continue;
204        }
205
206        if begin_expr && input.peek(Token![.]) {
207            if input.peek2(Ident) {
208                input.parse::<Token![.]>()?;
209                begin_expr = false;
210                continue;
211            } else if input.peek2(LitInt) {
212                input.parse::<Token![.]>()?;
213                let int: Index = input.parse()?;
214                tokens.push({
215                    let ident = format_ident!("_{}", int.index, span = int.span);
216                    TokenTree::Ident(ident)
217                });
218                begin_expr = false;
219                continue;
220            } else if input.peek2(LitFloat) {
221                let ahead = input.fork();
222                ahead.parse::<Token![.]>()?;
223                let float: LitFloat = ahead.parse()?;
224                let repr = float.to_string();
225                let mut indices = repr.split('.').map(syn::parse_str::<Index>);
226                if let (Some(Ok(first)), Some(Ok(second)), None) =
227                    (indices.next(), indices.next(), indices.next())
228                {
229                    input.advance_to(&ahead);
230                    tokens.push({
231                        let ident = format_ident!("_{}", first, span = float.span());
232                        TokenTree::Ident(ident)
233                    });
234                    tokens.push({
235                        let mut punct = Punct::new('.', Spacing::Alone);
236                        punct.set_span(float.span());
237                        TokenTree::Punct(punct)
238                    });
239                    tokens.push({
240                        let mut literal = Literal::u32_unsuffixed(second.index);
241                        literal.set_span(float.span());
242                        TokenTree::Literal(literal)
243                    });
244                    begin_expr = false;
245                    continue;
246                }
247            }
248        }
249
250        begin_expr = input.peek(Token![break])
251            || input.peek(Token![continue])
252            || input.peek(Token![if])
253            || input.peek(Token![in])
254            || input.peek(Token![match])
255            || input.peek(Token![mut])
256            || input.peek(Token![return])
257            || input.peek(Token![while])
258            || input.peek(Token![+])
259            || input.peek(Token![&])
260            || input.peek(Token![!])
261            || input.peek(Token![^])
262            || input.peek(Token![,])
263            || input.peek(Token![/])
264            || input.peek(Token![=])
265            || input.peek(Token![>])
266            || input.peek(Token![<])
267            || input.peek(Token![|])
268            || input.peek(Token![%])
269            || input.peek(Token![;])
270            || input.peek(Token![*])
271            || input.peek(Token![-]);
272
273        let token: TokenTree = if input.peek(token::Paren) {
274            let content;
275            let delimiter = parenthesized!(content in input);
276            let nested = parse_token_expr(&content, true)?;
277            let mut group = Group::new(Delimiter::Parenthesis, nested);
278            group.set_span(delimiter.span.join());
279            TokenTree::Group(group)
280        } else if input.peek(token::Brace) {
281            let content;
282            let delimiter = braced!(content in input);
283            let nested = parse_token_expr(&content, true)?;
284            let mut group = Group::new(Delimiter::Brace, nested);
285            group.set_span(delimiter.span.join());
286            TokenTree::Group(group)
287        } else if input.peek(token::Bracket) {
288            let content;
289            let delimiter = bracketed!(content in input);
290            let nested = parse_token_expr(&content, true)?;
291            let mut group = Group::new(Delimiter::Bracket, nested);
292            group.set_span(delimiter.span.join());
293            TokenTree::Group(group)
294        } else {
295            input.parse()?
296        };
297        tokens.push(token);
298    }
299    Ok(TokenStream::from_iter(tokens))
300}
301
302impl ToTokens for Display<'_> {
303    fn to_tokens(&self, tokens: &mut TokenStream) {
304        if self.infinite_recursive {
305            let span = self.fmt.span();
306            tokens.extend(quote_spanned! {span=>
307                #[warn(unconditional_recursion)]
308                fn _fmt() { _fmt() }
309            });
310        }
311
312        let fmt = &self.fmt;
313        let args = &self.args;
314
315        // Currently `write!(f, "text")` produces less efficient code than
316        // `f.write_str("text")`. We recognize the case when the format string
317        // has no braces and no interpolated values, and generate simpler code.
318        let write = if self.requires_fmt_machinery {
319            quote! {
320                ::core::write!(__formatter, #fmt #args)
321            }
322        } else {
323            quote! {
324                __formatter.write_str(#fmt)
325            }
326        };
327
328        tokens.extend(if self.bindings.is_empty() {
329            write
330        } else {
331            let locals = self.bindings.iter().map(|(local, _value)| local);
332            let values = self.bindings.iter().map(|(_local, value)| value);
333            quote! {
334                match (#(#values,)*) {
335                    (#(#locals,)*) => #write
336                }
337            }
338        });
339    }
340}
341
342impl ToTokens for Trait {
343    fn to_tokens(&self, tokens: &mut TokenStream) {
344        let trait_name = match self {
345            Trait::Debug => "Debug",
346            Trait::Display => "Display",
347            Trait::Octal => "Octal",
348            Trait::LowerHex => "LowerHex",
349            Trait::UpperHex => "UpperHex",
350            Trait::Pointer => "Pointer",
351            Trait::Binary => "Binary",
352            Trait::LowerExp => "LowerExp",
353            Trait::UpperExp => "UpperExp",
354        };
355        let ident = Ident::new(trait_name, Span::call_site());
356        tokens.extend(quote!(::core::fmt::#ident));
357    }
358}