1 /*
2  * Copyright 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.adservices.ondevicepersonalization;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 
21 import static org.junit.Assert.assertArrayEquals;
22 import static org.junit.Assert.assertEquals;
23 import static org.junit.Assert.assertThrows;
24 import static org.junit.Assert.assertTrue;
25 
26 import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
27 import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback;
28 import android.adservices.ondevicepersonalization.aidl.IFederatedComputeCallback;
29 import android.adservices.ondevicepersonalization.aidl.IFederatedComputeService;
30 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
31 import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
32 import android.adservices.ondevicepersonalization.aidl.IIsolatedService;
33 import android.adservices.ondevicepersonalization.aidl.IIsolatedServiceCallback;
34 import android.content.ContentValues;
35 import android.federatedcompute.common.TrainingOptions;
36 import android.net.Uri;
37 import android.os.Bundle;
38 import android.os.OutcomeReceiver;
39 import android.os.ParcelFileDescriptor;
40 import android.os.PersistableBundle;
41 
42 import androidx.test.ext.junit.runners.AndroidJUnit4;
43 import androidx.test.filters.SmallTest;
44 
45 import com.android.ondevicepersonalization.internal.util.ByteArrayParceledSlice;
46 import com.android.ondevicepersonalization.internal.util.PersistableBundleUtils;
47 
48 import org.junit.Before;
49 import org.junit.Test;
50 import org.junit.runner.RunWith;
51 
52 import java.util.ArrayList;
53 import java.util.List;
54 import java.util.concurrent.CountDownLatch;
55 
56 /** Unit Tests of IsolatedService class. */
57 @SmallTest
58 @RunWith(AndroidJUnit4.class)
59 public class IsolatedServiceTest {
60     private static final String EVENT_TYPE_KEY = "event_type";
61     private final TestService mTestService = new TestService();
62     private final CountDownLatch mLatch = new CountDownLatch(1);
63     private IIsolatedService mBinder;
64     private boolean mOnExecuteCalled;
65     private boolean mOnDownloadCalled;
66     private boolean mOnRenderCalled;
67     private boolean mOnEventCalled;
68     private boolean mOnTrainingExampleCalled;
69     private boolean mOnWebTriggerCalled;
70     private Bundle mCallbackResult;
71     private int mCallbackErrorCode;
72     private int mIsolatedServiceErrorCode;
73 
74     @Before
setUp()75     public void setUp() {
76         mTestService.onCreate();
77         mBinder = IIsolatedService.Stub.asInterface(mTestService.onBind(null));
78     }
79 
80     @Test
testServiceThrowsIfOpcodeInvalid()81     public void testServiceThrowsIfOpcodeInvalid() throws Exception {
82         assertThrows(
83                 IllegalArgumentException.class,
84                 () -> {
85                     mBinder.onRequest(9999, new Bundle(), new TestServiceCallback());
86                 });
87     }
88 
89     @Test
testOnExecute()90     public void testOnExecute() throws Exception {
91         PersistableBundle appParams = new PersistableBundle();
92         appParams.putString("x", "y");
93         ExecuteInputParcel input =
94                 new ExecuteInputParcel.Builder()
95                         .setAppPackageName("com.testapp")
96                         .setSerializedAppParams(new ByteArrayParceledSlice(
97                                 PersistableBundleUtils.toByteArray(appParams)))
98                         .build();
99         Bundle params = new Bundle();
100         params.putParcelable(Constants.EXTRA_INPUT, input);
101         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
102         params.putBinder(
103                 Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER,
104                 new TestFederatedComputeService());
105         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
106         mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback());
107         mLatch.await();
108         assertTrue(mOnExecuteCalled);
109         ExecuteOutputParcel result =
110                 mCallbackResult.getParcelable(Constants.EXTRA_RESULT, ExecuteOutputParcel.class);
111         assertEquals(5, result.getRequestLogRecord().getRows().get(0).getAsInteger("a").intValue());
112         assertEquals("123", result.getRenderingConfig().getKeys().get(0));
113     }
114 
115     @Test
testOnExecutePropagatesError()116     public void testOnExecutePropagatesError() throws Exception {
117         PersistableBundle appParams = new PersistableBundle();
118         appParams.putInt("error", 1); // Trigger an error in the service.
119         ExecuteInputParcel input =
120                 new ExecuteInputParcel.Builder()
121                         .setAppPackageName("com.testapp")
122                         .setSerializedAppParams(new ByteArrayParceledSlice(
123                                 PersistableBundleUtils.toByteArray(appParams)))
124                         .build();
125         Bundle params = new Bundle();
126         params.putParcelable(Constants.EXTRA_INPUT, input);
127         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
128         params.putBinder(
129                 Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER,
130                 new TestFederatedComputeService());
131         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
132         mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback());
133         mLatch.await();
134         assertTrue(mOnExecuteCalled);
135         assertEquals(Constants.STATUS_SERVICE_FAILED, mCallbackErrorCode);
136         assertEquals(1, mIsolatedServiceErrorCode);
137     }
138 
139     @Test
testOnExecuteWithoutAppParams()140     public void testOnExecuteWithoutAppParams() throws Exception {
141         ExecuteInputParcel input = new ExecuteInputParcel.Builder().setAppPackageName("com.testapp").build();
142         Bundle params = new Bundle();
143         params.putParcelable(Constants.EXTRA_INPUT, input);
144         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
145         params.putBinder(
146                 Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER,
147                 new TestFederatedComputeService());
148         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
149         mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback());
150         mLatch.await();
151         assertTrue(mOnExecuteCalled);
152     }
153 
154     @Test
testOnExecuteThrowsIfParamsMissing()155     public void testOnExecuteThrowsIfParamsMissing() throws Exception {
156         assertThrows(
157                 NullPointerException.class,
158                 () -> {
159                     mBinder.onRequest(Constants.OP_EXECUTE, null, new TestServiceCallback());
160                 });
161     }
162 
163     @Test
testOnExecuteThrowsIfInputMissing()164     public void testOnExecuteThrowsIfInputMissing() throws Exception {
165         Bundle params = new Bundle();
166         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
167         params.putBinder(
168                 Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER,
169                 new TestFederatedComputeService());
170         mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback());
171         mLatch.await();
172         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
173     }
174 
175     @Test
testOnExecuteThrowsIfDataAccessServiceMissing()176     public void testOnExecuteThrowsIfDataAccessServiceMissing() throws Exception {
177         ExecuteInputParcel input = new ExecuteInputParcel.Builder().setAppPackageName("com.testapp").build();
178         Bundle params = new Bundle();
179         params.putBinder(
180                 Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER,
181                 new TestFederatedComputeService());
182         params.putParcelable(Constants.EXTRA_INPUT, input);
183         mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback());
184         mLatch.await();
185         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
186     }
187 
188     @Test
testOnExecuteThrowsIfFederatedComputeServiceMissing()189     public void testOnExecuteThrowsIfFederatedComputeServiceMissing() throws Exception {
190         ExecuteInputParcel input = new ExecuteInputParcel.Builder().setAppPackageName("com.testapp").build();
191         Bundle params = new Bundle();
192         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
193         params.putParcelable(Constants.EXTRA_INPUT, input);
194         mBinder.onRequest(Constants.OP_EXECUTE, params, new TestServiceCallback());
195         mLatch.await();
196         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
197     }
198 
199     @Test
testOnExecuteThrowsIfCallbackMissing()200     public void testOnExecuteThrowsIfCallbackMissing() throws Exception {
201         ExecuteInputParcel input = new ExecuteInputParcel.Builder().setAppPackageName("com.testapp").build();
202         Bundle params = new Bundle();
203         params.putParcelable(Constants.EXTRA_INPUT, input);
204         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
205         assertThrows(
206                 NullPointerException.class,
207                 () -> {
208                     mBinder.onRequest(Constants.OP_EXECUTE, params, null);
209                 });
210     }
211 
212     @Test
testOnDownload()213     public void testOnDownload() throws Exception {
214         DownloadInputParcel input =
215                 new DownloadInputParcel.Builder()
216                         .setDataAccessServiceBinder(new TestDataAccessService())
217                         .build();
218         Bundle params = new Bundle();
219         params.putParcelable(Constants.EXTRA_INPUT, input);
220         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
221         params.putBinder(
222                 Constants.EXTRA_FEDERATED_COMPUTE_SERVICE_BINDER,
223                 new TestFederatedComputeService());
224         mBinder.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback());
225         mLatch.await();
226         assertTrue(mOnDownloadCalled);
227         DownloadCompletedOutputParcel result =
228                 mCallbackResult.getParcelable(
229                         Constants.EXTRA_RESULT, DownloadCompletedOutputParcel.class);
230         assertEquals("12", result.getRetainedKeys().get(0));
231     }
232 
233     @Test
testOnDownloadThrowsIfParamsMissing()234     public void testOnDownloadThrowsIfParamsMissing() throws Exception {
235         assertThrows(
236                 NullPointerException.class,
237                 () -> {
238                     mBinder.onRequest(Constants.OP_DOWNLOAD, null, new TestServiceCallback());
239                 });
240     }
241 
242     @Test
testOnDownloadThrowsIfInputMissing()243     public void testOnDownloadThrowsIfInputMissing() throws Exception {
244         Bundle params = new Bundle();
245         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
246         mBinder.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback());
247         mLatch.await();
248         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
249     }
250 
251     @Test
testOnDownloadThrowsIfDataAccessServiceMissing()252     public void testOnDownloadThrowsIfDataAccessServiceMissing() throws Exception {
253         DownloadInputParcel input =
254                 new DownloadInputParcel.Builder()
255                         .setDataAccessServiceBinder(new TestDataAccessService())
256                         .build();
257         Bundle params = new Bundle();
258         params.putParcelable(Constants.EXTRA_INPUT, input);
259         mBinder.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback());
260         mLatch.await();
261         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
262     }
263 
264     @Test
testOnDownloadThrowsIfFederatedComputeServiceMissing()265     public void testOnDownloadThrowsIfFederatedComputeServiceMissing() throws Exception {
266         DownloadInputParcel input =
267                 new DownloadInputParcel.Builder()
268                         .setDataAccessServiceBinder(new TestDataAccessService())
269                         .build();
270         Bundle params = new Bundle();
271         params.putParcelable(Constants.EXTRA_INPUT, input);
272         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
273         mBinder.onRequest(Constants.OP_DOWNLOAD, params, new TestServiceCallback());
274         mLatch.await();
275         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
276     }
277 
278     @Test
testOnDownloadThrowsIfCallbackMissing()279     public void testOnDownloadThrowsIfCallbackMissing() throws Exception {
280         ParcelFileDescriptor[] pfds = ParcelFileDescriptor.createPipe();
281         DownloadInputParcel input =
282                 new DownloadInputParcel.Builder()
283                         .setDataAccessServiceBinder(new TestDataAccessService())
284                         .build();
285         Bundle params = new Bundle();
286         params.putParcelable(Constants.EXTRA_INPUT, input);
287         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
288         assertThrows(
289                 NullPointerException.class,
290                 () -> {
291                     mBinder.onRequest(Constants.OP_DOWNLOAD, params, null);
292                 });
293     }
294 
295     @Test
testOnRender()296     public void testOnRender() throws Exception {
297         RenderInputParcel input =
298                 new RenderInputParcel.Builder()
299                         .setRenderingConfig(
300                                 new RenderingConfig.Builder().addKey("a").addKey("b").build())
301                         .build();
302         Bundle params = new Bundle();
303         params.putParcelable(Constants.EXTRA_INPUT, input);
304         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
305         mBinder.onRequest(Constants.OP_RENDER, params, new TestServiceCallback());
306         mLatch.await();
307         assertTrue(mOnRenderCalled);
308         RenderOutputParcel result =
309                 mCallbackResult.getParcelable(Constants.EXTRA_RESULT, RenderOutputParcel.class);
310         assertEquals("htmlstring", result.getContent());
311     }
312 
313     @Test
testOnRenderPropagatesError()314     public void testOnRenderPropagatesError() throws Exception {
315         RenderInputParcel input =
316                 new RenderInputParcel.Builder()
317                         .setRenderingConfig(
318                                 new RenderingConfig.Builder()
319                                         .addKey("z") // Trigger error in service.
320                                         .build())
321                         .build();
322         Bundle params = new Bundle();
323         params.putParcelable(Constants.EXTRA_INPUT, input);
324         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
325         mBinder.onRequest(Constants.OP_RENDER, params, new TestServiceCallback());
326         mLatch.await();
327         assertTrue(mOnRenderCalled);
328         assertEquals(Constants.STATUS_SERVICE_FAILED, mCallbackErrorCode);
329     }
330 
331     @Test
testOnRenderThrowsIfParamsMissing()332     public void testOnRenderThrowsIfParamsMissing() throws Exception {
333         assertThrows(
334                 NullPointerException.class,
335                 () -> {
336                     mBinder.onRequest(Constants.OP_RENDER, null, new TestServiceCallback());
337                 });
338     }
339 
340     @Test
testOnRenderThrowsIfInputMissing()341     public void testOnRenderThrowsIfInputMissing() throws Exception {
342         Bundle params = new Bundle();
343         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
344         mBinder.onRequest(Constants.OP_RENDER, params, new TestServiceCallback());
345         mLatch.await();
346         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
347     }
348 
349     @Test
testOnRenderThrowsIfDataAccessServiceMissing()350     public void testOnRenderThrowsIfDataAccessServiceMissing() throws Exception {
351         RenderInputParcel input =
352                 new RenderInputParcel.Builder()
353                         .setRenderingConfig(
354                                 new RenderingConfig.Builder().addKey("a").addKey("b").build())
355                         .build();
356         Bundle params = new Bundle();
357         params.putParcelable(Constants.EXTRA_INPUT, input);
358         mBinder.onRequest(Constants.OP_RENDER, params, new TestServiceCallback());
359         mLatch.await();
360         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
361     }
362 
363     @Test
testOnRenderThrowsIfCallbackMissing()364     public void testOnRenderThrowsIfCallbackMissing() throws Exception {
365         RenderInputParcel input =
366                 new RenderInputParcel.Builder()
367                         .setRenderingConfig(
368                                 new RenderingConfig.Builder().addKey("a").addKey("b").build())
369                         .build();
370         Bundle params = new Bundle();
371         params.putParcelable(Constants.EXTRA_INPUT, input);
372         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
373         assertThrows(
374                 NullPointerException.class,
375                 () -> {
376                     mBinder.onRequest(Constants.OP_RENDER, params, null);
377                 });
378     }
379 
380     @Test
testOnEvent()381     public void testOnEvent() throws Exception {
382         Bundle params = new Bundle();
383         params.putParcelable(
384                 Constants.EXTRA_INPUT,
385                 new EventInputParcel.Builder().setParameters(PersistableBundle.EMPTY).build());
386         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
387         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
388         mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback());
389         mLatch.await();
390         assertTrue(mOnEventCalled);
391         EventOutputParcel result =
392                 mCallbackResult.getParcelable(Constants.EXTRA_RESULT, EventOutputParcel.class);
393         assertEquals(1, result.getEventLogRecord().getType());
394         assertEquals(2, result.getEventLogRecord().getRowIndex());
395     }
396 
397     @Test
testOnEventPropagatesError()398     public void testOnEventPropagatesError() throws Exception {
399         PersistableBundle eventParams = new PersistableBundle();
400         // Input value 9999 will trigger an error in the mock service.
401         eventParams.putInt(EVENT_TYPE_KEY, 9999);
402         Bundle params = new Bundle();
403         params.putParcelable(
404                 Constants.EXTRA_INPUT,
405                 new EventInputParcel.Builder().setParameters(eventParams).build());
406         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
407         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
408         mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback());
409         mLatch.await();
410         assertTrue(mOnEventCalled);
411         assertEquals(Constants.STATUS_SERVICE_FAILED, mCallbackErrorCode);
412     }
413 
414     @Test
testOnEventThrowsIfParamsMissing()415     public void testOnEventThrowsIfParamsMissing() throws Exception {
416         assertThrows(
417                 NullPointerException.class,
418                 () -> {
419                     mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, null, new TestServiceCallback());
420                 });
421     }
422 
423     @Test
testOnEventThrowsIfInputMissing()424     public void testOnEventThrowsIfInputMissing() throws Exception {
425         Bundle params = new Bundle();
426         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
427         mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback());
428         mLatch.await();
429         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
430     }
431 
432     @Test
testOnEventThrowsIfDataAccessServiceMissing()433     public void testOnEventThrowsIfDataAccessServiceMissing() throws Exception {
434         Bundle params = new Bundle();
435         params.putParcelable(
436                 Constants.EXTRA_INPUT,
437                 new EventInputParcel.Builder().setParameters(PersistableBundle.EMPTY).build());
438         mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, params, new TestServiceCallback());
439         mLatch.await();
440         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
441     }
442 
443     @Test
testOnEventThrowsIfCallbackMissing()444     public void testOnEventThrowsIfCallbackMissing() throws Exception {
445         Bundle params = new Bundle();
446         params.putParcelable(
447                 Constants.EXTRA_INPUT,
448                 new EventInputParcel.Builder().setParameters(PersistableBundle.EMPTY).build());
449         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
450         assertThrows(
451                 NullPointerException.class,
452                 () -> {
453                     mBinder.onRequest(Constants.OP_WEB_VIEW_EVENT, params, null);
454                 });
455     }
456 
457     @Test
testOnTrainingExamples()458     public void testOnTrainingExamples() throws Exception {
459         JoinedLogRecord joinedLogRecord = new JoinedLogRecord.Builder().build();
460         TrainingExamplesInputParcel input =
461                 new TrainingExamplesInputParcel.Builder()
462                         .setPopulationName("")
463                         .setTaskName("")
464                         .setResumptionToken(new byte[] {0})
465                         .build();
466         Bundle params = new Bundle();
467         params.putParcelable(Constants.EXTRA_INPUT, input);
468         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
469         mBinder.onRequest(Constants.OP_TRAINING_EXAMPLE, params, new TestServiceCallback());
470         mLatch.await();
471         assertTrue(mOnTrainingExampleCalled);
472         TrainingExamplesOutputParcel result =
473                 mCallbackResult.getParcelable(
474                         Constants.EXTRA_RESULT, TrainingExamplesOutputParcel.class);
475         List<TrainingExampleRecord> examples = result.getTrainingExampleRecords().getList();
476         assertThat(examples).hasSize(1);
477         assertArrayEquals(new byte[] {12}, examples.get(0).getTrainingExample());
478         assertArrayEquals(new byte[] {13}, examples.get(0).getResumptionToken());
479     }
480 
481     @Test
testOnTrainingExampleThrowsIfParamsMissing()482     public void testOnTrainingExampleThrowsIfParamsMissing() throws Exception {
483         assertThrows(
484                 NullPointerException.class,
485                 () -> {
486                     mBinder.onRequest(
487                             Constants.OP_TRAINING_EXAMPLE, null, new TestServiceCallback());
488                 });
489     }
490 
491     @Test
testOnTrainingExampleThrowsIfDataAccessServiceMissing()492     public void testOnTrainingExampleThrowsIfDataAccessServiceMissing() throws Exception {
493         JoinedLogRecord joinedLogRecord = new JoinedLogRecord.Builder().build();
494         TrainingExamplesInputParcel input =
495                 new TrainingExamplesInputParcel.Builder()
496                         .setPopulationName("")
497                         .setTaskName("")
498                         .setResumptionToken(new byte[] {0})
499                         .build();
500         Bundle params = new Bundle();
501         params.putParcelable(Constants.EXTRA_INPUT, input);
502         mBinder.onRequest(Constants.OP_TRAINING_EXAMPLE, params, new TestServiceCallback());
503         mLatch.await();
504         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
505     }
506 
507     @Test
testOnTrainingExampleThrowsIfCallbackMissing()508     public void testOnTrainingExampleThrowsIfCallbackMissing() throws Exception {
509         JoinedLogRecord joinedLogRecord = new JoinedLogRecord.Builder().build();
510         TrainingExamplesInputParcel input =
511                 new TrainingExamplesInputParcel.Builder()
512                         .setPopulationName("")
513                         .setTaskName("")
514                         .setResumptionToken(new byte[] {0})
515                         .build();
516         Bundle params = new Bundle();
517         params.putParcelable(Constants.EXTRA_INPUT, input);
518         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
519         mBinder.onRequest(Constants.OP_TRAINING_EXAMPLE, params, new TestServiceCallback());
520         assertThrows(
521                 NullPointerException.class,
522                 () -> {
523                     mBinder.onRequest(Constants.OP_TRAINING_EXAMPLE, params, null);
524                 });
525     }
526 
527     @Test
testOnWebTrigger()528     public void testOnWebTrigger() throws Exception {
529         WebTriggerInputParcel input =
530                 new WebTriggerInputParcel.Builder(
531                         Uri.parse("http://desturl"),
532                         "com.browser",
533                         new byte[] {1, 2, 3})
534                     .build();
535         Bundle params = new Bundle();
536         params.putParcelable(Constants.EXTRA_INPUT, input);
537         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
538         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
539         mBinder.onRequest(Constants.OP_WEB_TRIGGER, params, new TestServiceCallback());
540         mLatch.await();
541         assertTrue(mOnWebTriggerCalled);
542         WebTriggerOutputParcel result =
543                 mCallbackResult.getParcelable(Constants.EXTRA_RESULT, WebTriggerOutputParcel.class);
544         assertEquals(5, result.getRequestLogRecord().getRows().get(0).getAsInteger("a").intValue());
545     }
546 
547     @Test
testOnWebTriggerPropagatesError()548     public void testOnWebTriggerPropagatesError() throws Exception {
549         WebTriggerInputParcel input =
550                 new WebTriggerInputParcel.Builder(
551                         Uri.parse("http://error"),
552                         "com.browser",
553                         new byte[] {1, 2, 3})
554                     .build();
555         Bundle params = new Bundle();
556         params.putParcelable(Constants.EXTRA_INPUT, input);
557         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
558         params.putBinder(Constants.EXTRA_MODEL_SERVICE_BINDER, new TestIsolatedModelService());
559         mBinder.onRequest(Constants.OP_WEB_TRIGGER, params, new TestServiceCallback());
560         mLatch.await();
561         assertTrue(mOnWebTriggerCalled);
562         assertEquals(Constants.STATUS_SERVICE_FAILED, mCallbackErrorCode);
563     }
564 
565     @Test
testOnWebTriggerThrowsIfParamsMissing()566     public void testOnWebTriggerThrowsIfParamsMissing() throws Exception {
567         assertThrows(
568                 NullPointerException.class,
569                 () -> {
570                     mBinder.onRequest(Constants.OP_WEB_TRIGGER, null, new TestServiceCallback());
571                 });
572     }
573 
574     @Test
testOnWebTriggerThrowsIfInputMissing()575     public void testOnWebTriggerThrowsIfInputMissing() throws Exception {
576         Bundle params = new Bundle();
577         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
578         mBinder.onRequest(Constants.OP_WEB_TRIGGER, params, new TestServiceCallback());
579         mLatch.await();
580         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
581     }
582 
583     @Test
testOnWebTriggerThrowsIfDataAccessServiceMissing()584     public void testOnWebTriggerThrowsIfDataAccessServiceMissing() throws Exception {
585         WebTriggerInputParcel input =
586                 new WebTriggerInputParcel.Builder(
587                         Uri.parse("http://desturl"),
588                         "com.browser",
589                         new byte[] {1, 2, 3})
590                     .build();
591         Bundle params = new Bundle();
592         params.putParcelable(Constants.EXTRA_INPUT, input);
593         mBinder.onRequest(Constants.OP_WEB_TRIGGER, params, new TestServiceCallback());
594         mLatch.await();
595         assertEquals(Constants.STATUS_INTERNAL_ERROR, mCallbackErrorCode);
596     }
597 
598     @Test
testOnWebTriggerThrowsIfCallbackMissing()599     public void testOnWebTriggerThrowsIfCallbackMissing() throws Exception {
600         WebTriggerInputParcel input =
601                 new WebTriggerInputParcel.Builder(
602                         Uri.parse("http://desturl"),
603                         "com.browser",
604                         new byte[] {1, 2, 3})
605                     .build();
606         Bundle params = new Bundle();
607         params.putParcelable(Constants.EXTRA_INPUT, input);
608         params.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, new TestDataAccessService());
609         assertThrows(
610                 NullPointerException.class,
611                 () -> {
612                     mBinder.onRequest(Constants.OP_WEB_TRIGGER, params, null);
613                 });
614     }
615 
616     static class TestDataAccessService extends IDataAccessService.Stub {
617         @Override
onRequest(int operation, Bundle params, IDataAccessServiceCallback callback)618         public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {}
619         @Override
logApiCallStats(int apiName, long latencyMillis, int responseCode)620         public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {}
621     }
622 
623     static class TestFederatedComputeService extends IFederatedComputeService.Stub {
624         @Override
schedule(TrainingOptions trainingOptions, IFederatedComputeCallback callback)625         public void schedule(TrainingOptions trainingOptions, IFederatedComputeCallback callback) {}
626 
cancel(String populationName, IFederatedComputeCallback callback)627         public void cancel(String populationName, IFederatedComputeCallback callback) {}
628     }
629 
630     class TestWorker implements IsolatedWorker {
631         @Override
onExecute( ExecuteInput input, OutcomeReceiver<ExecuteOutput, IsolatedServiceException> receiver)632         public void onExecute(
633                 ExecuteInput input,
634                 OutcomeReceiver<ExecuteOutput, IsolatedServiceException> receiver) {
635             mOnExecuteCalled = true;
636             if (input.getAppParams() != null && input.getAppParams().getInt("error") > 0) {
637                 receiver.onError(new IsolatedServiceException(1));
638             } else {
639                 ContentValues row = new ContentValues();
640                 row.put("a", 5);
641                 receiver.onResult(
642                         new ExecuteOutput.Builder()
643                                 .setRequestLogRecord(
644                                         new RequestLogRecord.Builder().addRow(row).build())
645                                 .setRenderingConfig(
646                                         new RenderingConfig.Builder().addKey("123").build())
647                                 .build());
648             }
649         }
650 
651         @Override
onDownloadCompleted( DownloadCompletedInput input, OutcomeReceiver<DownloadCompletedOutput, IsolatedServiceException> receiver)652         public void onDownloadCompleted(
653                 DownloadCompletedInput input,
654                 OutcomeReceiver<DownloadCompletedOutput, IsolatedServiceException> receiver) {
655             mOnDownloadCalled = true;
656             receiver.onResult(new DownloadCompletedOutput.Builder().addRetainedKey("12").build());
657         }
658 
659         @Override
onRender( RenderInput input, OutcomeReceiver<RenderOutput, IsolatedServiceException> receiver)660         public void onRender(
661                 RenderInput input,
662                 OutcomeReceiver<RenderOutput, IsolatedServiceException> receiver) {
663             mOnRenderCalled = true;
664             if (input.getRenderingConfig().getKeys().size() >= 1
665                     && input.getRenderingConfig().getKeys().get(0).equals("z")) {
666                 receiver.onError(new IsolatedServiceException(1));
667             } else {
668                 receiver.onResult(new RenderOutput.Builder().setContent("htmlstring").build());
669             }
670         }
671 
672         @Override
onEvent( EventInput input, OutcomeReceiver<EventOutput, IsolatedServiceException> receiver)673         public void onEvent(
674                 EventInput input,
675                 OutcomeReceiver<EventOutput, IsolatedServiceException> receiver) {
676             mOnEventCalled = true;
677             int eventType = input.getParameters().getInt(EVENT_TYPE_KEY);
678             if (eventType == 9999) {
679                 receiver.onError(new IsolatedServiceException(1));
680             } else {
681                 receiver.onResult(
682                         new EventOutput.Builder()
683                                 .setEventLogRecord(
684                                         new EventLogRecord.Builder()
685                                                 .setType(1)
686                                                 .setRowIndex(2)
687                                                 .setData(new ContentValues())
688                                                 .build())
689                                 .build());
690             }
691         }
692 
693         @Override
onTrainingExamples( TrainingExamplesInput input, OutcomeReceiver<TrainingExamplesOutput, IsolatedServiceException> receiver)694         public void onTrainingExamples(
695                 TrainingExamplesInput input,
696                 OutcomeReceiver<TrainingExamplesOutput, IsolatedServiceException> receiver) {
697             mOnTrainingExampleCalled = true;
698             List<TrainingExampleRecord> exampleRecordList = new ArrayList<>();
699             TrainingExampleRecord record =
700                     new TrainingExampleRecord.Builder()
701                             .setTrainingExample(new byte[] {12})
702                             .setResumptionToken(new byte[] {13})
703                             .build();
704             exampleRecordList.add(record);
705             receiver.onResult(
706                     new TrainingExamplesOutput.Builder()
707                             .setTrainingExampleRecords(exampleRecordList)
708                             .build());
709         }
710 
711         @Override
onWebTrigger( WebTriggerInput input, OutcomeReceiver<WebTriggerOutput, IsolatedServiceException> receiver)712         public void onWebTrigger(
713                 WebTriggerInput input,
714                 OutcomeReceiver<WebTriggerOutput, IsolatedServiceException> receiver) {
715             mOnWebTriggerCalled = true;
716             if (input.getDestinationUrl().toString().equals("http://error")) {
717                 receiver.onError(new IsolatedServiceException(1));
718             } else {
719                 ContentValues row = new ContentValues();
720                 row.put("a", 5);
721                 receiver.onResult(
722                         new WebTriggerOutput.Builder()
723                             .setRequestLogRecord(
724                                     new RequestLogRecord.Builder().addRow(row).build())
725                             .build());
726             }
727         }
728     }
729 
730     class TestService extends IsolatedService {
731         @Override
onRequest(RequestToken token)732         public IsolatedWorker onRequest(RequestToken token) {
733             return new TestWorker();
734         }
735     }
736 
737     class TestServiceCallback extends IIsolatedServiceCallback.Stub {
738         @Override
onSuccess(Bundle result)739         public void onSuccess(Bundle result) {
740             mCallbackResult = result;
741             mLatch.countDown();
742         }
743 
744         @Override
onError(int errorCode, int isolatedServiceErrorCode)745         public void onError(int errorCode, int isolatedServiceErrorCode) {
746             mCallbackErrorCode = errorCode;
747             mIsolatedServiceErrorCode = isolatedServiceErrorCode;
748             mLatch.countDown();
749         }
750     }
751 
752     class TestIsolatedModelService extends IIsolatedModelService.Stub {
753         @Override
runInference(Bundle params, IIsolatedModelServiceCallback callback)754         public void runInference(Bundle params, IIsolatedModelServiceCallback callback) {}
755     }
756 }
757