1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.federatedcompute;
18 
import static android.federatedcompute.common.ClientConstants.EXTRA_TASK_ID;
import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR;

import static com.google.common.truth.Truth.assertThat;

import static org.junit.Assert.assertTrue;

import android.federatedcompute.ExampleStoreQueryCallbackImpl.IteratorAdapter;
import android.federatedcompute.aidl.IExampleStoreCallback;
import android.federatedcompute.aidl.IExampleStoreIterator;
import android.federatedcompute.aidl.IExampleStoreService;
import android.os.Bundle;

import androidx.test.ext.junit.runners.AndroidJUnit4;

import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.tensorflow.example.BytesList;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import javax.annotation.Nonnull;
50 
51 @RunWith(AndroidJUnit4.class)
52 public class ExampleStoreServiceTest {
53     private static final String EXPECTED_TASK_NAME = "federated_task";
54     private static final Example EXAMPLE_PROTO_1 =
55             Example.newBuilder()
56                     .setFeatures(
57                             Features.newBuilder()
58                                     .putFeature(
59                                             "feature1",
60                                             Feature.newBuilder()
61                                                     .setBytesList(
62                                                             BytesList.newBuilder()
63                                                                     .addValue(
64                                                                             ByteString.copyFromUtf8(
65                                                                                     "f1_value1")))
66                                                     .build()))
67                     .build();
68     private IExampleStoreIterator mCallbackResult;
69     private int mCallbackErrorCode;
70     private boolean mStartQueryCalled;
71     private final CountDownLatch mLatch = new CountDownLatch(1);
72     private final TestJavaExampleStoreService mTestExampleStoreService =
73             new TestJavaExampleStoreService();
74     private IExampleStoreService mBinder;
75 
76     @Before
doBeforeEachTest()77     public void doBeforeEachTest() {
78         mTestExampleStoreService.onCreate();
79         mBinder = IExampleStoreService.Stub.asInterface(mTestExampleStoreService.onBind(null));
80     }
81 
82     @Test
testStartQuerySuccess()83     public void testStartQuerySuccess() throws Exception {
84         Bundle bundle = new Bundle();
85         bundle.putString(EXTRA_TASK_ID, EXPECTED_TASK_NAME);
86         mBinder.startQuery(bundle, new TestJavaExampleStoreServiceCallback());
87         mLatch.await();
88         assertTrue(mStartQueryCalled);
89         assertThat(mCallbackResult).isInstanceOf(IteratorAdapter.class);
90     }
91 
92     @Test
testStartQueryFailure()93     public void testStartQueryFailure() throws Exception {
94         Bundle bundle = new Bundle();
95         bundle.putString(EXTRA_TASK_ID, "wrong_taskName");
96         mBinder.startQuery(bundle, new TestJavaExampleStoreServiceCallback());
97         mLatch.await();
98         assertTrue(mStartQueryCalled);
99         assertThat(mCallbackErrorCode).isEqualTo(STATUS_INTERNAL_ERROR);
100         assertThat(mCallbackErrorCode).isEqualTo(STATUS_INTERNAL_ERROR);
101     }
102 
103     class TestJavaExampleStoreService extends ExampleStoreService {
104         @Override
startQuery(@onnull Bundle params, @Nonnull QueryCallback callback)105         public void startQuery(@Nonnull Bundle params, @Nonnull QueryCallback callback) {
106             mStartQueryCalled = true;
107             String taskName = params.getString(EXTRA_TASK_ID);
108             if (!taskName.equals(EXPECTED_TASK_NAME)) {
109                 callback.onStartQueryFailure(STATUS_INTERNAL_ERROR);
110                 return;
111             }
112             callback.onStartQuerySuccess(
113                     new ListJavaExampleStoreIterator(ImmutableList.of(EXAMPLE_PROTO_1)));
114         }
115 
116         @Override
checkCallerPermission()117         protected boolean checkCallerPermission() {
118             return true;
119         }
120     }
121 
122     /**
123      * A simple {@link ExampleStoreIterator} that returns the contents of the {@link List} it's
124      * constructed with.
125      */
126     private static class ListJavaExampleStoreIterator implements ExampleStoreIterator {
127         private final Iterator<Example> mExampleIterator;
128 
ListJavaExampleStoreIterator(List<Example> examples)129         ListJavaExampleStoreIterator(List<Example> examples) {
130             mExampleIterator = examples.iterator();
131         }
132 
133         @Override
next(IteratorCallback callback)134         public synchronized void next(IteratorCallback callback) {
135             callback.onIteratorNextSuccess(null);
136         }
137 
138         @Override
close()139         public void close() {}
140     }
141 
142     class TestJavaExampleStoreServiceCallback extends IExampleStoreCallback.Stub {
143         @Override
onStartQuerySuccess(IExampleStoreIterator iterator)144         public void onStartQuerySuccess(IExampleStoreIterator iterator) {
145             mCallbackResult = iterator;
146             mLatch.countDown();
147         }
148 
149         @Override
onStartQueryFailure(int errorCode)150         public void onStartQueryFailure(int errorCode) {
151             mCallbackErrorCode = errorCode;
152             mLatch.countDown();
153         }
154     }
155 }
156