1 // Copyright (C) 2023 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use syn::{parse_macro_input, DeriveInput, Error};
16 
17 #[proc_macro_derive(NameAndVersionMap)]
derive_name_and_version_map(input: proc_macro::TokenStream) -> proc_macro::TokenStream18 pub fn derive_name_and_version_map(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
19     let input = parse_macro_input!(input as DeriveInput);
20     name_and_version_map::expand(input).unwrap_or_else(Error::into_compile_error).into()
21 }
22 
23 mod name_and_version_map {
24     use proc_macro2::TokenStream;
25     use quote::quote;
26     use syn::{
27         Data, DataStruct, DeriveInput, Error, Field, GenericArgument, PathArguments, Result, Type,
28     };
29 
expand(input: DeriveInput) -> Result<TokenStream>30     pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream> {
31         let name = &input.ident;
32         let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
33 
34         let mapfield = get_map_field(get_struct(&input)?)?;
35         let mapfield_name = mapfield
36             .ident
37             .as_ref()
38             .ok_or(Error::new_spanned(mapfield, "mapfield ident is none"))?;
39         let (_, value_type) = get_map_type(&mapfield.ty)?;
40 
41         let expanded = quote! {
42             #[automatically_derived]
43             impl #impl_generics NameAndVersionMap for #name #ty_generics #where_clause {
44                 type Value = #value_type;
45 
46                 fn map_field(&self) -> &BTreeMap<NameAndVersion, Self::Value> {
47                     self.#mapfield_name.map_field()
48                 }
49 
50                 fn map_field_mut(&mut self) -> &mut BTreeMap<NameAndVersion, Self::Value> {
51                     self.#mapfield_name.map_field_mut()
52                 }
53 
54                 fn insert_or_error(&mut self, key: NameAndVersion, val: Self::Value) -> Result<(), CrateError> {
55                     self.#mapfield_name.insert_or_error(key, val)
56                 }
57 
58                 fn num_crates(&self) -> usize {
59                     self.#mapfield_name.num_crates()
60                 }
61 
62                 fn get_versions<'a, 'b>(&'a self, name: &'b str) -> Box<dyn Iterator<Item = (&'a NameAndVersion, &'a Self::Value)> + 'a> {
63                     self.#mapfield_name.get_versions(name)
64                 }
65 
66                 fn get_versions_mut<'a, 'b>(&'a mut self, name: &'b str) -> Box<dyn Iterator<Item = (&'a NameAndVersion, &'a mut Self::Value)> + 'a> {
67                     self.#mapfield_name.get_versions_mut(name)
68                 }
69 
70                 fn filter_versions<'a: 'b, 'b, F: Fn(&mut dyn Iterator<Item = (&'b NameAndVersion, &'b Self::Value)>,
71                 ) -> HashSet<Version> + 'a>(
72                     &'a self,
73                     f: F,
74                 ) -> Box<dyn Iterator<Item =(&'a NameAndVersion, &'a Self::Value)> + 'a> {
75                     self.#mapfield_name.filter_versions(f)
76                 }
77             }
78         };
79 
80         Ok(TokenStream::from(expanded))
81     }
82 
get_struct(input: &DeriveInput) -> Result<&DataStruct>83     fn get_struct(input: &DeriveInput) -> Result<&DataStruct> {
84         match &input.data {
85             Data::Struct(strukt) => Ok(strukt),
86             _ => Err(Error::new_spanned(input, "Not a struct")),
87         }
88     }
89 
get_map_field(strukt: &DataStruct) -> Result<&Field>90     fn get_map_field(strukt: &DataStruct) -> Result<&Field> {
91         for field in &strukt.fields {
92             if let Ok((key_type, _value_type)) = get_map_type(&field.ty) {
93                 if let syn::Type::Path(path) = &key_type {
94                     if path.path.segments.len() == 1
95                         && path.path.segments[0].ident == "NameAndVersion"
96                     {
97                         return Ok(field);
98                     }
99                 }
100             }
101         }
102         return Err(Error::new_spanned(strukt.struct_token, "No field of type NameAndVersionMap"));
103     }
104 
get_map_type(typ: &Type) -> Result<(&Type, &Type)>105     fn get_map_type(typ: &Type) -> Result<(&Type, &Type)> {
106         if let syn::Type::Path(path) = &typ {
107             if path.path.segments.len() == 1 && path.path.segments[0].ident == "BTreeMap" {
108                 if let PathArguments::AngleBracketed(args) = &path.path.segments[0].arguments {
109                     if args.args.len() == 2 {
110                         return Ok((get_type(&args.args[0])?, get_type(&args.args[1])?));
111                     }
112                 }
113             }
114         }
115         Err(Error::new_spanned(typ, "Must be BTreeMap"))
116     }
117 
get_type(arg: &GenericArgument) -> Result<&Type>118     fn get_type(arg: &GenericArgument) -> Result<&Type> {
119         if let GenericArgument::Type(typ) = arg {
120             return Ok(typ);
121         }
122         Err(Error::new_spanned(arg, "Could not extract argument type"))
123     }
124 }
125