/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.federatedcompute;

import static android.federatedcompute.common.ClientConstants.EXTRA_TASK_ID;
import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR;

import static com.google.common.truth.Truth.assertThat;

import static org.junit.Assert.assertTrue;

import android.federatedcompute.ExampleStoreQueryCallbackImpl.IteratorAdapter;
import android.federatedcompute.aidl.IExampleStoreCallback;
import android.federatedcompute.aidl.IExampleStoreIterator;
import android.federatedcompute.aidl.IExampleStoreService;
import android.os.Bundle;

import androidx.test.ext.junit.runners.AndroidJUnit4;

import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;

import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.tensorflow.example.BytesList;
import org.tensorflow.example.Example;
import org.tensorflow.example.Feature;
import org.tensorflow.example.Features;

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import javax.annotation.Nonnull;

/**
 * Tests the binder entry point of {@link ExampleStoreService}.
 *
 * <p>Each test obtains an {@link IExampleStoreService} from a {@link TestJavaExampleStoreService}
 * via {@code onBind}, calls {@code startQuery} on it, and checks what the
 * {@link TestJavaExampleStoreServiceCallback} recorded.
 */
@RunWith(AndroidJUnit4.class)
public class ExampleStoreServiceTest {
    private static final String EXPECTED_TASK_NAME = "federated_task";

    /** Upper bound on how long a test waits for the query callback before failing. */
    private static final long CALLBACK_TIMEOUT_SECONDS = 10;

    private static final Example EXAMPLE_PROTO_1 =
            Example.newBuilder()
                    .setFeatures(
                            Features.newBuilder()
                                    .putFeature(
                                            "feature1",
                                            Feature.newBuilder()
                                                    .setBytesList(
                                                            BytesList.newBuilder()
                                                                    .addValue(
                                                                            ByteString.copyFromUtf8(
                                                                                    "f1_value1")))
                                                    .build()))
                    .build();

    // State recorded by the test service and the test callback; read by the assertions.
    private IExampleStoreIterator mCallbackResult;
    private int mCallbackErrorCode;
    private boolean mStartQueryCalled;

    // Counted down once by whichever callback method fires.
    private final CountDownLatch mLatch = new CountDownLatch(1);
    private final TestJavaExampleStoreService mTestExampleStoreService =
            new TestJavaExampleStoreService();
    private IExampleStoreService mBinder;

    /** Creates the service under test and binds to it. */
    @Before
    public void doBeforeEachTest() {
        mTestExampleStoreService.onCreate();
        mBinder = IExampleStoreService.Stub.asInterface(mTestExampleStoreService.onBind(null));
    }

    /** A query for the expected task id reaches the service and yields an iterator adapter. */
    @Test
    public void testStartQuerySuccess() throws Exception {
        Bundle bundle = new Bundle();
        bundle.putString(EXTRA_TASK_ID, EXPECTED_TASK_NAME);

        mBinder.startQuery(bundle, new TestJavaExampleStoreServiceCallback());

        awaitCallback();
        assertTrue(mStartQueryCalled);
        assertThat(mCallbackResult).isInstanceOf(IteratorAdapter.class);
    }

    /** A query for an unexpected task id reaches the service and reports an internal error. */
    @Test
    public void testStartQueryFailure() throws Exception {
        Bundle bundle = new Bundle();
        bundle.putString(EXTRA_TASK_ID, "wrong_taskName");

        mBinder.startQuery(bundle, new TestJavaExampleStoreServiceCallback());

        awaitCallback();
        assertTrue(mStartQueryCalled);
        assertThat(mCallbackErrorCode).isEqualTo(STATUS_INTERNAL_ERROR);
    }

    /**
     * Waits for the query callback with a bounded timeout, so that a callback which never fires
     * fails the test instead of hanging the whole run.
     */
    private void awaitCallback() throws InterruptedException {
        assertTrue(
                "Timed out waiting for the startQuery callback",
                mLatch.await(CALLBACK_TIMEOUT_SECONDS, TimeUnit.SECONDS));
    }

    /**
     * Service under test: succeeds with a one-element iterator for {@link #EXPECTED_TASK_NAME}
     * and reports {@code STATUS_INTERNAL_ERROR} for any other task id.
     */
    class TestJavaExampleStoreService extends ExampleStoreService {
        @Override
        public void startQuery(@Nonnull Bundle params, @Nonnull QueryCallback callback) {
            mStartQueryCalled = true;
            String taskName = params.getString(EXTRA_TASK_ID);
            // Constant-first comparison: a bundle without a task id is treated as a mismatch
            // and reported through the callback rather than throwing a NullPointerException.
            if (!EXPECTED_TASK_NAME.equals(taskName)) {
                callback.onStartQueryFailure(STATUS_INTERNAL_ERROR);
                return;
            }
            callback.onStartQuerySuccess(
                    new ListJavaExampleStoreIterator(ImmutableList.of(EXAMPLE_PROTO_1)));
        }

        /** Always grants permission so the tests exercise the query path itself. */
        @Override
        protected boolean checkCallerPermission() {
            return true;
        }
    }

    /**
     * A minimal {@link ExampleStoreIterator} constructed from a {@link List} of examples.
     *
     * <p>NOTE(review): {@link #next} currently reports success with a {@code null} result on
     * every call and never reads {@code mExampleIterator}, so the list contents are not actually
     * returned. Confirm whether that is intentional for these tests.
     */
    private static class ListJavaExampleStoreIterator implements ExampleStoreIterator {
        private final Iterator<Example> mExampleIterator;

        ListJavaExampleStoreIterator(List<Example> examples) {
            mExampleIterator = examples.iterator();
        }

        @Override
        public synchronized void next(IteratorCallback callback) {
            callback.onIteratorNextSuccess(null);
        }

        @Override
        public void close() {}
    }

    /** Records the outcome of {@code startQuery} on the enclosing test and releases the latch. */
    class TestJavaExampleStoreServiceCallback extends IExampleStoreCallback.Stub {
        @Override
        public void onStartQuerySuccess(IExampleStoreIterator iterator) {
            mCallbackResult = iterator;
            mLatch.countDown();
        }

        @Override
        public void onStartQueryFailure(int errorCode) {
            mCallbackErrorCode = errorCode;
            mLatch.countDown();
        }
    }
}