1 /* 2 * Copyright 2024 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.adservices.ondevicepersonalization; 18 19 import static com.google.common.truth.Truth.assertThat; 20 21 import static junit.framework.Assert.assertEquals; 22 23 import static org.junit.Assert.assertThrows; 24 25 import android.adservices.ondevicepersonalization.aidl.IDataAccessService; 26 import android.adservices.ondevicepersonalization.aidl.IDataAccessServiceCallback; 27 import android.os.Bundle; 28 29 import org.junit.Before; 30 import org.junit.Test; 31 32 import java.util.HashMap; 33 34 public class InferenceInputTest { 35 private static final String MODEL_KEY = "model_key"; 36 private RemoteDataImpl mRemoteData; 37 38 @Before setup()39 public void setup() { 40 mRemoteData = 41 new RemoteDataImpl( 42 IDataAccessService.Stub.asInterface(new TestDataAccessService())); 43 } 44 45 @Test buildParams_reusable()46 public void buildParams_reusable() { 47 InferenceInput.Params.Builder builder = 48 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY); 49 builder.build(); 50 51 InferenceInput.Params params = builder.setModelKey("other_kay").build(); 52 53 assertThat(params.getModelKey()).isEqualTo("other_kay"); 54 } 55 56 @Test buildInferenceInput_reusable()57 public void buildInferenceInput_reusable() { 58 HashMap<Integer, Object> outputData = new HashMap<>(); 59 outputData.put(0, new float[1]); 60 Object[] input = new Object[1]; 61 input[0] = new float[] {1.2f}; 62 InferenceInput.Params params = 63 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 64 65 InferenceInput.Builder builder = 66 new InferenceInput.Builder( 67 params, 68 input, 69 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 70 .setBatchSize(1); 71 builder.build(); 72 InferenceInput inferenceInput = builder.setBatchSize(10).build(); 73 assertThat(inferenceInput.getBatchSize()).isEqualTo(10); 74 } 75 76 @Test buildInput_success()77 public void buildInput_success() { 78 HashMap<Integer, Object> outputData = new HashMap<>(); 79 outputData.put(0, new float[1]); 80 Object[] input = new Object[1]; 81 input[0] = new float[] {1.2f}; 82 InferenceInput.Params params = 83 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 84 85 InferenceInput inferenceInput = 86 new InferenceInput.Builder( 87 params, 88 input, 89 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 90 .setBatchSize(1) 91 .build(); 92 93 float[] inputData = (float[]) inferenceInput.getInputData()[0]; 94 assertEquals(inputData[0], 1.2f); 95 assertThat(inferenceInput.getBatchSize()).isEqualTo(1); 96 assertThat(inferenceInput.getExpectedOutputStructure().getDataOutputs()).hasSize(1); 97 assertThat(inferenceInput.getParams()).isEqualTo(params); 98 } 99 100 @Test buildInput_batchNotSet_success()101 public void buildInput_batchNotSet_success() { 102 HashMap<Integer, Object> outputData = new HashMap<>(); 103 outputData.put(0, new float[1]); 104 Object[] input = new Object[1]; 105 input[0] = new float[] {1.2f}; 106 InferenceInput.Params params = 107 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 108 109 InferenceInput inferenceInput = 110 new InferenceInput.Builder( 111 params, 112 input, 113 new InferenceOutput.Builder().setDataOutputs(outputData).build()) 114 .build(); 115 116 assertThat(inferenceInput.getBatchSize()).isEqualTo(1); 117 } 118 119 @Test buildParams_success()120 public void buildParams_success() { 121 InferenceInput.Params params = 122 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY).build(); 123 124 assertThat(params.getRecommendedNumThreads()).isEqualTo(1); 125 assertThat(params.getDelegateType()).isEqualTo(InferenceInput.Params.DELEGATE_CPU); 126 assertThat(params.getModelType()) 127 .isEqualTo(InferenceInput.Params.MODEL_TYPE_TENSORFLOW_LITE); 128 assertThat(params.getKeyValueStore()).isEqualTo(mRemoteData); 129 assertThat(params.getModelKey()).isEqualTo(MODEL_KEY); 130 } 131 132 @Test buildParams_negativeThread_throws()133 public void buildParams_negativeThread_throws() { 134 assertThrows( 135 IllegalStateException.class, 136 () -> 137 new InferenceInput.Params.Builder(mRemoteData, MODEL_KEY) 138 .setRecommendedNumThreads(-2) 139 .build()); 140 } 141 142 @Test buildParams_nullModelKey_throws()143 public void buildParams_nullModelKey_throws() { 144 assertThrows( 145 NullPointerException.class, 146 () -> new InferenceInput.Params.Builder(mRemoteData, null).build()); 147 } 148 149 static class TestDataAccessService extends IDataAccessService.Stub { 150 @Override onRequest(int operation, Bundle params, IDataAccessServiceCallback callback)151 public void onRequest(int operation, Bundle params, IDataAccessServiceCallback callback) {} 152 @Override logApiCallStats(int apiName, long latencyMillis, int responseCode)153 public void logApiCallStats(int apiName, long latencyMillis, int responseCode) {} 154 } 155 } 156