1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "hpke_jni.h"

#include <cstdint>
#include <limits>
#include <vector>

#include <openssl/curve25519.h>
#include <openssl/hpke.h>
#include <openssl/span.h>
22 
23 // Hybrid Public Key Encryption (HPKE) encryption operation
24 // RFC: https://datatracker.ietf.org/doc/rfc9180
25 //
26 // Based from chromium's boringSSL implementation
27 // https://source.chromium.org/chromium/chromium/src/+/main:content/browser/aggregation_service/aggregatable_report.cc;l=211
Java_com_android_federatedcompute_services_encryption_jni_HpkeJni_encrypt(JNIEnv * env,jobject object,jbyteArray publicKey,jbyteArray plainText,jbyteArray associatedData)28 JNIEXPORT jbyteArray JNICALL Java_com_android_federatedcompute_services_encryption_jni_HpkeJni_encrypt
29         (JNIEnv* env, jobject object,
30          jbyteArray publicKey, jbyteArray plainText, jbyteArray associatedData) {
31 
32     if (publicKey == NULL || plainText == NULL || associatedData == NULL) {
33         return {};
34     }
35 
36     if (env->GetArrayLength(publicKey) != X25519_PUBLIC_VALUE_LEN) {
37         return {};
38     }
39 
40     bssl::ScopedEVP_HPKE_CTX sender_context;
41 
42     std::vector<uint8_t> payload(EVP_HPKE_MAX_ENC_LENGTH);
43     size_t encapsulated_shared_secret_len;
44 
45     jbyte* peer_public_key = env->GetByteArrayElements(publicKey, 0);
46     jbyte* info = env->GetByteArrayElements(associatedData, 0);
47 
48     if (!EVP_HPKE_CTX_setup_sender(
49             /* ctx= */ sender_context.get(),
50             /* out_enc= */ payload.data(),
51             /* out_enc_len= */ &encapsulated_shared_secret_len,
52             /* max_enc= */ payload.size(),
53             /* kem= */ EVP_hpke_x25519_hkdf_sha256(),
54             /* kdf= */ EVP_hpke_hkdf_sha256(),
55             /* aead= */ EVP_hpke_chacha20_poly1305(),
56             /* peer_public_key= */ reinterpret_cast<const uint8_t*>(peer_public_key),
57             /* peer_public_key_len= */ env->GetArrayLength(publicKey),
58             /* info= */ reinterpret_cast<const uint8_t*>(info),
59             /* info_len= */ env->GetArrayLength(associatedData))) {
60         env->ReleaseByteArrayElements(publicKey, peer_public_key, JNI_ABORT);
61         env->ReleaseByteArrayElements(associatedData, info, JNI_ABORT);
62         return {};
63     }
64 
65     env->ReleaseByteArrayElements(publicKey, peer_public_key, JNI_ABORT);
66     env->ReleaseByteArrayElements(associatedData, info, JNI_ABORT);
67 
68     payload.resize(encapsulated_shared_secret_len + env->GetArrayLength(plainText) +
69                    EVP_HPKE_CTX_max_overhead(sender_context.get()));
70 
71     bssl::Span<uint8_t> ciphertext = bssl::MakeSpan(payload).subspan(encapsulated_shared_secret_len);
72     size_t ciphertext_len;
73 
74     jbyte* plain_text = env->GetByteArrayElements(plainText, 0);
75 
76     if (!EVP_HPKE_CTX_seal(
77             /* ctx= */ sender_context.get(),
78             /* out= */ ciphertext.data(),
79             /* out_len= */ &ciphertext_len,
80             /* max_out_len= */ ciphertext.size(),
81             /* in= */ reinterpret_cast<const uint8_t*>(plain_text),
82             /* in_len=*/ env->GetArrayLength(plainText),
83             /* ad= */ nullptr,
84             /* ad_len= */ 0)) {
85         env->ReleaseByteArrayElements(plainText, plain_text, JNI_ABORT);
86         return {};
87     }
88 
89     env->ReleaseByteArrayElements(plainText, plain_text, JNI_ABORT);
90 
91     payload.resize(encapsulated_shared_secret_len + ciphertext_len);
92 
93     jbyteArray payload_byte_array = env->NewByteArray(payload.size());
94     env->SetByteArrayRegion(payload_byte_array, 0, payload.size(),
95                             reinterpret_cast<const jbyte*>(payload.data()));
96     return payload_byte_array;
97 }
98 
99 // Hybrid Public Key Encryption (HPKE) decryption operation
100 // RFC: https://datatracker.ietf.org/doc/rfc9180
101 //
102 // Based from chromium's boringSSL implementation
103 // https://source.chromium.org/chromium/chromium/src/+/main:content/browser/aggregation_service/aggregation_service_test_utils.cc;l=305
Java_com_android_federatedcompute_services_encryption_jni_HpkeJni_decrypt(JNIEnv * env,jobject object,jbyteArray privateKey,jbyteArray ciphertext,jbyteArray associatedData)104 JNIEXPORT jbyteArray JNICALL Java_com_android_federatedcompute_services_encryption_jni_HpkeJni_decrypt
105         (JNIEnv* env, jobject object,
106          jbyteArray privateKey, jbyteArray ciphertext, jbyteArray associatedData) {
107 
108     if (privateKey == NULL || ciphertext == NULL || associatedData == NULL) {
109         return {};
110     }
111 
112     if (env->GetArrayLength(privateKey) != X25519_PRIVATE_KEY_LEN) {
113         return {};
114     }
115 
116     bssl::ScopedEVP_HPKE_KEY hpke_key;
117     jbyte* private_key = env->GetByteArrayElements(privateKey, 0);
118     if (!EVP_HPKE_KEY_init(hpke_key.get(),
119                            EVP_hpke_x25519_hkdf_sha256(),
120                            reinterpret_cast<const uint8_t*>(private_key),
121                            env->GetArrayLength(privateKey))) {
122         env->ReleaseByteArrayElements(privateKey, private_key, JNI_ABORT);
123         return {};
124     }
125 
126     env->ReleaseByteArrayElements(privateKey, private_key, JNI_ABORT);
127 
128     std::vector<uint8_t> payload(env->GetArrayLength(ciphertext));
129     env->GetByteArrayRegion(ciphertext,
130                             0,
131                             env->GetArrayLength(ciphertext),
132                             reinterpret_cast<jbyte*>(payload.data()));
133 
134     bssl::Span<uint8_t> payload_span = bssl::MakeSpan(payload);
135     bssl::Span<uint8_t> enc = payload_span.subspan(0, X25519_PUBLIC_VALUE_LEN);
136     bssl::ScopedEVP_HPKE_CTX recipient_context;
137     jbyte* associated_data = env->GetByteArrayElements(associatedData, 0);
138     if (!EVP_HPKE_CTX_setup_recipient(
139             /*ctx=*/ recipient_context.get(),
140             /*key=*/ hpke_key.get(),
141             /*kdf=*/ EVP_hpke_hkdf_sha256(),
142             /*aead=*/ EVP_hpke_chacha20_poly1305(),
143             /*enc=*/ enc.data(),
144             /*enc_len=*/ enc.size(),
145             /*info=*/ reinterpret_cast<const uint8_t*>(associated_data),
146             /*info_len=*/ env->GetArrayLength(associatedData))) {
147         env->ReleaseByteArrayElements(associatedData, associated_data, JNI_ABORT);
148         return {};
149     }
150 
151     env->ReleaseByteArrayElements(associatedData, associated_data, JNI_ABORT);
152 
153     bssl::Span<const uint8_t> ciphertext_span = payload_span.subspan(X25519_PUBLIC_VALUE_LEN);
154     std::vector<uint8_t> plaintext(ciphertext_span.size());
155     size_t plaintext_len;
156     if (!EVP_HPKE_CTX_open(
157             /*ctx=*/ recipient_context.get(),
158             /*out=*/ plaintext.data(),
159             /*out_len*/ &plaintext_len,
160             /*max_out_len=*/ plaintext.size(),
161             /*in=*/ ciphertext_span.data(),
162             /*in_len=*/ ciphertext_span.size(),
163             /*ad=*/ nullptr,
164             /*ad_len=*/ 0)) {
165         return {};
166     }
167 
168     plaintext.resize(plaintext_len);
169 
170     jbyteArray payload_byte_array = env->NewByteArray(plaintext.size());
171     env->SetByteArrayRegion(payload_byte_array,
172                             0,
173                             plaintext.size(),
174                             reinterpret_cast<const jbyte*>(plaintext.data()));
175     return payload_byte_array;
176 }