1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.federatedcompute.services;
18 
19 import static android.federatedcompute.common.ClientConstants.STATUS_INTERNAL_ERROR;
20 import static android.federatedcompute.common.ClientConstants.STATUS_KILL_SWITCH_ENABLED;
21 import static android.federatedcompute.common.ClientConstants.STATUS_SUCCESS;
22 
23 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doAnswer;
24 import static com.android.federatedcompute.services.stats.FederatedComputeStatsLog.FEDERATED_COMPUTE_API_CALLED__API_NAME__CANCEL;
25 import static com.android.federatedcompute.services.stats.FederatedComputeStatsLog.FEDERATED_COMPUTE_API_CALLED__API_NAME__SCHEDULE;
26 
27 import static com.google.common.truth.Truth.assertThat;
28 
29 import static org.junit.Assert.assertThrows;
30 import static org.mockito.ArgumentMatchers.any;
31 import static org.mockito.ArgumentMatchers.anyString;
32 import static org.mockito.Mockito.spy;
33 import static org.mockito.Mockito.verify;
34 import static org.mockito.Mockito.when;
35 
36 import android.content.ComponentName;
37 import android.content.Context;
38 import android.federatedcompute.aidl.IFederatedComputeCallback;
39 import android.federatedcompute.common.TrainingOptions;
40 
41 import androidx.test.core.app.ApplicationProvider;
42 
43 import com.android.federatedcompute.services.common.PhFlagsTestUtil;
44 import com.android.federatedcompute.services.scheduling.FederatedComputeJobManager;
45 import com.android.federatedcompute.services.statsd.ApiCallStats;
46 import com.android.federatedcompute.services.statsd.FederatedComputeStatsdLogger;
47 import com.android.odp.module.common.Clock;
48 
49 import org.junit.Before;
50 import org.junit.Test;
51 import org.junit.runner.RunWith;
52 import org.junit.runners.JUnit4;
53 import org.mockito.ArgumentCaptor;
54 import org.mockito.Mock;
55 import org.mockito.MockitoAnnotations;
56 import org.mockito.invocation.InvocationOnMock;
57 import org.mockito.stubbing.Answer;
58 
59 import java.util.concurrent.CountDownLatch;
60 import java.util.concurrent.TimeUnit;
61 
62 @RunWith(JUnit4.class)
63 public final class FederatedComputeManagingServiceDelegateTest {
64     private static final int BINDER_CONNECTION_TIMEOUT_MS = 10_000;
65 
66     private static final String CALLING_PACKAGE_NAME = "callingPkg";
67     private static final String CALLING_CLASS_NAME = "callingClass";
68 
69     public static final ComponentName OWNER_COMPONENT_NAME =
70             ComponentName.createRelative(CALLING_PACKAGE_NAME, CALLING_CLASS_NAME);
71 
72     private FederatedComputeManagingServiceDelegate mFcpService;
73     private Context mContext;
74     private final FederatedComputeStatsdLogger mFcStatsdLogger =
75             spy(FederatedComputeStatsdLogger.getInstance());
76 
77     @Mock FederatedComputeJobManager mMockJobManager;
78     @Mock private Clock mClock;
79 
80     @Before
setUp()81     public void setUp() throws Exception {
82         MockitoAnnotations.initMocks(this);
83         PhFlagsTestUtil.setUpDeviceConfigPermissions();
84         PhFlagsTestUtil.disableGlobalKillSwitch();
85 
86         mContext = ApplicationProvider.getApplicationContext();
87         mFcpService =
88                 new FederatedComputeManagingServiceDelegate(
89                         mContext, new TestInjector(), mFcStatsdLogger, mClock);
90         when(mClock.elapsedRealtime()).thenReturn(100L, 200L);
91     }
92 
93     @Test
testScheduleMissingPackageName_throwsException()94     public void testScheduleMissingPackageName_throwsException() {
95         TrainingOptions trainingOptions =
96                 new TrainingOptions.Builder()
97                         .setPopulationName("fake-population")
98                         .setOwnerComponentName(OWNER_COMPONENT_NAME)
99                         .build();
100 
101         assertThrows(
102                 NullPointerException.class,
103                 () -> mFcpService.schedule(null, trainingOptions, new FederatedComputeCallback()));
104     }
105 
106     @Test
testScheduleMissingCallback_throwsException()107     public void testScheduleMissingCallback_throwsException() {
108         TrainingOptions trainingOptions =
109                 new TrainingOptions.Builder()
110                         .setPopulationName("fake-population")
111                         .setOwnerComponentName(OWNER_COMPONENT_NAME)
112                         .build();
113         assertThrows(
114                 NullPointerException.class,
115                 () -> mFcpService.schedule(mContext.getPackageName(), trainingOptions, null));
116     }
117 
118     @Test
testSchedule_returnsSuccess()119     public void testSchedule_returnsSuccess() throws Exception {
120         when(mMockJobManager.onTrainerStartCalled(anyString(), any())).thenReturn(STATUS_SUCCESS);
121 
122         TrainingOptions trainingOptions =
123                 new TrainingOptions.Builder()
124                         .setPopulationName("fake-population")
125                         .setOwnerComponentName(OWNER_COMPONENT_NAME)
126                         .build();
127         invokeScheduleAndVerifyLogging(trainingOptions, STATUS_SUCCESS);
128     }
129 
130     @Test
testScheduleFailed()131     public void testScheduleFailed() throws Exception {
132         when(mMockJobManager.onTrainerStartCalled(anyString(), any()))
133                 .thenReturn(STATUS_INTERNAL_ERROR);
134 
135         TrainingOptions trainingOptions =
136                 new TrainingOptions.Builder()
137                         .setPopulationName("fake-population")
138                         .setOwnerComponentName(OWNER_COMPONENT_NAME)
139                         .build();
140         invokeScheduleAndVerifyLogging(trainingOptions, STATUS_INTERNAL_ERROR);
141     }
142 
143     @Test
testScheduleThrowsRTE()144     public void testScheduleThrowsRTE() throws Exception {
145         when(mMockJobManager.onTrainerStartCalled(anyString(), any()))
146                 .thenThrow(RuntimeException.class);
147 
148         TrainingOptions trainingOptions =
149                 new TrainingOptions.Builder()
150                         .setPopulationName("fake-population")
151                         .setOwnerComponentName(OWNER_COMPONENT_NAME)
152                         .build();
153         invokeScheduleAndVerifyLogging(trainingOptions, STATUS_INTERNAL_ERROR);
154     }
155 
156     @Test
testScheduleThrowsNPE()157     public void testScheduleThrowsNPE() throws Exception {
158         when(mMockJobManager.onTrainerStartCalled(anyString(), any()))
159                 .thenThrow(NullPointerException.class);
160 
161         TrainingOptions trainingOptions =
162                 new TrainingOptions.Builder()
163                         .setPopulationName("fake-population")
164                         .setOwnerComponentName(OWNER_COMPONENT_NAME)
165                         .build();
166         invokeScheduleAndVerifyLogging(trainingOptions, STATUS_INTERNAL_ERROR);
167     }
168 
169     @Test
testScheduleClockThrowsRTE()170     public void testScheduleClockThrowsRTE() throws Exception {
171         when(mClock.elapsedRealtime()).thenThrow(RuntimeException.class);
172         TrainingOptions trainingOptions =
173                 new TrainingOptions.Builder()
174                         .setPopulationName("fake-population")
175                         .setOwnerComponentName(OWNER_COMPONENT_NAME)
176                         .build();
177         FederatedComputeCallback callback = spy(new FederatedComputeCallback());
178 
179         mFcpService.schedule(mContext.getPackageName(), trainingOptions, callback);
180 
181         ArgumentCaptor<Integer> argument = ArgumentCaptor.forClass(Integer.class);
182         verify(callback).onFailure(argument.capture());
183         assertThat(argument.getValue()).isEqualTo(STATUS_INTERNAL_ERROR);
184     }
185 
186     @Test
testScheduleClockThrowsIAE()187     public void testScheduleClockThrowsIAE() throws Exception {
188         when(mClock.elapsedRealtime()).thenThrow(IllegalArgumentException.class);
189 
190         TrainingOptions trainingOptions =
191                 new TrainingOptions.Builder()
192                         .setPopulationName("fake-population")
193                         .setOwnerComponentName(OWNER_COMPONENT_NAME)
194                         .build();
195         assertThrows(
196                 IllegalArgumentException.class,
197                 () ->
198                         mFcpService.schedule(
199                                 mContext.getPackageName(),
200                                 trainingOptions,
201                                 new FederatedComputeCallback()));
202     }
203 
204     @Test
testScheduleEnabledGlobalKillSwitch_returnsError()205     public void testScheduleEnabledGlobalKillSwitch_returnsError() throws Exception {
206         PhFlagsTestUtil.enableGlobalKillSwitch();
207         try {
208             TrainingOptions trainingOptions =
209                     new TrainingOptions.Builder()
210                             .setPopulationName("fake-population")
211                             .setOwnerComponentName(OWNER_COMPONENT_NAME)
212                             .build();
213             invokeScheduleAndVerifyLogging(trainingOptions, STATUS_KILL_SWITCH_ENABLED, 0);
214         } finally {
215             PhFlagsTestUtil.disableGlobalKillSwitch();
216         }
217     }
218 
219     @Test
testCancelMissingPackageName_throwsException()220     public void testCancelMissingPackageName_throwsException() {
221         assertThrows(
222                 NullPointerException.class,
223                 () -> mFcpService.cancel(null, "fake-population", new FederatedComputeCallback()));
224     }
225 
226     @Test
testCancelMissingCallback_throwsException()227     public void testCancelMissingCallback_throwsException() {
228         assertThrows(
229                 NullPointerException.class,
230                 () -> mFcpService.cancel(OWNER_COMPONENT_NAME, "fake-population", null));
231     }
232 
233     @Test
testCancel_returnsSuccess()234     public void testCancel_returnsSuccess() throws Exception {
235         when(mMockJobManager.onTrainerStopCalled(any(), anyString())).thenReturn(STATUS_SUCCESS);
236 
237         invokeCancelAndVerifyLogging("fake-population", STATUS_SUCCESS);
238     }
239 
240     @Test
testCancelFails()241     public void testCancelFails() throws Exception {
242         when(mMockJobManager.onTrainerStopCalled(any(), any())).thenReturn(STATUS_INTERNAL_ERROR);
243 
244         invokeCancelAndVerifyLogging("fake-population", STATUS_INTERNAL_ERROR);
245     }
246 
247     @Test
testCancelEnabledGlobalKillSwitch_returnsError()248     public void testCancelEnabledGlobalKillSwitch_returnsError() throws Exception {
249         PhFlagsTestUtil.enableGlobalKillSwitch();
250         try {
251             invokeCancelAndVerifyLogging("fake-population", STATUS_KILL_SWITCH_ENABLED, 0);
252         } finally {
253             PhFlagsTestUtil.disableGlobalKillSwitch();
254         }
255     }
256 
257     @Test
testCancelThrowsRTE()258     public void testCancelThrowsRTE() throws Exception {
259         when(mMockJobManager.onTrainerStopCalled(any(), any())).thenThrow(RuntimeException.class);
260 
261         invokeCancelAndVerifyLogging("fake-population", STATUS_INTERNAL_ERROR);
262     }
263 
264     @Test
testCancelThrowsNPE()265     public void testCancelThrowsNPE() throws Exception {
266         when(mMockJobManager.onTrainerStopCalled(any(), any()))
267                 .thenThrow(NullPointerException.class);
268 
269         invokeCancelAndVerifyLogging("fake-population", STATUS_INTERNAL_ERROR);
270     }
271 
272     @Test
testCancelClockThrowsRTE()273     public void testCancelClockThrowsRTE() throws Exception {
274         when(mClock.elapsedRealtime()).thenThrow(RuntimeException.class);
275         FederatedComputeCallback callback = spy(new FederatedComputeCallback());
276 
277         mFcpService.cancel(OWNER_COMPONENT_NAME, "fake-population", callback);
278 
279         ArgumentCaptor<Integer> argument = ArgumentCaptor.forClass(Integer.class);
280         verify(callback).onFailure(argument.capture());
281         assertThat(argument.getValue()).isEqualTo(STATUS_INTERNAL_ERROR);
282     }
283 
284     @Test
testCancelClockThrowsIAE()285     public void testCancelClockThrowsIAE() throws Exception {
286         when(mClock.elapsedRealtime()).thenThrow(IllegalArgumentException.class);
287 
288         assertThrows(
289                 IllegalArgumentException.class,
290                 () ->
291                         mFcpService.cancel(
292                                 OWNER_COMPONENT_NAME,
293                                 "fake-population",
294                                 new FederatedComputeCallback()));
295     }
296 
invokeScheduleAndVerifyLogging( TrainingOptions trainingOptions, int expectedResultCode)297     private void invokeScheduleAndVerifyLogging(
298             TrainingOptions trainingOptions, int expectedResultCode) throws InterruptedException {
299         invokeScheduleAndVerifyLogging(trainingOptions, expectedResultCode, 100L);
300     }
301 
invokeScheduleAndVerifyLogging( TrainingOptions trainingOptions, int expectedResultCode, long latency)302     private void invokeScheduleAndVerifyLogging(
303             TrainingOptions trainingOptions, int expectedResultCode, long latency)
304             throws InterruptedException {
305         ArgumentCaptor<ApiCallStats> argument = ArgumentCaptor.forClass(ApiCallStats.class);
306         final CountDownLatch logOperationCalledLatch = new CountDownLatch(1);
307         doAnswer(
308                         new Answer<Object>() {
309                             @Override
310                             public Object answer(InvocationOnMock invocation) throws Throwable {
311                                 // The method logAPiCallStats is called.
312                                 invocation.callRealMethod();
313                                 logOperationCalledLatch.countDown();
314                                 return null;
315                             }
316                         })
317                 .when(mFcStatsdLogger)
318                 .logApiCallStats(argument.capture());
319 
320         var callback = new FederatedComputeCallback();
321         mFcpService.schedule(mContext.getPackageName(), trainingOptions, callback);
322 
323         callback.mJobFinishCountDown.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
324         logOperationCalledLatch.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
325 
326         assertThat(argument.getValue().getResponseCode()).isEqualTo(expectedResultCode);
327         assertThat(argument.getValue().getLatencyMillis()).isEqualTo(latency);
328         assertThat(argument.getValue().getApiName())
329                 .isEqualTo(FEDERATED_COMPUTE_API_CALLED__API_NAME__SCHEDULE);
330     }
331 
invokeCancelAndVerifyLogging(String populationName, int expectedResultCode)332     private void invokeCancelAndVerifyLogging(String populationName, int expectedResultCode)
333             throws InterruptedException {
334         invokeCancelAndVerifyLogging(populationName, expectedResultCode, 100);
335     }
336 
invokeCancelAndVerifyLogging( String populationName, int expectedResultCode, long latency)337     private void invokeCancelAndVerifyLogging(
338             String populationName, int expectedResultCode, long latency)
339             throws InterruptedException {
340 
341         final CountDownLatch logOperationCalledLatch = new CountDownLatch(1);
342         ArgumentCaptor<ApiCallStats> argument = ArgumentCaptor.forClass(ApiCallStats.class);
343         doAnswer(
344                         new Answer<Object>() {
345                             @Override
346                             public Object answer(InvocationOnMock invocation) throws Throwable {
347                                 // The method logAPiCallStats is called.
348                                 invocation.callRealMethod();
349                                 logOperationCalledLatch.countDown();
350                                 return null;
351                             }
352                         })
353                 .when(mFcStatsdLogger)
354                 .logApiCallStats(argument.capture());
355         var callback = new FederatedComputeCallback();
356         mFcpService.cancel(OWNER_COMPONENT_NAME, populationName, callback);
357 
358         callback.mJobFinishCountDown.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
359         logOperationCalledLatch.await(BINDER_CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
360 
361         assertThat(argument.getValue().getResponseCode()).isEqualTo(expectedResultCode);
362         assertThat(argument.getValue().getLatencyMillis()).isEqualTo(latency);
363         assertThat(argument.getValue().getApiName())
364                 .isEqualTo(FEDERATED_COMPUTE_API_CALLED__API_NAME__CANCEL);
365     }
366 
367     static class FederatedComputeCallback extends IFederatedComputeCallback.Stub {
368         public boolean mError = false;
369         public int mErrorCode = 0;
370         private final CountDownLatch mJobFinishCountDown = new CountDownLatch(1);
371 
372         @Override
onSuccess()373         public void onSuccess() {
374             mJobFinishCountDown.countDown();
375         }
376 
377         @Override
onFailure(int errorCode)378         public void onFailure(int errorCode) {
379             mError = true;
380             mErrorCode = errorCode;
381             mJobFinishCountDown.countDown();
382         }
383     }
384 
385     class TestInjector extends FederatedComputeManagingServiceDelegate.Injector {
getJobManager(Context mContext)386         FederatedComputeJobManager getJobManager(Context mContext) {
387             return mMockJobManager;
388         }
389     }
390 }
391