1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.federatedcompute.services.encryption;
18 
import android.content.Context;

import com.android.federatedcompute.internal.util.LogUtil;
import com.android.federatedcompute.services.common.FederatedComputeExecutors;
import com.android.federatedcompute.services.common.Flags;
import com.android.federatedcompute.services.common.FlagsFactory;
import com.android.federatedcompute.services.data.FederatedComputeEncryptionKey;
import com.android.federatedcompute.services.data.FederatedComputeEncryptionKeyDao;
import com.android.federatedcompute.services.http.FederatedComputeHttpRequest;
import com.android.federatedcompute.services.http.FederatedComputeHttpResponse;
import com.android.federatedcompute.services.http.HttpClient;
import com.android.federatedcompute.services.http.HttpClientUtil;
import com.android.odp.module.common.Clock;
import com.android.odp.module.common.MonotonicClock;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.Futures;

import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
51 
52 /** Class to manage key fetch. */
53 public class FederatedComputeEncryptionKeyManager {
54     private static final String TAG = "FederatedComputeEncryptionKeyManager";
55 
56     private interface EncryptionKeyResponseContract {
57         String RESPONSE_HEADER_CACHE_CONTROL_LABEL = "cache-control";
58         String RESPONSE_HEADER_AGE_LABEL = "age";
59 
60         String RESPONSE_HEADER_CACHE_CONTROL_MAX_AGE_LABEL = "max-age=";
61 
62         String RESPONSE_KEYS_LABEL = "keys";
63 
64         String RESPONSE_KEY_ID_LABEL = "id";
65 
66         String RESPONSE_PUBLIC_KEY = "key";
67     }
68 
69     @VisibleForTesting private final FederatedComputeEncryptionKeyDao mEncryptionKeyDao;
70 
71     private static volatile FederatedComputeEncryptionKeyManager sBackgroundKeyManager;
72 
73     private final Clock mClock;
74 
75     private final Flags mFlags;
76 
77     private final HttpClient mHttpClient;
78 
79     private final ExecutorService mBackgroundExecutor;
80 
FederatedComputeEncryptionKeyManager( Clock clock, FederatedComputeEncryptionKeyDao encryptionKeyDao, Flags flags, HttpClient httpClient, ExecutorService backgroundExecutor)81     public FederatedComputeEncryptionKeyManager(
82             Clock clock,
83             FederatedComputeEncryptionKeyDao encryptionKeyDao,
84             Flags flags,
85             HttpClient httpClient,
86             ExecutorService backgroundExecutor) {
87         mClock = clock;
88         mEncryptionKeyDao = encryptionKeyDao;
89         mFlags = flags;
90         mHttpClient = httpClient;
91         mBackgroundExecutor = backgroundExecutor;
92     }
93 
94     /**
95      * @return a singleton instance for key manager
96      */
getInstance(Context context)97     public static FederatedComputeEncryptionKeyManager getInstance(Context context) {
98         if (sBackgroundKeyManager == null) {
99             synchronized (FederatedComputeEncryptionKeyManager.class) {
100                 if (sBackgroundKeyManager == null) {
101                     FederatedComputeEncryptionKeyDao encryptionKeyDao =
102                             FederatedComputeEncryptionKeyDao.getInstance(context);
103                     HttpClient client = new HttpClient();
104                     Clock clock = MonotonicClock.getInstance();
105                     Flags flags = FlagsFactory.getFlags();
106                     sBackgroundKeyManager =
107                             new FederatedComputeEncryptionKeyManager(
108                                     clock,
109                                     encryptionKeyDao,
110                                     flags,
111                                     client,
112                                     FederatedComputeExecutors.getBackgroundExecutor());
113                 }
114             }
115         }
116         return sBackgroundKeyManager;
117     }
118 
119     /** For testing only, returns an instance of key manager for test. */
120     @VisibleForTesting
getInstanceForTest( Clock clock, FederatedComputeEncryptionKeyDao encryptionKeyDao, Flags flags, HttpClient client, ExecutorService executor)121     public static FederatedComputeEncryptionKeyManager getInstanceForTest(
122             Clock clock,
123             FederatedComputeEncryptionKeyDao encryptionKeyDao,
124             Flags flags,
125             HttpClient client,
126             ExecutorService executor) {
127         if (sBackgroundKeyManager == null) {
128             synchronized (FederatedComputeEncryptionKeyManager.class) {
129                 if (sBackgroundKeyManager == null) {
130                     sBackgroundKeyManager =
131                             new FederatedComputeEncryptionKeyManager(
132                                     clock, encryptionKeyDao, flags, client, executor);
133                 }
134             }
135         }
136         return sBackgroundKeyManager;
137     }
138 
139     /**
140      * Fetch the active key from the server, persists the fetched key to encryption_key table, and
141      * deletes expired keys
142      */
fetchAndPersistActiveKeys( @ederatedComputeEncryptionKey.KeyType int keyType, boolean isScheduledJob)143     public FluentFuture<List<FederatedComputeEncryptionKey>> fetchAndPersistActiveKeys(
144             @FederatedComputeEncryptionKey.KeyType int keyType, boolean isScheduledJob) {
145         String fetchUri = mFlags.getEncryptionKeyFetchUrl();
146         if (fetchUri == null) {
147             return FluentFuture.from(Futures.immediateFailedFuture(
148                     new IllegalArgumentException("Url to fetch active encryption keys is null")));
149         }
150 
151         FederatedComputeHttpRequest request;
152         try {
153             request =
154                     FederatedComputeHttpRequest.create(
155                             fetchUri,
156                             HttpClientUtil.HttpMethod.GET,
157                             new HashMap<String, String>(),
158                             HttpClientUtil.EMPTY_BODY);
159         } catch (Exception e) {
160             return FluentFuture.from(Futures.immediateFailedFuture(e));
161         }
162 
163         return FluentFuture.from(mHttpClient.performRequestAsyncWithRetry(request))
164                 .transform(
165                         response ->
166                                 parseFetchEncryptionKeyPayload(
167                                         response, keyType, mClock.currentTimeMillis()),
168                         mBackgroundExecutor)
169                 .transform(
170                         result -> {
171                             result.forEach(mEncryptionKeyDao::insertEncryptionKey);
172                             if (isScheduledJob) {
173                                 // When the job is a background scheduled job, delete the
174                                 // expired keys, otherwise, only fetch from the key server.
175                                 mEncryptionKeyDao.deleteExpiredKeys();
176                             }
177                             return result;
178                         },
179                         mBackgroundExecutor); // TODO: Add timeout controlled by Ph flags
180     }
181 
parseFetchEncryptionKeyPayload( FederatedComputeHttpResponse keyFetchResponse, @FederatedComputeEncryptionKey.KeyType int keyType, Long fetchTime)182     private ImmutableList<FederatedComputeEncryptionKey> parseFetchEncryptionKeyPayload(
183             FederatedComputeHttpResponse keyFetchResponse,
184             @FederatedComputeEncryptionKey.KeyType int keyType,
185             Long fetchTime) {
186         String payload = new String(Objects.requireNonNull(keyFetchResponse.getPayload()));
187         Map<String, List<String>> headers = keyFetchResponse.getHeaders();
188         long ttlInSeconds = getTTL(headers);
189         if (ttlInSeconds <= 0) {
190             ttlInSeconds = mFlags.getFederatedComputeEncryptionKeyMaxAgeSeconds();
191         }
192 
193         try {
194             JSONObject responseObj = new JSONObject(payload);
195             JSONArray keysArr =
196                     responseObj.getJSONArray(EncryptionKeyResponseContract.RESPONSE_KEYS_LABEL);
197             ImmutableList.Builder<FederatedComputeEncryptionKey> encryptionKeys =
198                     ImmutableList.builder();
199 
200             for (int i = 0; i < keysArr.length(); i++) {
201                 JSONObject keyObj = keysArr.getJSONObject(i);
202                 FederatedComputeEncryptionKey key =
203                         new FederatedComputeEncryptionKey.Builder()
204                                 .setKeyIdentifier(
205                                         keyObj.getString(
206                                                 EncryptionKeyResponseContract
207                                                         .RESPONSE_KEY_ID_LABEL))
208                                 .setPublicKey(
209                                         keyObj.getString(
210                                                 EncryptionKeyResponseContract.RESPONSE_PUBLIC_KEY))
211                                 .setKeyType(keyType)
212                                 .setCreationTime(fetchTime)
213                                 .setExpiryTime(
214                                         fetchTime + ttlInSeconds * 1000) // convert to milliseconds
215                                 .build();
216                 encryptionKeys.add(key);
217             }
218             return encryptionKeys.build();
219         } catch (JSONException e) {
220             LogUtil.e(TAG, "Invalid Json response: " + e.getMessage());
221             return ImmutableList.of();
222         }
223     }
224 
225     /**
226      * Parse the "age" and "cache-control" of response headers. Calculate the ttl of the current key
227      * maxage (in cache-control) - age.
228      *
229      * @return the ttl in seconds of the keys.
230      */
231     @VisibleForTesting
getTTL(Map<String, List<String>> headers)232     static long getTTL(Map<String, List<String>> headers) {
233         String cacheControl = null;
234         int cachedAge = 0;
235         int remainingHeaders = 2;
236         for (String key : headers.keySet()) {
237             if (key != null) {
238                 if (key.equalsIgnoreCase(
239                         EncryptionKeyResponseContract.RESPONSE_HEADER_CACHE_CONTROL_LABEL)) {
240                     List<String> field = headers.get(key);
241                     if (field != null && field.size() > 0) {
242                         cacheControl = field.get(0).toLowerCase(Locale.ENGLISH);
243                         remainingHeaders -= 1;
244                     }
245 
246                 } else if (key.equalsIgnoreCase(
247                         EncryptionKeyResponseContract.RESPONSE_HEADER_AGE_LABEL)) {
248                     List<String> field = headers.get(key);
249                     if (field != null && field.size() > 0) {
250                         try {
251                             cachedAge = Integer.parseInt(field.get(0));
252                         } catch (NumberFormatException e) {
253                             LogUtil.e(TAG, "Error parsing age header");
254                         }
255                         remainingHeaders -= 1;
256                     }
257                 }
258             }
259             if (remainingHeaders == 0) {
260                 break;
261             }
262         }
263         if (cacheControl == null) {
264             LogUtil.d(TAG, "Cache-Control header or value is missing");
265             return 0;
266         }
267 
268         String[] tokens = cacheControl.split(",", /* limit= */ 0);
269         long maxAge = 0;
270         for (String s : tokens) {
271             String token = s.trim();
272             if (token.startsWith(
273                     EncryptionKeyResponseContract.RESPONSE_HEADER_CACHE_CONTROL_MAX_AGE_LABEL)) {
274                 try {
275                     maxAge =
276                             Long.parseLong(
277                                     token.substring(
278                                             /* beginIndex= */ EncryptionKeyResponseContract
279                                                     .RESPONSE_HEADER_CACHE_CONTROL_MAX_AGE_LABEL
280                                                     .length())); // in the format of
281                     // "max-age=<number>"
282                 } catch (NumberFormatException e) {
283                     LogUtil.d(TAG, "Failed to parse max-age value");
284                     return 0;
285                 }
286             }
287         }
288         if (maxAge == 0) {
289             LogUtil.d(TAG, "max-age directive is missing");
290             return 0;
291         }
292         return maxAge - cachedAge;
293     }
294 
295     /** Get active keys, if there is no active key, then force a fetch from the key service.
296      * In the case of key fetching from the key service, the http call
297      * is executed on a BlockingExecutor.
298      * @return The list of active keys.
299      */
getOrFetchActiveKeys(int keyType, int keyCount)300     public List<FederatedComputeEncryptionKey> getOrFetchActiveKeys(int keyType, int keyCount) {
301         List<FederatedComputeEncryptionKey> activeKeys = mEncryptionKeyDao
302                 .getLatestExpiryNKeys(keyCount);
303         if (activeKeys.size() > 0) {
304             return activeKeys;
305         }
306         try {
307             var fetchedKeysUnused = fetchAndPersistActiveKeys(keyType,
308                     /* isScheduledJob= */ false).get(/* timeout= */ 5, TimeUnit.SECONDS);
309             activeKeys = mEncryptionKeyDao.getLatestExpiryNKeys(keyCount);
310             if (activeKeys.size() > 0) {
311                 return activeKeys;
312             }
313         } catch (TimeoutException e) {
314             LogUtil.e(TAG, "Time out when forcing encryption key fetch: "
315                     + e.getMessage());
316         } catch (Exception e) {
317             LogUtil.e(TAG, "Exception encountered when forcing encryption key fetch: "
318                     + e.getMessage());
319         }
320         return activeKeys;
321     }
322 }
323