1 // Copyright 2022, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
//! Derive macros for `AsCborValue`, `FromRawTag` and `LegacySerialize`.
16 use proc_macro2::TokenStream;
17 use quote::{format_ident, quote, quote_spanned};
18 use syn::{
19 parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam,
20 Generics, Ident, Index,
21 };
22
23 /// Derive macro that implements the `AsCborValue` trait. Using this macro requires
24 /// that `AsCborValue`, `CborError` and `cbor_type_error` are locally `use`d.
25 #[proc_macro_derive(AsCborValue)]
derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream26 pub fn derive_as_cbor_value(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
27 let input = parse_macro_input!(input as DeriveInput);
28 derive_as_cbor_value_internal(&input)
29 }
30
derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream31 fn derive_as_cbor_value_internal(input: &DeriveInput) -> proc_macro::TokenStream {
32 let name = &input.ident;
33
34 // Add a bound `T: AsCborValue` for every type parameter `T`.
35 let generics = add_trait_bounds(&input.generics);
36 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
37
38 let from_val = from_val_struct(&input.data);
39 let to_val = to_val_struct(&input.data);
40 let cddl = cddl_struct(name, &input.data);
41
42 let expanded = quote! {
43 // The generated impl
44 impl #impl_generics AsCborValue for #name #ty_generics #where_clause {
45 fn from_cbor_value(value: ciborium::value::Value) -> Result<Self, CborError> {
46 #from_val
47 }
48 fn to_cbor_value(self) -> Result<ciborium::value::Value, CborError> {
49 #to_val
50 }
51 fn cddl_typename() -> Option<String> {
52 Some(stringify!(#name).to_string())
53 }
54 fn cddl_schema() -> Option<String> {
55 #cddl
56 }
57 }
58 };
59
60 expanded.into()
61 }
62
63 /// Add a bound `T: AsCborValue` for every type parameter `T`.
add_trait_bounds(generics: &Generics) -> Generics64 fn add_trait_bounds(generics: &Generics) -> Generics {
65 let mut generics = generics.clone();
66 for param in &mut generics.params {
67 if let GenericParam::Type(ref mut type_param) = *param {
68 type_param.bounds.push(parse_quote!(AsCborValue));
69 }
70 }
71 generics
72 }
73
74 /// Generate an expression to convert an instance of a compound type to `ciborium::value::Value`
to_val_struct(data: &Data) -> TokenStream75 fn to_val_struct(data: &Data) -> TokenStream {
76 match *data {
77 Data::Struct(ref data) => {
78 match data.fields {
79 Fields::Named(ref fields) => {
80 // Expands to an expression like
81 //
82 // {
83 // let mut v = Vec::new();
84 // v.try_reserve(3).map_err(|_e| CborError::AllocationFailed)?;
85 // v.push(AsCborValue::to_cbor_value(self.x)?);
86 // v.push(AsCborValue::to_cbor_value(self.y)?);
87 // v.push(AsCborValue::to_cbor_value(self.z)?);
88 // Ok(ciborium::value::Value::Array(v))
89 // }
90 let nfields = fields.named.len();
91 let recurse = fields.named.iter().map(|f| {
92 let name = &f.ident;
93 quote_spanned! {f.span()=>
94 v.push(AsCborValue::to_cbor_value(self.#name)?)
95 }
96 });
97 quote! {
98 {
99 let mut v = Vec::new();
100 v.try_reserve(#nfields).map_err(|_e| CborError::AllocationFailed)?;
101 #(#recurse; )*
102 Ok(ciborium::value::Value::Array(v))
103 }
104 }
105 }
106 Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
107 // For a newtype, expands to an expression
108 //
109 // self.0.to_cbor_value()
110 quote! {
111 self.0.to_cbor_value()
112 }
113 }
114 Fields::Unnamed(ref fields) => {
115 // Expands to an expression like
116 //
117 //
118 // {
119 // let mut v = Vec::new();
120 // v.try_reserve(3).map_err(|_e| CborError::AllocationFailed)?;
121 // v.push(AsCborValue::to_cbor_value(self.0)?);
122 // v.push(AsCborValue::to_cbor_value(self.1)?);
123 // v.push(AsCborValue::to_cbor_value(self.2)?);
124 // Ok(ciborium::value::Value::Array(v))
125 // }
126 let nfields = fields.unnamed.len();
127 let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
128 let index = Index::from(i);
129 quote_spanned! {f.span()=>
130 v.push(AsCborValue::to_cbor_value(self.#index)?)
131 }
132 });
133 quote! {
134 {
135 let mut v = Vec::new();
136 v.try_reserve(#nfields).map_err(|_e| CborError::AllocationFailed)?;
137 #(#recurse; )*
138 Ok(ciborium::value::Value::Array(v))
139 }
140 }
141 }
142 Fields::Unit => unimplemented!(),
143 }
144 }
145 Data::Enum(_) => {
146 quote! {
147 let v: ciborium::value::Integer = (self as i32).into();
148 Ok(ciborium::value::Value::Integer(v))
149 }
150 }
151 Data::Union(_) => unimplemented!(),
152 }
153 }
154
155 /// Generate an expression to convert a `ciborium::value::Value` into an instance of a compound
156 /// type.
from_val_struct(data: &Data) -> TokenStream157 fn from_val_struct(data: &Data) -> TokenStream {
158 match data {
159 Data::Struct(ref data) => {
160 match data.fields {
161 Fields::Named(ref fields) => {
162 // Expands to an expression like
163 //
164 // let mut a = match value {
165 // ciborium::value::Value::Array(a) => a,
166 // _ => return cbor_type_error(&value, "arr"),
167 // };
168 // if a.len() != 3 {
169 // return Err(CborError::UnexpectedItem("arr", "arr len 3"));
170 // }
171 // // Fields specified in reverse order to reduce shifting.
172 // Ok(Self {
173 // z: <ZType>::from_cbor_value(a.remove(2))?,
174 // y: <YType>::from_cbor_value(a.remove(1))?,
175 // x: <XType>::from_cbor_value(a.remove(0))?,
176 // })
177 //
178 // but using fully qualified function call syntax.
179 let nfields = fields.named.len();
180 let recurse = fields.named.iter().enumerate().rev().map(|(i, f)| {
181 let name = &f.ident;
182 let index = Index::from(i);
183 let typ = &f.ty;
184 quote_spanned! {f.span()=>
185 #name: <#typ>::from_cbor_value(a.remove(#index))?
186 }
187 });
188 quote! {
189 let mut a = match value {
190 ciborium::value::Value::Array(a) => a,
191 _ => return cbor_type_error(&value, "arr"),
192 };
193 if a.len() != #nfields {
194 return Err(CborError::UnexpectedItem(
195 "arr",
196 concat!("arr len ", stringify!(#nfields)),
197 ));
198 }
199 // Fields specified in reverse order to reduce shifting.
200 Ok(Self {
201 #(#recurse, )*
202 })
203 }
204 }
205 Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
206 // For a newtype, expands to an expression like
207 //
208 // Ok(Self(<InnerType>::from_cbor_value(value)?))
209 let inner = fields.unnamed.first().unwrap();
210 let typ = &inner.ty;
211 quote! {
212 Ok(Self(<#typ>::from_cbor_value(value)?))
213 }
214 }
215 Fields::Unnamed(ref fields) => {
216 // Expands to an expression like
217 //
218 // let mut a = match value {
219 // ciborium::value::Value::Array(a) => a,
220 // _ => return cbor_type_error(&value, "arr"),
221 // };
222 // if a.len() != 3 {
223 // return Err(CborError::UnexpectedItem("arr", "arr len 3"));
224 // }
225 // // Fields specified in reverse order to reduce shifting.
226 // let field_2 = <Type2>::from_cbor_value(a.remove(2))?;
227 // let field_1 = <Type1>::from_cbor_value(a.remove(1))?;
228 // let field_0 = <Type0>::from_cbor_value(a.remove(0))?;
229 // Ok(Self(field_0, field_1, field_2))
230 let nfields = fields.unnamed.len();
231 let recurse1 = fields.unnamed.iter().enumerate().rev().map(|(i, f)| {
232 let typ = &f.ty;
233 let varname = format_ident!("field_{}", i);
234 quote_spanned! {f.span()=>
235 let #varname = <#typ>::from_cbor_value(a.remove(#i))?;
236 }
237 });
238 let recurse2 = fields.unnamed.iter().enumerate().map(|(i, f)| {
239 let varname = format_ident!("field_{}", i);
240 quote_spanned! {f.span()=>
241 #varname
242 }
243 });
244 quote! {
245 let mut a = match value {
246 ciborium::value::Value::Array(a) => a,
247 _ => return cbor_type_error(&value, "arr"),
248 };
249 if a.len() != #nfields {
250 return Err(CborError::UnexpectedItem("arr",
251 concat!("arr len ",
252 stringify!(#nfields))));
253 }
254 // Fields specified in reverse order to reduce shifting.
255 #(#recurse1)*
256
257 Ok(Self( #(#recurse2, )* ))
258 }
259 }
260 Fields::Unit => unimplemented!(),
261 }
262 }
263 Data::Enum(enum_data) => {
264 // This only copes with variants with no fields.
265 // Expands to an expression like:
266 //
267 // use core::convert::TryInto;
268 // let v: i32 = match value {
269 // ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
270 // CborError::OutOfRangeIntegerValue
271 // })?,
272 // v => return cbor_type_error(&v, &"int"),
273 // };
274 // match v {
275 // x if x == Self::Variant1 as i32 => Ok(Self::Variant1),
276 // x if x == Self::Variant2 as i32 => Ok(Self::Variant2),
277 // x if x == Self::Variant3 as i32 => Ok(Self::Variant3),
278 // _ => Err( CborError::OutOfRangeIntegerValue),
279 // }
280 let recurse = enum_data.variants.iter().map(|variant| {
281 let vname = &variant.ident;
282 quote_spanned! {variant.span()=>
283 x if x == Self::#vname as i32 => Ok(Self::#vname),
284 }
285 });
286
287 quote! {
288 use core::convert::TryInto;
289 // First get the int value as an `i32`.
290 let v: i32 = match value {
291 ciborium::value::Value::Integer(i) => i.try_into().map_err(|_| {
292 CborError::OutOfRangeIntegerValue
293 })?,
294 v => return cbor_type_error(&v, &"int"),
295 };
296 // Now match against enum possibilities.
297 match v {
298 #(#recurse)*
299 _ => Err(
300 CborError::OutOfRangeIntegerValue
301 ),
302 }
303 }
304 }
305 Data::Union(_) => unimplemented!(),
306 }
307 }
308
309 /// Generate an expression that expresses the CDDL schema for the type.
cddl_struct(name: &Ident, data: &Data) -> TokenStream310 fn cddl_struct(name: &Ident, data: &Data) -> TokenStream {
311 match *data {
312 Data::Struct(ref data) => {
313 match data.fields {
314 Fields::Named(ref fields) => {
315 if fields.named.iter().next().is_none() {
316 return quote! {
317 Some(format!("[]"))
318 };
319 }
320 // Expands to an expression like
321 //
322 // format!("[
323 // x: {},
324 // y: {},
325 // z: {},
326 // ]",
327 // <TypeX>::cddl_ref(),
328 // <TypeY>::cddl_ref(),
329 // <TypeZ>::cddl_ref(),
330 // )
331 let fmt_recurse = fields.named.iter().map(|f| {
332 let name = &f.ident;
333 quote_spanned! {f.span()=>
334 concat!(" ", stringify!(#name), ": {},\n")
335 }
336 });
337 let fmt = quote! {
338 concat!("[\n",
339 #(#fmt_recurse, )*
340 "]")
341 };
342 let recurse = fields.named.iter().map(|f| {
343 let typ = &f.ty;
344 quote_spanned! {f.span()=>
345 <#typ>::cddl_ref()
346 }
347 });
348 quote! {
349 Some(format!(
350 #fmt,
351 #(#recurse, )*
352 ))
353 }
354 }
355 Fields::Unnamed(ref fields) if fields.unnamed.len() == 1 => {
356 let inner = fields.unnamed.first().unwrap();
357 let typ = &inner.ty;
358 quote! {
359 Some(<#typ>::cddl_ref())
360 }
361 }
362 Fields::Unnamed(ref fields) => {
363 if fields.unnamed.iter().next().is_none() {
364 return quote! {
365 Some(format!("()"))
366 };
367 }
368 // Expands to an expression like
369 //
370 // format!("[
371 // {},
372 // {},
373 // {},
374 // ]",
375 // <TypeX>::cddl_ref(),
376 // <TypeY>::cddl_ref(),
377 // <TypeZ>::cddl_ref(),
378 // )
379 //
380 let fmt_recurse = fields.unnamed.iter().map(|f| {
381 quote_spanned! {f.span()=>
382 " {},\n"
383 }
384 });
385 let fmt = quote! {
386 concat!("[\n",
387 #(#fmt_recurse, )*
388 "]")
389 };
390 let recurse = fields.unnamed.iter().map(|f| {
391 let typ = &f.ty;
392 quote_spanned! {f.span()=>
393 <#typ>::cddl_ref()
394 }
395 });
396 quote! {
397 Some(format!(
398 #fmt,
399 #(#recurse, )*
400 ))
401 }
402 }
403 Fields::Unit => unimplemented!(),
404 }
405 }
406 Data::Enum(ref enum_data) => {
407 // This only copes with variants with no fields.
408 // Expands to an expression like:
409 //
410 // format!("&(
411 // EnumName_Variant1: {},
412 // EnumName_Variant2: {},
413 // EnumName_Variant3: {},
414 // )",
415 // Self::Variant1 as i32,
416 // Self::Variant2 as i32,
417 // Self::Variant3 as i32,
418 // )
419 //
420 let fmt_recurse = enum_data.variants.iter().map(|variant| {
421 let vname = &variant.ident;
422 quote_spanned! {variant.span()=>
423 concat!(" ",
424 stringify!(#name),
425 "_",
426 stringify!(#vname),
427 ": {},\n")
428 }
429 });
430 let fmt = quote! {
431 concat!("&(\n",
432 #(#fmt_recurse, )*
433 ")")
434 };
435 let recurse = enum_data.variants.iter().map(|variant| {
436 let vname = &variant.ident;
437 quote_spanned! {variant.span()=>
438 Self::#vname as i32
439 }
440 });
441 quote! {
442 Some(format!(
443 #fmt,
444 #(#recurse, )*
445 ))
446 }
447 }
448 Data::Union(_) => unimplemented!(),
449 }
450 }
451
452 /// Derive macro that implements a `from_raw_tag_value` method for the `Tag` enum.
453 #[proc_macro_derive(FromRawTag)]
derive_from_raw_tag(input: proc_macro::TokenStream) -> proc_macro::TokenStream454 pub fn derive_from_raw_tag(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
455 let input = parse_macro_input!(input as DeriveInput);
456 derive_from_raw_tag_internal(&input)
457 }
458
derive_from_raw_tag_internal(input: &DeriveInput) -> proc_macro::TokenStream459 fn derive_from_raw_tag_internal(input: &DeriveInput) -> proc_macro::TokenStream {
460 let name = &input.ident;
461 let from_val = from_raw_tag(name, &input.data);
462 let expanded = quote! {
463 pub fn from_raw_tag_value(raw_tag: u32) -> #name {
464 #from_val
465 }
466 };
467 expanded.into()
468 }
469
470 /// Generate an expression to convert a `u32` into an instance of an fieldless enum.
471 /// Assumes the existence of an `Invalid` variant as a fallback, and assumes that a
472 /// `raw_tag_value` function is in scope.
from_raw_tag(name: &Ident, data: &Data) -> TokenStream473 fn from_raw_tag(name: &Ident, data: &Data) -> TokenStream {
474 match data {
475 Data::Enum(enum_data) => {
476 let recurse = enum_data.variants.iter().map(|variant| {
477 let vname = &variant.ident;
478 quote_spanned! {variant.span()=>
479 x if x == raw_tag_value(#name::#vname) => #name::#vname,
480 }
481 });
482
483 quote! {
484 match raw_tag {
485 #(#recurse)*
486 _ => #name::Invalid,
487 }
488 }
489 }
490 _ => unimplemented!(),
491 }
492 }
493
494 /// Derive macro that implements the `legacy::InnerSerialize` trait. Using this macro requires
495 /// that `InnerSerialize` and `Error` from `kmr_wire::legacy` be locally `use`d.
496 #[proc_macro_derive(LegacySerialize)]
derive_legacy_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream497 pub fn derive_legacy_serialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
498 let input = parse_macro_input!(input as DeriveInput);
499 derive_legacy_serialize_internal(&input)
500 }
501
derive_legacy_serialize_internal(input: &DeriveInput) -> proc_macro::TokenStream502 fn derive_legacy_serialize_internal(input: &DeriveInput) -> proc_macro::TokenStream {
503 let name = &input.ident;
504
505 let deserialize_val = deserialize_struct(&input.data);
506 let serialize_val = serialize_struct(&input.data);
507
508 let expanded = quote! {
509 impl InnerSerialize for #name {
510 fn deserialize(data: &[u8]) -> Result<(Self, &[u8]), Error> {
511 #deserialize_val
512 }
513 fn serialize_into(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
514 #serialize_val
515 }
516 }
517 };
518
519 expanded.into()
520 }
521
deserialize_struct(data: &Data) -> TokenStream522 fn deserialize_struct(data: &Data) -> TokenStream {
523 match data {
524 Data::Struct(ref data) => {
525 match data.fields {
526 Fields::Named(ref fields) => {
527 // Expands to an expression like
528 //
529 // let (x, data) = <XType>::deserialize(data)?;
530 // let (y, data) = <YType>::deserialize(data)?;
531 // let (z, data) = <ZType>::deserialize(data)?;
532 // Ok((Self {
533 // x,
534 // y,
535 // z,
536 // }, data))
537 //
538 let recurse1 = fields.named.iter().map(|f| {
539 let name = &f.ident;
540 let typ = &f.ty;
541 quote_spanned! {f.span()=>
542 let (#name, data) = <#typ>::deserialize(data)?;
543 }
544 });
545 let recurse2 = fields.named.iter().map(|f| {
546 let name = &f.ident;
547 quote_spanned! {f.span()=>
548 #name
549 }
550 });
551 quote! {
552 #(#recurse1)*
553 Ok((Self {
554 #(#recurse2, )*
555 }, data))
556 }
557 }
558 Fields::Unnamed(_) => unimplemented!(),
559 Fields::Unit => unimplemented!(),
560 }
561 }
562 Data::Enum(_) => unimplemented!(),
563 Data::Union(_) => unimplemented!(),
564 }
565 }
566
serialize_struct(data: &Data) -> TokenStream567 fn serialize_struct(data: &Data) -> TokenStream {
568 match data {
569 Data::Struct(ref data) => {
570 match data.fields {
571 Fields::Named(ref fields) => {
572 // Expands to an expression like
573 //
574 // self.x.serialize_into(buf)?;
575 // self.y.serialize_into(buf)?;
576 // self.z.serialize_into(buf)?;
577 // Ok(())
578 //
579 let recurse = fields.named.iter().map(|f| {
580 let name = &f.ident;
581 quote_spanned! {f.span()=>
582 self.#name.serialize_into(buf)?;
583 }
584 });
585 quote! {
586 #(#recurse)*
587 Ok(())
588 }
589 }
590 Fields::Unnamed(_) => unimplemented!(),
591 Fields::Unit => unimplemented!(),
592 }
593 }
594 Data::Enum(_) => unimplemented!(),
595 Data::Union(_) => unimplemented!(),
596 }
597 }
598