1 // Copyright 2022, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Functionality related to RSA.
16 
17 use super::{KeyMaterial, KeySizeInBits, OpaqueOr, RsaExponent};
18 use crate::{der_err, km_err, tag, try_to_vec, Error, FallibleAllocExt};
19 use alloc::vec::Vec;
20 use der::{asn1::BitStringRef, Decode, Encode};
21 use kmr_wire::keymint::{Digest, KeyParam, PaddingMode};
22 use pkcs1::RsaPrivateKey;
23 use spki::{AlgorithmIdentifier, SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
24 use zeroize::ZeroizeOnDrop;
25 
/// Minimum number of bytes that PKCS#1 v1.5 signature padding adds to an undigested message.
///
/// The encoded message is `0x00 || 0x01 || PS || 0x00 || T`. Here `PS` is at least 8 bytes of
/// `0xFF`, giving 2 + 8 + 1 bytes of overhead. Digested messages additionally carry the DER
/// `DigestInfo` (digest algorithmIdentifier) that PKCS#1 requires, so they need more room.
pub const PKCS1_UNDIGESTED_SIGNATURE_PADDING_OVERHEAD: usize = 2 + 8 + 1;
29 
/// OID value for PKCS#1-encoded RSA keys held in PKCS#8 and X.509; see RFC 3447 A.1.
///
/// This is `rsaEncryption` (1.2.840.113549.1.1.1). In this module it serves two purposes:
/// - it is the `algorithm` of the generated `SubjectPublicKeyInfo`;
/// - imported PKCS#8 keys must carry it.
///
/// Because this is a `const`, a malformed OID string would be rejected at compile time.
pub const X509_OID: pkcs8::ObjectIdentifier =
    pkcs8::ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");

/// OID value for PKCS#1 signature with SHA-256 and RSA, see RFC 4055 s5
/// (`sha256WithRSAEncryption`, 1.2.840.113549.1.1.11).
pub const SHA256_PKCS1_SIGNATURE_OID: pkcs8::ObjectIdentifier =
    pkcs8::ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.11");
37 
/// An RSA key, in the form of an ASN.1 DER encoding of an PKCS#1 `RSAPrivateKey` structure,
/// as specified by RFC 3447 sections A.1.2 and 3.2:
///
/// ```asn1
/// RSAPrivateKey ::= SEQUENCE {
///     version           Version,
///     modulus           INTEGER,  -- n
///     publicExponent    INTEGER,  -- e
///     privateExponent   INTEGER,  -- d
///     prime1            INTEGER,  -- p
///     prime2            INTEGER,  -- q
///     exponent1         INTEGER,  -- d mod (p-1)
///     exponent2         INTEGER,  -- d mod (q-1)
///     coefficient       INTEGER,  -- (inverse of q) mod p
///     otherPrimeInfos   OtherPrimeInfos OPTIONAL
/// }
///
/// OtherPrimeInfos ::= SEQUENCE SIZE(1..MAX) OF OtherPrimeInfo
///
/// OtherPrimeInfo ::= SEQUENCE {
///     prime             INTEGER,  -- ri
///     exponent          INTEGER,  -- di
///     coefficient       INTEGER   -- ti
/// }
/// ```
///
/// The DER bytes are kept unparsed; the methods on this type parse them on demand. The
/// `ZeroizeOnDrop` derive wipes the key bytes when the value is dropped.
///
/// NOTE(review): the derived `PartialEq`/`Eq` compare the raw key bytes using ordinary `Vec`
/// equality, which is not constant-time; confirm no caller compares keys where timing matters.
#[derive(Clone, PartialEq, Eq, ZeroizeOnDrop)]
pub struct Key(pub Vec<u8>);
65 
66 impl Key {
67     /// Return the `subjectPublicKey` that holds an ASN.1 DER-encoded `SEQUENCE`
68     /// as per RFC 3279 section 2.3.1:
69     ///     ```asn1
70     ///     RSAPublicKey ::= SEQUENCE {
71     ///        modulus            INTEGER,    -- n
72     ///        publicExponent     INTEGER  }  -- e
73     ///     ```
subject_public_key(&self) -> Result<Vec<u8>, Error>74     pub fn subject_public_key(&self) -> Result<Vec<u8>, Error> {
75         let rsa_pvt_key = RsaPrivateKey::from_der(self.0.as_slice())
76             .map_err(|e| der_err!(e, "failed to parse RsaPrivateKey"))?;
77 
78         let rsa_pub_key = rsa_pvt_key.public_key();
79         let mut encoded_data = Vec::<u8>::new();
80         rsa_pub_key
81             .encode_to_vec(&mut encoded_data)
82             .map_err(|e| der_err!(e, "failed to encode RSA PublicKey"))?;
83         Ok(encoded_data)
84     }
85 
86     /// Size of the key in bytes.
size(&self) -> usize87     pub fn size(&self) -> usize {
88         let rsa_pvt_key = match RsaPrivateKey::from_der(self.0.as_slice()) {
89             Ok(k) => k,
90             Err(e) => {
91                 log::error!("failed to determine RSA key length: {:?}", e);
92                 return 0;
93             }
94         };
95         let len = u32::from(rsa_pvt_key.modulus.len());
96         len as usize
97     }
98 }
99 
100 impl OpaqueOr<Key> {
101     /// Encode into `buf` the public key information as an ASN.1 DER encodable
102     /// `SubjectPublicKeyInfo`, as described in RFC 5280 section 4.1.
103     ///
104     /// ```asn1
105     /// SubjectPublicKeyInfo  ::=  SEQUENCE  {
106     ///    algorithm            AlgorithmIdentifier,
107     ///    subjectPublicKey     BIT STRING  }
108     ///
109     /// AlgorithmIdentifier  ::=  SEQUENCE  {
110     ///    algorithm               OBJECT IDENTIFIER,
111     ///    parameters              ANY DEFINED BY algorithm OPTIONAL  }
112     /// ```
113     ///
114     /// For RSA keys, the contents are described in RFC 3279 section 2.3.1.
115     ///
116     /// - The `AlgorithmIdentifier` has an algorithm OID of 1.2.840.113549.1.1.1.
117     /// - The `AlgorithmIdentifier` has `NULL` parameters.
118     /// - The `subjectPublicKey` bit string holds an ASN.1 DER-encoded `SEQUENCE`:
119     ///     ```asn1
120     ///     RSAPublicKey ::= SEQUENCE {
121     ///        modulus            INTEGER,    -- n
122     ///        publicExponent     INTEGER  }  -- e
123     ///     ```
subject_public_key_info<'a>( &'a self, buf: &'a mut Vec<u8>, rsa: &dyn super::Rsa, ) -> Result<SubjectPublicKeyInfoRef<'a>, Error>124     pub fn subject_public_key_info<'a>(
125         &'a self,
126         buf: &'a mut Vec<u8>,
127         rsa: &dyn super::Rsa,
128     ) -> Result<SubjectPublicKeyInfoRef<'a>, Error> {
129         let pub_key = rsa.subject_public_key(self)?;
130         buf.try_extend_from_slice(&pub_key)?;
131         Ok(SubjectPublicKeyInfo {
132             algorithm: AlgorithmIdentifier { oid: X509_OID, parameters: Some(der::AnyRef::NULL) },
133             subject_public_key: BitStringRef::from_bytes(buf).unwrap(),
134         })
135     }
136 }
137 
138 /// RSA decryption mode.
139 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
140 pub enum DecryptionMode {
141     /// No padding.
142     NoPadding,
143     /// RSA-OAEP padding.
144     OaepPadding {
145         /// Digest to use for the message
146         msg_digest: Digest,
147         /// Digest to use in the MGF1 function.
148         mgf_digest: Digest,
149     },
150     /// PKCS#1 v1.5 padding.
151     Pkcs1_1_5Padding,
152 }
153 
154 impl DecryptionMode {
155     /// Determine the [`DecryptionMode`] from parameters.
new(params: &[KeyParam]) -> Result<Self, Error>156     pub fn new(params: &[KeyParam]) -> Result<Self, Error> {
157         let padding = tag::get_padding_mode(params)?;
158         match padding {
159             PaddingMode::None => Ok(DecryptionMode::NoPadding),
160             PaddingMode::RsaOaep => {
161                 let msg_digest = tag::get_digest(params)?;
162                 let mgf_digest = tag::get_mgf_digest(params)?;
163                 Ok(DecryptionMode::OaepPadding { msg_digest, mgf_digest })
164             }
165             PaddingMode::RsaPkcs115Encrypt => Ok(DecryptionMode::Pkcs1_1_5Padding),
166             _ => Err(km_err!(
167                 UnsupportedPaddingMode,
168                 "padding mode {:?} not supported for RSA decryption",
169                 padding
170             )),
171         }
172     }
173 }
174 
175 /// RSA signature mode.
176 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
177 pub enum SignMode {
178     /// No padding.
179     NoPadding,
180     /// RSA-PSS signature scheme using the given digest.
181     PssPadding(Digest),
182     /// PKCS#1 v1.5 padding using the given digest.
183     Pkcs1_1_5Padding(Digest),
184 }
185 
186 impl SignMode {
187     /// Determine the [`SignMode`] from parameters.
new(params: &[KeyParam]) -> Result<Self, Error>188     pub fn new(params: &[KeyParam]) -> Result<Self, Error> {
189         let padding = tag::get_padding_mode(params)?;
190         match padding {
191             PaddingMode::None => Ok(SignMode::NoPadding),
192             PaddingMode::RsaPss => {
193                 let digest = tag::get_digest(params)?;
194                 Ok(SignMode::PssPadding(digest))
195             }
196             PaddingMode::RsaPkcs115Sign => {
197                 let digest = tag::get_digest(params)?;
198                 Ok(SignMode::Pkcs1_1_5Padding(digest))
199             }
200             _ => Err(km_err!(
201                 UnsupportedPaddingMode,
202                 "padding mode {:?} not supported for RSA signing",
203                 padding
204             )),
205         }
206     }
207 }
208 
209 /// Import an RSA key in PKCS#8 format, also returning the key size in bits and public exponent.
import_pkcs8_key(data: &[u8]) -> Result<(KeyMaterial, KeySizeInBits, RsaExponent), Error>210 pub fn import_pkcs8_key(data: &[u8]) -> Result<(KeyMaterial, KeySizeInBits, RsaExponent), Error> {
211     let key_info = pkcs8::PrivateKeyInfo::try_from(data)
212         .map_err(|_| km_err!(InvalidArgument, "failed to parse PKCS#8 RSA key"))?;
213     if key_info.algorithm.oid != X509_OID {
214         return Err(km_err!(
215             InvalidArgument,
216             "unexpected OID {:?} for PKCS#1 RSA key import",
217             key_info.algorithm.oid
218         ));
219     }
220     // For RSA, the inner private key is an ASN.1 `RSAPrivateKey`, as per PKCS#1 (RFC 3447 A.1.2).
221     import_pkcs1_key(key_info.private_key)
222 }
223 
224 /// Import an RSA key in PKCS#1 format, also returning the key size in bits and public exponent.
import_pkcs1_key( private_key: &[u8], ) -> Result<(KeyMaterial, KeySizeInBits, RsaExponent), Error>225 pub fn import_pkcs1_key(
226     private_key: &[u8],
227 ) -> Result<(KeyMaterial, KeySizeInBits, RsaExponent), Error> {
228     let key = Key(try_to_vec(private_key)?);
229 
230     // Need to parse it to find size/exponent.
231     let parsed_key = pkcs1::RsaPrivateKey::try_from(private_key)
232         .map_err(|_| km_err!(InvalidArgument, "failed to parse inner PKCS#1 key"))?;
233     let key_size = parsed_key.modulus.as_bytes().len() as u32 * 8;
234 
235     let pub_exponent_bytes = parsed_key.public_exponent.as_bytes();
236     if pub_exponent_bytes.len() > 8 {
237         return Err(km_err!(
238             InvalidArgument,
239             "public exponent of length {} too big",
240             pub_exponent_bytes.len()
241         ));
242     }
243     let offset = 8 - pub_exponent_bytes.len();
244     let mut pub_exponent_arr = [0u8; 8];
245     pub_exponent_arr[offset..].copy_from_slice(pub_exponent_bytes);
246     let pub_exponent = u64::from_be_bytes(pub_exponent_arr);
247 
248     Ok((KeyMaterial::Rsa(key.into()), KeySizeInBits(key_size), RsaExponent(pub_exponent)))
249 }
250