1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.ondevicepersonalization.services.federatedcompute;
18 
19 import android.adservices.ondevicepersonalization.Constants;
20 import android.adservices.ondevicepersonalization.TrainingExampleRecord;
21 import android.adservices.ondevicepersonalization.TrainingExamplesInputParcel;
22 import android.adservices.ondevicepersonalization.TrainingExamplesOutputParcel;
23 import android.adservices.ondevicepersonalization.UserData;
24 import android.annotation.NonNull;
25 import android.content.ComponentName;
26 import android.content.Context;
27 import android.federatedcompute.ExampleStoreService;
28 import android.federatedcompute.FederatedComputeManager;
29 import android.federatedcompute.common.ClientConstants;
30 import android.os.Bundle;
31 import android.os.OutcomeReceiver;
32 
33 import com.android.odp.module.common.Clock;
34 import com.android.odp.module.common.MonotonicClock;
35 import com.android.ondevicepersonalization.internal.util.LoggerFactory;
36 import com.android.ondevicepersonalization.internal.util.OdpParceledListSlice;
37 import com.android.ondevicepersonalization.services.Flags;
38 import com.android.ondevicepersonalization.services.FlagsFactory;
39 import com.android.ondevicepersonalization.services.OnDevicePersonalizationExecutors;
40 import com.android.ondevicepersonalization.services.data.DataAccessPermission;
41 import com.android.ondevicepersonalization.services.data.DataAccessServiceImpl;
42 import com.android.ondevicepersonalization.services.data.events.EventState;
43 import com.android.ondevicepersonalization.services.data.events.EventsDao;
44 import com.android.ondevicepersonalization.services.data.user.UserPrivacyStatus;
45 import com.android.ondevicepersonalization.services.manifest.AppManifestConfigHelper;
46 import com.android.ondevicepersonalization.services.policyengine.UserDataAccessor;
47 import com.android.ondevicepersonalization.services.process.IsolatedServiceInfo;
48 import com.android.ondevicepersonalization.services.process.PluginProcessRunner;
49 import com.android.ondevicepersonalization.services.process.ProcessRunner;
50 import com.android.ondevicepersonalization.services.process.SharedIsolatedProcessRunner;
51 import com.android.ondevicepersonalization.services.util.StatsUtils;
52 
53 import com.google.common.util.concurrent.FluentFuture;
54 import com.google.common.util.concurrent.FutureCallback;
55 import com.google.common.util.concurrent.Futures;
56 import com.google.common.util.concurrent.ListenableFuture;
57 import com.google.common.util.concurrent.ListeningScheduledExecutorService;
58 
59 import java.util.Objects;
60 import java.util.concurrent.TimeUnit;
61 
62 /** Implementation of ExampleStoreService for OnDevicePersonalization */
63 public final class OdpExampleStoreService extends ExampleStoreService {
64 
65     private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
66     private static final String TAG = OdpExampleStoreService.class.getSimpleName();
67     private static final String TASK_NAME = "ExampleStore";
68 
69     static class Injector {
getClock()70         Clock getClock() {
71             return MonotonicClock.getInstance();
72         }
73 
getFlags()74         Flags getFlags() {
75             return FlagsFactory.getFlags();
76         }
77 
getScheduledExecutor()78         ListeningScheduledExecutorService getScheduledExecutor() {
79             return OnDevicePersonalizationExecutors.getScheduledExecutor();
80         }
81 
getProcessRunner()82         ProcessRunner getProcessRunner() {
83             return FlagsFactory.getFlags().isSharedIsolatedProcessFeatureEnabled()
84                     ? SharedIsolatedProcessRunner.getInstance()
85                     : PluginProcessRunner.getInstance();
86         }
87     }
88 
89     private final Injector mInjector = new Injector();
90 
91     /** Generates a unique task identifier from the given strings */
getTaskIdentifier(String populationName, String taskId)92     public static String getTaskIdentifier(String populationName, String taskId) {
93         return populationName + "_" + taskId;
94     }
95 
96     /** Generates a unique task identifier from the given strings */
getTaskIdentifier( String populationName, String taskId, String collectionUri)97     public static String getTaskIdentifier(
98             String populationName, String taskId, String collectionUri) {
99         return populationName + "_" + taskId + "_" + collectionUri;
100     }
101 
isCollectionUriPresent(String collectionUri)102     private static boolean isCollectionUriPresent(String collectionUri) {
103         return collectionUri != null && !collectionUri.isEmpty();
104     }
105 
106     @Override
startQuery(@onNull Bundle params, @NonNull QueryCallback callback)107     public void startQuery(@NonNull Bundle params, @NonNull QueryCallback callback) {
108         try {
109             ContextData contextData =
110                     ContextData.fromByteArray(
111                             Objects.requireNonNull(
112                                     params.getByteArray(ClientConstants.EXTRA_CONTEXT_DATA)));
113             String packageName = contextData.getPackageName();
114             String ownerClassName = contextData.getClassName();
115             String populationName =
116                     Objects.requireNonNull(params.getString(ClientConstants.EXTRA_POPULATION_NAME));
117             String taskId = Objects.requireNonNull(params.getString(ClientConstants.EXTRA_TASK_ID));
118             String collectionUri = params.getString(ClientConstants.EXTRA_COLLECTION_URI);
119             int eligibilityMinExample =
120                     params.getInt(ClientConstants.EXTRA_ELIGIBILITY_MIN_EXAMPLE);
121 
122             EventsDao eventDao = EventsDao.getInstance(getContext());
123 
124             boolean privacyStatusEligible = true;
125 
126             if (!UserPrivacyStatus.getInstance().isMeasurementEnabled()) {
127                 privacyStatusEligible = false;
128                 sLogger.w(TAG + ": Measurement control is not given.");
129             }
130 
131             // Cancel job if on longer valid. This is written to the table during scheduling
132             // via {@link FederatedComputeServiceImpl} and deleted either during cancel or
133             // during maintenance for uninstalled packages.
134             ComponentName owner = ComponentName.createRelative(packageName, ownerClassName);
135             EventState eventStatePopulation = eventDao.getEventState(populationName, owner);
136             if (!privacyStatusEligible || eventStatePopulation == null) {
137                 sLogger.w("Job was either cancelled or package was uninstalled");
138                 // Cancel job.
139                 FederatedComputeManager FCManager =
140                         getContext().getSystemService(FederatedComputeManager.class);
141                 if (FCManager == null) {
142                     sLogger.e(TAG + ": Failed to get FederatedCompute Service");
143                     callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
144                     return;
145                 }
146                 FCManager.cancel(
147                         owner,
148                         populationName,
149                         OnDevicePersonalizationExecutors.getBackgroundExecutor(),
150                         new OutcomeReceiver<Object, Exception>() {
151                             @Override
152                             public void onResult(Object result) {
153                                 sLogger.d(TAG + ": Successfully canceled job");
154                                 callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
155                             }
156 
157                             @Override
158                             public void onError(Exception error) {
159                                 sLogger.e(TAG + ": Error while cancelling job", error);
160                                 OutcomeReceiver.super.onError(error);
161                                 callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
162                             }
163                         });
164                 return;
165             }
166 
167             // Get resumptionToken
168             EventState eventState =
169                     eventDao.getEventState(
170                             isCollectionUriPresent(collectionUri)
171                                     ? getTaskIdentifier(populationName, taskId, collectionUri)
172                                     : getTaskIdentifier(populationName, taskId),
173                             owner);
174             byte[] resumptionToken = null;
175             if (eventState != null) {
176                 resumptionToken = eventState.getToken();
177             }
178 
179             TrainingExamplesInputParcel.Builder input =
180                     new TrainingExamplesInputParcel.Builder()
181                             .setResumptionToken(resumptionToken)
182                             .setPopulationName(populationName)
183                             .setTaskName(taskId);
184             if (isCollectionUriPresent(collectionUri)) {
185                 input.setCollectionName(collectionUri);
186             }
187 
188             String className =
189                     AppManifestConfigHelper.getServiceNameFromOdpSettings(
190                             getContext(), packageName);
191             ListenableFuture<IsolatedServiceInfo> loadFuture =
192                     mInjector
193                             .getProcessRunner()
194                             .loadIsolatedService(
195                                     TASK_NAME,
196                                     ComponentName.createRelative(packageName, className));
197             ListenableFuture<TrainingExamplesOutputParcel> resultFuture =
198                     FluentFuture.from(loadFuture)
199                             .transformAsync(
200                                     result ->
201                                             executeOnTrainingExamples(
202                                                     result, input.build(), packageName),
203                                     OnDevicePersonalizationExecutors.getBackgroundExecutor())
204                             .transform(
205                                     result -> {
206                                         return result.getParcelable(
207                                                 Constants.EXTRA_RESULT,
208                                                 TrainingExamplesOutputParcel.class);
209                                     },
210                                     OnDevicePersonalizationExecutors.getBackgroundExecutor())
211                             .withTimeout(
212                                     mInjector.getFlags().getIsolatedServiceDeadlineSeconds(),
213                                     TimeUnit.SECONDS,
214                                     mInjector.getScheduledExecutor());
215 
216             Futures.addCallback(
217                     resultFuture,
218                     new FutureCallback<TrainingExamplesOutputParcel>() {
219                         @Override
220                         public void onSuccess(
221                                 TrainingExamplesOutputParcel trainingExamplesOutputParcel) {
222                             OdpParceledListSlice<TrainingExampleRecord> trainingExampleRecordList =
223                                     trainingExamplesOutputParcel.getTrainingExampleRecords();
224 
225                             if (trainingExampleRecordList == null
226                                     || trainingExampleRecordList.getList().size()
227                                             < eligibilityMinExample) {
228                                 callback.onStartQueryFailure(
229                                         ClientConstants.STATUS_NOT_ENOUGH_DATA);
230                             } else {
231                                 callback.onStartQuerySuccess(
232                                         OdpExampleStoreIteratorFactory.getInstance()
233                                                 .createIterator(
234                                                         trainingExampleRecordList.getList()));
235                             }
236                         }
237 
238                         @Override
239                         public void onFailure(Throwable t) {
240                             sLogger.w(t, "%s : Request failed.", TAG);
241                             callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
242                         }
243                     },
244                     OnDevicePersonalizationExecutors.getBackgroundExecutor());
245 
246             var unused =
247                     Futures.whenAllComplete(loadFuture, resultFuture)
248                             .callAsync(
249                                     () ->
250                                             mInjector
251                                                     .getProcessRunner()
252                                                     .unloadIsolatedService(loadFuture.get()),
253                                     OnDevicePersonalizationExecutors.getBackgroundExecutor());
254         } catch (Exception e) {
255             sLogger.w(e, "%s : Start query failed.", TAG);
256             callback.onStartQueryFailure(ClientConstants.STATUS_INTERNAL_ERROR);
257         }
258     }
259 
executeOnTrainingExamples( IsolatedServiceInfo isolatedServiceInfo, TrainingExamplesInputParcel exampleInput, String packageName)260     private ListenableFuture<Bundle> executeOnTrainingExamples(
261             IsolatedServiceInfo isolatedServiceInfo,
262             TrainingExamplesInputParcel exampleInput,
263             String packageName) {
264         sLogger.d(TAG + ": executeOnTrainingExamples() started.");
265         Bundle serviceParams = new Bundle();
266         serviceParams.putParcelable(Constants.EXTRA_INPUT, exampleInput);
267         String serviceClass =
268                 AppManifestConfigHelper.getServiceNameFromOdpSettings(getContext(), packageName);
269         DataAccessServiceImpl binder =
270                 new DataAccessServiceImpl(
271                         ComponentName.createRelative(packageName, serviceClass),
272                         getContext(),
273                         // ODP provides accurate user signal in training flow, so we disable write
274                         // access of databases to prevent leak.
275                         /* localDataPermission */ DataAccessPermission.READ_ONLY,
276                         /* eventDataPermission */ DataAccessPermission.READ_ONLY);
277         serviceParams.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, binder);
278         UserDataAccessor userDataAccessor = new UserDataAccessor();
279         UserData userData = userDataAccessor.getUserDataWithAppInstall();
280         serviceParams.putParcelable(Constants.EXTRA_USER_DATA, userData);
281         ListenableFuture<Bundle> result =
282                 mInjector
283                         .getProcessRunner()
284                         .runIsolatedService(
285                                 isolatedServiceInfo, Constants.OP_TRAINING_EXAMPLE, serviceParams);
286         return FluentFuture.from(result)
287                 .transform(
288                         val -> {
289                             StatsUtils.writeServiceRequestMetrics(
290                                     Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE,
291                                     val, mInjector.getClock(),
292                                     Constants.STATUS_SUCCESS,
293                                     isolatedServiceInfo.getStartTimeMillis());
294                             return val;
295                         },
296                         OnDevicePersonalizationExecutors.getBackgroundExecutor())
297                 .catchingAsync(
298                         Exception.class,
299                         e -> {
300                             StatsUtils.writeServiceRequestMetrics(
301                                     Constants.API_NAME_SERVICE_ON_TRAINING_EXAMPLE,
302                                     /* result= */ null, mInjector.getClock(),
303                                     Constants.STATUS_INTERNAL_ERROR,
304                                     isolatedServiceInfo.getStartTimeMillis());
305                             return Futures.immediateFailedFuture(e);
306                         },
307                         OnDevicePersonalizationExecutors.getBackgroundExecutor());
308     }
309 
310     // used for tests to provide mock/real implementation of context.
getContext()311     private Context getContext() {
312         return this.getApplicationContext();
313     }
314 }
315