1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.adservices.ondevicepersonalization;
18 
19 
20 import static org.junit.Assert.assertEquals;
21 import static org.junit.Assert.assertTrue;
22 
23 import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
24 import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback;
25 import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback;
26 import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService;
27 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
28 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
29 import android.adservices.ondevicepersonalization.aidl.IIsolatedService;
30 import android.adservices.ondevicepersonalization.aidl.IIsolatedServiceCallback;
31 import android.content.Context;
32 import android.federatedcompute.common.TrainingOptions;
33 import android.net.Uri;
34 import android.os.Bundle;
35 import android.os.PersistableBundle;
36 import android.os.RemoteException;
37 
38 import androidx.test.core.app.ApplicationProvider;
39 import androidx.test.filters.SmallTest;
40 
41 import com.android.federatedcompute.internal.util.AbstractServiceBinder;
42 import com.android.ondevicepersonalization.internal.util.ByteArrayParceledSlice;
43 import com.android.ondevicepersonalization.internal.util.PersistableBundleUtils;
44 
45 import org.junit.After;
46 import org.junit.Before;
47 import org.junit.Test;
48 import org.junit.runner.RunWith;
49 import org.junit.runners.Parameterized;
50 
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
55 
56 @SmallTest
57 @RunWith(Parameterized.class)
58 public class IsolatedServiceExceptionSafetyTest {
59 
60     private final Context mContext = ApplicationProvider.getApplicationContext();
61 
62     private IIsolatedService mIsolatedService;
63     private AbstractServiceBinder<IIsolatedService> mServiceBinder;
64     private int mCallbackErrorCode;
65     private int mIsolatedServiceErrorCode;
66     private CountDownLatch mLatch;
67 
68     @Parameterized.Parameter(0)
69     public String operation;
70 
71     @Parameterized.Parameters
data()72     public static Collection<Object[]> data() {
73         return Arrays.asList(
74                 new Object[][] {
75                     {RuntimeException.class.getName()},
76                     {NullPointerException.class.getName()},
77                     {IllegalArgumentException.class.getName()}
78                 });
79     }
80 
81     @Before
setUp()82     public void setUp() throws Exception {
83         mServiceBinder = AbstractServiceBinder.getIsolatedServiceBinderByServiceName(
84                 mContext,
85                 "android.adservices.ondevicepersonalization.IsolatedServiceExceptionSafetyTestImpl",
86                 mContext.getPackageName(),
87                 "testIsolatedProcess",
88                 0,
89                 IIsolatedService.Stub::asInterface);
90 
91         mIsolatedService = mServiceBinder.getService(Runnable::run);
92         mLatch = new CountDownLatch(1);
93     }
94 
95     @After
tearDown()96     public void tearDown() {
97         mServiceBinder.unbindFromService();
98         mIsolatedService = null;
99         mCallbackErrorCode = 0;
100     }
101 
102     @Test
testOnRequestExceptions()103     public void testOnRequestExceptions() throws Exception {
104         PersistableBundle appParams = new PersistableBundle();
105         appParams.putString("ex", operation);
106         ExecuteInputParcel input =
107                 new ExecuteInputParcel.Builder()
108                         .setAppPackageName("com.testapp")
109                         .setSerializedAppParams(new ByteArrayParceledSlice(
110                                 PersistableBundleUtils.toByteArray(appParams)))
111                         .build();
112         Bundle params = new Bundle();
113         params.putParcelable(Constants.EXTRA_INPUT, input);
114         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
115         params.putBinder(
116                 Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER,
117                 new TestFederatedComputeService());
118         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
119         mIsolatedService.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback());
120         assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS));
121         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
122     }
123 
124     @Test
testOnDownloadExceptions()125     public void testOnDownloadExceptions() throws Exception {
126         DownloadInputParcel input =
127                 new DownloadInputParcel.Builder()
128                         .setDataAccessServiceBinder(new TestDataAccessService(operation))
129                         .build();
130         Bundle params = new Bundle();
131         params.putParcelable(Constants.EXTRA_INPUT, input);
132         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
133         params.putBinder(
134                 Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER,
135                 new TestFederatedComputeService());
136         mIsolatedService.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback());
137         assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS));
138         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
139     }
140 
141     @Test
testOnRender()142     public void testOnRender() throws Exception {
143         RenderInputParcel input =
144                 new RenderInputParcel.Builder()
145                         .setRenderingConfig(
146                                 new RenderingConfig.Builder().addKey(operation).build())
147                         .build();
148         Bundle params = new Bundle();
149         params.putParcelable(Constants.EXTRA_INPUT, input);
150         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
151         mIsolatedService.onRequest(Constants.OP_RENDER, params, new TestServiceCallback());
152         assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS));
153         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
154     }
155 
156     @Test
testOnEvent()157     public void testOnEvent() throws Exception {
158         PersistableBundle appParams = new PersistableBundle();
159         appParams.putString("ex", operation);
160         Bundle params = new Bundle();
161         params.putParcelable(
162                 Constants.EXTRA_INPUT,
163                 new EventInputParcel.Builder().setParameters(appParams).build());
164         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
165         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
166         mIsolatedService.onRequest(Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback());
167         assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS));
168         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
169     }
170 
171     @Test
testOnTrainingExamples()172     public void testOnTrainingExamples() throws Exception {
173         TrainingExamplesInputParcel input =
174                 new TrainingExamplesInputParcel.Builder()
175                         .setPopulationName("")
176                         .setTaskName(operation)
177                         .setResumptionToken(new byte[] {0})
178                         .build();
179         Bundle params = new Bundle();
180         params.putParcelable(Constants.EXTRA_INPUT, input);
181         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
182         mIsolatedService.onRequest(
183                 Constants.OP_TRAINING_EXAMPLE, params, new TestServiceCallback());
184         assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS));
185         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
186     }
187 
188     @Test
testOnWebTrigger()189     public void testOnWebTrigger() throws Exception {
190         WebTriggerInputParcel input =
191                 new WebTriggerInputParcel.Builder(
192                                 Uri.parse("http://desturl"), operation, new byte[] {1, 2, 3})
193                         .build();
194         Bundle params = new Bundle();
195         params.putParcelable(Constants.EXTRA_INPUT, input);
196         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
197         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
198         mIsolatedService.onRequest(Constants.OP_WEB_TRIGGER, params, new TestServiceCallback());
199         assertTrue(mLatch.await(5000, TimeUnit.MILLISECONDS));
200         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
201     }
202 
203     class TestServiceCallback extends IIsolatedServiceCallback.Stub {
204         @Override
onSuccess(Bundle result)205         public void onSuccess(Bundle result) {
206             mLatch.countDown();
207         }
208 
209         @Override
onError(int errorCode, int isolatedServiceErrorCode)210         public void onError(int errorCode, int isolatedServiceErrorCode) {
211             mCallbackErrorCode = errorCode;
212             mIsolatedServiceErrorCode = isolatedServiceErrorCode;
213             mLatch.countDown();
214         }
215     }
216 
217     static class TestDataAccessService extends IDataAccessService.Stub {
218 
219         String mOp;
220 
TestDataAccessService(String operation)221         TestDataAccessService(String operation) {
222             this.mOp = operation;
223         }
224 
TestDataAccessService()225         TestDataAccessService() {
226             mOp = null;
227         }
228 
229         @Override
onRequest(int operation, Bundle params, IDataAccessServiceCallback callback)230         public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {
231             // pass parameters for onDownloadCompleted testing
232             if (mOp != null) {
233                 Bundle bndl = new Bundle();
234                 bndl.putParcelable(
235                         Constants.EXTRA_RESULT, new ByteArrayParceledSlice(mOp.getBytes()));
236                 try {
237                     callback.onSuccess(bndl);
238                 } catch (RemoteException e) {
239                     throw new RuntimeException(e);
240                 }
241             }
242         }
243 
244         @Override
logApiCallStats(int apiName, long latencyMillis, int responseCode)245         public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {}
246     }
247 
248     static class TestFederatedComputeService extends IFederatedComputeService.Stub {
249         @Override
schedule(TrainingOptions trainingOptions, IFederatedComputeCallback callback)250         public void schedule(TrainingOptions trainingOptions, IFederatedComputeCallback callback) {}
251 
cancel(String populationName, IFederatedComputeCallback callback)252         public void cancel(String populationName, IFederatedComputeCallback callback) {}
253     }
254 
255     static class TestIsolatedModelService extends IIsolatedModelService.Stub {
256         @Override
runInference(Bundle params, IIsolatedModelServiceCallback callback)257         public void runInference(Bundle params, IIsolatedModelServiceCallback callback) {}
258     }
259 }
260