1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "keymint_benchmark"
18 
19 #include <iostream>
20 
21 #include <base/command_line.h>
22 #include <benchmark/benchmark.h>
23 
24 #include <aidl/Vintf.h>
25 #include <aidl/android/hardware/security/keymint/ErrorCode.h>
26 #include <aidl/android/hardware/security/keymint/IKeyMintDevice.h>
27 #include <android/binder_manager.h>
28 #include <binder/IServiceManager.h>
29 
30 #include <keymint_support/authorization_set.h>
31 #include <keymint_support/openssl_utils.h>
32 #include <openssl/curve25519.h>
33 #include <openssl/x509.h>
34 
35 #define SMALL_MESSAGE_SIZE 64
36 #define MEDIUM_MESSAGE_SIZE 1024
37 #define LARGE_MESSAGE_SIZE 131072
38 
39 namespace aidl::android::hardware::security::keymint::test {
40 
41 ::std::ostream& operator<<(::std::ostream& os, const keymint::AuthorizationSet& set);
42 
43 using ::android::sp;
44 using Status = ::ndk::ScopedAStatus;
45 using ::std::optional;
46 using ::std::shared_ptr;
47 using ::std::string;
48 using ::std::vector;
49 
50 class KeyMintBenchmarkTest {
51   public:
KeyMintBenchmarkTest()52     KeyMintBenchmarkTest() {
53         message_cache_.push_back(string(SMALL_MESSAGE_SIZE, 'x'));
54         message_cache_.push_back(string(MEDIUM_MESSAGE_SIZE, 'x'));
55         message_cache_.push_back(string(LARGE_MESSAGE_SIZE, 'x'));
56     }
57 
newInstance(const char * instanceName)58     static KeyMintBenchmarkTest* newInstance(const char* instanceName) {
59         if (AServiceManager_isDeclared(instanceName)) {
60             ::ndk::SpAIBinder binder(AServiceManager_waitForService(instanceName));
61             KeyMintBenchmarkTest* test = new KeyMintBenchmarkTest();
62             test->InitializeKeyMint(IKeyMintDevice::fromBinder(binder));
63             return test;
64         } else {
65             return nullptr;
66         }
67     }
68 
getError()69     int getError() { return static_cast<int>(error_); }
70 
GenerateMessage(int size)71     const string GenerateMessage(int size) {
72         for (const string& message : message_cache_) {
73             if (message.size() == size) {
74                 return message;
75             }
76         }
77         string message = string(size, 'x');
78         message_cache_.push_back(message);
79         return message;
80     }
81 
getBlockMode(string transform)82     optional<BlockMode> getBlockMode(string transform) {
83         if (transform.find("/ECB") != string::npos) {
84             return BlockMode::ECB;
85         } else if (transform.find("/CBC") != string::npos) {
86             return BlockMode::CBC;
87         } else if (transform.find("/CTR") != string::npos) {
88             return BlockMode::CTR;
89         } else if (transform.find("/GCM") != string::npos) {
90             return BlockMode::GCM;
91         }
92         return {};
93     }
94 
getPadding(string transform,bool sign)95     PaddingMode getPadding(string transform, bool sign) {
96         if (transform.find("/PKCS7") != string::npos) {
97             return PaddingMode::PKCS7;
98         } else if (transform.find("/PSS") != string::npos) {
99             return PaddingMode::RSA_PSS;
100         } else if (transform.find("/OAEP") != string::npos) {
101             return PaddingMode::RSA_OAEP;
102         } else if (transform.find("/PKCS1") != string::npos) {
103             return sign ? PaddingMode::RSA_PKCS1_1_5_SIGN : PaddingMode::RSA_PKCS1_1_5_ENCRYPT;
104         } else if (sign && transform.find("RSA") != string::npos) {
105             // RSA defaults to PKCS1 for sign
106             return PaddingMode::RSA_PKCS1_1_5_SIGN;
107         }
108         return PaddingMode::NONE;
109     }
110 
getAlgorithm(string transform)111     optional<Algorithm> getAlgorithm(string transform) {
112         if (transform.find("AES") != string::npos) {
113             return Algorithm::AES;
114         } else if (transform.find("Hmac") != string::npos) {
115             return Algorithm::HMAC;
116         } else if (transform.find("DESede") != string::npos) {
117             return Algorithm::TRIPLE_DES;
118         } else if (transform.find("RSA") != string::npos) {
119             return Algorithm::RSA;
120         } else if (transform.find("EC") != string::npos) {
121             return Algorithm::EC;
122         }
123         std::cerr << "Can't find algorithm for " << transform << std::endl;
124         return {};
125     }
126 
getAlgorithmString(string transform)127     string getAlgorithmString(string transform) {
128         if (transform.find("AES") != string::npos) {
129             return "AES";
130         } else if (transform.find("Hmac") != string::npos) {
131             return "HMAC";
132         } else if (transform.find("DESede") != string::npos) {
133             return "TRIPLE_DES";
134         } else if (transform.find("RSA") != string::npos) {
135             return "RSA";
136         } else if (transform.find("EC") != string::npos) {
137             return "EC";
138         }
139         std::cerr << "Can't find algorithm for " << transform << std::endl;
140         return "";
141     }
142 
getDigest(string transform)143     Digest getDigest(string transform) {
144         if (transform.find("MD5") != string::npos) {
145             return Digest::MD5;
146         } else if (transform.find("SHA1") != string::npos ||
147                    transform.find("SHA-1") != string::npos) {
148             return Digest::SHA1;
149         } else if (transform.find("SHA224") != string::npos) {
150             return Digest::SHA_2_224;
151         } else if (transform.find("SHA256") != string::npos) {
152             return Digest::SHA_2_256;
153         } else if (transform.find("SHA384") != string::npos) {
154             return Digest::SHA_2_384;
155         } else if (transform.find("SHA512") != string::npos) {
156             return Digest::SHA_2_512;
157         } else if (transform.find("RSA") != string::npos &&
158                    transform.find("OAEP") != string::npos) {
159             if (securityLevel_ == SecurityLevel::STRONGBOX) {
160                 return Digest::SHA_2_256;
161             } else {
162                 return Digest::SHA1;
163             }
164         } else if (transform.find("Hmac") != string::npos) {
165             return Digest::SHA_2_256;
166         }
167         return Digest::NONE;
168     }
169 
getDigestString(string transform)170     string getDigestString(string transform) {
171         if (transform.find("MD5") != string::npos) {
172             return "MD5";
173         } else if (transform.find("SHA1") != string::npos ||
174                    transform.find("SHA-1") != string::npos) {
175             return "SHA1";
176         } else if (transform.find("SHA224") != string::npos) {
177             return "SHA_2_224";
178         } else if (transform.find("SHA256") != string::npos) {
179             return "SHA_2_256";
180         } else if (transform.find("SHA384") != string::npos) {
181             return "SHA_2_384";
182         } else if (transform.find("SHA512") != string::npos) {
183             return "SHA_2_512";
184         } else if (transform.find("RSA") != string::npos &&
185                    transform.find("OAEP") != string::npos) {
186             if (securityLevel_ == SecurityLevel::STRONGBOX) {
187                 return "SHA_2_256";
188             } else {
189                 return "SHA1";
190             }
191         } else if (transform.find("Hmac") != string::npos) {
192             return "SHA_2_256";
193         }
194         return "";
195     }
196 
getCurveFromLength(int keySize)197     optional<EcCurve> getCurveFromLength(int keySize) {
198         switch (keySize) {
199             case 224:
200                 return EcCurve::P_224;
201             case 256:
202                 return EcCurve::P_256;
203             case 384:
204                 return EcCurve::P_384;
205             case 521:
206                 return EcCurve::P_521;
207             default:
208                 return std::nullopt;
209         }
210     }
211 
GenerateKey(string transform,int keySize,bool sign=false)212     bool GenerateKey(string transform, int keySize, bool sign = false) {
213         if (transform == key_transform_) {
214             return true;
215         } else if (key_transform_ != "") {
216             // Deleting old key first
217             key_transform_ = "";
218             if (DeleteKey() != ErrorCode::OK) {
219                 return false;
220             }
221         }
222         std::optional<Algorithm> algorithm = getAlgorithm(transform);
223         if (!algorithm) {
224             std::cerr << "Error: invalid algorithm " << transform << std::endl;
225             return false;
226         }
227         key_transform_ = transform;
228         AuthorizationSetBuilder authSet = AuthorizationSetBuilder()
229                                                   .Authorization(TAG_NO_AUTH_REQUIRED)
230                                                   .Authorization(TAG_PURPOSE, KeyPurpose::ENCRYPT)
231                                                   .Authorization(TAG_PURPOSE, KeyPurpose::DECRYPT)
232                                                   .Authorization(TAG_PURPOSE, KeyPurpose::SIGN)
233                                                   .Authorization(TAG_PURPOSE, KeyPurpose::VERIFY)
234                                                   .Authorization(TAG_KEY_SIZE, keySize)
235                                                   .Authorization(TAG_ALGORITHM, algorithm.value())
236                                                   .Digest(getDigest(transform))
237                                                   .Padding(getPadding(transform, sign));
238         std::optional<BlockMode> blockMode = getBlockMode(transform);
239         if (blockMode) {
240             authSet.BlockMode(blockMode.value());
241             if (blockMode == BlockMode::GCM) {
242                 authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
243             }
244         }
245         if (algorithm == Algorithm::HMAC) {
246             authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
247         }
248         if (algorithm == Algorithm::RSA) {
249             authSet.Authorization(TAG_RSA_PUBLIC_EXPONENT, 65537U);
250             authSet.SetDefaultValidity();
251         }
252         if (algorithm == Algorithm::EC) {
253             authSet.SetDefaultValidity();
254             std::optional<EcCurve> curve = getCurveFromLength(keySize);
255             if (!curve) {
256                 std::cerr << "Error: invalid EC-Curve from size " << keySize << std::endl;
257                 return false;
258             }
259             authSet.Authorization(TAG_EC_CURVE, curve.value());
260         }
261         error_ = GenerateKey(authSet);
262         return error_ == ErrorCode::OK;
263     }
264 
getOperationParams(string transform,bool sign=false)265     AuthorizationSet getOperationParams(string transform, bool sign = false) {
266         AuthorizationSetBuilder builder = AuthorizationSetBuilder()
267                                                   .Padding(getPadding(transform, sign))
268                                                   .Digest(getDigest(transform));
269         std::optional<BlockMode> blockMode = getBlockMode(transform);
270         if (sign && (transform.find("Hmac") != string::npos)) {
271             builder.Authorization(TAG_MAC_LENGTH, 128);
272         }
273         if (blockMode) {
274             builder.BlockMode(*blockMode);
275             if (blockMode == BlockMode::GCM) {
276                 builder.Authorization(TAG_MAC_LENGTH, 128);
277             }
278         }
279         return std::move(builder);
280     }
281 
Process(const string & message,const string & signature="")282     optional<string> Process(const string& message, const string& signature = "") {
283         ErrorCode result;
284 
285         string output;
286         result = Finish(message, signature, &output);
287         if (result != ErrorCode::OK) {
288             error_ = result;
289             return {};
290         }
291         return output;
292     }
293 
DeleteKey()294     ErrorCode DeleteKey() {
295         Status result = keymint_->deleteKey(key_blob_);
296         key_blob_ = vector<uint8_t>();
297         return GetReturnErrorCode(result);
298     }
299 
Begin(KeyPurpose purpose,const AuthorizationSet & in_params,AuthorizationSet * out_params)300     ErrorCode Begin(KeyPurpose purpose, const AuthorizationSet& in_params,
301                     AuthorizationSet* out_params) {
302         Status result;
303         BeginResult out;
304         result = keymint_->begin(purpose, key_blob_, in_params.vector_data(), std::nullopt, &out);
305         if (result.isOk()) {
306             *out_params = out.params;
307             op_ = out.operation;
308         }
309         return GetReturnErrorCode(result);
310     }
311 
312     /* Copied the function LocalRsaEncryptMessage from
313      * hardware/interfaces/security/keymint/aidl/vts/functional/KeyMintAidlTestBase.cpp in VTS.
314      * Replaced asserts with the condition check and return false in case of failure condition.
315      * Require return value to skip the benchmark test case from further execution in case
316      * LocalRsaEncryptMessage fails.
317      */
LocalRsaEncryptMessage(const string & message,const AuthorizationSet & params)318     optional<string> LocalRsaEncryptMessage(const string& message, const AuthorizationSet& params) {
319         // Retrieve the public key from the leaf certificate.
320         if (cert_chain_.empty()) {
321             std::cerr << "Local RSA encrypt Error: invalid cert_chain_" << std::endl;
322             return "Failure";
323         }
324         X509_Ptr key_cert(parse_cert_blob(cert_chain_[0].encodedCertificate));
325         EVP_PKEY_Ptr pub_key(X509_get_pubkey(key_cert.get()));
326         RSA_Ptr rsa(EVP_PKEY_get1_RSA(const_cast<EVP_PKEY*>(pub_key.get())));
327 
328         // Retrieve relevant tags.
329         Digest digest = Digest::NONE;
330         Digest mgf_digest = Digest::SHA1;
331         PaddingMode padding = PaddingMode::NONE;
332 
333         auto digest_tag = params.GetTagValue(TAG_DIGEST);
334         if (digest_tag.has_value()) digest = digest_tag.value();
335         auto pad_tag = params.GetTagValue(TAG_PADDING);
336         if (pad_tag.has_value()) padding = pad_tag.value();
337         auto mgf_tag = params.GetTagValue(TAG_RSA_OAEP_MGF_DIGEST);
338         if (mgf_tag.has_value()) mgf_digest = mgf_tag.value();
339 
340         const EVP_MD* md = openssl_digest(digest);
341         const EVP_MD* mgf_md = openssl_digest(mgf_digest);
342 
343         // Set up encryption context.
344         EVP_PKEY_CTX_Ptr ctx(EVP_PKEY_CTX_new(pub_key.get(), /* engine= */ nullptr));
345         if (EVP_PKEY_encrypt_init(ctx.get()) <= 0) {
346             std::cerr << "Local RSA encrypt Error: Encryption init failed" << std::endl;
347             return "Failure";
348         }
349 
350         int rc = -1;
351         switch (padding) {
352             case PaddingMode::NONE:
353                 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_NO_PADDING);
354                 break;
355             case PaddingMode::RSA_PKCS1_1_5_ENCRYPT:
356                 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_PADDING);
357                 break;
358             case PaddingMode::RSA_OAEP:
359                 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING);
360                 break;
361             default:
362                 break;
363         }
364         if (rc <= 0) {
365             std::cerr << "Local RSA encrypt Error: Set padding failed" << std::endl;
366             return "Failure";
367         }
368         if (padding == PaddingMode::RSA_OAEP) {
369             if (!EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), md)) {
370                 std::cerr << "Local RSA encrypt Error: Set digest failed: " << ERR_peek_last_error()
371                           << std::endl;
372                 return "Failure";
373             }
374             if (!EVP_PKEY_CTX_set_rsa_mgf1_md(ctx.get(), mgf_md)) {
375                 std::cerr << "Local RSA encrypt Error: Set digest failed: " << ERR_peek_last_error()
376                           << std::endl;
377                 return "Failure";
378             }
379         }
380 
381         // Determine output size.
382         size_t outlen;
383         if (EVP_PKEY_encrypt(ctx.get(), nullptr /* out */, &outlen,
384                              reinterpret_cast<const uint8_t*>(message.data()),
385                              message.size()) <= 0) {
386             std::cerr << "Local RSA encrypt Error: Determine output size failed: "
387                       << ERR_peek_last_error() << std::endl;
388             return "Failure";
389         }
390 
391         // Left-zero-pad the input if necessary.
392         const uint8_t* to_encrypt = reinterpret_cast<const uint8_t*>(message.data());
393         size_t to_encrypt_len = message.size();
394 
395         std::unique_ptr<string> zero_padded_message;
396         if (padding == PaddingMode::NONE && to_encrypt_len < outlen) {
397             zero_padded_message.reset(new string(outlen, '\0'));
398             memcpy(zero_padded_message->data() + (outlen - to_encrypt_len), message.data(),
399                    message.size());
400             to_encrypt = reinterpret_cast<const uint8_t*>(zero_padded_message->data());
401             to_encrypt_len = outlen;
402         }
403 
404         // Do the encryption.
405         string output(outlen, '\0');
406         if (EVP_PKEY_encrypt(ctx.get(), reinterpret_cast<uint8_t*>(output.data()), &outlen,
407                              to_encrypt, to_encrypt_len) <= 0) {
408             std::cerr << "Local RSA encrypt Error: Encryption failed: " << ERR_peek_last_error()
409                       << std::endl;
410             return "Failure";
411         }
412         return output;
413     }
414 
415     SecurityLevel securityLevel_;
416     string name_;
417 
418   private:
GenerateKey(const AuthorizationSet & key_desc,const optional<AttestationKey> & attest_key=std::nullopt)419     ErrorCode GenerateKey(const AuthorizationSet& key_desc,
420                           const optional<AttestationKey>& attest_key = std::nullopt) {
421         key_blob_.clear();
422         cert_chain_.clear();
423         KeyCreationResult creationResult;
424         Status result = keymint_->generateKey(key_desc.vector_data(), attest_key, &creationResult);
425         if (result.isOk()) {
426             key_blob_ = std::move(creationResult.keyBlob);
427             cert_chain_ = std::move(creationResult.certificateChain);
428             creationResult.keyCharacteristics.clear();
429         }
430         return GetReturnErrorCode(result);
431     }
432 
InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint)433     void InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint) {
434         if (!keyMint) {
435             std::cerr << "Trying initialize nullptr in InitializeKeyMint" << std::endl;
436             return;
437         }
438         keymint_ = std::move(keyMint);
439         KeyMintHardwareInfo info;
440         Status result = keymint_->getHardwareInfo(&info);
441         if (!result.isOk()) {
442             std::cerr << "InitializeKeyMint: getHardwareInfo failed with "
443                       << result.getServiceSpecificError() << std::endl;
444         }
445         securityLevel_ = info.securityLevel;
446         name_.assign(info.keyMintName.begin(), info.keyMintName.end());
447     }
448 
Finish(const string & input,const string & signature,string * output)449     ErrorCode Finish(const string& input, const string& signature, string* output) {
450         if (!op_) {
451             std::cerr << "Finish: Operation is nullptr" << std::endl;
452             return ErrorCode::UNEXPECTED_NULL_POINTER;
453         }
454 
455         vector<uint8_t> oPut;
456         Status result =
457                 op_->finish(vector<uint8_t>(input.begin(), input.end()),
458                             vector<uint8_t>(signature.begin(), signature.end()), {} /* authToken */,
459                             {} /* timestampToken */, {} /* confirmationToken */, &oPut);
460 
461         if (result.isOk()) output->append(oPut.begin(), oPut.end());
462 
463         op_.reset();
464         return GetReturnErrorCode(result);
465     }
466 
Update(const string & input,string * output)467     ErrorCode Update(const string& input, string* output) {
468         Status result;
469         if (!op_) {
470             std::cerr << "Update: Operation is nullptr" << std::endl;
471             return ErrorCode::UNEXPECTED_NULL_POINTER;
472         }
473 
474         std::vector<uint8_t> o_put;
475         result = op_->update(vector<uint8_t>(input.begin(), input.end()), {} /* authToken */,
476                              {} /* timestampToken */, &o_put);
477 
478         if (result.isOk() && output) *output = {o_put.begin(), o_put.end()};
479         return GetReturnErrorCode(result);
480     }
481 
GetReturnErrorCode(const Status & result)482     ErrorCode GetReturnErrorCode(const Status& result) {
483         error_ = static_cast<ErrorCode>(result.getServiceSpecificError());
484         if (result.isOk()) return ErrorCode::OK;
485 
486         if (result.getExceptionCode() == EX_SERVICE_SPECIFIC) {
487             return static_cast<ErrorCode>(result.getServiceSpecificError());
488         }
489 
490         return ErrorCode::UNKNOWN_ERROR;
491     }
492 
parse_cert_blob(const vector<uint8_t> & blob)493     X509_Ptr parse_cert_blob(const vector<uint8_t>& blob) {
494         const uint8_t* p = blob.data();
495         return X509_Ptr(d2i_X509(nullptr /* allocate new */, &p, blob.size()));
496     }
497 
498     std::shared_ptr<IKeyMintOperation> op_;
499     vector<Certificate> cert_chain_;
500     vector<uint8_t> key_blob_;
501     vector<KeyCharacteristics> key_characteristics_;
502     std::shared_ptr<IKeyMintDevice> keymint_;
503     std::vector<string> message_cache_;
504     std::string key_transform_;
505     ErrorCode error_;
506 };
507 
508 KeyMintBenchmarkTest* keymintTest;
509 
settings(benchmark::internal::Benchmark * benchmark)510 static void settings(benchmark::internal::Benchmark* benchmark) {
511     benchmark->Unit(benchmark::kMillisecond);
512 }
513 
addDefaultLabel(benchmark::State & state)514 static void addDefaultLabel(benchmark::State& state) {
515     std::string secLevel;
516     switch (keymintTest->securityLevel_) {
517         case SecurityLevel::STRONGBOX:
518             secLevel = "STRONGBOX";
519             break;
520         case SecurityLevel::SOFTWARE:
521             secLevel = "SOFTWARE";
522             break;
523         case SecurityLevel::TRUSTED_ENVIRONMENT:
524             secLevel = "TEE";
525             break;
526         case SecurityLevel::KEYSTORE:
527             secLevel = "KEYSTORE";
528             break;
529     }
530     state.SetLabel("hardware_name:" + keymintTest->name_ + " sec_level:" + secLevel);
531 }
532 
// clang-format off

