1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "keymint_benchmark"
18
19 #include <iostream>
20
21 #include <base/command_line.h>
22 #include <benchmark/benchmark.h>
23
24 #include <aidl/Vintf.h>
25 #include <aidl/android/hardware/security/keymint/ErrorCode.h>
26 #include <aidl/android/hardware/security/keymint/IKeyMintDevice.h>
27 #include <android/binder_manager.h>
28 #include <binder/IServiceManager.h>
29
30 #include <keymint_support/authorization_set.h>
31 #include <keymint_support/openssl_utils.h>
32 #include <openssl/curve25519.h>
33 #include <openssl/x509.h>
34
35 #define SMALL_MESSAGE_SIZE 64
36 #define MEDIUM_MESSAGE_SIZE 1024
37 #define LARGE_MESSAGE_SIZE 131072
38
39 namespace aidl::android::hardware::security::keymint::test {
40
41 ::std::ostream& operator<<(::std::ostream& os, const keymint::AuthorizationSet& set);
42
43 using ::android::sp;
44 using Status = ::ndk::ScopedAStatus;
45 using ::std::optional;
46 using ::std::shared_ptr;
47 using ::std::string;
48 using ::std::vector;
49
50 class KeyMintBenchmarkTest {
51 public:
KeyMintBenchmarkTest()52 KeyMintBenchmarkTest() {
53 message_cache_.push_back(string(SMALL_MESSAGE_SIZE, 'x'));
54 message_cache_.push_back(string(MEDIUM_MESSAGE_SIZE, 'x'));
55 message_cache_.push_back(string(LARGE_MESSAGE_SIZE, 'x'));
56 }
57
newInstance(const char * instanceName)58 static KeyMintBenchmarkTest* newInstance(const char* instanceName) {
59 if (AServiceManager_isDeclared(instanceName)) {
60 ::ndk::SpAIBinder binder(AServiceManager_waitForService(instanceName));
61 KeyMintBenchmarkTest* test = new KeyMintBenchmarkTest();
62 test->InitializeKeyMint(IKeyMintDevice::fromBinder(binder));
63 return test;
64 } else {
65 return nullptr;
66 }
67 }
68
getError()69 int getError() { return static_cast<int>(error_); }
70
GenerateMessage(int size)71 const string GenerateMessage(int size) {
72 for (const string& message : message_cache_) {
73 if (message.size() == size) {
74 return message;
75 }
76 }
77 string message = string(size, 'x');
78 message_cache_.push_back(message);
79 return message;
80 }
81
getBlockMode(string transform)82 optional<BlockMode> getBlockMode(string transform) {
83 if (transform.find("/ECB") != string::npos) {
84 return BlockMode::ECB;
85 } else if (transform.find("/CBC") != string::npos) {
86 return BlockMode::CBC;
87 } else if (transform.find("/CTR") != string::npos) {
88 return BlockMode::CTR;
89 } else if (transform.find("/GCM") != string::npos) {
90 return BlockMode::GCM;
91 }
92 return {};
93 }
94
getPadding(string transform,bool sign)95 PaddingMode getPadding(string transform, bool sign) {
96 if (transform.find("/PKCS7") != string::npos) {
97 return PaddingMode::PKCS7;
98 } else if (transform.find("/PSS") != string::npos) {
99 return PaddingMode::RSA_PSS;
100 } else if (transform.find("/OAEP") != string::npos) {
101 return PaddingMode::RSA_OAEP;
102 } else if (transform.find("/PKCS1") != string::npos) {
103 return sign ? PaddingMode::RSA_PKCS1_1_5_SIGN : PaddingMode::RSA_PKCS1_1_5_ENCRYPT;
104 } else if (sign && transform.find("RSA") != string::npos) {
105 // RSA defaults to PKCS1 for sign
106 return PaddingMode::RSA_PKCS1_1_5_SIGN;
107 }
108 return PaddingMode::NONE;
109 }
110
getAlgorithm(string transform)111 optional<Algorithm> getAlgorithm(string transform) {
112 if (transform.find("AES") != string::npos) {
113 return Algorithm::AES;
114 } else if (transform.find("Hmac") != string::npos) {
115 return Algorithm::HMAC;
116 } else if (transform.find("DESede") != string::npos) {
117 return Algorithm::TRIPLE_DES;
118 } else if (transform.find("RSA") != string::npos) {
119 return Algorithm::RSA;
120 } else if (transform.find("EC") != string::npos) {
121 return Algorithm::EC;
122 }
123 std::cerr << "Can't find algorithm for " << transform << std::endl;
124 return {};
125 }
126
getAlgorithmString(string transform)127 string getAlgorithmString(string transform) {
128 if (transform.find("AES") != string::npos) {
129 return "AES";
130 } else if (transform.find("Hmac") != string::npos) {
131 return "HMAC";
132 } else if (transform.find("DESede") != string::npos) {
133 return "TRIPLE_DES";
134 } else if (transform.find("RSA") != string::npos) {
135 return "RSA";
136 } else if (transform.find("EC") != string::npos) {
137 return "EC";
138 }
139 std::cerr << "Can't find algorithm for " << transform << std::endl;
140 return "";
141 }
142
getDigest(string transform)143 Digest getDigest(string transform) {
144 if (transform.find("MD5") != string::npos) {
145 return Digest::MD5;
146 } else if (transform.find("SHA1") != string::npos ||
147 transform.find("SHA-1") != string::npos) {
148 return Digest::SHA1;
149 } else if (transform.find("SHA224") != string::npos) {
150 return Digest::SHA_2_224;
151 } else if (transform.find("SHA256") != string::npos) {
152 return Digest::SHA_2_256;
153 } else if (transform.find("SHA384") != string::npos) {
154 return Digest::SHA_2_384;
155 } else if (transform.find("SHA512") != string::npos) {
156 return Digest::SHA_2_512;
157 } else if (transform.find("RSA") != string::npos &&
158 transform.find("OAEP") != string::npos) {
159 if (securityLevel_ == SecurityLevel::STRONGBOX) {
160 return Digest::SHA_2_256;
161 } else {
162 return Digest::SHA1;
163 }
164 } else if (transform.find("Hmac") != string::npos) {
165 return Digest::SHA_2_256;
166 }
167 return Digest::NONE;
168 }
169
getDigestString(string transform)170 string getDigestString(string transform) {
171 if (transform.find("MD5") != string::npos) {
172 return "MD5";
173 } else if (transform.find("SHA1") != string::npos ||
174 transform.find("SHA-1") != string::npos) {
175 return "SHA1";
176 } else if (transform.find("SHA224") != string::npos) {
177 return "SHA_2_224";
178 } else if (transform.find("SHA256") != string::npos) {
179 return "SHA_2_256";
180 } else if (transform.find("SHA384") != string::npos) {
181 return "SHA_2_384";
182 } else if (transform.find("SHA512") != string::npos) {
183 return "SHA_2_512";
184 } else if (transform.find("RSA") != string::npos &&
185 transform.find("OAEP") != string::npos) {
186 if (securityLevel_ == SecurityLevel::STRONGBOX) {
187 return "SHA_2_256";
188 } else {
189 return "SHA1";
190 }
191 } else if (transform.find("Hmac") != string::npos) {
192 return "SHA_2_256";
193 }
194 return "";
195 }
196
getCurveFromLength(int keySize)197 optional<EcCurve> getCurveFromLength(int keySize) {
198 switch (keySize) {
199 case 224:
200 return EcCurve::P_224;
201 case 256:
202 return EcCurve::P_256;
203 case 384:
204 return EcCurve::P_384;
205 case 521:
206 return EcCurve::P_521;
207 default:
208 return std::nullopt;
209 }
210 }
211
GenerateKey(string transform,int keySize,bool sign=false)212 bool GenerateKey(string transform, int keySize, bool sign = false) {
213 if (transform == key_transform_) {
214 return true;
215 } else if (key_transform_ != "") {
216 // Deleting old key first
217 key_transform_ = "";
218 if (DeleteKey() != ErrorCode::OK) {
219 return false;
220 }
221 }
222 std::optional<Algorithm> algorithm = getAlgorithm(transform);
223 if (!algorithm) {
224 std::cerr << "Error: invalid algorithm " << transform << std::endl;
225 return false;
226 }
227 key_transform_ = transform;
228 AuthorizationSetBuilder authSet = AuthorizationSetBuilder()
229 .Authorization(TAG_NO_AUTH_REQUIRED)
230 .Authorization(TAG_PURPOSE, KeyPurpose::ENCRYPT)
231 .Authorization(TAG_PURPOSE, KeyPurpose::DECRYPT)
232 .Authorization(TAG_PURPOSE, KeyPurpose::SIGN)
233 .Authorization(TAG_PURPOSE, KeyPurpose::VERIFY)
234 .Authorization(TAG_KEY_SIZE, keySize)
235 .Authorization(TAG_ALGORITHM, algorithm.value())
236 .Digest(getDigest(transform))
237 .Padding(getPadding(transform, sign));
238 std::optional<BlockMode> blockMode = getBlockMode(transform);
239 if (blockMode) {
240 authSet.BlockMode(blockMode.value());
241 if (blockMode == BlockMode::GCM) {
242 authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
243 }
244 }
245 if (algorithm == Algorithm::HMAC) {
246 authSet.Authorization(TAG_MIN_MAC_LENGTH, 128);
247 }
248 if (algorithm == Algorithm::RSA) {
249 authSet.Authorization(TAG_RSA_PUBLIC_EXPONENT, 65537U);
250 authSet.SetDefaultValidity();
251 }
252 if (algorithm == Algorithm::EC) {
253 authSet.SetDefaultValidity();
254 std::optional<EcCurve> curve = getCurveFromLength(keySize);
255 if (!curve) {
256 std::cerr << "Error: invalid EC-Curve from size " << keySize << std::endl;
257 return false;
258 }
259 authSet.Authorization(TAG_EC_CURVE, curve.value());
260 }
261 error_ = GenerateKey(authSet);
262 return error_ == ErrorCode::OK;
263 }
264
getOperationParams(string transform,bool sign=false)265 AuthorizationSet getOperationParams(string transform, bool sign = false) {
266 AuthorizationSetBuilder builder = AuthorizationSetBuilder()
267 .Padding(getPadding(transform, sign))
268 .Digest(getDigest(transform));
269 std::optional<BlockMode> blockMode = getBlockMode(transform);
270 if (sign && (transform.find("Hmac") != string::npos)) {
271 builder.Authorization(TAG_MAC_LENGTH, 128);
272 }
273 if (blockMode) {
274 builder.BlockMode(*blockMode);
275 if (blockMode == BlockMode::GCM) {
276 builder.Authorization(TAG_MAC_LENGTH, 128);
277 }
278 }
279 return std::move(builder);
280 }
281
Process(const string & message,const string & signature="")282 optional<string> Process(const string& message, const string& signature = "") {
283 ErrorCode result;
284
285 string output;
286 result = Finish(message, signature, &output);
287 if (result != ErrorCode::OK) {
288 error_ = result;
289 return {};
290 }
291 return output;
292 }
293
DeleteKey()294 ErrorCode DeleteKey() {
295 Status result = keymint_->deleteKey(key_blob_);
296 key_blob_ = vector<uint8_t>();
297 return GetReturnErrorCode(result);
298 }
299
Begin(KeyPurpose purpose,const AuthorizationSet & in_params,AuthorizationSet * out_params)300 ErrorCode Begin(KeyPurpose purpose, const AuthorizationSet& in_params,
301 AuthorizationSet* out_params) {
302 Status result;
303 BeginResult out;
304 result = keymint_->begin(purpose, key_blob_, in_params.vector_data(), std::nullopt, &out);
305 if (result.isOk()) {
306 *out_params = out.params;
307 op_ = out.operation;
308 }
309 return GetReturnErrorCode(result);
310 }
311
312 /* Copied the function LocalRsaEncryptMessage from
313 * hardware/interfaces/security/keymint/aidl/vts/functional/KeyMintAidlTestBase.cpp in VTS.
314 * Replaced asserts with the condition check and return false in case of failure condition.
315 * Require return value to skip the benchmark test case from further execution in case
316 * LocalRsaEncryptMessage fails.
317 */
LocalRsaEncryptMessage(const string & message,const AuthorizationSet & params)318 optional<string> LocalRsaEncryptMessage(const string& message, const AuthorizationSet& params) {
319 // Retrieve the public key from the leaf certificate.
320 if (cert_chain_.empty()) {
321 std::cerr << "Local RSA encrypt Error: invalid cert_chain_" << std::endl;
322 return "Failure";
323 }
324 X509_Ptr key_cert(parse_cert_blob(cert_chain_[0].encodedCertificate));
325 EVP_PKEY_Ptr pub_key(X509_get_pubkey(key_cert.get()));
326 RSA_Ptr rsa(EVP_PKEY_get1_RSA(const_cast<EVP_PKEY*>(pub_key.get())));
327
328 // Retrieve relevant tags.
329 Digest digest = Digest::NONE;
330 Digest mgf_digest = Digest::SHA1;
331 PaddingMode padding = PaddingMode::NONE;
332
333 auto digest_tag = params.GetTagValue(TAG_DIGEST);
334 if (digest_tag.has_value()) digest = digest_tag.value();
335 auto pad_tag = params.GetTagValue(TAG_PADDING);
336 if (pad_tag.has_value()) padding = pad_tag.value();
337 auto mgf_tag = params.GetTagValue(TAG_RSA_OAEP_MGF_DIGEST);
338 if (mgf_tag.has_value()) mgf_digest = mgf_tag.value();
339
340 const EVP_MD* md = openssl_digest(digest);
341 const EVP_MD* mgf_md = openssl_digest(mgf_digest);
342
343 // Set up encryption context.
344 EVP_PKEY_CTX_Ptr ctx(EVP_PKEY_CTX_new(pub_key.get(), /* engine= */ nullptr));
345 if (EVP_PKEY_encrypt_init(ctx.get()) <= 0) {
346 std::cerr << "Local RSA encrypt Error: Encryption init failed" << std::endl;
347 return "Failure";
348 }
349
350 int rc = -1;
351 switch (padding) {
352 case PaddingMode::NONE:
353 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_NO_PADDING);
354 break;
355 case PaddingMode::RSA_PKCS1_1_5_ENCRYPT:
356 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_PADDING);
357 break;
358 case PaddingMode::RSA_OAEP:
359 rc = EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING);
360 break;
361 default:
362 break;
363 }
364 if (rc <= 0) {
365 std::cerr << "Local RSA encrypt Error: Set padding failed" << std::endl;
366 return "Failure";
367 }
368 if (padding == PaddingMode::RSA_OAEP) {
369 if (!EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), md)) {
370 std::cerr << "Local RSA encrypt Error: Set digest failed: " << ERR_peek_last_error()
371 << std::endl;
372 return "Failure";
373 }
374 if (!EVP_PKEY_CTX_set_rsa_mgf1_md(ctx.get(), mgf_md)) {
375 std::cerr << "Local RSA encrypt Error: Set digest failed: " << ERR_peek_last_error()
376 << std::endl;
377 return "Failure";
378 }
379 }
380
381 // Determine output size.
382 size_t outlen;
383 if (EVP_PKEY_encrypt(ctx.get(), nullptr /* out */, &outlen,
384 reinterpret_cast<const uint8_t*>(message.data()),
385 message.size()) <= 0) {
386 std::cerr << "Local RSA encrypt Error: Determine output size failed: "
387 << ERR_peek_last_error() << std::endl;
388 return "Failure";
389 }
390
391 // Left-zero-pad the input if necessary.
392 const uint8_t* to_encrypt = reinterpret_cast<const uint8_t*>(message.data());
393 size_t to_encrypt_len = message.size();
394
395 std::unique_ptr<string> zero_padded_message;
396 if (padding == PaddingMode::NONE && to_encrypt_len < outlen) {
397 zero_padded_message.reset(new string(outlen, '\0'));
398 memcpy(zero_padded_message->data() + (outlen - to_encrypt_len), message.data(),
399 message.size());
400 to_encrypt = reinterpret_cast<const uint8_t*>(zero_padded_message->data());
401 to_encrypt_len = outlen;
402 }
403
404 // Do the encryption.
405 string output(outlen, '\0');
406 if (EVP_PKEY_encrypt(ctx.get(), reinterpret_cast<uint8_t*>(output.data()), &outlen,
407 to_encrypt, to_encrypt_len) <= 0) {
408 std::cerr << "Local RSA encrypt Error: Encryption failed: " << ERR_peek_last_error()
409 << std::endl;
410 return "Failure";
411 }
412 return output;
413 }
414
415 SecurityLevel securityLevel_;
416 string name_;
417
418 private:
GenerateKey(const AuthorizationSet & key_desc,const optional<AttestationKey> & attest_key=std::nullopt)419 ErrorCode GenerateKey(const AuthorizationSet& key_desc,
420 const optional<AttestationKey>& attest_key = std::nullopt) {
421 key_blob_.clear();
422 cert_chain_.clear();
423 KeyCreationResult creationResult;
424 Status result = keymint_->generateKey(key_desc.vector_data(), attest_key, &creationResult);
425 if (result.isOk()) {
426 key_blob_ = std::move(creationResult.keyBlob);
427 cert_chain_ = std::move(creationResult.certificateChain);
428 creationResult.keyCharacteristics.clear();
429 }
430 return GetReturnErrorCode(result);
431 }
432
InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint)433 void InitializeKeyMint(std::shared_ptr<IKeyMintDevice> keyMint) {
434 if (!keyMint) {
435 std::cerr << "Trying initialize nullptr in InitializeKeyMint" << std::endl;
436 return;
437 }
438 keymint_ = std::move(keyMint);
439 KeyMintHardwareInfo info;
440 Status result = keymint_->getHardwareInfo(&info);
441 if (!result.isOk()) {
442 std::cerr << "InitializeKeyMint: getHardwareInfo failed with "
443 << result.getServiceSpecificError() << std::endl;
444 }
445 securityLevel_ = info.securityLevel;
446 name_.assign(info.keyMintName.begin(), info.keyMintName.end());
447 }
448
Finish(const string & input,const string & signature,string * output)449 ErrorCode Finish(const string& input, const string& signature, string* output) {
450 if (!op_) {
451 std::cerr << "Finish: Operation is nullptr" << std::endl;
452 return ErrorCode::UNEXPECTED_NULL_POINTER;
453 }
454
455 vector<uint8_t> oPut;
456 Status result =
457 op_->finish(vector<uint8_t>(input.begin(), input.end()),
458 vector<uint8_t>(signature.begin(), signature.end()), {} /* authToken */,
459 {} /* timestampToken */, {} /* confirmationToken */, &oPut);
460
461 if (result.isOk()) output->append(oPut.begin(), oPut.end());
462
463 op_.reset();
464 return GetReturnErrorCode(result);
465 }
466
Update(const string & input,string * output)467 ErrorCode Update(const string& input, string* output) {
468 Status result;
469 if (!op_) {
470 std::cerr << "Update: Operation is nullptr" << std::endl;
471 return ErrorCode::UNEXPECTED_NULL_POINTER;
472 }
473
474 std::vector<uint8_t> o_put;
475 result = op_->update(vector<uint8_t>(input.begin(), input.end()), {} /* authToken */,
476 {} /* timestampToken */, &o_put);
477
478 if (result.isOk() && output) *output = {o_put.begin(), o_put.end()};
479 return GetReturnErrorCode(result);
480 }
481
GetReturnErrorCode(const Status & result)482 ErrorCode GetReturnErrorCode(const Status& result) {
483 error_ = static_cast<ErrorCode>(result.getServiceSpecificError());
484 if (result.isOk()) return ErrorCode::OK;
485
486 if (result.getExceptionCode() == EX_SERVICE_SPECIFIC) {
487 return static_cast<ErrorCode>(result.getServiceSpecificError());
488 }
489
490 return ErrorCode::UNKNOWN_ERROR;
491 }
492
parse_cert_blob(const vector<uint8_t> & blob)493 X509_Ptr parse_cert_blob(const vector<uint8_t>& blob) {
494 const uint8_t* p = blob.data();
495 return X509_Ptr(d2i_X509(nullptr /* allocate new */, &p, blob.size()));
496 }
497
498 std::shared_ptr<IKeyMintOperation> op_;
499 vector<Certificate> cert_chain_;
500 vector<uint8_t> key_blob_;
501 vector<KeyCharacteristics> key_characteristics_;
502 std::shared_ptr<IKeyMintDevice> keymint_;
503 std::vector<string> message_cache_;
504 std::string key_transform_;
505 ErrorCode error_;
506 };
507
508 KeyMintBenchmarkTest* keymintTest;
509
settings(benchmark::internal::Benchmark * benchmark)510 static void settings(benchmark::internal::Benchmark* benchmark) {
511 benchmark->Unit(benchmark::kMillisecond);
512 }
513
addDefaultLabel(benchmark::State & state)514 static void addDefaultLabel(benchmark::State& state) {
515 std::string secLevel;
516 switch (keymintTest->securityLevel_) {
517 case SecurityLevel::STRONGBOX:
518 secLevel = "STRONGBOX";
519 break;
520 case SecurityLevel::SOFTWARE:
521 secLevel = "SOFTWARE";
522 break;
523 case SecurityLevel::TRUSTED_ENVIRONMENT:
524 secLevel = "TEE";
525 break;
526 case SecurityLevel::KEYSTORE:
527 secLevel = "KEYSTORE";
528 break;
529 }
530 state.SetLabel("hardware_name:" + keymintTest->name_ + " sec_level:" + secLevel);
531 }
532
// clang-format off
// Each macro below registers one or more google-benchmark cases. Every registration is already
// terminated with ';' by the macro and has settings() applied.
#define BENCHMARK_KM(func, transform, keySize) \
    BENCHMARK_CAPTURE(func, transform/keySize, #transform "/" #keySize, keySize)->Apply(settings);
#define BENCHMARK_KM_MSG(func, transform, keySize, msgSize) \
    BENCHMARK_CAPTURE(func, transform/keySize/msgSize, #transform "/" #keySize "/" #msgSize, \
                      keySize, msgSize) \
        ->Apply(settings);

// Registers `func` for the small, medium and large message sizes.
#define BENCHMARK_KM_ALL_MSGS(func, transform, keySize) \
    BENCHMARK_KM_MSG(func, transform, keySize, SMALL_MESSAGE_SIZE) \
    BENCHMARK_KM_MSG(func, transform, keySize, MEDIUM_MESSAGE_SIZE) \
    BENCHMARK_KM_MSG(func, transform, keySize, LARGE_MESSAGE_SIZE)

#define BENCHMARK_KM_CIPHER(transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(encrypt, transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)

// Asymmetric ciphers: only the private-key operation (decrypt) is registered, because KeyMint
// does not perform public-key operations.
#define BENCHMARK_KM_ASYM_CIPHER(transform, keySize, msgSize) \
    BENCHMARK_KM_MSG(decrypt, transform, keySize, msgSize)

#define BENCHMARK_KM_CIPHER_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(encrypt, transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(decrypt, transform, keySize)

#define BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(sign, transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(verify, transform, keySize)

// Asymmetric signatures: only the private-key operation (sign) is registered, because KeyMint
// does not perform public-key operations.
// No trailing backslash on the last line: it would splice the following comment into the macro.
#define BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, keySize) \
    BENCHMARK_KM_ALL_MSGS(sign, transform, keySize)
// clang-format on
566
567 /*
568 * ============= KeyGen TESTS ==================
569 */
570
isValidSBKeySize(string transform,int keySize)571 static bool isValidSBKeySize(string transform, int keySize) {
572 std::optional<Algorithm> algorithm = keymintTest->getAlgorithm(transform);
573 switch (algorithm.value()) {
574 case Algorithm::AES:
575 return (keySize == 128 || keySize == 256);
576 case Algorithm::HMAC:
577 return (keySize % 8 == 0 && keySize >= 64 && keySize <= 512);
578 case Algorithm::TRIPLE_DES:
579 return (keySize == 168);
580 case Algorithm::RSA:
581 return (keySize == 2048);
582 case Algorithm::EC:
583 return (keySize == 256);
584 }
585 return false;
586 }
587
keygen(benchmark::State & state,string transform,int keySize)588 static void keygen(benchmark::State& state, string transform, int keySize) {
589 // Skip the test for unsupported key size in StrongBox
590 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
591 !isValidSBKeySize(transform, keySize)) {
592 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
593 " is not supported in StrongBox for algorithm: " +
594 keymintTest->getAlgorithmString(transform))
595 .c_str());
596 return;
597 }
598 addDefaultLabel(state);
599 for (auto _ : state) {
600 if (!keymintTest->GenerateKey(transform, keySize)) {
601 state.SkipWithError(
602 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
603 }
604 state.PauseTiming();
605
606 keymintTest->DeleteKey();
607 state.ResumeTiming();
608 }
609 }
610
611 BENCHMARK_KM(keygen, AES, 128);
612 BENCHMARK_KM(keygen, AES, 256);
613
614 BENCHMARK_KM(keygen, RSA, 2048);
615 BENCHMARK_KM(keygen, RSA, 3072);
616 BENCHMARK_KM(keygen, RSA, 4096);
617
618 BENCHMARK_KM(keygen, EC, 224);
619 BENCHMARK_KM(keygen, EC, 256);
620 BENCHMARK_KM(keygen, EC, 384);
621 BENCHMARK_KM(keygen, EC, 521);
622
623 BENCHMARK_KM(keygen, DESede, 168);
624
625 BENCHMARK_KM(keygen, Hmac, 64);
626 BENCHMARK_KM(keygen, Hmac, 128);
627 BENCHMARK_KM(keygen, Hmac, 256);
628 BENCHMARK_KM(keygen, Hmac, 512);
629
630 /*
631 * ============= SIGNATURE TESTS ==================
632 */
sign(benchmark::State & state,string transform,int keySize,int msgSize)633 static void sign(benchmark::State& state, string transform, int keySize, int msgSize) {
634 // Skip the test for unsupported key size or unsupported digest in StrongBox
635 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
636 if (!isValidSBKeySize(transform, keySize)) {
637 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
638 " is not supported in StrongBox for algorithm: " +
639 keymintTest->getAlgorithmString(transform))
640 .c_str());
641 return;
642 }
643 if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
644 state.SkipWithError(
645 ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
646 " is not supported in StrongBox")
647 .c_str());
648 return;
649 }
650 }
651 addDefaultLabel(state);
652 if (!keymintTest->GenerateKey(transform, keySize, true)) {
653 state.SkipWithError(
654 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
655 return;
656 }
657
658 auto in_params = keymintTest->getOperationParams(transform, true);
659 AuthorizationSet out_params;
660 string message = keymintTest->GenerateMessage(msgSize);
661
662 for (auto _ : state) {
663 state.PauseTiming();
664 ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
665 if (error != ErrorCode::OK) {
666 state.SkipWithError(
667 ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
668 return;
669 }
670 state.ResumeTiming();
671 out_params.Clear();
672 if (!keymintTest->Process(message)) {
673 state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
674 break;
675 }
676 }
677 }
678
verify(benchmark::State & state,string transform,int keySize,int msgSize)679 static void verify(benchmark::State& state, string transform, int keySize, int msgSize) {
680 // Skip the test for unsupported key size or unsupported digest in StrongBox
681 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX) {
682 if (!isValidSBKeySize(transform, keySize)) {
683 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
684 " is not supported in StrongBox for algorithm: " +
685 keymintTest->getAlgorithmString(transform))
686 .c_str());
687 return;
688 }
689 if (keymintTest->getDigest(transform) != Digest::SHA_2_256) {
690 state.SkipWithError(
691 ("Skipped for STRONGBOX: Digest: " + keymintTest->getDigestString(transform) +
692 " is not supported in StrongBox")
693 .c_str());
694 return;
695 }
696 }
697 addDefaultLabel(state);
698 if (!keymintTest->GenerateKey(transform, keySize, true)) {
699 state.SkipWithError(
700 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
701 return;
702 }
703 AuthorizationSet out_params;
704 auto in_params = keymintTest->getOperationParams(transform, true);
705 string message = keymintTest->GenerateMessage(msgSize);
706 ErrorCode error = keymintTest->Begin(KeyPurpose::SIGN, in_params, &out_params);
707 if (error != ErrorCode::OK) {
708 state.SkipWithError(
709 ("Error beginning sign, " + std::to_string(keymintTest->getError())).c_str());
710 return;
711 }
712 std::optional<string> signature = keymintTest->Process(message);
713 if (!signature) {
714 state.SkipWithError(("Sign error, " + std::to_string(keymintTest->getError())).c_str());
715 return;
716 }
717 out_params.Clear();
718 if (transform.find("Hmac") != string::npos) {
719 in_params = keymintTest->getOperationParams(transform, false);
720 }
721 for (auto _ : state) {
722 state.PauseTiming();
723 error = keymintTest->Begin(KeyPurpose::VERIFY, in_params, &out_params);
724 if (error != ErrorCode::OK) {
725 state.SkipWithError(
726 ("Verify begin error, " + std::to_string(keymintTest->getError())).c_str());
727 return;
728 }
729 state.ResumeTiming();
730 if (!keymintTest->Process(message, *signature)) {
731 state.SkipWithError(
732 ("Verify error, " + std::to_string(keymintTest->getError())).c_str());
733 break;
734 }
735 }
736 }
737
738 // clang-format off
739 #define BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(transform) \
740 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 64) \
741 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 128) \
742 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 256) \
743 BENCHMARK_KM_SIGNATURE_ALL_MSGS(transform, 512)
744
745 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA1)
746 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
747 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA224)
748 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA256)
749 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA384)
750 BENCHMARK_KM_SIGNATURE_ALL_HMAC_KEYS(HmacSHA512)
751
752 #define BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(transform) \
753 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 224) \
754 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 256) \
755 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 384) \
756 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 521)
757
758 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(NONEwithECDSA);
759 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA1withECDSA);
760 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA224withECDSA);
761 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA256withECDSA);
762 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA384withECDSA);
763 BENCHMARK_KM_SIGNATURE_ALL_ECDSA_KEYS(SHA512withECDSA);
764
765 #define BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(transform) \
766 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 2048) \
767 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 3072) \
768 BENCHMARK_KM_ASYM_SIGNATURE_ALL_MSGS(transform, 4096)
769
770 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA);
771 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA);
772 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA);
773 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA256withRSA);
774 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA);
775 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA);
776
777 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(MD5withRSA/PSS);
778 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA1withRSA/PSS);
779 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA224withRSA/PSS);
780 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA384withRSA/PSS);
781 BENCHMARK_KM_SIGNATURE_ALL_RSA_KEYS(SHA512withRSA/PSS);
782
783 // clang-format on
784
785 /*
786 * ============= CIPHER TESTS ==================
787 */
788
encrypt(benchmark::State & state,string transform,int keySize,int msgSize)789 static void encrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
790 // Skip the test for unsupported key size in StrongBox
791 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
792 (!isValidSBKeySize(transform, keySize))) {
793 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
794 " is not supported in StrongBox for algorithm: " +
795 keymintTest->getAlgorithmString(transform))
796 .c_str());
797 return;
798 }
799 addDefaultLabel(state);
800 if (!keymintTest->GenerateKey(transform, keySize)) {
801 state.SkipWithError(
802 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
803 return;
804 }
805 auto in_params = keymintTest->getOperationParams(transform);
806 AuthorizationSet out_params;
807 string message = keymintTest->GenerateMessage(msgSize);
808
809 for (auto _ : state) {
810 state.PauseTiming();
811 auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
812 if (error != ErrorCode::OK) {
813 state.SkipWithError(
814 ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
815 return;
816 }
817 out_params.Clear();
818 state.ResumeTiming();
819 if (!keymintTest->Process(message)) {
820 state.SkipWithError(
821 ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
822 break;
823 }
824 }
825 }
826
decrypt(benchmark::State & state,string transform,int keySize,int msgSize)827 static void decrypt(benchmark::State& state, string transform, int keySize, int msgSize) {
828 // Skip the test for unsupported key size in StrongBox
829 if (keymintTest->securityLevel_ == SecurityLevel::STRONGBOX &&
830 (!isValidSBKeySize(transform, keySize))) {
831 state.SkipWithError(("Skipped for STRONGBOX: Keysize: " + std::to_string(keySize) +
832 " is not supported in StrongBox for algorithm: " +
833 keymintTest->getAlgorithmString(transform))
834 .c_str());
835 return;
836 }
837 addDefaultLabel(state);
838 if (!keymintTest->GenerateKey(transform, keySize)) {
839 state.SkipWithError(
840 ("Key generation error, " + std::to_string(keymintTest->getError())).c_str());
841 return;
842 }
843 AuthorizationSet out_params;
844 AuthorizationSet in_params = keymintTest->getOperationParams(transform);
845 string message = keymintTest->GenerateMessage(msgSize);
846 optional<string> encryptedMessage;
847
848 if (keymintTest->getAlgorithm(transform).value() == Algorithm::RSA) {
849 // Public key operation not supported, doing local Encryption
850 encryptedMessage = keymintTest->LocalRsaEncryptMessage(message, in_params);
851 if ((keySize / 8) != (*encryptedMessage).size()) {
852 state.SkipWithError("Local Encryption falied");
853 return;
854 }
855 } else {
856 auto error = keymintTest->Begin(KeyPurpose::ENCRYPT, in_params, &out_params);
857 if (error != ErrorCode::OK) {
858 state.SkipWithError(
859 ("Encryption begin error, " + std::to_string(keymintTest->getError())).c_str());
860 return;
861 }
862 encryptedMessage = keymintTest->Process(message);
863 if (!encryptedMessage) {
864 state.SkipWithError(
865 ("Encryption error, " + std::to_string(keymintTest->getError())).c_str());
866 return;
867 }
868 in_params.push_back(out_params);
869 out_params.Clear();
870 }
871 for (auto _ : state) {
872 state.PauseTiming();
873 auto error = keymintTest->Begin(KeyPurpose::DECRYPT, in_params, &out_params);
874 if (error != ErrorCode::OK) {
875 state.SkipWithError(
876 ("Decryption begin error, " + std::to_string(keymintTest->getError())).c_str());
877 return;
878 }
879 state.ResumeTiming();
880 if (!keymintTest->Process(*encryptedMessage)) {
881 state.SkipWithError(
882 ("Decryption error, " + std::to_string(keymintTest->getError())).c_str());
883 break;
884 }
885 }
886 }
887
888 // clang-format off
889 // AES
890 #define BENCHMARK_KM_CIPHER_ALL_AES_KEYS(transform) \
891 BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 128) \
892 BENCHMARK_KM_CIPHER_ALL_MSGS(transform, 256)
893
894 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/NoPadding);
895 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CBC/PKCS7Padding);
896 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/CTR/NoPadding);
897 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/NoPadding);
898 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/ECB/PKCS7Padding);
899 BENCHMARK_KM_CIPHER_ALL_AES_KEYS(AES/GCM/NoPadding);
900
901 // Triple DES
902 BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/NoPadding, 168);
903 BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/CBC/PKCS7Padding, 168);
904 BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/NoPadding, 168);
905 BENCHMARK_KM_CIPHER_ALL_MSGS(DESede/ECB/PKCS7Padding, 168);
906
907 #define BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(transform, msgSize) \
908 BENCHMARK_KM_ASYM_CIPHER(transform, 2048, msgSize) \
909 BENCHMARK_KM_ASYM_CIPHER(transform, 3072, msgSize) \
910 BENCHMARK_KM_ASYM_CIPHER(transform, 4096, msgSize)
911
912 BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/NoPadding, SMALL_MESSAGE_SIZE);
913 BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/PKCS1Padding, SMALL_MESSAGE_SIZE);
914 BENCHMARK_KM_CIPHER_ALL_RSA_KEYS(RSA/ECB/OAEPPadding, SMALL_MESSAGE_SIZE);
915
916 // clang-format on
917 } // namespace aidl::android::hardware::security::keymint::test
918
main(int argc,char ** argv)919 int main(int argc, char** argv) {
920 ::benchmark::Initialize(&argc, argv);
921 base::CommandLine::Init(argc, argv);
922 base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
923 auto service_name = command_line->GetSwitchValueASCII("service_name");
924 if (service_name.empty()) {
925 service_name =
926 std::string(
927 aidl::android::hardware::security::keymint::IKeyMintDevice::descriptor) +
928 "/default";
929 }
930 std::cerr << service_name << std::endl;
931 aidl::android::hardware::security::keymint::test::keymintTest =
932 aidl::android::hardware::security::keymint::test::KeyMintBenchmarkTest::newInstance(
933 service_name.c_str());
934 if (!aidl::android::hardware::security::keymint::test::keymintTest) {
935 return 1;
936 }
937 ::benchmark::RunSpecifiedBenchmarks();
938 }
939