/* * Copyright (C) 2022 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.android.rkpdapp.provisioner; import android.content.Context; import android.os.RemoteException; import android.util.Log; import com.android.rkpdapp.GeekResponse; import com.android.rkpdapp.RkpdException; import com.android.rkpdapp.database.InstantConverter; import com.android.rkpdapp.database.ProvisionedKey; import com.android.rkpdapp.database.ProvisionedKeyDao; import com.android.rkpdapp.database.RkpKey; import com.android.rkpdapp.interfaces.ServerInterface; import com.android.rkpdapp.interfaces.SystemInterface; import com.android.rkpdapp.metrics.ProvisioningAttempt; import com.android.rkpdapp.utils.Settings; import com.android.rkpdapp.utils.StatsProcessor; import com.android.rkpdapp.utils.X509Utils; import java.security.cert.X509Certificate; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import co.nstant.in.cbor.CborException; /** * Provides an easy package to run the provisioning process from start to finish, interfacing * with the system interface and the server backend in order to provision attestation certificates * to the device. */ public class Provisioner { private static final String TAG = "RkpdProvisioner"; private static final int FAILURE_MAXIMUM = 5; private static final Object provisionKeysLock = new Object(); private final Context mContext; private final ProvisionedKeyDao mKeyDao; private final boolean mIsAsync; public Provisioner(final Context applicationContext, ProvisionedKeyDao keyDao, boolean isAsync) { mContext = applicationContext; mKeyDao = keyDao; mIsAsync = isAsync; } /** * Check to see if we need to perform provisioning or not for the given * IRemotelyProvisionedComponent. * @param serviceName the name of the remotely provisioned component to be provisioned * @return true if the remotely provisioned component requires more keys, false if the pool * of available keys is healthy. */ public boolean isProvisioningNeeded(ProvisioningAttempt metrics, String serviceName) { return calculateKeysRequired(metrics, serviceName) > 0; } /** * Generate, sign and store remotely provisioned keys. */ public void provisionKeys(ProvisioningAttempt metrics, SystemInterface systemInterface, GeekResponse geekResponse) throws CborException, RkpdException, InterruptedException { synchronized (provisionKeysLock) { try { int keysRequired = calculateKeysRequired(metrics, systemInterface.getServiceName()); Log.i(TAG, "Requested number of keys for provisioning: " + keysRequired); if (keysRequired == 0) { metrics.setStatus(ProvisioningAttempt.Status.NO_PROVISIONING_NEEDED); return; } List keysGenerated = generateKeys(metrics, keysRequired, systemInterface); checkForInterrupts(); List certChains = fetchCertificates(metrics, keysGenerated, systemInterface, geekResponse); checkForInterrupts(); List keys = associateCertsWithKeys(certChains, keysGenerated); mKeyDao.insertKeys(keys); Log.i(TAG, "Total provisioned keys: " + keys.size()); metrics.setStatus(ProvisioningAttempt.Status.KEYS_SUCCESSFULLY_PROVISIONED); } catch (InterruptedException e) { metrics.setStatus(ProvisioningAttempt.Status.INTERRUPTED); throw e; } catch (RkpdException e) { if (Settings.getFailureCounter(mContext) > FAILURE_MAXIMUM) { Log.e(TAG, "Too many failures, resetting defaults."); Settings.resetDefaultConfig(mContext); } // Rethrow to provide failure signal to caller throw e; } } } private List generateKeys(ProvisioningAttempt metrics, int numKeysRequired, SystemInterface systemInterface) throws CborException, RkpdException, InterruptedException { List keyArray = new ArrayList<>(numKeysRequired); checkForInterrupts(); for (long i = 0; i < numKeysRequired; i++) { keyArray.add(systemInterface.generateKey(metrics)); } return keyArray; } private List fetchCertificates(ProvisioningAttempt metrics, List keysGenerated, SystemInterface systemInterface, GeekResponse geekResponse) throws RkpdException, CborException, InterruptedException { int provisionedSoFar = 0; List certChains = new ArrayList<>(keysGenerated.size()); int maxBatchSize; try { maxBatchSize = systemInterface.getBatchSize(); } catch (RemoteException e) { throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR, "Error getting batch size from the system", e); } while (provisionedSoFar != keysGenerated.size()) { int batchSize = Math.min(keysGenerated.size() - provisionedSoFar, maxBatchSize); certChains.addAll(batchProvision(metrics, systemInterface, geekResponse, keysGenerated.subList(provisionedSoFar, batchSize + provisionedSoFar))); provisionedSoFar += batchSize; } return certChains; } private List batchProvision(ProvisioningAttempt metrics, SystemInterface systemInterface, GeekResponse response, List keysGenerated) throws RkpdException, CborException, InterruptedException { int batch_size = keysGenerated.size(); if (batch_size < 1) { throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR, "Request at least 1 key to be signed. Num requested: " + batch_size); } byte[] certRequest = systemInterface.generateCsr(metrics, response, keysGenerated); if (certRequest == null) { throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR, "Failed to serialize payload"); } return new ServerInterface(mContext, mIsAsync).requestSignedCertificates(certRequest, response.getChallenge(), metrics); } private List associateCertsWithKeys(List certChains, List keysGenerated) throws RkpdException { List provisionedKeys = new ArrayList<>(); for (byte[] chain : certChains) { X509Certificate[] certChain = X509Utils.formatX509Certs(chain); X509Certificate leafCertificate = certChain[0]; long expirationDate = X509Utils.getExpirationTimeForCertificateChain(certChain) .toInstant().toEpochMilli(); byte[] rawPublicKey = X509Utils.getAndFormatRawPublicKey(leafCertificate); if (rawPublicKey == null) { Log.e(TAG, "Skipping malformed public key."); continue; } for (RkpKey key : keysGenerated) { if (Arrays.equals(key.getPublicKey(), rawPublicKey)) { provisionedKeys.add(key.generateProvisionedKey(chain, InstantConverter.fromTimestamp(expirationDate))); keysGenerated.remove(key); break; } } } return provisionedKeys; } /** * Calculate the number of keys to be provisioned. */ private int calculateKeysRequired(ProvisioningAttempt metrics, String serviceName) { int numExtraAttestationKeys = Settings.getExtraSignedKeysAvailable(mContext); Instant expirationTime = Settings.getExpirationTime(mContext); StatsProcessor.PoolStats poolStats = StatsProcessor.processPool(mKeyDao, serviceName, numExtraAttestationKeys, expirationTime); metrics.setIsKeyPoolEmpty(poolStats.keysUnassigned == 0); return poolStats.keysToGenerate; } private void checkForInterrupts() throws InterruptedException { if (Thread.interrupted()) { throw new InterruptedException(); } } /** * Clears bad attestation keys on the basis of information provided in the FetchGeek response. */ public void clearBadAttestationKeys(GeekResponse resp) { if (resp.lastBadCertTimeStart == null || resp.lastBadCertTimeEnd == null) { // if there is no time sent, no need to do anything. return; } if (resp.lastBadCertTimeStart.equals(Settings.getLastBadCertTimeStart(mContext)) && resp.lastBadCertTimeEnd.equals(Settings.getLastBadCertTimeEnd(mContext))) { // if the time is same as already stored version, no need to do anything. return; } // clear the attestation keys on the basis of time. checkAndDeleteBadKeys(resp.lastBadCertTimeStart, resp.lastBadCertTimeEnd); // store the time. Settings.setLastBadCertTimeRange(mContext, resp.lastBadCertTimeStart, resp.lastBadCertTimeEnd); } private void checkAndDeleteBadKeys(Instant startTime, Instant endTime) { try { List allKeys = mKeyDao.getAllKeys(); for (int i = 0; i < allKeys.size(); i++) { ProvisionedKey key = allKeys.get(i); X509Certificate[] certChain = X509Utils.formatX509Certs(key.certificateChain); X509Certificate leafCertificate = certChain[0]; Instant creationTime = leafCertificate.getNotBefore().toInstant() .truncatedTo(ChronoUnit.MILLIS); if (!creationTime.isBefore(startTime) && !creationTime.isAfter(endTime)) { mKeyDao.deleteKey(key.keyBlob); } } } catch (RkpdException ex) { Log.e(TAG, "Could not convert certificate chain to X509 certificates.", ex); } } }