// Registers `func` under the name transform/keySize. The benchmark function receives
// the stringified "transform/keySize" as its transform argument plus keySize.
#define BENCHMARK_KM(func, transform, keySize) \
    BENCHMARK_CAPTURE(func, transform/keySize, #transform "/" #keySize, keySize)->Apply(settings);

// Same as BENCHMARK_KM, with an additional message size argument that is also part of
// the benchmark name and of the transform string.
#define BENCHMARK_KM_MSG(func, transform, keySize, msgSize)                                      \
    BENCHMARK_CAPTURE(func, transform/keySize/msgSize, #transform "/" #keySize "/" #msgSize, \
                      keySize, msgSize)                                                          \
            ->Apply(settings);

// Registers `func` once for each of the three standard message sizes.
#define BENCHMARK_KM_ALL_MSGS(func, transform, keySize)             \
    BENCHMARK_KM_MSG(func, transform, keySize, SMALL_MESSAGE_SIZE)  \
    BENCHMARK_KM_MSG(func, transform, keySize, MEDIUM_MESSAGE_SIZE) \
    BENCHMARK_KM_MSG(func, transform, keySize, LARGE_MESSAGE_SIZE)

// Registers both the encrypt and the decrypt benchmark for one message size.
#define BENCHMARK_KM_CIPHER(transform, keySize, msgSize)   \
    BENCHMARK_KM_MSG(encrypt, transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)

// Skip public key operations as they are not supported in KeyMint.
#define BENCHMARK_KM_ASYM_CIPHER(transform, keySize, msgSize)   \
    BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)

