1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.federatedcompute.services; 18 19 import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR; 20 import static android.federatedcompute.common.ClientConstants.STATUS_KILL_SWITCH_ENABLED; 21 import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS; 22 23 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doAnswer; 24 import static com.android.federatedcompute.services.stats.FederatedComputeStatsLog.FEDERATED_COMPUTE_API_CALLED__API_NAME__CANCEL; 25 import static com.android.federatedcompute.services.stats.FederatedComputeStatsLog.FEDERATED_COMPUTE_API_CALLED__API_NAME__SCHEDULE; 26 27 import static com.google.common.truth.Truth.assertThat; 28 29 import static org.junit.Assert.assertThrows; 30 import static org.mockito.ArgumentMatchers.any; 31 import static org.mockito.ArgumentMatchers.anyString; 32 import static org.mockito.Mockito.spy; 33 import static org.mockito.Mockito.verify; 34 import static org.mockito.Mockito.when; 35 36 import android.content.ComponentName; 37 import android.content.Context; 38 import android.federatedcompute.aidl.IFederatedComputeCallback; 39 import android.federatedcompute.common.TrainingOptions; 40 41 import androidx.test.core.app.ApplicationProvider; 42 43 import com.android.federatedcompute.services.common.PhFlagsTestUtil; 44 import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager; 45 import com.android.federatedcompute.services.statsd.ApiCallStats; 46 import com.android.federatedcompute.services.statsd.FederatedComputeStatsdLogger; 47 import com.android.odp.module.common.Clock; 48 49 import org.junit.Before; 50 import org.junit.Test; 51 import org.junit.runner.RunWith; 52 import org.junit.runners.JUnit4; 53 import org.mockito.ArgumentCaptor; 54 import org.mockito.Mock; 55 import org.mockito.MockitoAnnotations; 56 import org.mockito.invocation.InvocationOnMock; 57 import org.mockito.stubbing.Answer; 58 59 import java.util.concurrent.CountDownLatch; 60 import java.util.concurrent.TimeUnit; 61 62 @RunWith(JUnit4.class) 63 public final class FederatedComputeManagingServiceDelegateTest { 64 private static final int BINDER_CONNECTION_TIMEOUT_MS = 10_000; 65 66 private static final String CALLING_PACKAGE_NAME = "callingPkg"; 67 private static final String CALLING_CLASS_NAME = "callingClass"; 68 69 public static final ComponentName OWNER_COMPONENT_NAME = 70 ComponentName.createRelative(CALLING_PACKAGE_NAME, CALLING_CLASS_NAME); 71 72 private FederatedComputeManagingServiceDelegate mFcpService; 73 private Context mContext; 74 private final FederatedComputeStatsdLogger mFcStatsdLogger = 75 spy(FederatedComputeStatsdLogger.getInstance()); 76 77 @Mock FederatedComputeJobManager mMockJobManager; 78 @Mock private Clock mClock; 79 80 @Before setUp()81 public void setUp() throws Exception { 82 MockitoAnnotations.initMocks(this); 83 PhFlagsTestUtil.setUpDeviceConfigPermissions(); 84 PhFlagsTestUtil.disableGlobalKillSwitch(); 85 86 mContext = ApplicationProvider.getApplicationContext(); 87 mFcpService = 88 new FederatedComputeManagingServiceDelegate( 89 mContext, new TestInjector(), mFcStatsdLogger, mClock); 90 when(mClock.elapsedRealtime()).thenReturn(100L, 200L); 91 } 92 93 @Test testScheduleMissingPackageName_throwsException()94 public void testScheduleMissingPackageName_throwsException() { 95 TrainingOptions trainingOptions = 96 new TrainingOptions.Builder() 97 .setPopulationName("fake-population") 98 .setOwnerComponentName(OWNER_COMPONENT_NAME) 99 .build(); 100 101 assertThrows( 102 NullPointerException.class, 103 () -> mFcpService.schedule(null, trainingOptions, new FederatedComputeCallback())); 104 } 105 106 @Test testScheduleMissingCallback_throwsException()107 public void testScheduleMissingCallback_throwsException() { 108 TrainingOptions trainingOptions = 109 new TrainingOptions.Builder() 110 .setPopulationName("fake-population") 111 .setOwnerComponentName(OWNER_COMPONENT_NAME) 112 .build(); 113 assertThrows( 114 NullPointerException.class, 115 () -> mFcpService.schedule(mContext.getPackageName(), trainingOptions, null)); 116 } 117 118 @Test testSchedule_returnsSuccess()119 public void testSchedule_returnsSuccess() throws Exception { 120 when(mMockJobManager.onTrainerStartCalled(anyString(), any())).thenReturn(STATUS_SUCCESS); 121 122 TrainingOptions trainingOptions = 123 new TrainingOptions.Builder() 124 .setPopulationName("fake-population") 125 .setOwnerComponentName(OWNER_COMPONENT_NAME) 126 .build(); 127 invokeScheduleAndVerifyLogging(trainingOptions, STATUS_SUCCESS); 128 } 129 130 @Test testScheduleFailed()131 public void testScheduleFailed() throws Exception { 132 when(mMockJobManager.onTrainerStartCalled(anyString(), any())) 133 .thenReturn(STATUS_INTERNAL_ERROR); 134 135 TrainingOptions trainingOptions = 136 new TrainingOptions.Builder() 137 .setPopulationName("fake-population") 138 .setOwnerComponentName(OWNER_COMPONENT_NAME) 139 .build(); 140 invokeScheduleAndVerifyLogging(trainingOptions, STATUS_INTERNAL_ERROR); 141 } 142 143 @Test testScheduleThrowsRTE()144 public void testScheduleThrowsRTE() throws Exception { 145 when(mMockJobManager.onTrainerStartCalled(anyString(), any())) 146 .thenThrow(RuntimeException.class); 147 148 TrainingOptions trainingOptions = 149 new TrainingOptions.Builder() 150 .setPopulationName("fake-population") 151 .setOwnerComponentName(OWNER_COMPONENT_NAME) 152 .build(); 153 invokeScheduleAndVerifyLogging(trainingOptions, STATUS_INTERNAL_ERROR); 154 } 155 156 @Test testScheduleThrowsNPE()157 public void testScheduleThrowsNPE() throws Exception { 158 when(mMockJobManager.onTrainerStartCalled(anyString(), any())) 159 .thenThrow(NullPointerException.class); 160 161 TrainingOptions trainingOptions = 162 new TrainingOptions.Builder() 163 .setPopulationName("fake-population") 164 .setOwnerComponentName(OWNER_COMPONENT_NAME) 165 .build(); 166 invokeScheduleAndVerifyLogging(trainingOptions, STATUS_INTERNAL_ERROR); 167 } 168 169 @Test testScheduleClockThrowsRTE()170 public void testScheduleClockThrowsRTE() throws Exception { 171 when(mClock.elapsedRealtime()).thenThrow(RuntimeException.class); 172 TrainingOptions trainingOptions = 173 new TrainingOptions.Builder() 174 .setPopulationName("fake-population") 175 .setOwnerComponentName(OWNER_COMPONENT_NAME) 176 .build(); 177 FederatedComputeCallback callback = spy(new FederatedComputeCallback()); 178 179 mFcpService.schedule(mContext.getPackageName(), trainingOptions, callback); 180 181 ArgumentCaptor<Integer> argument = ArgumentCaptor.forClass(Integer.class); 182 verify(callback).onFailure(argument.capture()); 183 assertThat(argument.getValue()).isEqualTo(STATUS_INTERNAL_ERROR); 184 } 185 186 @Test testScheduleClockThrowsIAE()187 public void testScheduleClockThrowsIAE() throws Exception { 188 when(mClock.elapsedRealtime()).thenThrow(IllegalArgumentException.class); 189 190 TrainingOptions trainingOptions = 191 new TrainingOptions.Builder() 192 .setPopulationName("fake-population") 193 .setOwnerComponentName(OWNER_COMPONENT_NAME) 194 .build(); 195 assertThrows( 196 IllegalArgumentException.class, 197 () -> 198 mFcpService.schedule( 199 mContext.getPackageName(), 200 trainingOptions, 201 new FederatedComputeCallback())); 202 } 203 204 @Test testScheduleEnabledGlobalKillSwitch_returnsError()205 public void testScheduleEnabledGlobalKillSwitch_returnsError() throws Exception { 206 PhFlagsTestUtil.enableGlobalKillSwitch(); 207 try { 208 TrainingOptions trainingOptions = 209 new TrainingOptions.Builder() 210 .setPopulationName("fake-population") 211 .setOwnerComponentName(OWNER_COMPONENT_NAME) 212 .build(); 213 invokeScheduleAndVerifyLogging(trainingOptions, STATUS_KILL_SWITCH_ENABLED, 0); 214 } finally { 215 PhFlagsTestUtil.disableGlobalKillSwitch(); 216 } 217 } 218 219 @Test testCancelMissingPackageName_throwsException()220 public void testCancelMissingPackageName_throwsException() { 221 assertThrows( 222 NullPointerException.class, 223 () -> mFcpService.cancel(null, "fake-population", new FederatedComputeCallback())); 224 } 225 226 @Test testCancelMissingCallback_throwsException()227 public void testCancelMissingCallback_throwsException() { 228 assertThrows( 229 NullPointerException.class, 230 () -> mFcpService.cancel(OWNER_COMPONENT_NAME, "fake-population", null)); 231 } 232 233 @Test testCancel_returnsSuccess()234 public void testCancel_returnsSuccess() throws Exception { 235 when(mMockJobManager.onTrainerStopCalled(any(), anyString())).thenReturn(STATUS_SUCCESS); 236 237 invokeCancelAndVerifyLogging("fake-population", STATUS_SUCCESS); 238 } 239 240 @Test testCancelFails()241 public void testCancelFails() throws Exception { 242 when(mMockJobManager.onTrainerStopCalled(any(), any())).thenReturn(STATUS_INTERNAL_ERROR); 243 244 invokeCancelAndVerifyLogging("fake-population", STATUS_INTERNAL_ERROR); 245 } 246 247 @Test testCancelEnabledGlobalKillSwitch_returnsError()248 public void testCancelEnabledGlobalKillSwitch_returnsError() throws Exception { 249 PhFlagsTestUtil.enableGlobalKillSwitch(); 250 try { 251 invokeCancelAndVerifyLogging("fake-population", STATUS_KILL_SWITCH_ENABLED, 0); 252 } finally { 253 PhFlagsTestUtil.disableGlobalKillSwitch(); 254 } 255 } 256 257 @Test testCancelThrowsRTE()258 public void testCancelThrowsRTE() throws Exception { 259 when(mMockJobManager.onTrainerStopCalled(any(), any())).thenThrow(RuntimeException.class); 260 261 invokeCancelAndVerifyLogging("fake-population", STATUS_INTERNAL_ERROR); 262 } 263 264 @Test testCancelThrowsNPE()265 public void testCancelThrowsNPE() throws Exception { 266 when(mMockJobManager.onTrainerStopCalled(any(), any())) 267 .thenThrow(NullPointerException.class); 268 269 invokeCancelAndVerifyLogging("fake-population", STATUS_INTERNAL_ERROR); 270 } 271 272 @Test testCancelClockThrowsRTE()273 public void testCancelClockThrowsRTE() throws Exception { 274 when(mClock.elapsedRealtime()).thenThrow(RuntimeException.class); 275 FederatedComputeCallback callback = spy(new FederatedComputeCallback()); 276 277 mFcpService.cancel(OWNER_COMPONENT_NAME, "fake-population", callback); 278 279 ArgumentCaptor<Integer> argument = ArgumentCaptor.forClass(Integer.class); 280 verify(callback).onFailure(argument.capture()); 281 assertThat(argument.getValue()).isEqualTo(STATUS_INTERNAL_ERROR); 282 } 283 284 @Test testCancelClockThrowsIAE()285 public void testCancelClockThrowsIAE() throws Exception { 286 when(mClock.elapsedRealtime()).thenThrow(IllegalArgumentException.class); 287 288 assertThrows( 289 IllegalArgumentException.class, 290 () -> 291 mFcpService.cancel( 292 OWNER_COMPONENT_NAME, 293 "fake-population", 294 new FederatedComputeCallback())); 295 } 296 invokeScheduleAndVerifyLogging( TrainingOptions trainingOptions, int expectedResultCode)297 private void invokeScheduleAndVerifyLogging( 298 TrainingOptions trainingOptions, int expectedResultCode) throws InterruptedException { 299 invokeScheduleAndVerifyLogging(trainingOptions, expectedResultCode, 100L); 300 } 301 invokeScheduleAndVerifyLogging( TrainingOptions trainingOptions, int expectedResultCode, long latency)302 private void invokeScheduleAndVerifyLogging( 303 TrainingOptions trainingOptions, int expectedResultCode, long latency) 304 throws InterruptedException { 305 ArgumentCaptor<ApiCallStats> argument = ArgumentCaptor.forClass(ApiCallStats.class); 306 final CountDownLatch logOperationCalledLatch = new CountDownLatch(1); 307 doAnswer( 308 new Answer<Object>() { 309 @Override 310 public Object answer(InvocationOnMock invocation) throws Throwable { 311 // The method logAPiCallStats is called. 312 invocation.callRealMethod(); 313 logOperationCalledLatch.countDown(); 314 return null; 315 } 316 }) 317 .when(mFcStatsdLogger) 318 .logApiCallStats(argument.capture()); 319 320 var callback = new FederatedComputeCallback(); 321 mFcpService.schedule(mContext.getPackageName(), trainingOptions, callback); 322 323 callback.mJobFinishCountDown.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS); 324 logOperationCalledLatch.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS); 325 326 assertThat(argument.getValue().getResponseCode()).isEqualTo(expectedResultCode); 327 assertThat(argument.getValue().getLatencyMillis()).isEqualTo(latency); 328 assertThat(argument.getValue().getApiName()) 329 .isEqualTo(FEDERATED_COMPUTE_API_CALLED__API_NAME__SCHEDULE); 330 } 331 invokeCancelAndVerifyLogging(String populationName, int expectedResultCode)332 private void invokeCancelAndVerifyLogging(String populationName, int expectedResultCode) 333 throws InterruptedException { 334 invokeCancelAndVerifyLogging(populationName, expectedResultCode, 100); 335 } 336 invokeCancelAndVerifyLogging( String populationName, int expectedResultCode, long latency)337 private void invokeCancelAndVerifyLogging( 338 String populationName, int expectedResultCode, long latency) 339 throws InterruptedException { 340 341 final CountDownLatch logOperationCalledLatch = new CountDownLatch(1); 342 ArgumentCaptor<ApiCallStats> argument = ArgumentCaptor.forClass(ApiCallStats.class); 343 doAnswer( 344 new Answer<Object>() { 345 @Override 346 public Object answer(InvocationOnMock invocation) throws Throwable { 347 // The method logAPiCallStats is called. 348 invocation.callRealMethod(); 349 logOperationCalledLatch.countDown(); 350 return null; 351 } 352 }) 353 .when(mFcStatsdLogger) 354 .logApiCallStats(argument.capture()); 355 var callback = new FederatedComputeCallback(); 356 mFcpService.cancel(OWNER_COMPONENT_NAME, populationName, callback); 357 358 callback.mJobFinishCountDown.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS); 359 logOperationCalledLatch.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS); 360 361 assertThat(argument.getValue().getResponseCode()).isEqualTo(expectedResultCode); 362 assertThat(argument.getValue().getLatencyMillis()).isEqualTo(latency); 363 assertThat(argument.getValue().getApiName()) 364 .isEqualTo(FEDERATED_COMPUTE_API_CALLED__API_NAME__CANCEL); 365 } 366 367 static class FederatedComputeCallback extends IFederatedComputeCallback.Stub { 368 public boolean mError = false; 369 public int mErrorCode = 0; 370 private final CountDownLatch mJobFinishCountDown = new CountDownLatch(1); 371 372 @Override onSuccess()373 public void onSuccess() { 374 mJobFinishCountDown.countDown(); 375 } 376 377 @Override onFailure(int errorCode)378 public void onFailure(int errorCode) { 379 mError = true; 380 mErrorCode = errorCode; 381 mJobFinishCountDown.countDown(); 382 } 383 } 384 385 class TestInjector extends FederatedComputeManagingServiceDelegate.Injector { getJobManager(Context mContext)386 FederatedComputeJobManager getJobManager(Context mContext) { 387 return mMockJobManager; 388 } 389 } 390 } 391