1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.adservices.ondevicepersonalization;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 
21 import static junit.framework.Assert.assertEquals;
22 
23 import static org.junit.Assert.assertThrows;
24 
25 import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
26 import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback;
27 import android.os.Bundle;
28 
29 import org.junit.Before;
30 import org.junit.Test;
31 
32 import java.util.HashMap;
33 
34 public class InferenceInputTest {
35     private static final String MODEL_KEY = "model_key";
36     private RemoteDataImpl mRemoteData;
37 
38     @Before
setup()39     public void setup() {
40         mRemoteData =
41                 new RemoteDataImpl(
42                         IDataAccessService.Stub.asInterface(new TestDataAccessService()));
43     }
44 
45     @Test
buildParams_reusable()46     public void buildParams_reusable() {
47         InferenceInput.Params.Builder builder =
48                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY);
49         builder.build();
50 
51         InferenceInput.Params params = builder.setModelKey("other_kay").build();
52 
53         assertThat(params.getModelKey()).isEqualTo("other_kay");
54     }
55 
56     @Test
buildInferenceInput_reusable()57     public void buildInferenceInput_reusable() {
58         HashMap<Integer, Object> outputData = new HashMap<>();
59         outputData.put(0, new float[1]);
60         Object[] input = new Object[1];
61         input[0] = new float[] {1.2f};
62         InferenceInput.Params params =
63                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
64 
65         InferenceInput.Builder builder =
66                 new InferenceInput.Builder(
67                                 params,
68                                 input,
69                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
70                         .setBatchSize(1);
71         builder.build();
72         InferenceInput inferenceInput = builder.setBatchSize(10).build();
73         assertThat(inferenceInput.getBatchSize()).isEqualTo(10);
74     }
75 
76     @Test
buildInput_success()77     public void buildInput_success() {
78         HashMap<Integer, Object> outputData = new HashMap<>();
79         outputData.put(0, new float[1]);
80         Object[] input = new Object[1];
81         input[0] = new float[] {1.2f};
82         InferenceInput.Params params =
83                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
84 
85         InferenceInput inferenceInput =
86                 new InferenceInput.Builder(
87                                 params,
88                                 input,
89                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
90                         .setBatchSize(1)
91                         .build();
92 
93         float[] inputData = (float[]) inferenceInput.getInputData()[0];
94         assertEquals(inputData[0], 1.2f);
95         assertThat(inferenceInput.getBatchSize()).isEqualTo(1);
96         assertThat(inferenceInput.getExpectedOutputStructure().getDataOutputs()).hasSize(1);
97         assertThat(inferenceInput.getParams()).isEqualTo(params);
98     }
99 
100     @Test
buildInput_batchNotSet_success()101     public void buildInput_batchNotSet_success() {
102         HashMap<Integer, Object> outputData = new HashMap<>();
103         outputData.put(0, new float[1]);
104         Object[] input = new Object[1];
105         input[0] = new float[] {1.2f};
106         InferenceInput.Params params =
107                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
108 
109         InferenceInput inferenceInput =
110                 new InferenceInput.Builder(
111                                 params,
112                                 input,
113                                 new InferenceOutput.Builder().setDataOutputs(outputData).build())
114                         .build();
115 
116         assertThat(inferenceInput.getBatchSize()).isEqualTo(1);
117     }
118 
119     @Test
buildParams_success()120     public void buildParams_success() {
121         InferenceInput.Params params =
122                 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build();
123 
124         assertThat(params.getRecommendedNumThreads()).isEqualTo(1);
125         assertThat(params.getDelegateType()).isEqualTo(InferenceInput.Params.DELEGATE_CPU);
126         assertThat(params.getModelType())
127                 .isEqualTo(InferenceInput.Params.MODEL_TYPE_TENSORFLOW_LITE);
128         assertThat(params.getKeyValueStore()).isEqualTo(mRemoteData);
129         assertThat(params.getModelKey()).isEqualTo(MODEL_KEY);
130     }
131 
132     @Test
buildParams_negativeThread_throws()133     public void buildParams_negativeThread_throws() {
134         assertThrows(
135                 IllegalStateException.class,
136                 () ->
137                         new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY)
138                                 .setRecommendedNumThreads(-2)
139                                 .build());
140     }
141 
142     @Test
buildParams_nullModelKey_throws()143     public void buildParams_nullModelKey_throws() {
144         assertThrows(
145                 NullPointerException.class,
146                 () -> new InferenceInput.Params.Builder(mRemoteData, null).build());
147     }
148 
149     static class TestDataAccessService extends IDataAccessService.Stub {
150         @Override
onRequest(int operation, Bundle params, IDataAccessServiceCallback callback)151         public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {}
152         @Override
logApiCallStats(int apiName, long latencyMillis, int responseCode)153         public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {}
154     }
155 }
156