// Registers encrypt and decrypt benchmarks for all standard message sizes.
#define BENCHMARK_KM_CIPHER_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(encrypt, transform, keySize)   \
    BENCHMARK_KM_ALL_MSGS(decrypt, transform, keySize)

// Registers sign and verify benchmarks for all standard message sizes.
#define BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(sign, transform, keySize)         \
    BENCHMARK_KM_ALL_MSGS(verify, transform, keySize)

// Skip public key operations as they are not supported in KeyMint.
// (No trailing line continuation here: it would splice the next line into the macro.)
#define BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(sign, transform, keySize)
// clang-format on
566 
567 /*
568  * ============= KeyGen TESTS ==================
569  */
570 
isValidSBKeySize(string transform,int keySize)571 static bool isValidSBKeySize(string transform, int keySize) {
572     std::optional<Algorithm> algorithm = keymintTest->getAlgorithm(transform);
573     switch (algorithm.value()) {
574         case Algorithm::AES:
575             return (keySize == 128 || keySize == 256);
576         case Algorithm::HMAC:
577             return (keySize % 8 == 0 && keySize >= 64 && keySize <= 512);
578         case Algorithm::TRIPLE_DES:
579             return (keySize == 168);
580         case Algorithm::RSA:
581             return (keySize == 2048);
582         case Algorithm::EC:
583             return (keySize == 256);
584     }
585     return false;
586 }
587 
keygen(benchmark::State & state,string transform,int keySize)588 static void keygen(benchmark::State& state, string transform, int keySize) {
589     // Skip the test for unsupported key size in StrongBox
590     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
591         !isValidSBKeySize(transform, keySize)) {
592         state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
593                              " is not supported in StrongBox for algorithm: " +
594                              keymintTest->getAlgorithmString(transform))
595                                     .c_str());
596         return;
597     }
598     addDefaultLabel(state);
599     for (auto _ : state) {
600         if (!keymintTest->GenerateKey(transform, keySize)) {
601             state.SkipWithError(
602                     ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
603         }
604         state.PauseTiming();
605 
606         keymintTest->DeleteKey();
607         state.ResumeTiming();
608     }
609 }
610 
// Key generation benchmarks, one registration per algorithm / key size (in bits).

