1 /*
2  * Copyright 2020, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <certificate_utils.h>
18 
19 #include <openssl/err.h>
20 #include <openssl/evp.h>
21 #include <openssl/mem.h>
22 #include <openssl/ossl_typ.h>
23 #include <openssl/x509v3.h>
24 
25 #include <functional>
26 #include <limits>
27 #include <variant>
28 #include <vector>
29 
30 #ifndef __LP64__
31 #include <time64.h>
32 #endif
33 
34 namespace keystore {
35 
36 namespace {
37 
38 constexpr int kDigitalSignatureKeyUsageBit = 0;
39 constexpr int kKeyEnciphermentKeyUsageBit = 2;
40 constexpr int kDataEnciphermentKeyUsageBit = 3;
41 constexpr int kKeyCertSignBit = 5;
42 constexpr int kMaxKeyUsageBit = 8;
43 
44 DEFINE_OPENSSL_OBJECT_POINTER(ASN1_STRING);
45 DEFINE_OPENSSL_OBJECT_POINTER(RSA_PSS_PARAMS);
46 DEFINE_OPENSSL_OBJECT_POINTER(AUTHORITY_KEYID);
47 DEFINE_OPENSSL_OBJECT_POINTER(BASIC_CONSTRAINTS);
48 DEFINE_OPENSSL_OBJECT_POINTER(X509_ALGOR);
49 DEFINE_OPENSSL_OBJECT_POINTER(BIGNUM);
50 
51 }  // namespace
52 
// Common name (CN) given to the subject when the caller does not supply a subject name.
constexpr char kDefaultCommonName[] = "Default Common Name";
54 
55 std::variant<CertUtilsError, X509_NAME_Ptr>
makeCommonName(std::optional<std::reference_wrapper<const std::vector<uint8_t>>> name)56 makeCommonName(std::optional<std::reference_wrapper<const std::vector<uint8_t>>> name) {
57     if (name) {
58         const uint8_t* p = name->get().data();
59         X509_NAME_Ptr x509_name(d2i_X509_NAME(nullptr, &p, name->get().size()));
60         if (!x509_name) {
61             return CertUtilsError::MemoryAllocation;
62         }
63         return x509_name;
64     }
65 
66     X509_NAME_Ptr x509_name(X509_NAME_new());
67     if (!x509_name) {
68         return CertUtilsError::MemoryAllocation;
69     }
70     if (!X509_NAME_add_entry_by_txt(x509_name.get(), "CN", MBSTRING_ASC,
71                                     reinterpret_cast<const uint8_t*>(kDefaultCommonName),
72                                     sizeof(kDefaultCommonName) - 1, -1 /* loc */, 0 /* set */)) {
73         return CertUtilsError::BoringSsl;
74     }
75     return x509_name;
76 }
77 
makeKeyId(const X509 * cert)78 std::variant<CertUtilsError, std::vector<uint8_t>> makeKeyId(const X509* cert) {
79     std::vector<uint8_t> keyid(20);
80     unsigned int len;
81     if (!X509_pubkey_digest(cert, EVP_sha1(), keyid.data(), &len)) {
82         return CertUtilsError::Encoding;
83     }
84     return keyid;
85 }
86 
87 std::variant<CertUtilsError, AUTHORITY_KEYID_Ptr>
makeAuthorityKeyExtension(const std::vector<uint8_t> & keyid)88 makeAuthorityKeyExtension(const std::vector<uint8_t>& keyid) {
89     AUTHORITY_KEYID_Ptr auth_key(AUTHORITY_KEYID_new());
90     if (!auth_key) {
91         return CertUtilsError::MemoryAllocation;
92     }
93 
94     auth_key->keyid = ASN1_OCTET_STRING_new();
95     if (auth_key->keyid == nullptr) {
96         return CertUtilsError::MemoryAllocation;
97     }
98 
99     if (!ASN1_OCTET_STRING_set(auth_key->keyid, keyid.data(), keyid.size())) {
100         return CertUtilsError::BoringSsl;
101     }
102 
103     return auth_key;
104 }
105 
106 std::variant<CertUtilsError, ASN1_OCTET_STRING_Ptr>
makeSubjectKeyExtension(const std::vector<uint8_t> & keyid)107 makeSubjectKeyExtension(const std::vector<uint8_t>& keyid) {
108 
109     // Build OCTET_STRING
110     ASN1_OCTET_STRING_Ptr keyid_str(ASN1_OCTET_STRING_new());
111     if (!keyid_str || !ASN1_OCTET_STRING_set(keyid_str.get(), keyid.data(), keyid.size())) {
112         return CertUtilsError::BoringSsl;
113     }
114 
115     return keyid_str;
116 }
117 
118 std::variant<CertUtilsError, BASIC_CONSTRAINTS_Ptr>
makeBasicConstraintsExtension(bool is_ca,std::optional<int> path_length)119 makeBasicConstraintsExtension(bool is_ca, std::optional<int> path_length) {
120 
121     BASIC_CONSTRAINTS_Ptr bcons(BASIC_CONSTRAINTS_new());
122     if (!bcons) {
123         return CertUtilsError::MemoryAllocation;
124     }
125 
126     bcons->ca = is_ca;
127     bcons->pathlen = nullptr;
128     if (path_length) {
129         bcons->pathlen = ASN1_INTEGER_new();
130         if (bcons->pathlen == nullptr || !ASN1_INTEGER_set(bcons->pathlen, *path_length)) {
131             return CertUtilsError::BoringSsl;
132         }
133     }
134 
135     return bcons;
136 }
137 
138 std::variant<CertUtilsError, ASN1_BIT_STRING_Ptr>
makeKeyUsageExtension(bool is_signing_key,bool is_encryption_key,bool is_cert_key)139 makeKeyUsageExtension(bool is_signing_key, bool is_encryption_key, bool is_cert_key) {
140     // Build BIT_STRING with correct contents.
141     ASN1_BIT_STRING_Ptr key_usage(ASN1_BIT_STRING_new());
142     if (!key_usage) {
143         return CertUtilsError::BoringSsl;
144     }
145 
146     for (size_t i = 0; i <= kMaxKeyUsageBit; ++i) {
147         if (!ASN1_BIT_STRING_set_bit(key_usage.get(), i, 0)) {
148             return CertUtilsError::BoringSsl;
149         }
150     }
151 
152     if (is_signing_key) {
153         if (!ASN1_BIT_STRING_set_bit(key_usage.get(), kDigitalSignatureKeyUsageBit, 1)) {
154             return CertUtilsError::BoringSsl;
155         }
156     }
157 
158     if (is_encryption_key) {
159         if (!ASN1_BIT_STRING_set_bit(key_usage.get(), kKeyEnciphermentKeyUsageBit, 1) ||
160             !ASN1_BIT_STRING_set_bit(key_usage.get(), kDataEnciphermentKeyUsageBit, 1)) {
161             return CertUtilsError::BoringSsl;
162         }
163     }
164 
165     if (is_cert_key) {
166         if (!ASN1_BIT_STRING_set_bit(key_usage.get(), kKeyCertSignBit, 1)) {
167             return CertUtilsError::BoringSsl;
168         }
169     }
170 
171     return key_usage;
172 }
173 
// TODO Once boring ssl can take int64_t instead of time_t we can go back to using
//      ASN1_TIME_set: https://bugs.chromium.org/p/boringssl/issues/detail?id=416
//
// Converts a point in time, given in milliseconds relative to the Unix epoch, into the NUL
// terminated textual form expected by ASN1_TIME_set_string():
//   - years 1950 through 2049 use UTCTime "YYMMDDHHMMSSZ" (RFC5280 4.1.2.5.1),
//   - all other years from 0 through 9999 use GeneralizedTime "YYYYMMDDHHMMSSZ"
//     (RFC5280 4.1.2.5.2).
// Fractions of a second are dropped by rounding towards the previous full second.
//
// Returns std::nullopt if the time cannot be converted or its year is outside of 0..9999.
std::optional<std::array<char, 16>> toTimeString(int64_t timeMillis) {
    // The division below truncates towards zero. Shifting negative values by -999 first makes
    // the result round towards the previous second instead.
    if (timeMillis < 0) {
        if (__builtin_add_overflow(timeMillis, -999, &timeMillis)) {
            return std::nullopt;
        }
    }

    struct tm utc;
#if defined(__LP64__)
    const time_t seconds = timeMillis / 1000;
    const bool converted = gmtime_r(&seconds, &utc) != nullptr;
#else
    const time64_t seconds = timeMillis / 1000;
    const bool converted = gmtime64_r(&seconds, &utc) != nullptr;
#endif
    if (!converted) {
        return std::nullopt;
    }

    // tm_year counts years since 1900.
    int year = 0;
    if (__builtin_add_overflow(utc.tm_year, 1900, &year)) {
        return std::nullopt;
    }

    const bool is_utc_time = year >= 1950 && year < 2050;
    const bool is_generalized_time = year >= 0 && year < 10000;
    if (!is_utc_time && !is_generalized_time) {
        return std::nullopt;
    }

    std::array<char, 16> text;
    if (is_utc_time) {
        snprintf(text.data(), text.size(), "%02d%02d%02d%02d%02d%02dZ", year % 100,
                 utc.tm_mon + 1, utc.tm_mday, utc.tm_hour, utc.tm_min, utc.tm_sec);
    } else {
        snprintf(text.data(), text.size(), "%04d%02d%02d%02d%02d%02dZ", year, utc.tm_mon + 1,
                 utc.tm_mday, utc.tm_hour, utc.tm_min, utc.tm_sec);
    }
    return text;
}
211 
212 // Creates a rump certificate structure with serial, subject and issuer names, as well as
213 // activation and expiry date.
214 // Callers should pass an empty X509_Ptr and check the return value for CertUtilsError::Ok (0)
215 // before accessing the result.
216 std::variant<CertUtilsError, X509_Ptr>
makeCertRump(std::optional<std::reference_wrapper<const std::vector<uint8_t>>> serial,std::optional<std::reference_wrapper<const std::vector<uint8_t>>> subject,const int64_t activeDateTimeMilliSeconds,const int64_t usageExpireDateTimeMilliSeconds)217 makeCertRump(std::optional<std::reference_wrapper<const std::vector<uint8_t>>> serial,
218              std::optional<std::reference_wrapper<const std::vector<uint8_t>>> subject,
219              const int64_t activeDateTimeMilliSeconds,
220              const int64_t usageExpireDateTimeMilliSeconds) {
221 
222     // Create certificate structure.
223     X509_Ptr certificate(X509_new());
224     if (!certificate) {
225         return CertUtilsError::BoringSsl;
226     }
227 
228     // Set the X509 version.
229     if (!X509_set_version(certificate.get(), 2 /* version 3, but zero-based */)) {
230         return CertUtilsError::BoringSsl;
231     }
232 
233     BIGNUM_Ptr bn_serial;
234     if (serial) {
235         bn_serial = BIGNUM_Ptr(BN_bin2bn(serial->get().data(), serial->get().size(), nullptr));
236         if (!bn_serial) {
237             return CertUtilsError::MemoryAllocation;
238         }
239     } else {
240         bn_serial = BIGNUM_Ptr(BN_new());
241         if (!bn_serial) {
242             return CertUtilsError::MemoryAllocation;
243         }
244         BN_zero(bn_serial.get());
245     }
246 
247     // Set the certificate serialNumber
248     ASN1_INTEGER_Ptr serialNumber(ASN1_INTEGER_new());
249     if (!serialNumber || !BN_to_ASN1_INTEGER(bn_serial.get(), serialNumber.get()) ||
250         !X509_set_serialNumber(certificate.get(), serialNumber.get() /* Don't release; copied */))
251         return CertUtilsError::BoringSsl;
252 
253     // Set Subject Name
254     auto subjectName = makeCommonName(subject);
255     if (auto x509_subject = std::get_if<X509_NAME_Ptr>(&subjectName)) {
256         if (!X509_set_subject_name(certificate.get(), x509_subject->get() /* copied */)) {
257             return CertUtilsError::BoringSsl;
258         }
259     } else {
260         return std::get<CertUtilsError>(subjectName);
261     }
262 
263     auto notBeforeTime = toTimeString(activeDateTimeMilliSeconds);
264     if (!notBeforeTime) {
265         return CertUtilsError::TimeError;
266     }
267     // Set activation date.
268     ASN1_TIME_Ptr notBefore(ASN1_TIME_new());
269     if (!notBefore || !ASN1_TIME_set_string(notBefore.get(), notBeforeTime->data()) ||
270         !X509_set_notBefore(certificate.get(), notBefore.get() /* Don't release; copied */))
271         return CertUtilsError::BoringSsl;
272 
273     // Set expiration date.
274     auto notAfterTime = toTimeString(usageExpireDateTimeMilliSeconds);
275     if (!notAfterTime) {
276         return CertUtilsError::TimeError;
277     }
278 
279     ASN1_TIME_Ptr notAfter(ASN1_TIME_new());
280     if (!notAfter || !ASN1_TIME_set_string(notAfter.get(), notAfterTime->data()) ||
281         !X509_set_notAfter(certificate.get(), notAfter.get() /* Don't release; copied */)) {
282         return CertUtilsError::BoringSsl;
283     }
284 
285     return certificate;
286 }
287 
288 std::variant<CertUtilsError, X509_Ptr>
makeCert(const EVP_PKEY * evp_pkey,std::optional<std::reference_wrapper<const std::vector<uint8_t>>> serial,std::optional<std::reference_wrapper<const std::vector<uint8_t>>> subject,const int64_t activeDateTimeMilliSeconds,const int64_t usageExpireDateTimeMilliSeconds,bool addSubjectKeyIdEx,std::optional<KeyUsageExtension> keyUsageEx,std::optional<BasicConstraintsExtension> basicConstraints)289 makeCert(const EVP_PKEY* evp_pkey,
290          std::optional<std::reference_wrapper<const std::vector<uint8_t>>> serial,
291          std::optional<std::reference_wrapper<const std::vector<uint8_t>>> subject,
292          const int64_t activeDateTimeMilliSeconds, const int64_t usageExpireDateTimeMilliSeconds,
293          bool addSubjectKeyIdEx, std::optional<KeyUsageExtension> keyUsageEx,
294          std::optional<BasicConstraintsExtension> basicConstraints) {
295 
296     // Make the rump certificate with serial, subject, not before and not after dates.
297     auto certificateV =
298         makeCertRump(serial, subject, activeDateTimeMilliSeconds, usageExpireDateTimeMilliSeconds);
299     if (auto error = std::get_if<CertUtilsError>(&certificateV)) {
300         return *error;
301     }
302     auto certificate = std::move(std::get<X509_Ptr>(certificateV));
303 
304     // Set the public key.
305     if (!X509_set_pubkey(certificate.get(), const_cast<EVP_PKEY*>(evp_pkey))) {
306         return CertUtilsError::BoringSsl;
307     }
308 
309     if (keyUsageEx) {
310         // Make and add the key usage extension.
311         auto key_usage_extensionV = makeKeyUsageExtension(
312             keyUsageEx->isSigningKey, keyUsageEx->isEncryptionKey, keyUsageEx->isCertificationKey);
313         if (auto error = std::get_if<CertUtilsError>(&key_usage_extensionV)) {
314             return *error;
315         }
316         auto key_usage_extension = std::move(std::get<ASN1_BIT_STRING_Ptr>(key_usage_extensionV));
317         if (!X509_add1_ext_i2d(certificate.get(), NID_key_usage,
318                                key_usage_extension.get() /* Don't release; copied */,
319                                true /* critical */, 0 /* flags */)) {
320             return CertUtilsError::BoringSsl;
321         }
322     }
323 
324     if (basicConstraints) {
325         // Make and add basic constraints
326         auto basic_constraints_extensionV =
327             makeBasicConstraintsExtension(basicConstraints->isCa, basicConstraints->pathLength);
328         if (auto error = std::get_if<CertUtilsError>(&basic_constraints_extensionV)) {
329             return *error;
330         }
331         auto basic_constraints_extension =
332             std::move(std::get<BASIC_CONSTRAINTS_Ptr>(basic_constraints_extensionV));
333         if (!X509_add1_ext_i2d(certificate.get(), NID_basic_constraints,
334                                basic_constraints_extension.get() /* Don't release; copied */,
335                                true /* critical */, 0 /* flags */)) {
336             return CertUtilsError::BoringSsl;
337         }
338     }
339 
340     if (addSubjectKeyIdEx) {
341         // Make and add subject key id extension.
342         auto keyidV = makeKeyId(certificate.get());
343         if (auto error = std::get_if<CertUtilsError>(&keyidV)) {
344             return *error;
345         }
346         auto& keyid = std::get<std::vector<uint8_t>>(keyidV);
347 
348         auto subject_key_extensionV = makeSubjectKeyExtension(keyid);
349         if (auto error = std::get_if<CertUtilsError>(&subject_key_extensionV)) {
350             return *error;
351         }
352         auto subject_key_extension =
353             std::move(std::get<ASN1_OCTET_STRING_Ptr>(subject_key_extensionV));
354         if (!X509_add1_ext_i2d(certificate.get(), NID_subject_key_identifier,
355                                subject_key_extension.get() /* Don't release; copied */,
356                                false /* critical */, 0 /* flags */)) {
357             return CertUtilsError::BoringSsl;
358         }
359     }
360 
361     return certificate;
362 }
363 
setIssuer(X509 * cert,const X509 * signingCert,bool addAuthKeyExt)364 CertUtilsError setIssuer(X509* cert, const X509* signingCert, bool addAuthKeyExt) {
365 
366     X509_NAME* issuerName(X509_get_subject_name(signingCert));
367 
368     // Set Issuer Name
369     if (issuerName) {
370         if (!X509_set_issuer_name(cert, issuerName /* copied */)) {
371             return CertUtilsError::BoringSsl;
372         }
373     } else {
374         return CertUtilsError::Encoding;
375     }
376 
377     if (addAuthKeyExt) {
378         // Make and add authority key extension - self signed.
379         auto keyidV = makeKeyId(signingCert);
380         if (auto error = std::get_if<CertUtilsError>(&keyidV)) {
381             return *error;
382         }
383         auto& keyid = std::get<std::vector<uint8_t>>(keyidV);
384 
385         auto auth_key_extensionV = makeAuthorityKeyExtension(keyid);
386         if (auto error = std::get_if<CertUtilsError>(&auth_key_extensionV)) {
387             return *error;
388         }
389         auto auth_key_extension = std::move(std::get<AUTHORITY_KEYID_Ptr>(auth_key_extensionV));
390         if (!X509_add1_ext_i2d(cert, NID_authority_key_identifier, auth_key_extension.get(), false,
391                                0)) {
392             return CertUtilsError::BoringSsl;
393         }
394     }
395     return CertUtilsError::Ok;
396 }
397 
398 // Takes a certificate a signing certificate and the raw private signing_key. And signs
399 // the certificate with the latter.
signCert(X509 * certificate,EVP_PKEY * signing_key)400 CertUtilsError signCert(X509* certificate, EVP_PKEY* signing_key) {
401 
402     if (certificate == nullptr) {
403         return CertUtilsError::UnexpectedNullPointer;
404     }
405 
406     if (!X509_sign(certificate, signing_key, EVP_sha256())) {
407         return CertUtilsError::BoringSsl;
408     }
409 
410     return CertUtilsError::Ok;
411 }
412 
encodeCert(X509 * certificate)413 std::variant<CertUtilsError, std::vector<uint8_t>> encodeCert(X509* certificate) {
414     int len = i2d_X509(certificate, nullptr);
415     if (len < 0) {
416         return CertUtilsError::BoringSsl;
417     }
418 
419     auto result = std::vector<uint8_t>(len);
420     uint8_t* p = result.data();
421 
422     if (i2d_X509(certificate, &p) < 0) {
423         return CertUtilsError::BoringSsl;
424     }
425     return result;
426 }
427 
setRsaDigestAlgorField(X509_ALGOR ** alg_ptr,const EVP_MD * digest)428 CertUtilsError setRsaDigestAlgorField(X509_ALGOR** alg_ptr, const EVP_MD* digest) {
429     if (alg_ptr == nullptr || digest == nullptr) {
430         return CertUtilsError::UnexpectedNullPointer;
431     }
432     *alg_ptr = X509_ALGOR_new();
433     if (*alg_ptr == nullptr) {
434         return CertUtilsError::MemoryAllocation;
435     }
436     X509_ALGOR_set_md(*alg_ptr, digest);
437     return CertUtilsError::Ok;
438 }
439 
setPssMaskGeneratorField(X509_ALGOR ** alg_ptr,const EVP_MD * digest)440 CertUtilsError setPssMaskGeneratorField(X509_ALGOR** alg_ptr, const EVP_MD* digest) {
441     X509_ALGOR* mgf1_digest = nullptr;
442     if (auto error = setRsaDigestAlgorField(&mgf1_digest, digest)) {
443         return error;
444     }
445     X509_ALGOR_Ptr mgf1_digest_ptr(mgf1_digest);
446 
447     ASN1_OCTET_STRING* mgf1_digest_algor_str = nullptr;
448     if (!ASN1_item_pack(mgf1_digest, ASN1_ITEM_rptr(X509_ALGOR), &mgf1_digest_algor_str)) {
449         return CertUtilsError::Encoding;
450     }
451     ASN1_OCTET_STRING_Ptr mgf1_digest_algor_str_ptr(mgf1_digest_algor_str);
452 
453     *alg_ptr = X509_ALGOR_new();
454     if (*alg_ptr == nullptr) {
455         return CertUtilsError::MemoryAllocation;
456     }
457     X509_ALGOR_set0(*alg_ptr, OBJ_nid2obj(NID_mgf1), V_ASN1_SEQUENCE, mgf1_digest_algor_str);
458     // *alg_ptr took ownership of the octet string
459     mgf1_digest_algor_str_ptr.release();
460     return CertUtilsError::Ok;
461 }
462 
setSaltLength(RSA_PSS_PARAMS * pss_params,unsigned length)463 static CertUtilsError setSaltLength(RSA_PSS_PARAMS* pss_params, unsigned length) {
464     pss_params->saltLength = ASN1_INTEGER_new();
465     if (pss_params->saltLength == nullptr) {
466         return CertUtilsError::MemoryAllocation;
467     }
468     if (!ASN1_INTEGER_set(pss_params->saltLength, length)) {
469         return CertUtilsError::Encoding;
470     };
471     return CertUtilsError::Ok;
472 }
473 
buildRsaPssParameter(Digest digest)474 std::variant<CertUtilsError, ASN1_STRING_Ptr> buildRsaPssParameter(Digest digest) {
475     RSA_PSS_PARAMS_Ptr pss(RSA_PSS_PARAMS_new());
476     if (!pss) {
477         return CertUtilsError::MemoryAllocation;
478     }
479 
480     const EVP_MD* md = nullptr;
481 
482     switch (digest) {
483     case Digest::SHA1:
484         break;
485     case Digest::SHA224:
486         md = EVP_sha224();
487         break;
488     case Digest::SHA256:
489         md = EVP_sha256();
490         break;
491     case Digest::SHA384:
492         md = EVP_sha384();
493         break;
494     case Digest::SHA512:
495         md = EVP_sha512();
496         break;
497     default:
498         return CertUtilsError::InvalidArgument;
499     }
500 
501     if (md != nullptr) {
502         if (auto error = setSaltLength(pss.get(), EVP_MD_size(md))) {
503             return error;
504         }
505         if (auto error = setRsaDigestAlgorField(&pss->hashAlgorithm, md)) {
506             return error;
507         }
508         if (auto error = setPssMaskGeneratorField(&pss->maskGenAlgorithm, md)) {
509             return error;
510         }
511     }
512 
513     ASN1_STRING* algo_str = nullptr;
514     if (!ASN1_item_pack(pss.get(), ASN1_ITEM_rptr(RSA_PSS_PARAMS), &algo_str)) {
515         return CertUtilsError::BoringSsl;
516     }
517 
518     return ASN1_STRING_Ptr(algo_str);
519 }
520 
makeAlgo(Algo algo,Padding padding,Digest digest)521 std::variant<CertUtilsError, X509_ALGOR_Ptr> makeAlgo(Algo algo, Padding padding, Digest digest) {
522     ASN1_STRING_Ptr param;
523     int param_type = V_ASN1_UNDEF;
524     int nid = 0;
525     switch (algo) {
526     case Algo::ECDSA:
527         switch (digest) {
528         case Digest::SHA1:
529             nid = NID_ecdsa_with_SHA1;
530             break;
531         case Digest::SHA224:
532             nid = NID_ecdsa_with_SHA224;
533             break;
534         case Digest::SHA256:
535             nid = NID_ecdsa_with_SHA256;
536             break;
537         case Digest::SHA384:
538             nid = NID_ecdsa_with_SHA384;
539             break;
540         case Digest::SHA512:
541             nid = NID_ecdsa_with_SHA512;
542             break;
543         default:
544             return CertUtilsError::InvalidArgument;
545         }
546         break;
547     case Algo::RSA:
548         switch (padding) {
549         case Padding::PKCS1_5:
550             param_type = V_ASN1_NULL;
551             switch (digest) {
552             case Digest::SHA1:
553                 nid = NID_sha1WithRSAEncryption;
554                 break;
555             case Digest::SHA224:
556                 nid = NID_sha224WithRSAEncryption;
557                 break;
558             case Digest::SHA256:
559                 nid = NID_sha256WithRSAEncryption;
560                 break;
561             case Digest::SHA384:
562                 nid = NID_sha384WithRSAEncryption;
563                 break;
564             case Digest::SHA512:
565                 nid = NID_sha512WithRSAEncryption;
566                 break;
567             default:
568                 return CertUtilsError::InvalidArgument;
569             }
570             break;
571         case Padding::PSS: {
572             auto v = buildRsaPssParameter(digest);
573             if (auto param_str = std::get_if<ASN1_STRING_Ptr>(&v)) {
574                 param = std::move(*param_str);
575                 param_type = V_ASN1_SEQUENCE;
576                 nid = NID_rsassaPss;
577             } else {
578                 return std::get<CertUtilsError>(v);
579             }
580             break;
581         }
582         default:
583             return CertUtilsError::InvalidArgument;
584         }
585         break;
586     default:
587         return CertUtilsError::InvalidArgument;
588     }
589 
590     X509_ALGOR_Ptr result(X509_ALGOR_new());
591     if (!result) {
592         return CertUtilsError::MemoryAllocation;
593     }
594     if (!X509_ALGOR_set0(result.get(), OBJ_nid2obj(nid), param_type, param.get())) {
595         return CertUtilsError::Encoding;
596     }
597     // The X509 struct took ownership.
598     param.release();
599     return result;
600 }
601 
602 // This function allows for signing a
signCertWith(X509 * certificate,std::function<std::vector<uint8_t> (const uint8_t *,size_t)> sign,Algo algo,Padding padding,Digest digest)603 CertUtilsError signCertWith(X509* certificate,
604                             std::function<std::vector<uint8_t>(const uint8_t*, size_t)> sign,
605                             Algo algo, Padding padding, Digest digest) {
606     auto algo_objV = makeAlgo(algo, padding, digest);
607     if (auto error = std::get_if<CertUtilsError>(&algo_objV)) {
608         return *error;
609     }
610     auto& algo_obj = std::get<X509_ALGOR_Ptr>(algo_objV);
611     if (!X509_set1_signature_algo(certificate, algo_obj.get())) {
612         return CertUtilsError::BoringSsl;
613     }
614 
615     uint8_t* cert_buf = nullptr;
616     int buf_len = i2d_re_X509_tbs(certificate, &cert_buf);
617     if (buf_len < 0) {
618         return CertUtilsError::Encoding;
619     }
620 
621     bssl::UniquePtr<uint8_t> free_cert_buf(cert_buf);
622     auto signature = sign(cert_buf, buf_len);
623     if (signature.empty()) {
624         return CertUtilsError::SignatureFailed;
625     }
626 
627     if (!X509_set1_signature_value(certificate, signature.data(), signature.size())) {
628         return CertUtilsError::BoringSsl;
629     }
630 
631     return CertUtilsError::Ok;
632 }
633 
634 }  // namespace keystore
635