1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.rkpdapp.unittest;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static com.google.common.truth.Truth.assertWithMessage;
21 
22 import android.content.Context;
23 import android.content.pm.ApplicationInfo;
24 import android.content.pm.PackageManager;
25 import android.util.Base64;
26 
27 import androidx.test.core.app.ApplicationProvider;
28 
29 import com.android.rkpdapp.GeekResponse;
30 import com.android.rkpdapp.RkpdException;
31 import com.android.rkpdapp.interfaces.ServerInterface;
32 import com.android.rkpdapp.metrics.ProvisioningAttempt;
33 import com.android.rkpdapp.testutil.FakeRkpServer;
34 import com.android.rkpdapp.utils.CborUtils;
35 import com.android.rkpdapp.utils.Settings;
36 
37 import org.junit.After;
38 import org.junit.Before;
39 import org.junit.BeforeClass;
40 import org.junit.Test;
41 import org.mockito.Mockito;
42 
43 import java.io.ByteArrayInputStream;
44 import java.io.IOException;
45 import java.io.InputStream;
46 import java.net.HttpURLConnection;
47 import java.nio.charset.StandardCharsets;
48 import java.time.Duration;
49 import java.util.List;
50 
51 public class ServerInterfaceTest {
52     private static final Duration TIME_TO_REFRESH_HOURS = Duration.ofHours(2);
53     private static Context sContext;
54     private ServerInterface mServerInterface;
55 
56     @BeforeClass
init()57     public static void init() {
58         sContext = Mockito.spy(ApplicationProvider.getApplicationContext());
59     }
60 
61     @Before
setUp()62     public void setUp() {
63         Settings.clearPreferences(sContext);
64         mServerInterface = new ServerInterface(sContext, false);
65         Utils.mockConnectivityState(sContext, Utils.ConnectivityState.CONNECTED);
66     }
67 
68     @After
tearDown()69     public void tearDown() {
70         Settings.clearPreferences(sContext);
71         Mockito.reset(sContext);
72     }
73 
74     @Test
testRetryOnServerFailure()75     public void testRetryOnServerFailure() throws Exception {
76         try (FakeRkpServer server = new FakeRkpServer(FakeRkpServer.Response.INTERNAL_ERROR,
77                 FakeRkpServer.Response.INTERNAL_ERROR)) {
78             Settings.setDeviceConfig(sContext, 1 /* extraKeys */,
79                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
80             Settings.setMaxRequestTime(sContext, 100);
81             GeekResponse ignored = mServerInterface.fetchGeek(
82                     ProvisioningAttempt.createScheduledAttemptMetrics(sContext));
83             assertWithMessage("Expected RkpdException.").fail();
84         } catch (RkpdException e) {
85             assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.HTTP_SERVER_ERROR);
86             assertThat(e).hasMessageThat().contains("HTTP error status encountered");
87         }
88     }
89 
90     @Test
testFetchGeekRkpDisabled()91     public void testFetchGeekRkpDisabled() throws Exception {
92         try (FakeRkpServer server = new FakeRkpServer(
93                 FakeRkpServer.Response.FETCH_EEK_RKP_DISABLED,
94                 FakeRkpServer.Response.INTERNAL_ERROR)) {
95             Settings.setDeviceConfig(sContext, 1 /* extraKeys */,
96                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
97             GeekResponse response = mServerInterface.fetchGeek(
98                     ProvisioningAttempt.createScheduledAttemptMetrics(sContext));
99 
100             assertThat(response.numExtraAttestationKeys).isEqualTo(0);
101             assertThat(response.getChallenge()).isNotNull();
102             assertThat(response.getGeekChain(2)).isNotNull();
103         }
104     }
105 
106     @Test
testFetchGeekRkpEnabled()107     public void testFetchGeekRkpEnabled() throws Exception {
108         try (FakeRkpServer server = new FakeRkpServer(
109                 FakeRkpServer.Response.FETCH_EEK_OK,
110                 FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
111             Settings.setDeviceConfig(sContext, 1 /* extraKeys */,
112                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
113             GeekResponse response = mServerInterface.fetchGeek(
114                     ProvisioningAttempt.createScheduledAttemptMetrics(sContext));
115 
116             assertThat(response.numExtraAttestationKeys).isEqualTo(20);
117             assertThat(response.getChallenge()).isNotNull();
118             byte[] challenge = Base64.decode("AAABgEg1zGsBILStY/1VNI7st0AG9x2S/tba+H4=",
119                     Base64.DEFAULT);
120             assertThat(response.getChallenge()).isEqualTo(challenge);
121             byte[] ed25519GeekChain = Base64.decode(
122                     "g4RDoQEnoFgqpAEBAycgBiFYIJm57t1e5FL2hcZMYtw+YatXS"
123                             + "H11NymtdoAy0rPLY1jZWEAeIghLpLekyNdOAw7+uK8UTKc7b6XN3Np5xitk"
124                             + "/pk5r3bngPpmAIUNB5gqrJFcpyUUSQY0dcqKJ3rZ41pJ6wIDhEOhASegWCqk"
125                             + "AQEDJyAGIVgg6i+FDp5qDFz3vdn6KDK/2lXpIKJRA8kDkxjOoBUp7NFYQIJr"
126                             + "x12mNle3x3ESrRzCarMsIyrdFDDLghS2icXTHjG7uFAhSklNupEMbzNNg7xY"
127                             + "Ky6E28VZD5hh4sHqifLQrgSEQ6EBJ6BYTqUBAQJYIG+S0QRtcdinjojY0VaB"
128                             + "X5bReIPmMBuH7b8g0Uo7/mouAzgYIAQhWCC2XRxLmoM6nbUVWTehJvsP3+ec"
129                             + "rAHVpOzIOikAiFglOVhAgLKf0DKenUr+sCXywtIiaEbGILCq6BasZKFFg5vM"
130                             + "SVQlf6sWBVPwvTWT88a7WU5e+d4hBxSjtqSji4+Clpa6Aw==",
131                     Base64.DEFAULT);
132             byte[] p256GeekChain = Base64.decode(
133                     "g4RDoQEmoFhNpQECAyYgASFYIPcUituX9MxT79JkEcTjdR9mH6Rx"
134                             + "DGzP+glGgHSHVPKtIlggXn9b9uzk9hnM/xM3/Q+hyJPbGAZ2xF3m12p3hsMtr49YQC"
135                             + "+XjkL7vgctlUeFR5NAsB/Um0ekxESp8qEHhxDHn8sR9L+f6Dvg5zRMFfx7w34zBfTR"
136                             + "NDztAgRgehXgedOK/ySEQ6EBJqBYTaUBAgMmIAEhWCBRgKzPj5aM7A9Q4akbt5CGNI"
137                             + "vjw6xlAk209jEOCEYyOSJYIFTrlJ3+trTkczolTi8fnZ29+mbBEYvploxD5DD22nar"
138                             + "WECYOPs0OmXbc5ixJ6IVdPK+BueNIk7d8L/CAXTEtylrJBy12NJm+kTv9TAsBHTt6M"
139                             + "Zg2s6fVlcndCHT3pOP47jNhEOhASagWHGmAQICWCCDn/j9EBwSn5JBx1uN5E70GROa"
140                             + "xxttpw6V8mRTXacdwQM4GCABIVggFqRSEmOzhlZQ2N/yoKh9vNlup2hg6oxc8ZPllx"
141                             + "kNrN4iWCCJvsxsP16wOTSvl7o40RYdocwdZNOMSE74coEbOz4x7lhA+trPLaulMAxz"
142                             + "xeWrSZJZYET6xPIz5QSybBlk6RzjZDs0hgBlLfXdr6oBya+DyU74WpToZZNR4xgeOY"
143                             + "CnaUszzQ==",
144                     Base64.DEFAULT);
145             assertThat(response.getGeekChain(CborUtils.EC_CURVE_25519)).isEqualTo(ed25519GeekChain);
146             assertThat(response.getGeekChain(CborUtils.EC_CURVE_P256)).isEqualTo(p256GeekChain);
147         }
148     }
149 
150     @Test
testFetchKeyAndUpdate()151     public void testFetchKeyAndUpdate() throws Exception {
152         try (FakeRkpServer server = new FakeRkpServer(
153                 FakeRkpServer.Response.FETCH_EEK_OK,
154                 FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
155             Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
156                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
157             mServerInterface.fetchGeekAndUpdate(
158                     ProvisioningAttempt.createScheduledAttemptMetrics(sContext));
159 
160             assertThat(Settings.getExtraSignedKeysAvailable(sContext)).isEqualTo(20);
161             assertThat(Settings.getExpiringBy(sContext)).isEqualTo(Duration.ofHours(72));
162         }
163     }
164 
165     @Test
testRequestSignedCertUnregistered()166     public void testRequestSignedCertUnregistered() throws Exception {
167         try (FakeRkpServer server = new FakeRkpServer(
168                 FakeRkpServer.Response.FETCH_EEK_OK,
169                 FakeRkpServer.Response.SIGN_CERTS_DEVICE_UNREGISTERED)) {
170             Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
171                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
172             ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
173                     sContext);
174             mServerInterface.requestSignedCertificates(new byte[0], new byte[0], metrics);
175             assertWithMessage("Should fail due to unregistered device.").fail();
176         } catch (RkpdException e) {
177             assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.DEVICE_NOT_REGISTERED);
178         }
179     }
180 
181     @Test
testRequestSignedCertClientError()182     public void testRequestSignedCertClientError() throws Exception {
183         try (FakeRkpServer server = new FakeRkpServer(
184                 FakeRkpServer.Response.FETCH_EEK_OK,
185                 FakeRkpServer.Response.SIGN_CERTS_USER_UNAUTHORIZED)) {
186             Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
187                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
188             Settings.setMaxRequestTime(sContext, 100);
189             ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
190                     sContext);
191             mServerInterface.requestSignedCertificates(new byte[0], new byte[0], metrics);
192             assertWithMessage("Should fail due to client error.").fail();
193         } catch (RkpdException e) {
194             assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.HTTP_CLIENT_ERROR);
195         }
196     }
197 
198     @Test
testRequestSignedCertCborError()199     public void testRequestSignedCertCborError() throws Exception {
200         try (FakeRkpServer server = new FakeRkpServer(
201                 FakeRkpServer.Response.FETCH_EEK_OK,
202                 FakeRkpServer.Response.SIGN_CERTS_OK_INVALID_CBOR)) {
203             Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
204                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
205             ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
206                     sContext);
207             mServerInterface.requestSignedCertificates(new byte[0], new byte[0], metrics);
208             assertWithMessage("Should fail due to invalid cbor.").fail();
209         } catch (RkpdException e) {
210             assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.INTERNAL_ERROR);
211             assertThat(e).hasMessageThat().isEqualTo("Response failed to parse.");
212         }
213     }
214 
215     @Test
testRequestSignedCertValid()216     public void testRequestSignedCertValid() throws Exception {
217         try (FakeRkpServer server = new FakeRkpServer(
218                 FakeRkpServer.Response.FETCH_EEK_OK,
219                 FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
220             Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
221                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
222             ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
223                     sContext);
224             List<byte[]> certChains = mServerInterface.requestSignedCertificates(new byte[0],
225                     new byte[0], metrics);
226             assertThat(certChains).isEmpty();
227             assertThat(certChains).isNotNull();
228         }
229     }
230 
231     @Test
testDataBudgetEmptyFetchGeekNetworkConnected()232     public void testDataBudgetEmptyFetchGeekNetworkConnected() throws Exception {
233         try (FakeRkpServer server = new FakeRkpServer(
234                 FakeRkpServer.Response.FETCH_EEK_OK,
235                 FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
236             Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
237                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
238 
239             // Check the data budget in order to initialize a rolling window.
240             assertThat(Settings.hasErrDataBudget(sContext, null /* curTime */)).isTrue();
241             Settings.consumeErrDataBudget(sContext, Settings.FAILURE_DATA_USAGE_MAX);
242             ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
243                     sContext);
244 
245             mServerInterface.fetchGeek(metrics);
246             assertWithMessage("Network transaction should not have proceeded.").fail();
247         } catch (RkpdException e) {
248             assertThat(e).hasMessageThat().contains("Out of data budget due to repeated errors");
249             assertThat(e.getErrorCode()).isEqualTo(
250                     RkpdException.ErrorCode.NETWORK_COMMUNICATION_ERROR);
251         }
252     }
253 
254     @Test
testNetworkDisconnected()255     public void testNetworkDisconnected() throws Exception {
256         try (FakeRkpServer server = new FakeRkpServer(
257                 FakeRkpServer.Response.FETCH_EEK_OK,
258                 FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
259             Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
260                     TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
261 
262             ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
263                     sContext);
264 
265             // We are okay in mocking connectivity failure since network check is the first thing
266             // to happen.
267             Utils.mockConnectivityState(sContext, Utils.ConnectivityState.DISCONNECTED);
268             mServerInterface.fetchGeek(metrics);
269             assertWithMessage("Network transaction should not have proceeded.").fail();
270         } catch (RkpdException e) {
271             assertThat(e).hasMessageThat().contains("No network detected");
272             assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.NO_NETWORK_CONNECTIVITY);
273         }
274     }
275 
276     @Test
testReadErrorInvalidContentType()277     public void testReadErrorInvalidContentType() {
278         HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
279         Mockito.when(connection.getContentType()).thenReturn("application/NOPE");
280         assertThat(ServerInterface.readErrorFromConnection(connection))
281                 .isEqualTo("Unexpected content type from the server: application/NOPE");
282     }
283 
284     @Test
testReadTextErrorFromErrorStreamNoErrorData()285     public void testReadTextErrorFromErrorStreamNoErrorData() throws Exception {
286         final String expectedError = "No error data returned by server.";
287 
288         HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
289         Mockito.when(connection.getContentType()).thenReturn("text");
290         Mockito.when(connection.getInputStream()).thenThrow(new IOException());
291         Mockito.when(connection.getErrorStream()).thenReturn(null);
292 
293         assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(expectedError);
294     }
295 
296     @Test
testReadTextErrorFromErrorStream()297     public void testReadTextErrorFromErrorStream() throws Exception {
298         final String error = "Explanation for error goes here.";
299 
300         HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
301         Mockito.when(connection.getContentType()).thenReturn("text");
302         Mockito.when(connection.getInputStream()).thenThrow(new IOException());
303         Mockito.when(connection.getErrorStream())
304                 .thenReturn(new ByteArrayInputStream(error.getBytes(StandardCharsets.UTF_8)));
305 
306         assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(error);
307     }
308 
309     @Test
testReadTextError()310     public void testReadTextError() throws IOException {
311         final String error = "This is an error.  Oh No.";
312         final String[] textContentTypes = new String[]{
313                 "text",
314                 "text/ANYTHING",
315                 "text/what-is-this; charset=unknown",
316                 "text/lowercase; charset=utf-8",
317                 "text/uppercase; charset=UTF-8",
318                 "text/yolo; charset=ASCII"
319         };
320 
321         for (String contentType : textContentTypes) {
322             HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
323             Mockito.when(connection.getContentType()).thenReturn(contentType);
324             Mockito.when(connection.getInputStream())
325                     .thenReturn(new ByteArrayInputStream(error.getBytes(StandardCharsets.UTF_8)));
326 
327             assertWithMessage("Failed on content type '" + contentType + "'")
328                     .that(error)
329                     .isEqualTo(ServerInterface.readErrorFromConnection(connection));
330         }
331     }
332 
333     @Test
testReadJsonError()334     public void testReadJsonError() throws IOException {
335         final String error = "Not really JSON.";
336 
337         HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
338         Mockito.when(connection.getContentType()).thenReturn("application/json");
339         Mockito.when(connection.getInputStream())
340                 .thenReturn(new ByteArrayInputStream(error.getBytes(StandardCharsets.UTF_8)));
341 
342         assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(error);
343     }
344 
345     @Test
testReadErrorStreamThrowsException()346     public void testReadErrorStreamThrowsException() throws IOException {
347         InputStream stream = Mockito.mock(InputStream.class);
348         Mockito.when(stream.read(Mockito.any())).thenThrow(new IOException());
349 
350         HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
351         Mockito.when(connection.getContentType()).thenReturn("text");
352         Mockito.when(connection.getInputStream()).thenReturn(stream);
353 
354         final String error = ServerInterface.readErrorFromConnection(connection);
355         assertWithMessage("Error string: '" + error + "'")
356                 .that(error).startsWith("Error reading error string from server: ");
357     }
358 
359     @Test
testReadErrorEmptyStream()360     public void testReadErrorEmptyStream() throws IOException {
361         HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
362         Mockito.when(connection.getContentType()).thenReturn("text");
363         Mockito.when(connection.getInputStream())
364                 .thenReturn(new ByteArrayInputStream(new byte[0]));
365 
366         assertThat(ServerInterface.readErrorFromConnection(connection))
367                 .isEqualTo("No error data returned by server.");
368     }
369 
370     @Test
testReadErrorStreamTooLarge()371     public void testReadErrorStreamTooLarge() throws IOException {
372         final StringBuilder sb = new StringBuilder();
373         for (int i = 0; i < 2048; ++i) {
374             sb.append(i % 100);
375         }
376         final String bigString = sb.toString();
377 
378         HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
379         Mockito.when(connection.getContentType()).thenReturn("text");
380         Mockito.when(connection.getInputStream())
381                 .thenReturn(new ByteArrayInputStream(bigString.getBytes(StandardCharsets.UTF_8)));
382 
383         sb.setLength(1024);
384         assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(sb.toString());
385     }
386 
387     @Test
testServerConnectionTimeout()388     public void testServerConnectionTimeout() {
389         ServerInterface serverInterface = Mockito.spy(mServerInterface);
390         Mockito.when(serverInterface.getRegionalProperty()).thenReturn("cn");
391         assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
392                 ServerInterface.SYNC_CONNECT_TIMEOUT_RETRICTED_MS);
393 
394         Mockito.when(serverInterface.getRegionalProperty()).thenReturn("cn,us");
395         assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
396                 ServerInterface.SYNC_CONNECT_TIMEOUT_RETRICTED_MS);
397 
398         Mockito.when(serverInterface.getRegionalProperty()).thenReturn(null);
399         assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
400                 ServerInterface.SYNC_CONNECT_TIMEOUT_OPEN_MS);
401 
402         Mockito.when(serverInterface.getRegionalProperty()).thenReturn("");
403         assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
404                 ServerInterface.SYNC_CONNECT_TIMEOUT_OPEN_MS);
405 
406         Mockito.when(serverInterface.getRegionalProperty()).thenReturn("us");
407         assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
408                 ServerInterface.SYNC_CONNECT_TIMEOUT_OPEN_MS);
409     }
410 
411     @Test
testConnectionConsent()412     public void testConnectionConsent() throws Exception {
413         String cnGmsFeature = "cn.google.services";
414         PackageManager mockedPackageManager = Mockito.mock(PackageManager.class);
415         Context mockedContext = Mockito.mock(Context.class);
416         ApplicationInfo fakeApplicationInfo = new ApplicationInfo();
417 
418         Mockito.when(mockedContext.getPackageManager()).thenReturn(mockedPackageManager);
419         Mockito.when(mockedPackageManager.hasSystemFeature(cnGmsFeature)).thenReturn(true);
420         Mockito.when(mockedPackageManager.getApplicationInfo(Mockito.any(), Mockito.eq(0)))
421                 .thenReturn(fakeApplicationInfo);
422 
423         fakeApplicationInfo.enabled = false;
424         assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isFalse();
425 
426         fakeApplicationInfo.enabled = true;
427         assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isTrue();
428 
429         Mockito.when(mockedPackageManager.getApplicationInfo(Mockito.any(), Mockito.eq(0)))
430                 .thenThrow(new PackageManager.NameNotFoundException());
431         assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isFalse();
432 
433         Mockito.when(mockedPackageManager.hasSystemFeature(cnGmsFeature)).thenReturn(false);
434         assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isTrue();
435 
436         fakeApplicationInfo.enabled = false;
437         assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isTrue();
438     }
439 }
440