// AES
BENCHMARK_KM(keygen, AES, 128);
BENCHMARK_KM(keygen, AES, 256);

// RSA
BENCHMARK_KM(keygen, RSA, 2048);
BENCHMARK_KM(keygen, RSA, 3072);
BENCHMARK_KM(keygen, RSA, 4096);

// EC; the key size selects the curve (see getCurveFromLength()).
BENCHMARK_KM(keygen, EC, 224);
BENCHMARK_KM(keygen, EC, 256);
BENCHMARK_KM(keygen, EC, 384);
BENCHMARK_KM(keygen, EC, 521);

// Triple DES
BENCHMARK_KM(keygen, DESede, 168);

// HMAC
BENCHMARK_KM(keygen, Hmac, 64);
BENCHMARK_KM(keygen, Hmac, 128);
BENCHMARK_KM(keygen, Hmac, 256);
BENCHMARK_KM(keygen, Hmac, 512);
629 
630 /*
631  * ============= SIGNATURE TESTS ==================
632  */
sign(benchmark::State & state,string transform,int keySize,int msgSize)633 static void sign(benchmark::State& state, string transform, int keySize, int msgSize) {
634     // Skip the test for unsupported key size or unsupported digest in StrongBox
635     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
636         if (!isValidSBKeySize(transform, keySize)) {
637             state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
638                                  " is not supported in StrongBox for algorithm: " +
639                                  keymintTest->getAlgorithmString(transform))
640                                         .c_str());
641             return;
642         }
643         if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
644             state.SkipWithError(
645                     ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
646                      " is not supported in StrongBox")
647                             .c_str());
648             return;
649         }
650     }
651     addDefaultLabel(state);
652     if (!keymintTest->GenerateKey(transform, keySize, true)) {
653         state.SkipWithError(
654                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
655         return;
656     }
657 
658     auto in_params = keymintTest->getOperationParams(transform, true);
659     AuthorizationSet out_params;
660     string message = keymintTest->GenerateMessage(msgSize);
661 
662     for (auto _ : state) {
663         state.PauseTiming();
664         ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
665         if (error != ErrorCode::OK) {
666             state.SkipWithError(
667                     ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
668             return;
669         }
670         state.ResumeTiming();
671         out_params.Clear();
672         if (!keymintTest->Process(message)) {
673             state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
674             break;
675         }
676     }
677 }
678 
verify(benchmark::State & state,string transform,int keySize,int msgSize)679 static void verify(benchmark::State& state, string transform, int keySize, int msgSize) {
680     // Skip the test for unsupported key size or unsupported digest in StrongBox
681     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
682         if (!isValidSBKeySize(transform, keySize)) {
683             state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
684                                  " is not supported in StrongBox for algorithm: " +
685                                  keymintTest->getAlgorithmString(transform))
686                                         .c_str());
687             return;
688         }
689         if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
690             state.SkipWithError(
691                     ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
692                      " is not supported in StrongBox")
693                             .c_str());
694             return;
695         }
696     }
697     addDefaultLabel(state);
698     if (!keymintTest->GenerateKey(transform, keySize, true)) {
699         state.SkipWithError(
700                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
701         return;
702     }
703     AuthorizationSet out_params;
704     auto in_params = keymintTest->getOperationParams(transform, true);
705     string message = keymintTest->GenerateMessage(msgSize);
706     ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
707     if (error != ErrorCode::OK) {
708         state.SkipWithError(
709                 ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
710         return;
711     }
712     std::optional<string> signature = keymintTest->Process(message);
713     if (!signature) {
714         state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
715         return;
716     }
717     out_params.Clear();
718     if (transform.find("Hmac") != string::npos) {
719         in_params = keymintTest->getOperationParams(transform, false);
720     }
721     for (auto _ : state) {
722         state.PauseTiming();
723         error = keymintTest->Begin(KeyPurpose::VERIFY, in_params, &out_params);
724         if (error != ErrorCode::OK) {
725             state.SkipWithError(
726                     ("Verify begin error, " + std::to_string(keymintTest->getError())).c_str());
727             return;
728         }
729         state.ResumeTiming();
730         if (!keymintTest->Process(message, *signature)) {
731             state.SkipWithError(
732                     ("Verify error, " + std::to_string(keymintTest->getError())).c_str());
733             break;
734         }
735     }
736 }
737 
738 // clang-format off
739 #define BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(transform) \
740     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 64)      \
741     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 128)     \
742     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 256)     \
743     BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 512)
744 
745 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA1)
746 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
747 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA224)
748 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
749 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA384)
750 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA512)
751 
// Expands BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS (defined outside this section) once for each of
// the key sizes 224, 256, 384 and 521, so a single invocation covers all of them for one
// ECDSA transform.
#define BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(transform) \
    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 224)      \
    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 256)      \
    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 384)      \
    BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 521)

