1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.security.rkp;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 
21 import static org.junit.Assert.assertThrows;
22 import static org.mockito.AdditionalAnswers.answerVoid;
23 import static org.mockito.Mockito.any;
24 import static org.mockito.Mockito.anyInt;
25 import static org.mockito.Mockito.argThat;
26 import static org.mockito.Mockito.atLeastOnce;
27 import static org.mockito.Mockito.contains;
28 import static org.mockito.Mockito.doAnswer;
29 import static org.mockito.Mockito.doReturn;
30 import static org.mockito.Mockito.doThrow;
31 import static org.mockito.Mockito.eq;
32 import static org.mockito.Mockito.mock;
33 import static org.mockito.Mockito.times;
34 import static org.mockito.Mockito.verify;
35 import static org.mockito.Mockito.verifyNoMoreInteractions;
36 
37 import android.os.Binder;
38 import android.os.CancellationSignal;
39 import android.os.OperationCanceledException;
40 import android.os.OutcomeReceiver;
41 import android.os.RemoteException;
42 import android.security.rkp.IGetKeyCallback;
43 import android.security.rkp.IStoreUpgradedKeyCallback;
44 import android.security.rkp.service.RegistrationProxy;
45 import android.security.rkp.service.RemotelyProvisionedKey;
46 import android.security.rkp.service.RkpProxyException;
47 
48 import androidx.test.runner.AndroidJUnit4;
49 
50 import org.junit.Before;
51 import org.junit.Test;
52 import org.junit.runner.RunWith;
53 import org.mockito.stubbing.Answer;
54 import org.mockito.stubbing.VoidAnswer4;
55 
56 import java.lang.reflect.Field;
57 import java.util.Arrays;
58 import java.util.Map;
59 import java.util.concurrent.Executor;
60 
61 /**
62  * Build/Install/Run:
63  * atest FrameworksServicesTests:RemoteProvisioningRegistrationTest
64  */
65 @RunWith(AndroidJUnit4.class)
66 public class RemoteProvisioningRegistrationTest {
67     private RegistrationProxy mRegistrationProxy;
68     private RemoteProvisioningRegistration mRegistration;
69 
70     @Before
setUp()71     public void setUp() {
72         mRegistrationProxy = mock(RegistrationProxy.class);
73         mRegistration = new RemoteProvisioningRegistration(mRegistrationProxy, Runnable::run);
74     }
75 
76     // answerVoid wrapper with explicit types, avoiding long signatures when mocking getKeyAsync.
answerGetKeyAsync( VoidAnswer4<Integer, CancellationSignal, Executor, OutcomeReceiver<RemotelyProvisionedKey, Exception>> answer)77     static Answer<Void> answerGetKeyAsync(
78             VoidAnswer4<Integer, CancellationSignal, Executor,
79                     OutcomeReceiver<RemotelyProvisionedKey, Exception>> answer) {
80         return answerVoid(answer);
81     }
82 
83     // answerVoid wrapper for mocking storeUpgradeKeyAsync.
answerStoreUpgradedKeyAsync( VoidAnswer4<byte[], byte[], Executor, OutcomeReceiver<Void, Exception>> answer)84     static Answer<Void> answerStoreUpgradedKeyAsync(
85             VoidAnswer4<byte[], byte[], Executor, OutcomeReceiver<Void, Exception>> answer) {
86         return answerVoid(answer);
87     }
88 
89     // matcher helper, making it easier to match the different key types
matches( RemotelyProvisionedKey expectedKey)90     private android.security.rkp.RemotelyProvisionedKey matches(
91             RemotelyProvisionedKey expectedKey) {
92         return argThat((android.security.rkp.RemotelyProvisionedKey key) ->
93                 Arrays.equals(key.keyBlob, expectedKey.getKeyBlob())
94                         && Arrays.equals(key.encodedCertChain, expectedKey.getEncodedCertChain())
95         );
96     }
97 
98     @Test
getKeySuccess()99     public void getKeySuccess() throws Exception {
100         RemotelyProvisionedKey expectedKey = mock(RemotelyProvisionedKey.class);
101         doAnswer(
102                 answerGetKeyAsync((keyId, cancelSignal, executor, receiver) ->
103                         executor.execute(() -> receiver.onResult(expectedKey))))
104                 .when(mRegistrationProxy).getKeyAsync(eq(42), any(), any(), any());
105 
106         IGetKeyCallback callback = mock(IGetKeyCallback.class);
107         doReturn(new Binder()).when(callback).asBinder();
108         mRegistration.getKey(42, callback);
109         verify(callback).onSuccess(matches(expectedKey));
110         verify(callback, atLeastOnce()).asBinder();
111         verifyNoMoreInteractions(callback);
112     }
113 
114     @Test
getKeyHandlesArbitraryException()115     public void getKeyHandlesArbitraryException() throws Exception {
116         Exception expectedException = new Exception("oops!");
117         doAnswer(
118                 answerGetKeyAsync((keyId, cancelSignal, executor, receiver) ->
119                         executor.execute(() -> receiver.onError(expectedException))))
120                 .when(mRegistrationProxy).getKeyAsync(eq(0), any(), any(), any());
121         IGetKeyCallback callback = mock(IGetKeyCallback.class);
122         doReturn(new Binder()).when(callback).asBinder();
123         mRegistration.getKey(0, callback);
124         verify(callback).onError(eq(IGetKeyCallback.ErrorCode.ERROR_UNKNOWN), eq("oops!"));
125         verify(callback, atLeastOnce()).asBinder();
126         verifyNoMoreInteractions(callback);
127     }
128 
129     @Test
getKeyMapsRkpErrorsCorrectly()130     public void getKeyMapsRkpErrorsCorrectly() throws Exception {
131         Map<Byte, Integer> expectedConversions = Map.of(
132                 IGetKeyCallback.ErrorCode.ERROR_UNKNOWN,
133                 RkpProxyException.ERROR_UNKNOWN,
134                 IGetKeyCallback.ErrorCode.ERROR_REQUIRES_SECURITY_PATCH,
135                 RkpProxyException.ERROR_REQUIRES_SECURITY_PATCH,
136                 IGetKeyCallback.ErrorCode.ERROR_PENDING_INTERNET_CONNECTIVITY,
137                 RkpProxyException.ERROR_PENDING_INTERNET_CONNECTIVITY,
138                 IGetKeyCallback.ErrorCode.ERROR_PERMANENT,
139                 RkpProxyException.ERROR_PERMANENT);
140 
141         for (Field errorField: IGetKeyCallback.ErrorCode.class.getFields()) {
142             byte error = (Byte) errorField.get(null);
143             Exception expectedException = new RkpProxyException(expectedConversions.get(error),
144                     errorField.getName());
145             doAnswer(
146                     answerGetKeyAsync((keyId, cancelSignal, executor, receiver) ->
147                             executor.execute(() -> receiver.onError(expectedException))))
148                     .when(mRegistrationProxy).getKeyAsync(eq(0), any(), any(), any());
149             IGetKeyCallback callback = mock(IGetKeyCallback.class);
150             doReturn(new Binder()).when(callback).asBinder();
151             mRegistration.getKey(0, callback);
152             verify(callback).onError(eq(error), contains(errorField.getName()));
153             verify(callback, atLeastOnce()).asBinder();
154             verifyNoMoreInteractions(callback);
155         }
156     }
157 
158     @Test
getKeyCancelDuringProxyOperation()159     public void getKeyCancelDuringProxyOperation() throws Exception {
160         final Binder theBinder = new Binder();
161         IGetKeyCallback callback = mock(IGetKeyCallback.class);
162         doReturn(theBinder).when(callback).asBinder();
163         doAnswer(
164                 answerGetKeyAsync((keyId, cancelSignal, executor, receiver) -> {
165                     // Use a different callback object to ensure that the callback equivalence
166                     // relies on the actual IBinder object.
167                     IGetKeyCallback differentCallback = mock(IGetKeyCallback.class);
168                     doReturn(theBinder).when(differentCallback).asBinder();
169                     mRegistration.cancelGetKey(differentCallback);
170                     verify(differentCallback, atLeastOnce()).asBinder();
171                     verifyNoMoreInteractions(differentCallback);
172                     assertThat(cancelSignal.isCanceled()).isTrue();
173                     executor.execute(() -> receiver.onError(new OperationCanceledException()));
174                 }))
175                 .when(mRegistrationProxy).getKeyAsync(eq(Integer.MAX_VALUE), any(), any(), any());
176 
177         mRegistration.getKey(Integer.MAX_VALUE, callback);
178         verify(callback).onCancel();
179         verify(callback, atLeastOnce()).asBinder();
180         verifyNoMoreInteractions(callback);
181     }
182 
183     @Test
cancelGetKeyWithInvalidCallback()184     public void cancelGetKeyWithInvalidCallback() throws Exception {
185         IGetKeyCallback callback = mock(IGetKeyCallback.class);
186         doReturn(new Binder()).when(callback).asBinder();
187         assertThrows(IllegalArgumentException.class, () -> mRegistration.cancelGetKey(callback));
188     }
189 
190     @Test
getKeyRejectsDuplicateCallback()191     public void getKeyRejectsDuplicateCallback() throws Exception {
192         IGetKeyCallback callback = mock(IGetKeyCallback.class);
193         doReturn(new Binder()).when(callback).asBinder();
194         doAnswer(
195                 answerGetKeyAsync((keyId, cancelSignal, executor, receiver) -> {
196                     assertThrows(IllegalArgumentException.class, () ->
197                             mRegistration.getKey(0, callback));
198                     executor.execute(() -> receiver.onResult(mock(RemotelyProvisionedKey.class)));
199                 }))
200                 .when(mRegistrationProxy).getKeyAsync(anyInt(), any(), any(), any());
201 
202         mRegistration.getKey(0, callback);
203         verify(callback, times(1)).onSuccess(any());
204         verify(callback, atLeastOnce()).asBinder();
205         verifyNoMoreInteractions(callback);
206     }
207 
208     @Test
getKeyCancelAfterCompleteFails()209     public void getKeyCancelAfterCompleteFails() throws Exception {
210         IGetKeyCallback callback = mock(IGetKeyCallback.class);
211         doReturn(new Binder()).when(callback).asBinder();
212         doAnswer(
213                 answerGetKeyAsync((keyId, cancelSignal, executor, receiver) ->
214                         executor.execute(() ->
215                                 receiver.onResult(mock(RemotelyProvisionedKey.class))
216                         )))
217                 .when(mRegistrationProxy).getKeyAsync(eq(Integer.MIN_VALUE), any(), any(), any());
218 
219         mRegistration.getKey(Integer.MIN_VALUE, callback);
220         verify(callback).onSuccess(any());
221         assertThrows(IllegalArgumentException.class, () -> mRegistration.cancelGetKey(callback));
222         verify(callback, atLeastOnce()).asBinder();
223         verifyNoMoreInteractions(callback);
224     }
225 
226     @Test
getKeyCatchesExceptionFromProxy()227     public void getKeyCatchesExceptionFromProxy() throws Exception {
228         Exception expectedException = new RuntimeException("oops! bad input!");
229         doThrow(expectedException)
230                 .when(mRegistrationProxy)
231                 .getKeyAsync(anyInt(), any(), any(), any());
232 
233         IGetKeyCallback callback = mock(IGetKeyCallback.class);
234         doReturn(new Binder()).when(callback).asBinder();
235         mRegistration.getKey(0, callback);
236         verify(callback).onError(eq(IGetKeyCallback.ErrorCode.ERROR_UNKNOWN),
237                 eq(expectedException.getMessage()));
238         assertThrows(IllegalArgumentException.class, () -> mRegistration.cancelGetKey(callback));
239         verify(callback, atLeastOnce()).asBinder();
240         verifyNoMoreInteractions(callback);
241     }
242 
243     @Test
storeUpgradedKeySuccess()244     public void storeUpgradedKeySuccess() throws Exception {
245         doAnswer(
246                 answerStoreUpgradedKeyAsync((oldBlob, newBlob, executor, receiver) ->
247                         executor.execute(() -> receiver.onResult(null))))
248                 .when(mRegistrationProxy)
249                 .storeUpgradedKeyAsync(any(), any(), any(), any());
250 
251         IStoreUpgradedKeyCallback callback = mock(IStoreUpgradedKeyCallback.class);
252         doReturn(new Binder()).when(callback).asBinder();
253         mRegistration.storeUpgradedKeyAsync(new byte[0], new byte[0], callback);
254         verify(callback).onSuccess();
255         verify(callback, atLeastOnce()).asBinder();
256         verifyNoMoreInteractions(callback);
257     }
258 
259     @Test
storeUpgradedKeyFails()260     public void storeUpgradedKeyFails() throws Exception {
261         final String errorString = "this is a failure";
262         doAnswer(
263                 answerStoreUpgradedKeyAsync((oldBlob, newBlob, executor, receiver) ->
264                         executor.execute(() -> receiver.onError(new RemoteException(errorString)))))
265                 .when(mRegistrationProxy)
266                 .storeUpgradedKeyAsync(any(), any(), any(), any());
267 
268         IStoreUpgradedKeyCallback callback = mock(IStoreUpgradedKeyCallback.class);
269         doReturn(new Binder()).when(callback).asBinder();
270         mRegistration.storeUpgradedKeyAsync(new byte[0], new byte[0], callback);
271         verify(callback).onError(errorString);
272         verify(callback, atLeastOnce()).asBinder();
273         verifyNoMoreInteractions(callback);
274     }
275 
276     @Test
storeUpgradedKeyHandlesException()277     public void storeUpgradedKeyHandlesException() throws Exception {
278         final String errorString = "all aboard the failboat, toot toot";
279         doThrow(new IllegalArgumentException(errorString))
280                 .when(mRegistrationProxy)
281                 .storeUpgradedKeyAsync(any(), any(), any(), any());
282 
283         IStoreUpgradedKeyCallback callback = mock(IStoreUpgradedKeyCallback.class);
284         doReturn(new Binder()).when(callback).asBinder();
285         mRegistration.storeUpgradedKeyAsync(new byte[0], new byte[0], callback);
286         verify(callback).onError(errorString);
287         verify(callback, atLeastOnce()).asBinder();
288         verifyNoMoreInteractions(callback);
289     }
290 
291     @Test
storeUpgradedKeyDuplicateCallback()292     public void storeUpgradedKeyDuplicateCallback() throws Exception {
293         IStoreUpgradedKeyCallback callback = mock(IStoreUpgradedKeyCallback.class);
294         doReturn(new Binder()).when(callback).asBinder();
295 
296         doAnswer(
297                 answerStoreUpgradedKeyAsync((oldBlob, newBlob, executor, receiver) -> {
298                     assertThrows(IllegalArgumentException.class,
299                             () -> mRegistration.storeUpgradedKeyAsync(new byte[0], new byte[0],
300                                     callback));
301                     executor.execute(() -> receiver.onResult(null));
302                 }))
303                 .when(mRegistrationProxy)
304                 .storeUpgradedKeyAsync(any(), any(), any(), any());
305 
306         mRegistration.storeUpgradedKeyAsync(new byte[0], new byte[0], callback);
307         verify(callback).onSuccess();
308         verify(callback, atLeastOnce()).asBinder();
309         verifyNoMoreInteractions(callback);
310     }
311 
312 }
313