1 // Copyright 2022, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
//! Derive macros for `AsCborValue`, `FromRawTag` and `LegacySerialize`.
16 use proc_macro2::TokenStream;
17 use quote::{format_ident, quote, quote_spanned};
18 use syn::{
19     parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam,
20     Generics, Ident, Index,
21 };
22 
23 /// Derive macro that implements the `AsCborValue` trait.  Using this macro requires
24 /// that `AsCborValue`, `CborError` and `cbor_type_error` are locally `use`d.
25 #[proc_macro_derive(AsCborValue)]
derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream26 pub fn derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27     let input = parse_macro_input!(input as DeriveInput);
28     derive_as_cbor_value_internal(&input)
29 }
30 
derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream31 fn derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream {
32     let name = &input.ident;
33 
34     // Add a bound `T: AsCborValue` for every type parameter `T`.
35     let generics = add_trait_bounds(&input.generics);
36     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
37 
38     let from_val = from_val_struct(&input.data);
39     let to_val = to_val_struct(&input.data);
40     let cddl = cddl_struct(name, &input.data);
41 
42     let expanded = quote! {
43         // The generated impl
44         impl #impl_generics AsCborValue for #name #ty_generics #where_clause {
45             fn from_cbor_value(value: ciborium::value::Value) -> Result<Self, CborError> {
46                 #from_val
47             }
48             fn to_cbor_value(self) -> Result<ciborium::value::Value, CborError> {
49                 #to_val
50             }
51             fn cddl_typename() -> Option<String> {
52                 Some(stringify!(#name).to_string())
53             }
54             fn cddl_schema() -> Option<String> {
55                 #cddl
56             }
57         }
58     };
59 
60     expanded.into()
61 }
62 
63 /// Add a bound `T: AsCborValue` for every type parameter `T`.
add_trait_bounds(generics: &Generics) -> Generics64 fn add_trait_bounds(generics: &Generics) -> Generics {
65     let mut generics = generics.clone();
66     for param in &mut generics.params {
67         if let GenericParam::Type(ref mut type_param) = *param {
68             type_param.bounds.push(parse_quote!(AsCborValue));
69         }
70     }
71     generics
72 }
73 
74 /// Generate an expression to convert an instance of a compound type to `ciborium::value::Value`
to_val_struct(data: &Data) -> TokenStream75 fn to_val_struct(data: &Data) -> TokenStream {
76     match *data {
77         Data::Struct(ref data) => {
78             match data.fields {
79                 Fields::Named(ref fields) => {
80                     // Expands to an expression like
81                     //
82                     //     {
83                     //         let mut v = Vec::new();
84                     //         v.try_reserve(3).map_err(|_e| CborError::AllocationFailed)?;
85                     //         v.push(AsCborValue::to_cbor_value(self.x)?);
86                     //         v.push(AsCborValue::to_cbor_value(self.y)?);
87                     //         v.push(AsCborValue::to_cbor_value(self.z)?);
88                     //         Ok(ciborium::value::Value::Array(v))
89                     //     }
90                     let nfields = fields.named.len();
91                     let recurse = fields.named.iter().map(|f| {
92                         let name = &f.ident;
93                         quote_spanned! {f.span()=>
94                             v.push(AsCborValue::to_cbor_value(self.#name)?)
95                         }
96                     });
97                     quote! {
98                         {
99                             let mut v = Vec::new();
100                             v.try_reserve(#nfields).map_err(|_e| CborError::AllocationFailed)?;
101                             #(#recurse; )*
102                             Ok(ciborium::value::Value::Array(v))
103                         }
104                     }
105                 }
106                 Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
107                     // For a newtype, expands to an expression
108                     //
109                     //     self.0.to_cbor_value()
110                     quote! {
111                         self.0.to_cbor_value()
112                     }
113                 }
114                 Fields::Unnamed(ref fields) => {
115                     // Expands to an expression like
116                     //
117                     //
118                     //     {
119                     //         let mut v = Vec::new();
120                     //         v.try_reserve(3).map_err(|_e| CborError::AllocationFailed)?;
121                     //         v.push(AsCborValue::to_cbor_value(self.0)?);
122                     //         v.push(AsCborValue::to_cbor_value(self.1)?);
123                     //         v.push(AsCborValue::to_cbor_value(self.2)?);
124                     //         Ok(ciborium::value::Value::Array(v))
125                     //     }
126                     let nfields = fields.unnamed.len();
127                     let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
128                         let index = Index::from(i);
129                         quote_spanned! {f.span()=>
130                             v.push(AsCborValue::to_cbor_value(self.#index)?)
131                         }
132                     });
133                     quote! {
134                         {
135                             let mut v = Vec::new();
136                             v.try_reserve(#nfields).map_err(|_e| CborError::AllocationFailed)?;
137                             #(#recurse; )*
138                             Ok(ciborium::value::Value::Array(v))
139                         }
140                     }
141                 }
142                 Fields::Unit => unimplemented!(),
143             }
144         }
145         Data::Enum(_) => {
146             quote! {
147                 let v: ciborium::value::Integer = (self as i32).into();
148                 Ok(ciborium::value::Value::Integer(v))
149             }
150         }
151         Data::Union(_) => unimplemented!(),
152     }
153 }
154 
155 /// Generate an expression to convert a `ciborium::value::Value` into an instance of a compound
156 /// type.
from_val_struct(data: &Data) -> TokenStream157 fn from_val_struct(data: &Data) -> TokenStream {
158     match data {
159         Data::Struct(ref data) => {
160             match data.fields {
161                 Fields::Named(ref fields) => {
162                     // Expands to an expression like
163                     //
164                     //     let mut a = match value {
165                     //         ciborium::value::Value::Array(a) => a,
166                     //         _ => return cbor_type_error(&value, "arr"),
167                     //     };
168                     //     if a.len() != 3 {
169                     //         return Err(CborError::UnexpectedItem("arr", "arr len 3"));
170                     //     }
171                     //     // Fields specified in reverse order to reduce shifting.
172                     //     Ok(Self {
173                     //         z: <ZType>::from_cbor_value(a.remove(2))?,
174                     //         y: <YType>::from_cbor_value(a.remove(1))?,
175                     //         x: <XType>::from_cbor_value(a.remove(0))?,
176                     //     })
177                     //
178                     // but using fully qualified function call syntax.
179                     let nfields = fields.named.len();
180                     let recurse = fields.named.iter().enumerate().rev().map(|(i, f)| {
181                         let name = &f.ident;
182                         let index = Index::from(i);
183                         let typ = &f.ty;
184                         quote_spanned! {f.span()=>
185                                         #name: <#typ>::from_cbor_value(a.remove(#index))?
186                         }
187                     });
188                     quote! {
189                         let mut a = match value {
190                             ciborium::value::Value::Array(a) => a,
191                             _ => return cbor_type_error(&value, "arr"),
192                         };
193                         if a.len() != #nfields {
194                             return Err(CborError::UnexpectedItem(
195                                 "arr",
196                                 concat!("arr len ", stringify!(#nfields)),
197                             ));
198                         }
199                         // Fields specified in reverse order to reduce shifting.
200                         Ok(Self {
201                             #(#recurse, )*
202                         })
203                     }
204                 }
205                 Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
206                     // For a newtype, expands to an expression like
207                     //
208                     //     Ok(Self(<InnerType>::from_cbor_value(value)?))
209                     let inner = fields.unnamed.first().unwrap();
210                     let typ = &inner.ty;
211                     quote! {
212                         Ok(Self(<#typ>::from_cbor_value(value)?))
213                     }
214                 }
215                 Fields::Unnamed(ref fields) => {
216                     // Expands to an expression like
217                     //
218                     //     let mut a = match value {
219                     //         ciborium::value::Value::Array(a) => a,
220                     //         _ => return cbor_type_error(&value, "arr"),
221                     //     };
222                     //     if a.len() != 3 {
223                     //         return Err(CborError::UnexpectedItem("arr", "arr len 3"));
224                     //     }
225                     //     // Fields specified in reverse order to reduce shifting.
226                     //     let field_2 = <Type2>::from_cbor_value(a.remove(2))?;
227                     //     let field_1 = <Type1>::from_cbor_value(a.remove(1))?;
228                     //     let field_0 = <Type0>::from_cbor_value(a.remove(0))?;
229                     //     Ok(Self(field_0, field_1, field_2))
230                     let nfields = fields.unnamed.len();
231                     let recurse1 = fields.unnamed.iter().enumerate().rev().map(|(i, f)| {
232                         let typ = &f.ty;
233                         let varname = format_ident!("field_{}", i);
234                         quote_spanned! {f.span()=>
235                                         let #varname = <#typ>::from_cbor_value(a.remove(#i))?;
236                         }
237                     });
238                     let recurse2 = fields.unnamed.iter().enumerate().map(|(i, f)| {
239                         let varname = format_ident!("field_{}", i);
240                         quote_spanned! {f.span()=>
241                                         #varname
242                         }
243                     });
244                     quote! {
245                         let mut a = match value {
246                             ciborium::value::Value::Array(a) => a,
247                             _ => return cbor_type_error(&value, "arr"),
248                         };
249                         if a.len() != #nfields {
250                             return Err(CborError::UnexpectedItem("arr",
251                                                                  concat!("arr len ",
252                                                                          stringify!(#nfields))));
253                         }
254                         // Fields specified in reverse order to reduce shifting.
255                         #(#recurse1)*
256 
257                         Ok(Self( #(#recurse2, )* ))
258                     }
259                 }
260                 Fields::Unit => unimplemented!(),
261             }
262         }
263         Data::Enum(enum_data) => {
264             // This only copes with variants with no fields.
265             // Expands to an expression like:
266             //
267             //     use core::convert::TryInto;
268             //     let v: i32 = match value {
269             //         ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
270             //             CborError::OutOfRangeIntegerValue
271             //         })?,
272             //         v => return cbor_type_error(&v, &"int"),
273             //     };
274             //     match v {
275             //         x if x == Self::Variant1 as i32 => Ok(Self::Variant1),
276             //         x if x == Self::Variant2 as i32 => Ok(Self::Variant2),
277             //         x if x == Self::Variant3 as i32 => Ok(Self::Variant3),
278             //         _ => Err( CborError::OutOfRangeIntegerValue),
279             //     }
280             let recurse = enum_data.variants.iter().map(|variant| {
281                 let vname = &variant.ident;
282                 quote_spanned! {variant.span()=>
283                                 x if x == Self::#vname as i32 => Ok(Self::#vname),
284                 }
285             });
286 
287             quote! {
288                 use core::convert::TryInto;
289                 // First get the int value as an `i32`.
290                 let v: i32 = match value {
291                     ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
292                         CborError::OutOfRangeIntegerValue
293                     })?,
294                     v => return cbor_type_error(&v, &"int"),
295                 };
296                 // Now match against enum possibilities.
297                 match v {
298                     #(#recurse)*
299                     _ => Err(
300                         CborError::OutOfRangeIntegerValue
301                     ),
302                 }
303             }
304         }
305         Data::Union(_) => unimplemented!(),
306     }
307 }
308 
309 /// Generate an expression that expresses the CDDL schema for the type.
cddl_struct(name: &Ident, data: &Data) -> TokenStream310 fn cddl_struct(name: &Ident, data: &Data) -> TokenStream {
311     match *data {
312         Data::Struct(ref data) => {
313             match data.fields {
314                 Fields::Named(ref fields) => {
315                     if fields.named.iter().next().is_none() {
316                         return quote! {
317                             Some(format!("[]"))
318                         };
319                     }
320                     // Expands to an expression like
321                     //
322                     //     format!("[
323                     //         x: {},
324                     //         y: {},
325                     //         z: {},
326                     //     ]",
327                     //         <TypeX>::cddl_ref(),
328                     //         <TypeY>::cddl_ref(),
329                     //         <TypeZ>::cddl_ref(),
330                     //     )
331                     let fmt_recurse = fields.named.iter().map(|f| {
332                         let name = &f.ident;
333                         quote_spanned! {f.span()=>
334                                         concat!("    ", stringify!(#name), ": {},\n")
335                         }
336                     });
337                     let fmt = quote! {
338                         concat!("[\n",
339                                 #(#fmt_recurse, )*
340                                 "]")
341                     };
342                     let recurse = fields.named.iter().map(|f| {
343                         let typ = &f.ty;
344                         quote_spanned! {f.span()=>
345                                         <#typ>::cddl_ref()
346                         }
347                     });
348                     quote! {
349                         Some(format!(
350                             #fmt,
351                             #(#recurse, )*
352                         ))
353                     }
354                 }
355                 Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
356                     let inner = fields.unnamed.first().unwrap();
357                     let typ = &inner.ty;
358                     quote! {
359                         Some(<#typ>::cddl_ref())
360                     }
361                 }
362                 Fields::Unnamed(ref fields) => {
363                     if fields.unnamed.iter().next().is_none() {
364                         return quote! {
365                             Some(format!("()"))
366                         };
367                     }
368                     // Expands to an expression like
369                     //
370                     //     format!("[
371                     //         {},
372                     //         {},
373                     //         {},
374                     //     ]",
375                     //         <TypeX>::cddl_ref(),
376                     //         <TypeY>::cddl_ref(),
377                     //         <TypeZ>::cddl_ref(),
378                     //     )
379                     //
380                     let fmt_recurse = fields.unnamed.iter().map(|f| {
381                         quote_spanned! {f.span()=>
382                                         "    {},\n"
383                         }
384                     });
385                     let fmt = quote! {
386                         concat!("[\n",
387                                  #(#fmt_recurse, )*
388                                  "]")
389                     };
390                     let recurse = fields.unnamed.iter().map(|f| {
391                         let typ = &f.ty;
392                         quote_spanned! {f.span()=>
393                                         <#typ>::cddl_ref()
394                         }
395                     });
396                     quote! {
397                         Some(format!(
398                             #fmt,
399                             #(#recurse, )*
400                         ))
401                     }
402                 }
403                 Fields::Unit => unimplemented!(),
404             }
405         }
406         Data::Enum(ref enum_data) => {
407             // This only copes with variants with no fields.
408             // Expands to an expression like:
409             //
410             //     format!("&(
411             //         EnumName_Variant1: {},
412             //         EnumName_Variant2: {},
413             //         EnumName_Variant3: {},
414             //     )",
415             //         Self::Variant1 as i32,
416             //         Self::Variant2 as i32,
417             //         Self::Variant3 as i32,
418             //     )
419             //
420             let fmt_recurse = enum_data.variants.iter().map(|variant| {
421                 let vname = &variant.ident;
422                 quote_spanned! {variant.span()=>
423                                 concat!("    ",
424                                         stringify!(#name),
425                                         "_",
426                                         stringify!(#vname),
427                                         ": {},\n")
428                 }
429             });
430             let fmt = quote! {
431                 concat!("&(\n",
432                          #(#fmt_recurse, )*
433                          ")")
434             };
435             let recurse = enum_data.variants.iter().map(|variant| {
436                 let vname = &variant.ident;
437                 quote_spanned! {variant.span()=>
438                                 Self::#vname as i32
439                 }
440             });
441             quote! {
442                 Some(format!(
443                     #fmt,
444                     #(#recurse, )*
445                 ))
446             }
447         }
448         Data::Union(_) => unimplemented!(),
449     }
450 }
451 
452 /// Derive macro that implements a `from_raw_tag_value` method for the `Tag` enum.
453 #[proc_macro_derive(FromRawTag)]
derive_from_raw_tag(input: proc_macro::TokenStream) -> proc_macro::TokenStream454 pub fn derive_from_raw_tag(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
455     let input = parse_macro_input!(input as DeriveInput);
456     derive_from_raw_tag_internal(&input)
457 }
458 
derive_from_raw_tag_internal(input: &DeriveInput) -> proc_macro::TokenStream459 fn derive_from_raw_tag_internal(input: &DeriveInput) -> proc_macro::TokenStream {
460     let name = &input.ident;
461     let from_val = from_raw_tag(name, &input.data);
462     let expanded = quote! {
463         pub fn from_raw_tag_value(raw_tag: u32) -> #name {
464             #from_val
465         }
466     };
467     expanded.into()
468 }
469 
470 /// Generate an expression to convert a `u32` into an instance of an fieldless enum.
471 /// Assumes the existence of an `Invalid` variant as a fallback, and assumes that a
472 /// `raw_tag_value` function is in scope.
from_raw_tag(name: &Ident, data: &Data) -> TokenStream473 fn from_raw_tag(name: &Ident, data: &Data) -> TokenStream {
474     match data {
475         Data::Enum(enum_data) => {
476             let recurse = enum_data.variants.iter().map(|variant| {
477                 let vname = &variant.ident;
478                 quote_spanned! {variant.span()=>
479                                 x if x == raw_tag_value(#name::#vname) => #name::#vname,
480                 }
481             });
482 
483             quote! {
484                 match raw_tag {
485                     #(#recurse)*
486                     _ => #name::Invalid,
487                 }
488             }
489         }
490         _ => unimplemented!(),
491     }
492 }
493 
494 /// Derive macro that implements the `legacy::InnerSerialize` trait.  Using this macro requires
495 /// that `InnerSerialize` and `Error` from `kmr_wire::legacy` be locally `use`d.
496 #[proc_macro_derive(LegacySerialize)]
derive_legacy_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream497 pub fn derive_legacy_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
498     let input = parse_macro_input!(input as DeriveInput);
499     derive_legacy_serialize_internal(&input)
500 }
501 
derive_legacy_serialize_internal(input: &DeriveInput) -> proc_macro::TokenStream502 fn derive_legacy_serialize_internal(input: &DeriveInput) -> proc_macro::TokenStream {
503     let name = &input.ident;
504 
505     let deserialize_val = deserialize_struct(&input.data);
506     let serialize_val = serialize_struct(&input.data);
507 
508     let expanded = quote! {
509         impl InnerSerialize for #name {
510             fn deserialize(data: &[u8]) -> Result<(Self, &[u8]), Error> {
511                 #deserialize_val
512             }
513             fn serialize_into(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
514                 #serialize_val
515             }
516         }
517     };
518 
519     expanded.into()
520 }
521 
deserialize_struct(data: &Data) -> TokenStream522 fn deserialize_struct(data: &Data) -> TokenStream {
523     match data {
524         Data::Struct(ref data) => {
525             match data.fields {
526                 Fields::Named(ref fields) => {
527                     // Expands to an expression like
528                     //
529                     //     let (x, data) = <XType>::deserialize(data)?;
530                     //     let (y, data) = <YType>::deserialize(data)?;
531                     //     let (z, data) = <ZType>::deserialize(data)?;
532                     //     Ok((Self {
533                     //             x,
534                     //             y,
535                     //             z,
536                     //     }, data))
537                     //
538                     let recurse1 = fields.named.iter().map(|f| {
539                         let name = &f.ident;
540                         let typ = &f.ty;
541                         quote_spanned! {f.span()=>
542                                         let (#name, data) = <#typ>::deserialize(data)?;
543                         }
544                     });
545                     let recurse2 = fields.named.iter().map(|f| {
546                         let name = &f.ident;
547                         quote_spanned! {f.span()=>
548                                         #name
549                         }
550                     });
551                     quote! {
552                         #(#recurse1)*
553                         Ok((Self {
554                             #(#recurse2, )*
555                         }, data))
556                     }
557                 }
558                 Fields::Unnamed(_) => unimplemented!(),
559                 Fields::Unit => unimplemented!(),
560             }
561         }
562         Data::Enum(_) => unimplemented!(),
563         Data::Union(_) => unimplemented!(),
564     }
565 }
566 
serialize_struct(data: &Data) -> TokenStream567 fn serialize_struct(data: &Data) -> TokenStream {
568     match data {
569         Data::Struct(ref data) => {
570             match data.fields {
571                 Fields::Named(ref fields) => {
572                     // Expands to an expression like
573                     //
574                     //     self.x.serialize_into(buf)?;
575                     //     self.y.serialize_into(buf)?;
576                     //     self.z.serialize_into(buf)?;
577                     //     Ok(())
578                     //
579                     let recurse = fields.named.iter().map(|f| {
580                         let name = &f.ident;
581                         quote_spanned! {f.span()=>
582                                         self.#name.serialize_into(buf)?;
583                         }
584                     });
585                     quote! {
586                         #(#recurse)*
587                         Ok(())
588                     }
589                 }
590                 Fields::Unnamed(_) => unimplemented!(),
591                 Fields::Unit => unimplemented!(),
592             }
593         }
594         Data::Enum(_) => unimplemented!(),
595         Data::Union(_) => unimplemented!(),
596     }
597 }
598