// One registration per ECDSA digest variant, including the digest-less NONEwithECDSA.
BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(NONEwithECDSA);
BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA1withECDSA);
BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA224withECDSA);
BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA256withECDSA);
BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA384withECDSA);
BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA512withECDSA);
764 
765 #define BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(transform) \
766     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 2048)   \
767     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 3072)   \
768     BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 4096)
769 
770 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA);
771 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA);
772 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA);
773 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA256withRSA);
774 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA);
775 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA);
776 
777 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA/PSS);
778 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA/PSS);
779 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA/PSS);
780 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA/PSS);
781 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA/PSS);
782 
783 // clang-format on
784 
785 /*
786  * ============= CIPHER TESTS ==================
787  */
788 
encrypt(benchmark::State & state,string transform,int keySize,int msgSize)789 static void encrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
790     // Skip the test for unsupported key size in StrongBox
791     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
792         (!isValidSBKeySize(transform, keySize))) {
793         state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
794                              " is not supported in StrongBox for algorithm: " +
795                              keymintTest->getAlgorithmString(transform))
796                                     .c_str());
797         return;
798     }
799     addDefaultLabel(state);
800     if (!keymintTest->GenerateKey(transform, keySize)) {
801         state.SkipWithError(
802                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
803         return;
804     }
805     auto in_params = keymintTest->getOperationParams(transform);
806     AuthorizationSet out_params;
807     string message = keymintTest->GenerateMessage(msgSize);
808 
809     for (auto _ : state) {
810         state.PauseTiming();
811         auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
812         if (error != ErrorCode::OK) {
813             state.SkipWithError(
814                     ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
815             return;
816         }
817         out_params.Clear();
818         state.ResumeTiming();
819         if (!keymintTest->Process(message)) {
820             state.SkipWithError(
821                     ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
822             break;
823         }
824     }
825 }
826 
decrypt(benchmark::State & state,string transform,int keySize,int msgSize)827 static void decrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
828     // Skip the test for unsupported key size in StrongBox
829     if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
830         (!isValidSBKeySize(transform, keySize))) {
831         state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
832                              " is not supported in StrongBox for algorithm: " +
833                              keymintTest->getAlgorithmString(transform))
834                                     .c_str());
835         return;
836     }
837     addDefaultLabel(state);
838     if (!keymintTest->GenerateKey(transform, keySize)) {
839         state.SkipWithError(
840                 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
841         return;
842     }
843     AuthorizationSet out_params;
844     AuthorizationSet in_params = keymintTest->getOperationParams(transform);
845     string message = keymintTest->GenerateMessage(msgSize);
846     optional<string> encryptedMessage;
847 
848     if (keymintTest->getAlgorithm(transform).value() == Algorithm::RSA) {
849         // Public key operation not supported, doing local Encryption
850         encryptedMessage = keymintTest->LocalRsaEncryptMessage(message, in_params);
851         if ((keySize / 8) != (*encryptedMessage).size()) {
852             state.SkipWithError("Local Encryption falied");
853             return;
854         }
855     } else {
856         auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
857         if (error != ErrorCode::OK) {
858             state.SkipWithError(
859                     ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
860             return;
861         }
862         encryptedMessage = keymintTest->Process(message);
863         if (!encryptedMessage) {
864             state.SkipWithError(
865                     ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
866             return;
867         }
868         in_params.push_back(out_params);
869         out_params.Clear();
870     }
871     for (auto _ : state) {
872         state.PauseTiming();
873         auto error = keymintTest->Begin(KeyPurpose::DECRYPT, in_params, &out_params);
874         if (error != ErrorCode::OK) {
875             state.SkipWithError(
876                     ("Decryption begin error, " + std::to_string(keymintTest->getError())).c_str());
877             return;
878         }
879         state.ResumeTiming();
880         if (!keymintTest->Process(*encryptedMessage)) {
881             state.SkipWithError(
882                     ("Decryption error, " + std::to_string(keymintTest->getError())).c_str());
883             break;
884         }
885     }
886 }
887 
888 // clang-format off
// AES
// Expands BENCHMARK_KM_CIPHER_ALL_MSGS (defined outside this section) for the key sizes 128 and
// 256, so a single invocation covers both for one AES transform.
#define BENCHMARK_KM_CIPHER_ALL_AES_KEYS(transform) \
    BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 128)    \
    BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 256)

