thiserror_impl/
valid.rs

1use crate::ast::{Enum, Field, Input, Struct, Variant};
2use crate::attr::Attrs;
3use syn::{Error, GenericArgument, PathArguments, Result, Type};
4
5impl Input<'_> {
6    pub(crate) fn validate(&self) -> Result<()> {
7        match self {
8            Input::Struct(input) => input.validate(),
9            Input::Enum(input) => input.validate(),
10        }
11    }
12}
13
14impl Struct<'_> {
15    fn validate(&self) -> Result<()> {
16        check_non_field_attrs(&self.attrs)?;
17        if let Some(transparent) = self.attrs.transparent {
18            if self.fields.len() != 1 {
19                return Err(Error::new_spanned(
20                    transparent.original,
21                    "#[error(transparent)] requires exactly one field",
22                ));
23            }
24            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
25                return Err(Error::new_spanned(
26                    source.original,
27                    "transparent error struct can't contain #[source]",
28                ));
29            }
30        }
31        if let Some(fmt) = &self.attrs.fmt {
32            return Err(Error::new_spanned(
33                fmt.original,
34                "#[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl",
35            ));
36        }
37        check_field_attrs(&self.fields)?;
38        for field in &self.fields {
39            field.validate()?;
40        }
41        Ok(())
42    }
43}
44
45impl Enum<'_> {
46    fn validate(&self) -> Result<()> {
47        check_non_field_attrs(&self.attrs)?;
48        let has_display = self.has_display();
49        for variant in &self.variants {
50            variant.validate()?;
51            if has_display
52                && variant.attrs.display.is_none()
53                && variant.attrs.transparent.is_none()
54                && variant.attrs.fmt.is_none()
55            {
56                return Err(Error::new_spanned(
57                    variant.original,
58                    "missing #[error(\"...\")] display attribute",
59                ));
60            }
61        }
62        Ok(())
63    }
64}
65
66impl Variant<'_> {
67    fn validate(&self) -> Result<()> {
68        check_non_field_attrs(&self.attrs)?;
69        if self.attrs.transparent.is_some() {
70            if self.fields.len() != 1 {
71                return Err(Error::new_spanned(
72                    self.original,
73                    "#[error(transparent)] requires exactly one field",
74                ));
75            }
76            if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
77                return Err(Error::new_spanned(
78                    source.original,
79                    "transparent variant can't contain #[source]",
80                ));
81            }
82        }
83        check_field_attrs(&self.fields)?;
84        for field in &self.fields {
85            field.validate()?;
86        }
87        Ok(())
88    }
89}
90
91impl Field<'_> {
92    fn validate(&self) -> Result<()> {
93        if let Some(unexpected_display_attr) = if let Some(display) = &self.attrs.display {
94            Some(display.original)
95        } else if let Some(fmt) = &self.attrs.fmt {
96            Some(fmt.original)
97        } else {
98            None
99        } {
100            return Err(Error::new_spanned(
101                unexpected_display_attr,
102                "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
103            ));
104        }
105        Ok(())
106    }
107}
108
109fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
110    if let Some(from) = &attrs.from {
111        return Err(Error::new_spanned(
112            from.original,
113            "not expected here; the #[from] attribute belongs on a specific field",
114        ));
115    }
116    if let Some(source) = &attrs.source {
117        return Err(Error::new_spanned(
118            source.original,
119            "not expected here; the #[source] attribute belongs on a specific field",
120        ));
121    }
122    if let Some(backtrace) = &attrs.backtrace {
123        return Err(Error::new_spanned(
124            backtrace,
125            "not expected here; the #[backtrace] attribute belongs on a specific field",
126        ));
127    }
128    if attrs.transparent.is_some() {
129        if let Some(display) = &attrs.display {
130            return Err(Error::new_spanned(
131                display.original,
132                "cannot have both #[error(transparent)] and a display attribute",
133            ));
134        }
135        if let Some(fmt) = &attrs.fmt {
136            return Err(Error::new_spanned(
137                fmt.original,
138                "cannot have both #[error(transparent)] and #[error(fmt = ...)]",
139            ));
140        }
141    } else if let (Some(display), Some(_)) = (&attrs.display, &attrs.fmt) {
142        return Err(Error::new_spanned(
143            display.original,
144            "cannot have both #[error(fmt = ...)] and a format arguments attribute",
145        ));
146    }
147
148    Ok(())
149}
150
151fn check_field_attrs(fields: &[Field]) -> Result<()> {
152    let mut from_field = None;
153    let mut source_field = None;
154    let mut backtrace_field = None;
155    let mut has_backtrace = false;
156    for field in fields {
157        if let Some(from) = field.attrs.from {
158            if from_field.is_some() {
159                return Err(Error::new_spanned(
160                    from.original,
161                    "duplicate #[from] attribute",
162                ));
163            }
164            from_field = Some(field);
165        }
166        if let Some(source) = field.attrs.source {
167            if source_field.is_some() {
168                return Err(Error::new_spanned(
169                    source.original,
170                    "duplicate #[source] attribute",
171                ));
172            }
173            source_field = Some(field);
174        }
175        if let Some(backtrace) = field.attrs.backtrace {
176            if backtrace_field.is_some() {
177                return Err(Error::new_spanned(
178                    backtrace,
179                    "duplicate #[backtrace] attribute",
180                ));
181            }
182            backtrace_field = Some(field);
183            has_backtrace = true;
184        }
185        if let Some(transparent) = field.attrs.transparent {
186            return Err(Error::new_spanned(
187                transparent.original,
188                "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
189            ));
190        }
191        has_backtrace |= field.is_backtrace();
192    }
193    if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
194        if from_field.member != source_field.member {
195            return Err(Error::new_spanned(
196                from_field.attrs.from.unwrap().original,
197                "#[from] is only supported on the source field, not any other field",
198            ));
199        }
200    }
201    if let Some(from_field) = from_field {
202        let max_expected_fields = match backtrace_field {
203            Some(backtrace_field) => 1 + (from_field.member != backtrace_field.member) as usize,
204            None => 1 + has_backtrace as usize,
205        };
206        if fields.len() > max_expected_fields {
207            return Err(Error::new_spanned(
208                from_field.attrs.from.unwrap().original,
209                "deriving From requires no fields other than source and backtrace",
210            ));
211        }
212    }
213    if let Some(source_field) = source_field.or(from_field) {
214        if contains_non_static_lifetime(source_field.ty) {
215            return Err(Error::new_spanned(
216                &source_field.original.ty,
217                "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
218            ));
219        }
220    }
221    Ok(())
222}
223
224fn contains_non_static_lifetime(ty: &Type) -> bool {
225    match ty {
226        Type::Path(ty) => {
227            let bracketed = match &ty.path.segments.last().unwrap().arguments {
228                PathArguments::AngleBracketed(bracketed) => bracketed,
229                _ => return false,
230            };
231            for arg in &bracketed.args {
232                match arg {
233                    GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
234                    GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
235                        return true
236                    }
237                    _ => {}
238                }
239            }
240            false
241        }
242        Type::Reference(ty) => ty
243            .lifetime
244            .as_ref()
245            .map_or(false, |lifetime| lifetime.ident != "static"),
246        _ => false, // maybe implement later if there are common other cases
247    }
248}