// thiserror_impl/valid.rs
use crate::ast::{Enum, Field, Input, Struct, Variant};
use crate::attr::Attrs;
use syn::{Error, GenericArgument, PathArguments, Result, Type, TypeParamBound};
4
5impl Input<'_> {
6 pub(crate) fn validate(&self) -> Result<()> {
7 match self {
8 Input::Struct(input) => input.validate(),
9 Input::Enum(input) => input.validate(),
10 }
11 }
12}
13
14impl Struct<'_> {
15 fn validate(&self) -> Result<()> {
16 check_non_field_attrs(&self.attrs)?;
17 if let Some(transparent) = self.attrs.transparent {
18 if self.fields.len() != 1 {
19 return Err(Error::new_spanned(
20 transparent.original,
21 "#[error(transparent)] requires exactly one field",
22 ));
23 }
24 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
25 return Err(Error::new_spanned(
26 source.original,
27 "transparent error struct can't contain #[source]",
28 ));
29 }
30 }
31 if let Some(fmt) = &self.attrs.fmt {
32 return Err(Error::new_spanned(
33 fmt.original,
34 "#[error(fmt = ...)] is only supported in enums; for a struct, handwrite your own Display impl",
35 ));
36 }
37 check_field_attrs(&self.fields)?;
38 for field in &self.fields {
39 field.validate()?;
40 }
41 Ok(())
42 }
43}
44
45impl Enum<'_> {
46 fn validate(&self) -> Result<()> {
47 check_non_field_attrs(&self.attrs)?;
48 let has_display = self.has_display();
49 for variant in &self.variants {
50 variant.validate()?;
51 if has_display
52 && variant.attrs.display.is_none()
53 && variant.attrs.transparent.is_none()
54 && variant.attrs.fmt.is_none()
55 {
56 return Err(Error::new_spanned(
57 variant.original,
58 "missing #[error(\"...\")] display attribute",
59 ));
60 }
61 }
62 Ok(())
63 }
64}
65
66impl Variant<'_> {
67 fn validate(&self) -> Result<()> {
68 check_non_field_attrs(&self.attrs)?;
69 if self.attrs.transparent.is_some() {
70 if self.fields.len() != 1 {
71 return Err(Error::new_spanned(
72 self.original,
73 "#[error(transparent)] requires exactly one field",
74 ));
75 }
76 if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) {
77 return Err(Error::new_spanned(
78 source.original,
79 "transparent variant can't contain #[source]",
80 ));
81 }
82 }
83 check_field_attrs(&self.fields)?;
84 for field in &self.fields {
85 field.validate()?;
86 }
87 Ok(())
88 }
89}
90
91impl Field<'_> {
92 fn validate(&self) -> Result<()> {
93 if let Some(unexpected_display_attr) = if let Some(display) = &self.attrs.display {
94 Some(display.original)
95 } else if let Some(fmt) = &self.attrs.fmt {
96 Some(fmt.original)
97 } else {
98 None
99 } {
100 return Err(Error::new_spanned(
101 unexpected_display_attr,
102 "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant",
103 ));
104 }
105 Ok(())
106 }
107}
108
109fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
110 if let Some(from) = &attrs.from {
111 return Err(Error::new_spanned(
112 from.original,
113 "not expected here; the #[from] attribute belongs on a specific field",
114 ));
115 }
116 if let Some(source) = &attrs.source {
117 return Err(Error::new_spanned(
118 source.original,
119 "not expected here; the #[source] attribute belongs on a specific field",
120 ));
121 }
122 if let Some(backtrace) = &attrs.backtrace {
123 return Err(Error::new_spanned(
124 backtrace,
125 "not expected here; the #[backtrace] attribute belongs on a specific field",
126 ));
127 }
128 if attrs.transparent.is_some() {
129 if let Some(display) = &attrs.display {
130 return Err(Error::new_spanned(
131 display.original,
132 "cannot have both #[error(transparent)] and a display attribute",
133 ));
134 }
135 if let Some(fmt) = &attrs.fmt {
136 return Err(Error::new_spanned(
137 fmt.original,
138 "cannot have both #[error(transparent)] and #[error(fmt = ...)]",
139 ));
140 }
141 } else if let (Some(display), Some(_)) = (&attrs.display, &attrs.fmt) {
142 return Err(Error::new_spanned(
143 display.original,
144 "cannot have both #[error(fmt = ...)] and a format arguments attribute",
145 ));
146 }
147
148 Ok(())
149}
150
151fn check_field_attrs(fields: &[Field]) -> Result<()> {
152 let mut from_field = None;
153 let mut source_field = None;
154 let mut backtrace_field = None;
155 let mut has_backtrace = false;
156 for field in fields {
157 if let Some(from) = field.attrs.from {
158 if from_field.is_some() {
159 return Err(Error::new_spanned(
160 from.original,
161 "duplicate #[from] attribute",
162 ));
163 }
164 from_field = Some(field);
165 }
166 if let Some(source) = field.attrs.source {
167 if source_field.is_some() {
168 return Err(Error::new_spanned(
169 source.original,
170 "duplicate #[source] attribute",
171 ));
172 }
173 source_field = Some(field);
174 }
175 if let Some(backtrace) = field.attrs.backtrace {
176 if backtrace_field.is_some() {
177 return Err(Error::new_spanned(
178 backtrace,
179 "duplicate #[backtrace] attribute",
180 ));
181 }
182 backtrace_field = Some(field);
183 has_backtrace = true;
184 }
185 if let Some(transparent) = field.attrs.transparent {
186 return Err(Error::new_spanned(
187 transparent.original,
188 "#[error(transparent)] needs to go outside the enum or struct, not on an individual field",
189 ));
190 }
191 has_backtrace |= field.is_backtrace();
192 }
193 if let (Some(from_field), Some(source_field)) = (from_field, source_field) {
194 if from_field.member != source_field.member {
195 return Err(Error::new_spanned(
196 from_field.attrs.from.unwrap().original,
197 "#[from] is only supported on the source field, not any other field",
198 ));
199 }
200 }
201 if let Some(from_field) = from_field {
202 let max_expected_fields = match backtrace_field {
203 Some(backtrace_field) => 1 + (from_field.member != backtrace_field.member) as usize,
204 None => 1 + has_backtrace as usize,
205 };
206 if fields.len() > max_expected_fields {
207 return Err(Error::new_spanned(
208 from_field.attrs.from.unwrap().original,
209 "deriving From requires no fields other than source and backtrace",
210 ));
211 }
212 }
213 if let Some(source_field) = source_field.or(from_field) {
214 if contains_non_static_lifetime(source_field.ty) {
215 return Err(Error::new_spanned(
216 &source_field.original.ty,
217 "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static",
218 ));
219 }
220 }
221 Ok(())
222}
223
224fn contains_non_static_lifetime(ty: &Type) -> bool {
225 match ty {
226 Type::Path(ty) => {
227 let bracketed = match &ty.path.segments.last().unwrap().arguments {
228 PathArguments::AngleBracketed(bracketed) => bracketed,
229 _ => return false,
230 };
231 for arg in &bracketed.args {
232 match arg {
233 GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true,
234 GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => {
235 return true
236 }
237 _ => {}
238 }
239 }
240 false
241 }
242 Type::Reference(ty) => ty
243 .lifetime
244 .as_ref()
245 .map_or(false, |lifetime| lifetime.ident != "static"),
246 _ => false, }
248}