// One registration per AES block mode / padding combination.
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/PKCS7Padding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CTR/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/NoPadding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/PKCS7Padding);
BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/GCM/NoPadding);

// Triple DES
// A single key size (168) is used, so BENCHMARK_KM_CIPHER_ALL_MSGS is invoked directly.
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/NoPadding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/PKCS7Padding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/NoPadding, 168);
BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/PKCS7Padding, 168);

// RSA
// Expands BENCHMARK_KM_ASYM_CIPHER (defined outside this section) for the key sizes 2048, 3072
// and 4096 with one caller-chosen message size.
#define BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(transform, msgSize) \
    BENCHMARK_KM_ASYM_CIPHER(transform, 2048, msgSize)            \
    BENCHMARK_KM_ASYM_CIPHER(transform, 3072, msgSize)            \
    BENCHMARK_KM_ASYM_CIPHER(transform, 4096, msgSize)

// RSA ciphers are registered with SMALL_MESSAGE_SIZE only (presumably because the plaintext
// has to fit within a single RSA operation -- TODO confirm).
BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/NoPadding, SMALL_MESSAGE_SIZE);
BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/PKCS1Padding, SMALL_MESSAGE_SIZE);
BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/OAEPPadding, SMALL_MESSAGE_SIZE);
915 
916 // clang-format on
917 }  // namespace aidl::android::hardware::security::keymint::test
918 
main(int argc,char ** argv)919 int main(int argc, char** argv) {
920     ::benchmark::Initialize(&argc, argv);
921     base::CommandLine::Init(argc, argv);
922     base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
923     auto service_name = command_line->GetSwitchValueASCII("service_name");
924     if (service_name.empty()) {
925         service_name =
926                 std::string(
927                         aidl::android::hardware::security::keymint::IKeyMintDevice::descriptor) +
928                 "/default";
929     }
930     std::cerr << service_name << std::endl;
931     aidl::android::hardware::security::keymint::test::keymintTest =
932             aidl::android::hardware::security::keymint::test::KeyMintBenchmarkTest::newInstance(
933                     service_name.c_str());
934     if (!aidl::android::hardware::security::keymint::test::keymintTest) {
935         return 1;
936     }
937     ::benchmark::RunSpecifiedBenchmarks();
938 }
939