thiserror_impl/
fmt.rs

1use crate::ast::{ContainerKind, Field};
2use crate::attr::{Display, Trait};
3use crate::scan_expr::scan_expr;
4use crate::unraw::{IdentUnraw, MemberUnraw};
5use proc_macro2::{Delimiter, TokenStream, TokenTree};
6use quote::{format_ident, quote, quote_spanned, ToTokens as _};
7use std::collections::{BTreeSet, HashMap};
8use std::iter;
9use syn::ext::IdentExt;
10use syn::parse::discouraged::Speculative;
11use syn::parse::{Error, ParseStream, Parser, Result};
12use syn::{Expr, Ident, Index, LitStr, Token};
13
impl Display<'_> {
    /// Expands shorthand field references in the `#[error("...")]` format
    /// string into references to generated named format arguments.
    ///
    /// For example `{var}`, `{0:?}` and `{var:x}` are rewritten to
    /// `{__display_var}`, `{__field0:?}` and `{__field_var:x}` (with extra
    /// leading `_` if such a name collides with a user-written named argument).
    /// For each generated name, one `(name, value)` pair is pushed to
    /// `self.bindings` (presumably emitted by the caller as `name = value`
    /// format arguments). The following are copied through unchanged:
    /// references to arguments the user named explicitly (`name = expr`),
    /// references to anything that is not a field, and placeholders such as
    /// `{}` or `{:?}`.
    ///
    /// On success this also writes to `self`:
    /// - `fmt`: the rewritten literal, keeping the original span.
    /// - `requires_fmt_machinery`: set (never cleared) if the string contains
    ///   `{` or `}`.
    /// - `has_bonus_display`: set if any shorthand reference has an empty
    ///   format spec (e.g. `{var}`); such field values are bound through
    ///   `.as_display()`.
    /// - `infinite_recursive`: set if `{self}` is formatted with `Display`.
    /// - `implied_bounds`: the `(field index, formatting trait)` pairs used.
    ///
    /// `fields` are the fields of the struct or variant. `container` is only
    /// used in the error message below.
    ///
    /// # Errors
    ///
    /// Returns an error when the format string refers to a positional argument
    /// by number (e.g. `{0}`) while some field is unnamed (tuple-like) and the
    /// user also passed an unnamed extra argument. In that case it is ambiguous
    /// whether the number refers to the field or to the extra argument.
    ///
    /// If the format string is malformed (a `{` at the very end, a placeholder
    /// with no closing `}`, or an index that does not fit in `u32`), this
    /// returns `Ok(())` early. `self.fmt` and the other outputs are left
    /// unrewritten, so rustc reports the problem.
    pub fn expand_shorthand(&mut self, fields: &[Field], container: ContainerKind) -> Result<()> {
        // Find which extra arguments the user wrote after the format string:
        // the set of `name = expr` names, plus the tokens of the first unnamed
        // argument (if any). `explicit_named_args` is written to always return
        // `Ok` after consuming every token, hence the unwrap.
        let raw_args = self.args.clone();
        let FmtArguments {
            named: user_named_args,
            first_unnamed,
        } = explicit_named_args.parse2(raw_args).unwrap();

        // Map each field's member (name or tuple index) to its position in
        // `fields`. Referring to extra positional arguments by number is only
        // unambiguous when every field is named.
        let mut member_index = HashMap::new();
        let mut extra_positional_arguments_allowed = true;
        for (i, field) in fields.iter().enumerate() {
            member_index.insert(&field.member, i);
            extra_positional_arguments_allowed &= matches!(&field.member, MemberUnraw::Named(_));
        }

        let span = self.fmt.span();
        let fmt = self.fmt.value();
        // `read` is the not-yet-scanned remainder of the format string;
        // `out` accumulates the rewritten string.
        let mut read = fmt.as_str();
        let mut out = String::new();
        let mut has_bonus_display = false;
        let mut infinite_recursive = false;
        let mut implied_bounds = BTreeSet::new();
        let mut bindings = Vec::new();
        // Generated argument names already bound, so each is bound only once.
        let mut macro_named_args = BTreeSet::new();

        // Any brace means the string needs the formatting machinery: `}` is
        // checked here, `{` inside the loop below.
        self.requires_fmt_machinery = self.requires_fmt_machinery || fmt.contains('}');

        while let Some(brace) = read.find('{') {
            self.requires_fmt_machinery = true;
            // Copy everything up to and including the `{`.
            out += &read[..brace + 1];
            read = &read[brace + 1..];
            // `{{` is an escaped literal brace, not a placeholder.
            if read.starts_with('{') {
                out.push('{');
                read = &read[1..];
                continue;
            }
            // A `{` at the very end of the string is malformed: bail out
            // without rewriting `self.fmt` and let rustc report it.
            let next = match read.chars().next() {
                Some(next) => next,
                None => return Ok(()),
            };
            let member = match next {
                // Positional reference such as `{0}`.
                '0'..='9' => {
                    let int = take_int(&mut read);
                    // With an unnamed field present, `{N}` could mean either
                    // the field or one of the user's extra positional arguments.
                    if !extra_positional_arguments_allowed {
                        if let Some(first_unnamed) = &first_unnamed {
                            let msg = format!("ambiguous reference to positional arguments by number in a {container}; change this to a named argument");
                            return Err(Error::new_spanned(first_unnamed, msg));
                        }
                    }
                    // `int` is a non-empty run of digits, so parsing fails only
                    // when it overflows u32; leave that for rustc to diagnose.
                    match int.parse::<u32>() {
                        Ok(index) => MemberUnraw::Unnamed(Index { index, span }),
                        Err(_) => return Ok(()),
                    }
                }
                // Named reference such as `{var}` (ASCII identifiers only).
                'a'..='z' | 'A'..='Z' | '_' => {
                    // Raw identifiers (`{r#ident}`) are left unchanged.
                    if read.starts_with("r#") {
                        continue;
                    }
                    let repr = take_ident(&mut read);
                    if repr == "_" {
                        // Invalid. Let rustc produce the diagnostic.
                        out += repr;
                        continue;
                    }
                    let ident = IdentUnraw::new(Ident::new(repr, span));
                    if user_named_args.contains(&ident) {
                        // Refers to a named argument written by the user, not to a field.
                        out += repr;
                        continue;
                    }
                    MemberUnraw::Named(ident)
                }
                // Anything else (`{}`, `{:?}`, a non-ASCII name, ...) is not a
                // shorthand reference. It stays in `read` and is copied through.
                _ => continue,
            };
            // A placeholder with no closing `}` is malformed; leave it for rustc.
            let end_spec = match read.find('}') {
                Some(end_spec) => end_spec,
                None => return Ok(()),
            };
            // The last character of the spec (the text between the member and
            // `}`) selects the formatting trait. An empty spec (`{var}`) is
            // plain Display, bound through `.as_display()` below.
            let mut bonus_display = false;
            let bound = match read[..end_spec].chars().next_back() {
                Some('?') => Trait::Debug,
                Some('o') => Trait::Octal,
                Some('x') => Trait::LowerHex,
                Some('X') => Trait::UpperHex,
                Some('p') => Trait::Pointer,
                Some('b') => Trait::Binary,
                Some('e') => Trait::LowerExp,
                Some('E') => Trait::UpperExp,
                Some(_) => Trait::Display,
                None => {
                    bonus_display = true;
                    has_bonus_display = true;
                    Trait::Display
                }
            };
            // `{self}` formatted with Display would recurse into this impl.
            infinite_recursive |= member == *"self" && bound == Trait::Display;
            // Not a field: write the member back unchanged for format_args to
            // resolve (or reject) on its own.
            let field = match member_index.get(&member) {
                Some(&field) => field,
                None => {
                    out += &member.to_string();
                    continue;
                }
            };
            // Record the (field index, trait) pair this placeholder requires.
            implied_bounds.insert((field, bound));
            let formatvar_prefix = if bonus_display {
                "__display"
            } else if bound == Trait::Pointer {
                "__pointer"
            } else {
                "__field"
            };
            // Generated argument name, e.g. `__field0` or `__display_var`.
            let mut formatvar = IdentUnraw::new(match &member {
                MemberUnraw::Unnamed(index) => format_ident!("{}{}", formatvar_prefix, index),
                MemberUnraw::Named(ident) => {
                    format_ident!("{}_{}", formatvar_prefix, ident.to_string())
                }
            });
            // Prepend `_` until the name cannot collide with a user argument.
            while user_named_args.contains(&formatvar) {
                formatvar = IdentUnraw::new(format_ident!("_{}", formatvar.to_string()));
            }
            formatvar.set_span(span);
            out += &formatvar.to_string();
            if !macro_named_args.insert(formatvar.clone()) {
                // Already added to bindings by a previous use.
                continue;
            }
            // The value bound to the generated argument: `_N` for tuple fields,
            // otherwise the field name via `to_local()`. These are presumably
            // the local names the caller destructures fields into; confirm
            // against the code that consumes `bindings`.
            let mut binding_value = match &member {
                MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
                MemberUnraw::Named(ident) => ident.to_local(),
            };
            // Located at the format string, but name-resolved as if written at
            // the field's span.
            binding_value.set_span(span.resolved_at(fields[field].member.span()));
            // Empty spec: call `.as_display()` on the value. `{:p}`: wrap in
            // the private `Var` helper. NOTE(review): `Var` is defined outside
            // this file.
            let wrapped_binding_value = if bonus_display {
                quote_spanned!(span=> #binding_value.as_display())
            } else if bound == Trait::Pointer {
                quote!(::thiserror::__private::Var(#binding_value))
            } else {
                binding_value.into_token_stream()
            };
            bindings.push((formatvar.to_local(), wrapped_binding_value));
        }

        // Copy the tail after the last placeholder and commit the results.
        out += read;
        self.fmt = LitStr::new(&out, self.fmt.span());
        self.has_bonus_display = has_bonus_display;
        self.infinite_recursive = infinite_recursive;
        self.implied_bounds = implied_bounds;
        self.bindings = bindings;
        Ok(())
    }
}
164
165struct FmtArguments {
166    named: BTreeSet<IdentUnraw>,
167    first_unnamed: Option<TokenStream>,
168}
169
170#[allow(clippy::unnecessary_wraps)]
171fn explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
172    let ahead = input.fork();
173    if let Ok(set) = try_explicit_named_args(&ahead) {
174        input.advance_to(&ahead);
175        return Ok(set);
176    }
177
178    let ahead = input.fork();
179    if let Ok(set) = fallback_explicit_named_args(&ahead) {
180        input.advance_to(&ahead);
181        return Ok(set);
182    }
183
184    input.parse::<TokenStream>().unwrap();
185    Ok(FmtArguments {
186        named: BTreeSet::new(),
187        first_unnamed: None,
188    })
189}
190
191fn try_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
192    let mut syn_full = None;
193    let mut args = FmtArguments {
194        named: BTreeSet::new(),
195        first_unnamed: None,
196    };
197
198    while !input.is_empty() {
199        input.parse::<Token![,]>()?;
200        if input.is_empty() {
201            break;
202        }
203
204        let mut begin_unnamed = None;
205        if input.peek(Ident::peek_any) && input.peek2(Token![=]) && !input.peek2(Token![==]) {
206            let ident: IdentUnraw = input.parse()?;
207            input.parse::<Token![=]>()?;
208            args.named.insert(ident);
209        } else {
210            begin_unnamed = Some(input.fork());
211        }
212
213        let ahead = input.fork();
214        if *syn_full.get_or_insert_with(is_syn_full) && ahead.parse::<Expr>().is_ok() {
215            input.advance_to(&ahead);
216        } else {
217            scan_expr(input)?;
218        }
219
220        if let Some(begin_unnamed) = begin_unnamed {
221            if args.first_unnamed.is_none() {
222                args.first_unnamed = Some(between(&begin_unnamed, input));
223            }
224        }
225    }
226
227    Ok(args)
228}
229
230fn fallback_explicit_named_args(input: ParseStream) -> Result<FmtArguments> {
231    let mut args = FmtArguments {
232        named: BTreeSet::new(),
233        first_unnamed: None,
234    };
235
236    while !input.is_empty() {
237        if input.peek(Token![,])
238            && input.peek2(Ident::peek_any)
239            && input.peek3(Token![=])
240            && !input.peek3(Token![==])
241        {
242            input.parse::<Token![,]>()?;
243            let ident: IdentUnraw = input.parse()?;
244            input.parse::<Token![=]>()?;
245            args.named.insert(ident);
246        } else {
247            input.parse::<TokenTree>()?;
248        }
249    }
250
251    Ok(args)
252}
253
254fn is_syn_full() -> bool {
255    // Expr::Block contains syn::Block which contains Vec<syn::Stmt>. In the
256    // current version of Syn, syn::Stmt is exhaustive and could only plausibly
257    // represent `trait Trait {}` in Stmt::Item which contains syn::Item. Most
258    // of the point of syn's non-"full" mode is to avoid compiling Item and the
259    // entire expansive syntax tree it comprises. So the following expression
260    // being parsed to Expr::Block is a reliable indication that "full" is
261    // enabled.
262    let test = quote!({
263        trait Trait {}
264    });
265    match syn::parse2(test) {
266        Ok(Expr::Verbatim(_)) | Err(_) => false,
267        Ok(Expr::Block(_)) => true,
268        Ok(_) => unreachable!(),
269    }
270}
271
/// Splits the leading run of ASCII digits off `read`. Returns that run
/// (possibly empty) and leaves `read` pointing at the remainder.
fn take_int<'a>(read: &mut &'a str) -> &'a str {
    // Every accepted char is a single ASCII byte, so the byte offset of the
    // first non-digit is also the number of digits.
    let digits = read
        .find(|ch: char| !ch.is_ascii_digit())
        .unwrap_or(read.len());
    let (int, rest) = read.split_at(digits);
    *read = rest;
    int
}
284
/// Splits the leading run of ASCII identifier characters (`A-Z`, `a-z`,
/// `0-9`, `_`) off `read`. Returns that run (possibly empty) and leaves `read`
/// pointing at the remainder.
fn take_ident<'a>(read: &mut &'a str) -> &'a str {
    // Only single-byte ASCII chars are accepted, so the byte offset of the
    // first rejected char equals the identifier's length.
    let len = read
        .find(|ch: char| !(ch.is_ascii_alphanumeric() || ch == '_'))
        .unwrap_or(read.len());
    let (ident, rest) = read.split_at(len);
    *read = rest;
    ident
}
297
298fn between<'a>(begin: ParseStream<'a>, end: ParseStream<'a>) -> TokenStream {
299    let end = end.cursor();
300    let mut cursor = begin.cursor();
301    let mut tokens = TokenStream::new();
302
303    while cursor < end {
304        let (tt, next) = cursor.token_tree().unwrap();
305
306        if end < next {
307            if let Some((inside, _span, _after)) = cursor.group(Delimiter::None) {
308                cursor = inside;
309                continue;
310            }
311            if tokens.is_empty() {
312                tokens.extend(iter::once(tt));
313            }
314            break;
315        }
316
317        tokens.extend(iter::once(tt));
318        cursor = next;
319    }
320
321    tokens
322}