1 // Copyright (C) 2023 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use syn::{parse_macro_input, DeriveInput, Error};
16
17 #[proc_macro_derive(NameAndVersionMap)]
derive_name_and_version_map(input: proc_macro::TokenStream) -> proc_macro::TokenStream18 pub fn derive_name_and_version_map(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
19 let input = parse_macro_input!(input as DeriveInput);
20 name_and_version_map::expand(input).unwrap_or_else(Error::into_compile_error).into()
21 }
22
23 mod name_and_version_map {
24 use proc_macro2::TokenStream;
25 use quote::quote;
26 use syn::{
27 Data, DataStruct, DeriveInput, Error, Field, GenericArgument, PathArguments, Result, Type,
28 };
29
expand(input: DeriveInput) -> Result<TokenStream>30 pub(crate) fn expand(input: DeriveInput) -> Result<TokenStream> {
31 let name = &input.ident;
32 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
33
34 let mapfield = get_map_field(get_struct(&input)?)?;
35 let mapfield_name = mapfield
36 .ident
37 .as_ref()
38 .ok_or(Error::new_spanned(mapfield, "mapfield ident is none"))?;
39 let (_, value_type) = get_map_type(&mapfield.ty)?;
40
41 let expanded = quote! {
42 #[automatically_derived]
43 impl #impl_generics NameAndVersionMap for #name #ty_generics #where_clause {
44 type Value = #value_type;
45
46 fn map_field(&self) -> &BTreeMap<NameAndVersion, Self::Value> {
47 self.#mapfield_name.map_field()
48 }
49
50 fn map_field_mut(&mut self) -> &mut BTreeMap<NameAndVersion, Self::Value> {
51 self.#mapfield_name.map_field_mut()
52 }
53
54 fn insert_or_error(&mut self, key: NameAndVersion, val: Self::Value) -> Result<(), CrateError> {
55 self.#mapfield_name.insert_or_error(key, val)
56 }
57
58 fn num_crates(&self) -> usize {
59 self.#mapfield_name.num_crates()
60 }
61
62 fn get_versions<'a, 'b>(&'a self, name: &'b str) -> Box<dyn Iterator<Item = (&'a NameAndVersion, &'a Self::Value)> + 'a> {
63 self.#mapfield_name.get_versions(name)
64 }
65
66 fn get_versions_mut<'a, 'b>(&'a mut self, name: &'b str) -> Box<dyn Iterator<Item = (&'a NameAndVersion, &'a mut Self::Value)> + 'a> {
67 self.#mapfield_name.get_versions_mut(name)
68 }
69
70 fn filter_versions<'a: 'b, 'b, F: Fn(&mut dyn Iterator<Item = (&'b NameAndVersion, &'b Self::Value)>,
71 ) -> HashSet<Version> + 'a>(
72 &'a self,
73 f: F,
74 ) -> Box<dyn Iterator<Item =(&'a NameAndVersion, &'a Self::Value)> + 'a> {
75 self.#mapfield_name.filter_versions(f)
76 }
77 }
78 };
79
80 Ok(TokenStream::from(expanded))
81 }
82
get_struct(input: &DeriveInput) -> Result<&DataStruct>83 fn get_struct(input: &DeriveInput) -> Result<&DataStruct> {
84 match &input.data {
85 Data::Struct(strukt) => Ok(strukt),
86 _ => Err(Error::new_spanned(input, "Not a struct")),
87 }
88 }
89
get_map_field(strukt: &DataStruct) -> Result<&Field>90 fn get_map_field(strukt: &DataStruct) -> Result<&Field> {
91 for field in &strukt.fields {
92 if let Ok((key_type, _value_type)) = get_map_type(&field.ty) {
93 if let syn::Type::Path(path) = &key_type {
94 if path.path.segments.len() == 1
95 && path.path.segments[0].ident == "NameAndVersion"
96 {
97 return Ok(field);
98 }
99 }
100 }
101 }
102 return Err(Error::new_spanned(strukt.struct_token, "No field of type NameAndVersionMap"));
103 }
104
get_map_type(typ: &Type) -> Result<(&Type, &Type)>105 fn get_map_type(typ: &Type) -> Result<(&Type, &Type)> {
106 if let syn::Type::Path(path) = &typ {
107 if path.path.segments.len() == 1 && path.path.segments[0].ident == "BTreeMap" {
108 if let PathArguments::AngleBracketed(args) = &path.path.segments[0].arguments {
109 if args.args.len() == 2 {
110 return Ok((get_type(&args.args[0])?, get_type(&args.args[1])?));
111 }
112 }
113 }
114 }
115 Err(Error::new_spanned(typ, "Must be BTreeMap"))
116 }
117
get_type(arg: &GenericArgument) -> Result<&Type>118 fn get_type(arg: &GenericArgument) -> Result<&Type> {
119 if let GenericArgument::Type(typ) = arg {
120 return Ok(typ);
121 }
122 Err(Error::new_spanned(arg, "Could not extract argument type"))
123 }
124 }
125