1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.rkpdapp.provisioner;
18 
19 import android.content.Context;
20 import android.os.RemoteException;
21 import android.util.Log;
22 
23 import com.android.rkpdapp.GeekResponse;
24 import com.android.rkpdapp.RkpdException;
25 import com.android.rkpdapp.database.InstantConverter;
26 import com.android.rkpdapp.database.ProvisionedKey;
27 import com.android.rkpdapp.database.ProvisionedKeyDao;
28 import com.android.rkpdapp.database.RkpKey;
29 import com.android.rkpdapp.interfaces.ServerInterface;
30 import com.android.rkpdapp.interfaces.SystemInterface;
31 import com.android.rkpdapp.metrics.ProvisioningAttempt;
32 import com.android.rkpdapp.utils.Settings;
33 import com.android.rkpdapp.utils.StatsProcessor;
34 import com.android.rkpdapp.utils.X509Utils;
35 
36 import java.security.cert.X509Certificate;
37 import java.time.Instant;
38 import java.time.temporal.ChronoUnit;
39 import java.util.ArrayList;
40 import java.util.Arrays;
41 import java.util.List;
42 
43 import co.nstant.in.cbor.CborException;
44 
45 /**
46  * Provides an easy package to run the provisioning process from start to finish, interfacing
47  * with the system interface and the server backend in order to provision attestation certificates
48  * to the device.
49  */
50 public class Provisioner {
51     private static final String TAG = "RkpdProvisioner";
52     private static final int FAILURE_MAXIMUM = 5;
53     private static final Object provisionKeysLock = new Object();
54 
55     private final Context mContext;
56     private final ProvisionedKeyDao mKeyDao;
57     private final boolean mIsAsync;
58 
Provisioner(final Context applicationContext, ProvisionedKeyDao keyDao, boolean isAsync)59     public Provisioner(final Context applicationContext, ProvisionedKeyDao keyDao,
60             boolean isAsync) {
61         mContext = applicationContext;
62         mKeyDao = keyDao;
63         mIsAsync = isAsync;
64     }
65 
66     /**
67      * Check to see if we need to perform provisioning or not for the given
68      * IRemotelyProvisionedComponent.
69      * @param serviceName the name of the remotely provisioned component to be provisioned
70      * @return true if the remotely provisioned component requires more keys, false if the pool
71      *         of available keys is healthy.
72      */
isProvisioningNeeded(ProvisioningAttempt metrics, String serviceName)73     public boolean isProvisioningNeeded(ProvisioningAttempt metrics, String serviceName) {
74         return calculateKeysRequired(metrics, serviceName) > 0;
75     }
76 
77     /**
78      * Generate, sign and store remotely provisioned keys.
79      */
provisionKeys(ProvisioningAttempt metrics, SystemInterface systemInterface, GeekResponse geekResponse)80     public void provisionKeys(ProvisioningAttempt metrics, SystemInterface systemInterface,
81             GeekResponse geekResponse) throws CborException, RkpdException, InterruptedException {
82         synchronized (provisionKeysLock) {
83             try {
84                 int keysRequired = calculateKeysRequired(metrics, systemInterface.getServiceName());
85                 Log.i(TAG, "Requested number of keys for provisioning: " + keysRequired);
86                 if (keysRequired == 0) {
87                     metrics.setStatus(ProvisioningAttempt.Status.NO_PROVISIONING_NEEDED);
88                     return;
89                 }
90 
91                 List<RkpKey> keysGenerated = generateKeys(metrics, keysRequired, systemInterface);
92                 checkForInterrupts();
93                 List<byte[]> certChains = fetchCertificates(metrics, keysGenerated, systemInterface,
94                         geekResponse);
95                 checkForInterrupts();
96                 List<ProvisionedKey> keys = associateCertsWithKeys(certChains, keysGenerated);
97 
98                 mKeyDao.insertKeys(keys);
99                 Log.i(TAG, "Total provisioned keys: " + keys.size());
100                 metrics.setStatus(ProvisioningAttempt.Status.KEYS_SUCCESSFULLY_PROVISIONED);
101             } catch (InterruptedException e) {
102                 metrics.setStatus(ProvisioningAttempt.Status.INTERRUPTED);
103                 throw e;
104             } catch (RkpdException e) {
105                 if (Settings.getFailureCounter(mContext) > FAILURE_MAXIMUM) {
106                     Log.e(TAG, "Too many failures, resetting defaults.");
107                     Settings.resetDefaultConfig(mContext);
108                 }
109                 // Rethrow to provide failure signal to caller
110                 throw e;
111             }
112         }
113     }
114 
generateKeys(ProvisioningAttempt metrics, int numKeysRequired, SystemInterface systemInterface)115     private List<RkpKey> generateKeys(ProvisioningAttempt metrics, int numKeysRequired,
116             SystemInterface systemInterface)
117             throws CborException, RkpdException, InterruptedException {
118         List<RkpKey> keyArray = new ArrayList<>(numKeysRequired);
119         checkForInterrupts();
120         for (long i = 0; i < numKeysRequired; i++) {
121             keyArray.add(systemInterface.generateKey(metrics));
122         }
123         return keyArray;
124     }
125 
fetchCertificates(ProvisioningAttempt metrics, List<RkpKey> keysGenerated, SystemInterface systemInterface, GeekResponse geekResponse)126     private List<byte[]> fetchCertificates(ProvisioningAttempt metrics, List<RkpKey> keysGenerated,
127             SystemInterface systemInterface, GeekResponse geekResponse)
128             throws RkpdException, CborException, InterruptedException {
129         int provisionedSoFar = 0;
130         List<byte[]> certChains = new ArrayList<>(keysGenerated.size());
131         int maxBatchSize;
132         try {
133             maxBatchSize = systemInterface.getBatchSize();
134         } catch (RemoteException e) {
135             throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
136                     "Error getting batch size from the system", e);
137         }
138         while (provisionedSoFar != keysGenerated.size()) {
139             int batchSize = Math.min(keysGenerated.size() - provisionedSoFar, maxBatchSize);
140             certChains.addAll(batchProvision(metrics, systemInterface, geekResponse,
141                     keysGenerated.subList(provisionedSoFar, batchSize + provisionedSoFar)));
142             provisionedSoFar += batchSize;
143         }
144         return certChains;
145     }
146 
batchProvision(ProvisioningAttempt metrics, SystemInterface systemInterface, GeekResponse response, List<RkpKey> keysGenerated)147     private List<byte[]> batchProvision(ProvisioningAttempt metrics,
148             SystemInterface systemInterface,
149             GeekResponse response, List<RkpKey> keysGenerated)
150             throws RkpdException, CborException, InterruptedException {
151         int batch_size = keysGenerated.size();
152         if (batch_size < 1) {
153             throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
154                     "Request at least 1 key to be signed. Num requested: " + batch_size);
155         }
156         byte[] certRequest = systemInterface.generateCsr(metrics, response, keysGenerated);
157         if (certRequest == null) {
158             throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
159                     "Failed to serialize payload");
160         }
161         return new ServerInterface(mContext, mIsAsync).requestSignedCertificates(certRequest,
162                 response.getChallenge(), metrics);
163     }
164 
associateCertsWithKeys(List<byte[]> certChains, List<RkpKey> keysGenerated)165     private List<ProvisionedKey> associateCertsWithKeys(List<byte[]> certChains,
166             List<RkpKey> keysGenerated) throws RkpdException {
167         List<ProvisionedKey> provisionedKeys = new ArrayList<>();
168         for (byte[] chain : certChains) {
169             X509Certificate[] certChain = X509Utils.formatX509Certs(chain);
170             X509Certificate leafCertificate = certChain[0];
171             long expirationDate = X509Utils.getExpirationTimeForCertificateChain(certChain)
172                     .toInstant().toEpochMilli();
173             byte[] rawPublicKey = X509Utils.getAndFormatRawPublicKey(leafCertificate);
174             if (rawPublicKey == null) {
175                 Log.e(TAG, "Skipping malformed public key.");
176                 continue;
177             }
178             for (RkpKey key : keysGenerated) {
179                 if (Arrays.equals(key.getPublicKey(), rawPublicKey)) {
180                     provisionedKeys.add(key.generateProvisionedKey(chain,
181                             InstantConverter.fromTimestamp(expirationDate)));
182                     keysGenerated.remove(key);
183                     break;
184                 }
185             }
186         }
187         return provisionedKeys;
188     }
189 
190     /**
191      * Calculate the number of keys to be provisioned.
192      */
calculateKeysRequired(ProvisioningAttempt metrics, String serviceName)193     private int calculateKeysRequired(ProvisioningAttempt metrics, String serviceName) {
194         int numExtraAttestationKeys = Settings.getExtraSignedKeysAvailable(mContext);
195         Instant expirationTime = Settings.getExpirationTime(mContext);
196         StatsProcessor.PoolStats poolStats = StatsProcessor.processPool(mKeyDao, serviceName,
197                 numExtraAttestationKeys, expirationTime);
198         metrics.setIsKeyPoolEmpty(poolStats.keysUnassigned == 0);
199         return poolStats.keysToGenerate;
200     }
201 
checkForInterrupts()202     private void checkForInterrupts() throws InterruptedException {
203         if (Thread.interrupted()) {
204             throw new InterruptedException();
205         }
206     }
207 
208     /**
209      * Clears bad attestation keys on the basis of information provided in the FetchGeek response.
210      */
clearBadAttestationKeys(GeekResponse resp)211     public void clearBadAttestationKeys(GeekResponse resp) {
212         if (resp.lastBadCertTimeStart == null || resp.lastBadCertTimeEnd == null) {
213             // if there is no time sent, no need to do anything.
214             return;
215         }
216         if (resp.lastBadCertTimeStart.equals(Settings.getLastBadCertTimeStart(mContext))
217                 && resp.lastBadCertTimeEnd.equals(Settings.getLastBadCertTimeEnd(mContext))) {
218             // if the time is same as already stored version, no need to do anything.
219             return;
220         }
221         // clear the attestation keys on the basis of time.
222         checkAndDeleteBadKeys(resp.lastBadCertTimeStart, resp.lastBadCertTimeEnd);
223 
224         // store the time.
225         Settings.setLastBadCertTimeRange(mContext, resp.lastBadCertTimeStart,
226                 resp.lastBadCertTimeEnd);
227     }
228 
checkAndDeleteBadKeys(Instant startTime, Instant endTime)229     private void checkAndDeleteBadKeys(Instant startTime, Instant endTime) {
230         try {
231             List<ProvisionedKey> allKeys = mKeyDao.getAllKeys();
232             for (int i = 0; i < allKeys.size(); i++) {
233                 ProvisionedKey key = allKeys.get(i);
234                 X509Certificate[] certChain = X509Utils.formatX509Certs(key.certificateChain);
235                 X509Certificate leafCertificate = certChain[0];
236                 Instant creationTime = leafCertificate.getNotBefore().toInstant()
237                         .truncatedTo(ChronoUnit.MILLIS);
238 
239                 if (!creationTime.isBefore(startTime) && !creationTime.isAfter(endTime)) {
240                     mKeyDao.deleteKey(key.keyBlob);
241                 }
242             }
243         } catch (RkpdException ex) {
244             Log.e(TAG, "Could not convert certificate chain to X509 certificates.", ex);
245         }
246     }
247 }
248