1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <keymaster/cppcose/cppcose.h>
18 
19 #include <iostream>
20 #include <stdio.h>
21 
22 #include <cppbor.h>
23 #include <cppbor_parse.h>
24 #include <openssl/ecdsa.h>
25 
26 #include <openssl/err.h>
27 
28 namespace cppcose {
// Byte length of one P-256 coordinate. Used in this file as the required size of the X and Y
// public key coordinates and as the size of each half (r, s) of a raw COSE ECDSA signature.
constexpr int kP256AffinePointSize = 32;
// Same as above for P-384 signatures.
constexpr int kP384AffinePointSize = 48;

// Owning bssl::UniquePtr aliases for the crypto library handles used in this file.
using EVP_PKEY_Ptr = bssl::UniquePtr<EVP_PKEY>;
using EVP_PKEY_CTX_Ptr = bssl::UniquePtr<EVP_PKEY_CTX>;
using ECDSA_SIG_Ptr = bssl::UniquePtr<ECDSA_SIG>;
using EC_KEY_Ptr = bssl::UniquePtr<EC_KEY>;
36 
37 namespace {
38 
aesGcmInitAndProcessAad(const bytevec & key,const bytevec & nonce,const bytevec & aad,bool encrypt)39 ErrMsgOr<bssl::UniquePtr<EVP_CIPHER_CTX>> aesGcmInitAndProcessAad(const bytevec& key,
40                                                                   const bytevec& nonce,
41                                                                   const bytevec& aad,
42                                                                   bool encrypt) {
43     if (key.size() != kAesGcmKeySize) return "Invalid key size";
44 
45     bssl::UniquePtr<EVP_CIPHER_CTX> ctx(EVP_CIPHER_CTX_new());
46     if (!ctx) return "Failed to allocate cipher context";
47 
48     if (!EVP_CipherInit_ex(ctx.get(), EVP_aes_256_gcm(), nullptr /* engine */, key.data(),
49                            nonce.data(), encrypt ? 1 : 0)) {
50         return "Failed to initialize cipher";
51     }
52 
53     int outlen;
54     if (!aad.empty() && !EVP_CipherUpdate(ctx.get(), nullptr /* out; null means AAD */, &outlen,
55                                           aad.data(), aad.size())) {
56         return "Failed to process AAD";
57     }
58 
59     return std::move(ctx);
60 }
61 
signP256Digest(const bytevec & key,const bytevec & data)62 ErrMsgOr<bytevec> signP256Digest(const bytevec& key, const bytevec& data) {
63     auto bn = BIGNUM_Ptr(BN_bin2bn(key.data(), key.size(), nullptr));
64     if (bn.get() == nullptr) {
65         return "Error creating BIGNUM";
66     }
67 
68     auto ec_key = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
69     if (EC_KEY_set_private_key(ec_key.get(), bn.get()) != 1) {
70         return "Error setting private key from BIGNUM";
71     }
72 
73     auto sig = ECDSA_SIG_Ptr(ECDSA_do_sign(data.data(), data.size(), ec_key.get()));
74     if (sig == nullptr) {
75         return "Error signing digest";
76     }
77     size_t len = i2d_ECDSA_SIG(sig.get(), nullptr);
78     bytevec signature(len);
79     unsigned char* p = (unsigned char*)signature.data();
80     i2d_ECDSA_SIG(sig.get(), &p);
81     return signature;
82 }
83 
ecdh(const bytevec & publicKey,const bytevec & privateKey)84 ErrMsgOr<bytevec> ecdh(const bytevec& publicKey, const bytevec& privateKey) {
85     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1));
86     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
87     if (EC_POINT_oct2point(group.get(), point.get(), publicKey.data(), publicKey.size(), nullptr) !=
88         1) {
89         return "Error decoding publicKey";
90     }
91     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
92     auto pkey = EVP_PKEY_Ptr(EVP_PKEY_new());
93     if (ecKey.get() == nullptr || pkey.get() == nullptr) {
94         return "Memory allocation failed";
95     }
96     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
97         return "Error setting group";
98     }
99     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
100         return "Error setting point";
101     }
102     if (EVP_PKEY_set1_EC_KEY(pkey.get(), ecKey.get()) != 1) {
103         return "Error setting key";
104     }
105 
106     auto bn = BIGNUM_Ptr(BN_bin2bn(privateKey.data(), privateKey.size(), nullptr));
107     if (bn.get() == nullptr) {
108         return "Error creating BIGNUM for private key";
109     }
110     auto privEcKey = EC_KEY_Ptr(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
111     if (EC_KEY_set_private_key(privEcKey.get(), bn.get()) != 1) {
112         return "Error setting private key from BIGNUM";
113     }
114     auto privPkey = EVP_PKEY_Ptr(EVP_PKEY_new());
115     if (EVP_PKEY_set1_EC_KEY(privPkey.get(), privEcKey.get()) != 1) {
116         return "Error setting private key";
117     }
118 
119     auto ctx = EVP_PKEY_CTX_Ptr(EVP_PKEY_CTX_new(privPkey.get(), NULL));
120     if (ctx.get() == nullptr) {
121         return "Error creating context";
122     }
123 
124     if (EVP_PKEY_derive_init(ctx.get()) != 1) {
125         return "Error initializing context";
126     }
127 
128     if (EVP_PKEY_derive_set_peer(ctx.get(), pkey.get()) != 1) {
129         return "Error setting peer";
130     }
131 
132     /* Determine buffer length for shared secret */
133     size_t secretLen = 0;
134     if (EVP_PKEY_derive(ctx.get(), NULL, &secretLen) != 1) {
135         return "Error determing length of shared secret";
136     }
137     bytevec sharedSecret(secretLen);
138 
139     if (EVP_PKEY_derive(ctx.get(), sharedSecret.data(), &secretLen) != 1) {
140         return "Error deriving shared secret";
141     }
142     return sharedSecret;
143 }
144 
ecdsaCoseSignatureToDer(int point_size,const bytevec & ecdsaCoseSignature)145 ErrMsgOr<bytevec> ecdsaCoseSignatureToDer(int point_size, const bytevec& ecdsaCoseSignature) {
146     if (ecdsaCoseSignature.size() != (size_t)(point_size * 2)) {
147         return "COSE signature wrong length";
148     }
149 
150     auto rBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data(), point_size, nullptr));
151     if (rBn.get() == nullptr) {
152         return "Error creating BIGNUM for r";
153     }
154 
155     auto sBn = BIGNUM_Ptr(BN_bin2bn(ecdsaCoseSignature.data() + point_size, point_size, nullptr));
156     if (sBn.get() == nullptr) {
157         return "Error creating BIGNUM for s";
158     }
159 
160     ECDSA_SIG sig;
161     sig.r = rBn.get();
162     sig.s = sBn.get();
163 
164     size_t len = i2d_ECDSA_SIG(&sig, nullptr);
165     bytevec derSignature(len);
166     unsigned char* p = (unsigned char*)derSignature.data();
167     i2d_ECDSA_SIG(&sig, &p);
168     return derSignature;
169 }
170 
ecdsaDerSignatureToCose(int point_size,const bytevec & ecdsaSignature)171 ErrMsgOr<bytevec> ecdsaDerSignatureToCose(int point_size, const bytevec& ecdsaSignature) {
172     const unsigned char* p = ecdsaSignature.data();
173     auto sig = ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, ecdsaSignature.size()));
174     if (sig == nullptr) {
175         return "Error decoding DER signature";
176     }
177 
178     bytevec ecdsaCoseSignature(point_size * 2, 0);
179     if (BN_bn2binpad(ECDSA_SIG_get0_r(sig.get()), ecdsaCoseSignature.data(), point_size) !=
180         point_size) {
181         return "Error encoding r";
182     }
183     if (BN_bn2binpad(ECDSA_SIG_get0_s(sig.get()), ecdsaCoseSignature.data() + point_size,
184                      point_size) != point_size) {
185         return "Error encoding s";
186     }
187     return ecdsaCoseSignature;
188 }
189 
verifyEcdsaDigest(int curve_nid,const bytevec & key,const bytevec & digest,const bytevec & signature)190 bool verifyEcdsaDigest(int curve_nid, const bytevec& key, const bytevec& digest,
191                        const bytevec& signature) {
192     const unsigned char* p = (unsigned char*)signature.data();
193     auto sig = ECDSA_SIG_Ptr(d2i_ECDSA_SIG(nullptr, &p, signature.size()));
194     if (sig.get() == nullptr) {
195         return false;
196     }
197 
198     auto group = EC_GROUP_Ptr(EC_GROUP_new_by_curve_name(curve_nid));
199     auto point = EC_POINT_Ptr(EC_POINT_new(group.get()));
200     if (EC_POINT_oct2point(group.get(), point.get(), key.data(), key.size(), nullptr) != 1) {
201         return false;
202     }
203     auto ecKey = EC_KEY_Ptr(EC_KEY_new());
204     if (ecKey.get() == nullptr) {
205         return false;
206     }
207     if (EC_KEY_set_group(ecKey.get(), group.get()) != 1) {
208         return false;
209     }
210     if (EC_KEY_set_public_key(ecKey.get(), point.get()) != 1) {
211         return false;
212     }
213 
214     int rc = ECDSA_do_verify(digest.data(), digest.size(), sig.get(), ecKey.get());
215     if (rc != 1) {
216         return false;
217     }
218     return true;
219 }
220 
221 }  // namespace
222 
generateHmacSha256(const bytevec & key,const bytevec & data)223 ErrMsgOr<HmacSha256> generateHmacSha256(const bytevec& key, const bytevec& data) {
224     HmacSha256 digest;
225     unsigned int outLen;
226     uint8_t* out = HMAC(EVP_sha256(),              //
227                         key.data(), key.size(),    //
228                         data.data(), data.size(),  //
229                         digest.data(), &outLen);
230 
231     if (out == nullptr || outLen != digest.size()) {
232         return "Error generating HMAC";
233     }
234     return digest;
235 }
236 
generateCoseMac0Mac(HmacSha256Function macFunction,const bytevec & externalAad,const bytevec & payload)237 ErrMsgOr<HmacSha256> generateCoseMac0Mac(HmacSha256Function macFunction, const bytevec& externalAad,
238                                          const bytevec& payload) {
239     auto macStructure = cppbor::Array()
240                             .add("MAC0")
241                             .add(cppbor::Map().add(ALGORITHM, HMAC_256).canonicalize().encode())
242                             .add(externalAad)
243                             .add(payload)
244                             .encode();
245 
246     auto macTag = macFunction(macStructure);
247     if (!macTag) {
248         return "Error computing public key MAC";
249     }
250 
251     return *macTag;
252 }
253 
constructCoseMac0(HmacSha256Function macFunction,const bytevec & externalAad,const bytevec & payload)254 ErrMsgOr<cppbor::Array> constructCoseMac0(HmacSha256Function macFunction,
255                                           const bytevec& externalAad, const bytevec& payload) {
256     auto tag = generateCoseMac0Mac(macFunction, externalAad, payload);
257     if (!tag) return tag.moveMessage();
258 
259     return cppbor::Array()
260         .add(cppbor::Map().add(ALGORITHM, HMAC_256).canonicalize().encode())
261         .add(cppbor::Map() /* unprotected */)
262         .add(payload)
263         .add(std::pair(tag->begin(), tag->end()));
264 }
265 
verifyAndParseCoseMac0(const cppbor::Item * macItem,const bytevec & macKey)266 ErrMsgOr<bytevec /* payload */> verifyAndParseCoseMac0(const cppbor::Item* macItem,
267                                                        const bytevec& macKey) {
268     auto mac = macItem ? macItem->asArray() : nullptr;
269     if (!mac || mac->size() != kCoseMac0EntryCount) {
270         return "Invalid COSE_Mac0";
271     }
272 
273     auto protectedParms = mac->get(kCoseMac0ProtectedParams)->asBstr();
274     auto unprotectedParms = mac->get(kCoseMac0UnprotectedParams)->asMap();
275     auto payload = mac->get(kCoseMac0Payload)->asBstr();
276     auto tag = mac->get(kCoseMac0Tag)->asBstr();
277     if (!protectedParms || !unprotectedParms || !payload || !tag) {
278         return "Invalid COSE_Mac0 contents";
279     }
280 
281     auto [protectedMap, _, errMsg] = cppbor::parse(protectedParms);
282     if (!protectedMap || !protectedMap->asMap()) {
283         return "Invalid Mac0 protected: " + errMsg;
284     }
285     auto& algo = protectedMap->asMap()->get(ALGORITHM);
286     if (!algo || !algo->asInt() || algo->asInt()->value() != HMAC_256) {
287         return "Unsupported Mac0 algorithm";
288     }
289 
290     auto macFunction = [&macKey](const bytevec& input) {
291         return generateHmacSha256(macKey, input);
292     };
293     auto macTag = generateCoseMac0Mac(macFunction, {} /* external_aad */, payload->value());
294     if (!macTag) return macTag.moveMessage();
295 
296     if (macTag->size() != tag->value().size() ||
297         CRYPTO_memcmp(macTag->data(), tag->value().data(), macTag->size()) != 0) {
298         return "MAC tag mismatch";
299     }
300 
301     return payload->value();
302 }
303 
createECDSACoseSign1Signature(const bytevec & key,const bytevec & protectedParams,const bytevec & payload,const bytevec & aad)304 ErrMsgOr<bytevec> createECDSACoseSign1Signature(const bytevec& key, const bytevec& protectedParams,
305                                                 const bytevec& payload, const bytevec& aad) {
306     bytevec signatureInput = cppbor::Array()
307                                  .add("Signature1")  //
308                                  .add(protectedParams)
309                                  .add(aad)
310                                  .add(payload)
311                                  .encode();
312     auto ecdsaSignature = signP256Digest(key, sha256(signatureInput));
313     if (!ecdsaSignature) return ecdsaSignature.moveMessage();
314 
315     return ecdsaDerSignatureToCose(kP256AffinePointSize, *ecdsaSignature);
316 }
317 
createCoseSign1Signature(const bytevec & key,const bytevec & protectedParams,const bytevec & payload,const bytevec & aad)318 ErrMsgOr<bytevec> createCoseSign1Signature(const bytevec& key, const bytevec& protectedParams,
319                                            const bytevec& payload, const bytevec& aad) {
320     bytevec signatureInput = cppbor::Array()
321                                  .add("Signature1")  //
322                                  .add(protectedParams)
323                                  .add(aad)
324                                  .add(payload)
325                                  .encode();
326 
327     if (key.size() != ED25519_PRIVATE_KEY_LEN) return "Invalid signing key";
328     bytevec signature(ED25519_SIGNATURE_LEN);
329     if (!ED25519_sign(signature.data(), signatureInput.data(), signatureInput.size(), key.data())) {
330         return "Signing failed";
331     }
332 
333     return signature;
334 }
335 
constructECDSACoseSign1(const bytevec & key,cppbor::Map protectedParams,const bytevec & payload,const bytevec & aad)336 ErrMsgOr<cppbor::Array> constructECDSACoseSign1(const bytevec& key, cppbor::Map protectedParams,
337                                                 const bytevec& payload, const bytevec& aad) {
338     bytevec protParms = protectedParams.add(ALGORITHM, ES256).canonicalize().encode();
339     auto signature = createECDSACoseSign1Signature(key, protParms, payload, aad);
340     if (!signature) return signature.moveMessage();
341 
342     return cppbor::Array()
343         .add(std::move(protParms))
344         .add(cppbor::Map() /* unprotected parameters */)
345         .add(std::move(payload))
346         .add(std::move(*signature));
347 }
348 
constructCoseSign1(const bytevec & key,cppbor::Map protectedParams,const bytevec & payload,const bytevec & aad)349 ErrMsgOr<cppbor::Array> constructCoseSign1(const bytevec& key, cppbor::Map protectedParams,
350                                            const bytevec& payload, const bytevec& aad) {
351     bytevec protParms = protectedParams.add(ALGORITHM, EDDSA).canonicalize().encode();
352     auto signature = createCoseSign1Signature(key, protParms, payload, aad);
353     if (!signature) return signature.moveMessage();
354 
355     return cppbor::Array()
356         .add(std::move(protParms))
357         .add(cppbor::Map() /* unprotected parameters */)
358         .add(std::move(payload))
359         .add(std::move(*signature));
360 }
361 
constructCoseSign1(const bytevec & key,const bytevec & payload,const bytevec & aad)362 ErrMsgOr<cppbor::Array> constructCoseSign1(const bytevec& key, const bytevec& payload,
363                                            const bytevec& aad) {
364     return constructCoseSign1(key, {} /* protectedParams */, payload, aad);
365 }
366 
verifyAndParseCoseSign1(const cppbor::Array * coseSign1,const bytevec & signingCoseKey,const bytevec & aad)367 ErrMsgOr<bytevec> verifyAndParseCoseSign1(const cppbor::Array* coseSign1,
368                                           const bytevec& signingCoseKey, const bytevec& aad) {
369     if (!coseSign1 || coseSign1->size() != kCoseSign1EntryCount) {
370         return "Invalid COSE_Sign1";
371     }
372 
373     const cppbor::Bstr* protectedParams = coseSign1->get(kCoseSign1ProtectedParams)->asBstr();
374     const cppbor::Map* unprotectedParams = coseSign1->get(kCoseSign1UnprotectedParams)->asMap();
375     const cppbor::Bstr* payload = coseSign1->get(kCoseSign1Payload)->asBstr();
376 
377     if (!protectedParams || !unprotectedParams || !payload) {
378         return "Missing input parameters";
379     }
380 
381     auto [parsedProtParams, _, errMsg] = cppbor::parse(protectedParams);
382     if (!parsedProtParams) {
383         return errMsg + " when parsing protected params.";
384     }
385     if (!parsedProtParams->asMap()) {
386         return "Protected params must be a map";
387     }
388 
389     auto& algorithm = parsedProtParams->asMap()->get(ALGORITHM);
390     if (!algorithm || !algorithm->asInt() ||
391         !(algorithm->asInt()->value() == EDDSA || algorithm->asInt()->value() == ES256 ||
392           algorithm->asInt()->value() == ES384)) {
393         return "Unsupported signature algorithm";
394     }
395 
396     const cppbor::Bstr* signature = coseSign1->get(kCoseSign1Signature)->asBstr();
397     if (!signature || signature->value().empty()) {
398         return "Missing signature input";
399     }
400 
401     bool selfSigned = signingCoseKey.empty();
402     bytevec signatureInput =
403         cppbor::Array().add("Signature1").add(*protectedParams).add(aad).add(*payload).encode();
404     if (algorithm->asInt()->value() == EDDSA) {
405         auto key = CoseKey::parseEd25519(selfSigned ? payload->value() : signingCoseKey);
406         if (!key || key->getBstrValue(CoseKey::PUBKEY_X)->empty()) {
407             return "Bad signing key: " + key.moveMessage();
408         }
409 
410         if (!ED25519_verify(signatureInput.data(), signatureInput.size(), signature->value().data(),
411                             key->getBstrValue(CoseKey::PUBKEY_X)->data())) {
412             return "Signature verification failed";
413         }
414     } else if (algorithm->asInt()->value() == ES256) {
415         auto key = CoseKey::parseP256(selfSigned ? payload->value() : signingCoseKey);
416         if (!key || key->getBstrValue(CoseKey::PUBKEY_X)->empty() ||
417             key->getBstrValue(CoseKey::PUBKEY_Y)->empty()) {
418             return "Bad signing key: " + key.moveMessage();
419         }
420         auto publicKey = key->getEcPublicKey();
421         if (!publicKey) return publicKey.moveMessage();
422 
423         auto ecdsaDerSignature = ecdsaCoseSignatureToDer(kP256AffinePointSize, signature->value());
424         if (!ecdsaDerSignature) return ecdsaDerSignature.moveMessage();
425 
426         // convert public key to uncompressed form by prepending 0x04 at begin.
427         publicKey->insert(publicKey->begin(), 0x04);
428 
429         if (!verifyEcdsaDigest(NID_X9_62_prime256v1, publicKey.moveValue(), sha256(signatureInput),
430                                *ecdsaDerSignature)) {
431             return "Signature verification failed";
432         }
433     } else {  // ES384
434         auto key = CoseKey::parseP384(selfSigned ? payload->value() : signingCoseKey);
435         if (!key || key->getBstrValue(CoseKey::PUBKEY_X)->empty() ||
436             key->getBstrValue(CoseKey::PUBKEY_Y)->empty()) {
437             return "Bad signing key: " + key.moveMessage();
438         }
439         auto publicKey = key->getEcPublicKey();
440         if (!publicKey) return publicKey.moveMessage();
441 
442         auto ecdsaDerSignature = ecdsaCoseSignatureToDer(kP384AffinePointSize, signature->value());
443         if (!ecdsaDerSignature) return ecdsaDerSignature.moveMessage();
444 
445         // convert public key to uncompressed form by prepending 0x04 at begin.
446         publicKey->insert(publicKey->begin(), 0x04);
447 
448         if (!verifyEcdsaDigest(NID_secp384r1, publicKey.moveValue(), sha384(signatureInput),
449                                *ecdsaDerSignature)) {
450             return "Signature verification failed";
451         }
452     }
453 
454     return payload->value();
455 }
456 
createCoseEncryptCiphertext(const bytevec & key,const bytevec & nonce,const bytevec & protectedParams,const bytevec & plaintextPayload,const bytevec & aad)457 ErrMsgOr<bytevec> createCoseEncryptCiphertext(const bytevec& key, const bytevec& nonce,
458                                               const bytevec& protectedParams,
459                                               const bytevec& plaintextPayload, const bytevec& aad) {
460     auto ciphertext = aesGcmEncrypt(key, nonce,
461                                     cppbor::Array()            // Enc strucure as AAD
462                                         .add("Encrypt")        // Context
463                                         .add(protectedParams)  // Protected
464                                         .add(aad)              // External AAD
465                                         .encode(),
466                                     plaintextPayload);
467 
468     if (!ciphertext) return ciphertext.moveMessage();
469     return ciphertext.moveValue();
470 }
471 
constructCoseEncrypt(const bytevec & key,const bytevec & nonce,const bytevec & plaintextPayload,const bytevec & aad,cppbor::Array recipients)472 ErrMsgOr<cppbor::Array> constructCoseEncrypt(const bytevec& key, const bytevec& nonce,
473                                              const bytevec& plaintextPayload, const bytevec& aad,
474                                              cppbor::Array recipients) {
475     auto encryptProtectedHeader = cppbor::Map()  //
476                                       .add(ALGORITHM, AES_GCM_256)
477                                       .canonicalize()
478                                       .encode();
479 
480     auto ciphertext =
481         createCoseEncryptCiphertext(key, nonce, encryptProtectedHeader, plaintextPayload, aad);
482     if (!ciphertext) return ciphertext.moveMessage();
483 
484     return cppbor::Array()
485         .add(encryptProtectedHeader)                       // Protected
486         .add(cppbor::Map().add(IV, nonce).canonicalize())  // Unprotected
487         .add(*ciphertext)                                  // Payload
488         .add(std::move(recipients));
489 }
490 
491 ErrMsgOr<std::pair<bytevec /* pubkey */, bytevec /* key ID */>>
getSenderPubKeyFromCoseEncrypt(const cppbor::Item * coseEncrypt)492 getSenderPubKeyFromCoseEncrypt(const cppbor::Item* coseEncrypt) {
493     if (!coseEncrypt || !coseEncrypt->asArray() ||
494         coseEncrypt->asArray()->size() != kCoseEncryptEntryCount) {
495         return "Invalid COSE_Encrypt";
496     }
497 
498     auto& recipients = coseEncrypt->asArray()->get(kCoseEncryptRecipients);
499     if (!recipients || !recipients->asArray() || recipients->asArray()->size() != 1) {
500         return "Invalid recipients list";
501     }
502 
503     auto& recipient = recipients->asArray()->get(0);
504     if (!recipient || !recipient->asArray() || recipient->asArray()->size() != 3) {
505         return "Invalid COSE_recipient";
506     }
507 
508     auto& ciphertext = recipient->asArray()->get(2);
509     if (!ciphertext->asSimple() || !ciphertext->asSimple()->asNull()) {
510         return "Unexpected value in recipients ciphertext field " +
511                cppbor::prettyPrint(ciphertext.get());
512     }
513 
514     auto& protParms = recipient->asArray()->get(0);
515     if (!protParms || !protParms->asBstr()) return "Invalid protected params";
516     auto [parsedProtParms, _, errMsg] = cppbor::parse(protParms->asBstr());
517     if (!parsedProtParms) return "Failed to parse protected params: " + errMsg;
518     if (!parsedProtParms->asMap()) return "Invalid protected params";
519 
520     auto& algorithm = parsedProtParms->asMap()->get(ALGORITHM);
521     if (!algorithm || !algorithm->asInt() || algorithm->asInt()->value() != ECDH_ES_HKDF_256) {
522         return "Invalid algorithm";
523     }
524 
525     auto& unprotParms = recipient->asArray()->get(1);
526     if (!unprotParms || !unprotParms->asMap()) return "Invalid unprotected params";
527 
528     auto& senderCoseKey = unprotParms->asMap()->get(COSE_KEY);
529     if (!senderCoseKey || !senderCoseKey->asMap()) return "Invalid sender COSE_Key";
530 
531     auto& keyType = senderCoseKey->asMap()->get(CoseKey::KEY_TYPE);
532     if (!keyType || !keyType->asInt() ||
533         (keyType->asInt()->value() != OCTET_KEY_PAIR && keyType->asInt()->value() != EC2)) {
534         return "Invalid key type";
535     }
536 
537     auto& curve = senderCoseKey->asMap()->get(CoseKey::CURVE);
538     if (!curve || !curve->asInt() ||
539         (keyType->asInt()->value() == OCTET_KEY_PAIR && curve->asInt()->value() != X25519) ||
540         (keyType->asInt()->value() == EC2 && curve->asInt()->value() != P256)) {
541         return "Unsupported curve";
542     }
543 
544     bytevec publicKey;
545     if (keyType->asInt()->value() == EC2) {
546         auto& pubX = senderCoseKey->asMap()->get(CoseKey::PUBKEY_X);
547         if (!pubX || !pubX->asBstr() || pubX->asBstr()->value().size() != kP256AffinePointSize) {
548             return "Invalid EC public key";
549         }
550         auto& pubY = senderCoseKey->asMap()->get(CoseKey::PUBKEY_Y);
551         if (!pubY || !pubY->asBstr() || pubY->asBstr()->value().size() != kP256AffinePointSize) {
552             return "Invalid EC public key";
553         }
554         auto key = CoseKey::getEcPublicKey(pubX->asBstr()->value(), pubY->asBstr()->value());
555         if (!key) return key.moveMessage();
556         publicKey = key.moveValue();
557     } else {
558         auto& pubkey = senderCoseKey->asMap()->get(CoseKey::PUBKEY_X);
559         if (!pubkey || !pubkey->asBstr() ||
560             pubkey->asBstr()->value().size() != X25519_PUBLIC_VALUE_LEN) {
561             return "Invalid X25519 public key";
562         }
563         publicKey = pubkey->asBstr()->value();
564     }
565 
566     auto& key_id = unprotParms->asMap()->get(KEY_ID);
567     if (key_id && key_id->asBstr()) {
568         return std::make_pair(publicKey, key_id->asBstr()->value());
569     }
570 
571     // If no key ID, just return an empty vector.
572     return std::make_pair(publicKey, bytevec{});
573 }
574 
decryptCoseEncrypt(const bytevec & key,const cppbor::Item * coseEncrypt,const bytevec & external_aad)575 ErrMsgOr<bytevec> decryptCoseEncrypt(const bytevec& key, const cppbor::Item* coseEncrypt,
576                                      const bytevec& external_aad) {
577     if (!coseEncrypt || !coseEncrypt->asArray() ||
578         coseEncrypt->asArray()->size() != kCoseEncryptEntryCount) {
579         return "Invalid COSE_Encrypt";
580     }
581 
582     auto& protParms = coseEncrypt->asArray()->get(kCoseEncryptProtectedParams);
583     auto& unprotParms = coseEncrypt->asArray()->get(kCoseEncryptUnprotectedParams);
584     auto& ciphertext = coseEncrypt->asArray()->get(kCoseEncryptPayload);
585     auto& recipients = coseEncrypt->asArray()->get(kCoseEncryptRecipients);
586 
587     if (!protParms || !protParms->asBstr() || !unprotParms || !ciphertext || !recipients) {
588         return "Invalid COSE_Encrypt";
589     }
590 
591     auto [parsedProtParams, _, errMsg] = cppbor::parse(protParms->asBstr()->value());
592     if (!parsedProtParams) {
593         return errMsg + " when parsing protected params.";
594     }
595     if (!parsedProtParams->asMap()) {
596         return "Protected params must be a map";
597     }
598 
599     auto& algorithm = parsedProtParams->asMap()->get(ALGORITHM);
600     if (!algorithm || !algorithm->asInt() || algorithm->asInt()->value() != AES_GCM_256) {
601         return "Unsupported encryption algorithm";
602     }
603 
604     if (!unprotParms->asMap() || unprotParms->asMap()->size() != 1) {
605         return "Invalid unprotected params";
606     }
607 
608     auto& nonce = unprotParms->asMap()->get(IV);
609     if (!nonce || !nonce->asBstr() || nonce->asBstr()->value().size() != kAesGcmNonceLength) {
610         return "Invalid nonce";
611     }
612 
613     if (!ciphertext->asBstr()) return "Invalid ciphertext";
614 
615     auto aad = cppbor::Array()                         // Enc strucure as AAD
616                    .add("Encrypt")                     // Context
617                    .add(protParms->asBstr()->value())  // Protected
618                    .add(external_aad)                  // External AAD
619                    .encode();
620 
621     return aesGcmDecrypt(key, nonce->asBstr()->value(), aad, ciphertext->asBstr()->value());
622 }
623 
consructKdfContext(const bytevec & pubKeyA,const bytevec & privKeyA,const bytevec & pubKeyB,bool senderIsA)624 ErrMsgOr<bytevec> consructKdfContext(const bytevec& pubKeyA, const bytevec& privKeyA,
625                                      const bytevec& pubKeyB, bool senderIsA) {
626     if (privKeyA.empty() || pubKeyA.empty() || pubKeyB.empty()) {
627         return "Missing input key parameters";
628     }
629 
630     bytevec kdfContext = cppbor::Array()
631                              .add(AES_GCM_256)
632                              .add(cppbor::Array()  // Sender Info
633                                       .add(cppbor::Bstr("client"))
634                                       .add(bytevec{} /* nonce */)
635                                       .add(senderIsA ? pubKeyA : pubKeyB))
636                              .add(cppbor::Array()  // Recipient Info
637                                       .add(cppbor::Bstr("server"))
638                                       .add(bytevec{} /* nonce */)
639                                       .add(senderIsA ? pubKeyB : pubKeyA))
640                              .add(cppbor::Array()               // SuppPubInfo
641                                       .add(kAesGcmKeySizeBits)  // output key length
642                                       .add(bytevec{}))          // protected
643                              .encode();
644     return kdfContext;
645 }
646 
ECDH_HKDF_DeriveKey(const bytevec & pubKeyA,const bytevec & privKeyA,const bytevec & pubKeyB,bool senderIsA)647 ErrMsgOr<bytevec> ECDH_HKDF_DeriveKey(const bytevec& pubKeyA, const bytevec& privKeyA,
648                                       const bytevec& pubKeyB, bool senderIsA) {
649     if (privKeyA.empty() || pubKeyA.empty() || pubKeyB.empty()) {
650         return "Missing input key parameters";
651     }
652 
653     // convert public key to uncompressed form by prepending 0x04 at begin
654     bytevec publicKey;
655     publicKey.insert(publicKey.begin(), 0x04);
656     publicKey.insert(publicKey.end(), pubKeyB.begin(), pubKeyB.end());
657     auto rawSharedKey = ecdh(publicKey, privKeyA);
658     if (!rawSharedKey) return rawSharedKey.moveMessage();
659 
660     auto kdfContext = consructKdfContext(pubKeyA, privKeyA, pubKeyB, senderIsA);
661     if (!kdfContext) return kdfContext.moveMessage();
662 
663     bytevec retval(SHA256_DIGEST_LENGTH);
664     bytevec salt{};
665     if (!HKDF(retval.data(), retval.size(),                //
666               EVP_sha256(),                                //
667               rawSharedKey->data(), rawSharedKey->size(),  //
668               salt.data(), salt.size(),                    //
669               kdfContext->data(), kdfContext->size())) {
670         return "ECDH HKDF failed";
671     }
672 
673     return retval;
674 }
675 
x25519_HKDF_DeriveKey(const bytevec & pubKeyA,const bytevec & privKeyA,const bytevec & pubKeyB,bool senderIsA)676 ErrMsgOr<bytevec> x25519_HKDF_DeriveKey(const bytevec& pubKeyA, const bytevec& privKeyA,
677                                         const bytevec& pubKeyB, bool senderIsA) {
678     if (privKeyA.empty() || pubKeyA.empty() || pubKeyB.empty()) {
679         return "Missing input key parameters";
680     }
681 
682     bytevec rawSharedKey(X25519_SHARED_KEY_LEN);
683     if (!::X25519(rawSharedKey.data(), privKeyA.data(), pubKeyB.data())) {
684         return "ECDH operation failed";
685     }
686 
687     auto kdfContext = consructKdfContext(pubKeyA, privKeyA, pubKeyB, senderIsA);
688     if (!kdfContext) return kdfContext.moveMessage();
689 
690     bytevec retval(SHA256_DIGEST_LENGTH);
691     bytevec salt{};
692     if (!HKDF(retval.data(), retval.size(),              //
693               EVP_sha256(),                              //
694               rawSharedKey.data(), rawSharedKey.size(),  //
695               salt.data(), salt.size(),                  //
696               kdfContext->data(), kdfContext->size())) {
697         return "ECDH HKDF failed";
698     }
699 
700     return retval;
701 }
702 
aesGcmEncrypt(const bytevec & key,const bytevec & nonce,const bytevec & aad,const bytevec & plaintext)703 ErrMsgOr<bytevec> aesGcmEncrypt(const bytevec& key, const bytevec& nonce, const bytevec& aad,
704                                 const bytevec& plaintext) {
705     auto ctx = aesGcmInitAndProcessAad(key, nonce, aad, true /* encrypt */);
706     if (!ctx) return ctx.moveMessage();
707 
708     bytevec ciphertext(plaintext.size() + kAesGcmTagSize);
709     int outlen;
710     if (!EVP_CipherUpdate(ctx->get(), ciphertext.data(), &outlen, plaintext.data(),
711                           plaintext.size())) {
712         return "Failed to encrypt plaintext";
713     }
714     assert(plaintext.size() == static_cast<uint64_t>(outlen));
715 
716     if (!EVP_CipherFinal_ex(ctx->get(), ciphertext.data() + outlen, &outlen)) {
717         return "Failed to finalize encryption";
718     }
719     assert(outlen == 0);
720 
721     if (!EVP_CIPHER_CTX_ctrl(ctx->get(), EVP_CTRL_GCM_GET_TAG, kAesGcmTagSize,
722                              ciphertext.data() + plaintext.size())) {
723         return "Failed to retrieve tag";
724     }
725 
726     return ciphertext;
727 }
728 
aesGcmDecrypt(const bytevec & key,const bytevec & nonce,const bytevec & aad,const bytevec & ciphertextWithTag)729 ErrMsgOr<bytevec> aesGcmDecrypt(const bytevec& key, const bytevec& nonce, const bytevec& aad,
730                                 const bytevec& ciphertextWithTag) {
731     auto ctx = aesGcmInitAndProcessAad(key, nonce, aad, false /* encrypt */);
732     if (!ctx) return ctx.moveMessage();
733 
734     if (ciphertextWithTag.size() < kAesGcmTagSize) return "Missing tag";
735 
736     bytevec plaintext(ciphertextWithTag.size() - kAesGcmTagSize);
737     int outlen;
738     if (!EVP_CipherUpdate(ctx->get(), plaintext.data(), &outlen, ciphertextWithTag.data(),
739                           ciphertextWithTag.size() - kAesGcmTagSize)) {
740         return "Failed to decrypt plaintext";
741     }
742     assert(plaintext.size() == static_cast<uint64_t>(outlen));
743 
744     bytevec tag(ciphertextWithTag.end() - kAesGcmTagSize, ciphertextWithTag.end());
745     if (!EVP_CIPHER_CTX_ctrl(ctx->get(), EVP_CTRL_GCM_SET_TAG, kAesGcmTagSize, tag.data())) {
746         return "Failed to set tag: " + std::to_string(ERR_peek_last_error());
747     }
748 
749     if (!EVP_CipherFinal_ex(ctx->get(), nullptr, &outlen)) {
750         return "Failed to finalize encryption";
751     }
752     assert(outlen == 0);
753 
754     return plaintext;
755 }
756 
// Returns the SHA-256 digest of data as a SHA256_DIGEST_LENGTH-byte vector.
bytevec sha256(const bytevec& data) {
    SHA256_CTX hasher;
    SHA256_Init(&hasher);
    SHA256_Update(&hasher, data.data(), data.size());

    bytevec digest(SHA256_DIGEST_LENGTH);
    SHA256_Final(digest.data(), &hasher);
    return digest;
}
765 
// Returns the SHA-384 digest of data as a SHA384_DIGEST_LENGTH-byte vector.
bytevec sha384(const bytevec& data) {
    // The SHA384_* functions operate on a SHA512_CTX.
    SHA512_CTX hasher;
    SHA384_Init(&hasher);
    SHA384_Update(&hasher, data.data(), data.size());

    bytevec digest(SHA384_DIGEST_LENGTH);
    SHA384_Final(digest.data(), &hasher);
    return digest;
}
774 
775 }  // namespace cppcose
776