1#!/usr/bin/env python 2# 3# Copyright 2018 - The Android Open Source Project 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16"""Tests for acloud.public.actions.common_operations.""" 17 18from __future__ import absolute_import 19from __future__ import division 20 21import unittest 22 23from unittest import mock 24 25from acloud import errors 26from acloud.internal import constants 27from acloud.internal.lib import android_build_client 28from acloud.internal.lib import android_compute_client 29from acloud.internal.lib import auth 30from acloud.internal.lib import driver_test_lib 31from acloud.internal.lib import utils 32from acloud.internal.lib import ssh 33from acloud.public import report 34from acloud.public.actions import common_operations 35 36 37class CommonOperationsTest(driver_test_lib.BaseDriverTest): 38 """Test Common Operations.""" 39 maxDiff = None 40 IP = ssh.IP(external="127.0.0.1", internal="10.0.0.1") 41 INSTANCE = "fake-instance" 42 CMD = "test-cmd" 43 AVD_TYPE = "fake-type" 44 BRANCH = "fake-branch" 45 BUILD_TARGET = "fake-target" 46 BUILD_ID = "fake-build-id" 47 LOGS = [{"path": "/log", "type": "TEXT"}] 48 49 # pylint: disable=protected-access 50 def setUp(self): 51 """Set up the test.""" 52 super().setUp() 53 self.build_client = mock.MagicMock() 54 self.device_factory = mock.MagicMock() 55 self.Patch( 56 android_build_client, 57 "AndroidBuildClient", 58 return_value=self.build_client) 59 self.compute_client = mock.MagicMock() 60 self.compute_client.gce_hostname = None 61 self.Patch( 62 android_compute_client, 63 "AndroidComputeClient", 64 return_value=self.compute_client) 65 self.Patch(auth, "CreateCredentials", return_value=mock.MagicMock()) 66 self.Patch(self.compute_client, "GetInstanceIP", return_value=self.IP) 67 self.Patch( 68 self.device_factory, "CreateInstance", return_value=self.INSTANCE) 69 self.Patch( 70 self.device_factory, 71 "GetComputeClient", 72 return_value=self.compute_client) 73 self.Patch(self.device_factory, "GetVncPorts", return_value=[6444]) 74 self.Patch(self.device_factory, "GetAdbPorts", return_value=[6520]) 75 self.Patch(self.device_factory, "GetBuildInfoDict", 76 return_value={"branch": self.BRANCH, 77 "build_id": self.BUILD_ID, 78 "build_target": self.BUILD_TARGET, 79 "gcs_bucket_build_id": self.BUILD_ID}) 80 self.Patch(self.device_factory, "GetLogs", 81 return_value={self.INSTANCE: self.LOGS}) 82 self.Patch( 83 self.device_factory, 84 "GetFetchCvdWrapperLogIfExist", return_value={}) 85 86 @staticmethod 87 def _CreateCfg(): 88 """A helper method that creates a mock configuration object.""" 89 cfg = mock.MagicMock() 90 cfg.service_account_name = "fake@service.com" 91 cfg.service_account_private_key_path = "/fake/path/to/key" 92 cfg.zone = "fake_zone" 93 cfg.disk_image_name = "fake_image.tar.gz" 94 cfg.disk_image_mime_type = "fake/type" 95 cfg.ssh_private_key_path = "cfg/private/key" 96 cfg.ssh_public_key_path = "" 97 cfg.extra_args_ssh_tunnel="extra args" 98 return cfg 99 100 def testDevicePoolCreateDevices(self): 101 """Test Device Pool Create Devices.""" 102 pool = common_operations.DevicePool(self.device_factory) 103 pool.CreateDevices(5) 104 self.assertEqual(self.device_factory.CreateInstance.call_count, 5) 105 self.assertEqual(len(pool.devices), 5) 106 107 def testCreateDevices(self): 108 """Test Create Devices.""" 109 cfg = self._CreateCfg() 110 _report = common_operations.CreateDevices(self.CMD, cfg, 111 self.device_factory, 1, 112 self.AVD_TYPE) 113 self.assertEqual(_report.command, self.CMD) 114 self.assertEqual(_report.status, report.Status.SUCCESS) 115 self.assertEqual( 116 _report.data, 117 {"devices": [{ 118 "ip": self.IP.external + ":6520", 119 "instance_name": self.INSTANCE, 120 "branch": self.BRANCH, 121 "build_id": self.BUILD_ID, 122 "build_target": self.BUILD_TARGET, 123 "gcs_bucket_build_id": self.BUILD_ID, 124 "logs": self.LOGS 125 }]}) 126 127 def testCreateDevicesWithAdbPort(self): 128 """Test Create Devices with adb port for cuttlefish avd type.""" 129 forwarded_ports = mock.Mock(adb_port=12345, vnc_port=56789) 130 mock_auto_connect = self.Patch(utils, "AutoConnect", 131 return_value=forwarded_ports) 132 cfg = self._CreateCfg() 133 _report = common_operations.CreateDevices(self.CMD, cfg, 134 self.device_factory, 1, 135 "cuttlefish", 136 autoconnect=True, 137 client_adb_port=12345) 138 139 mock_auto_connect.assert_called_with( 140 ip_addr="127.0.0.1", rsa_key_file="cfg/private/key", 141 target_vnc_port=6444, target_adb_port=6520, 142 ssh_user=constants.GCE_USER, client_adb_port=12345, 143 extra_args_ssh_tunnel="extra args") 144 self.assertEqual(_report.command, self.CMD) 145 self.assertEqual(_report.status, report.Status.SUCCESS) 146 self.assertEqual( 147 _report.data, 148 {"devices": [{ 149 "ip": self.IP.external + ":6520", 150 "instance_name": self.INSTANCE, 151 "branch": self.BRANCH, 152 "build_id": self.BUILD_ID, 153 "adb_port": 12345, 154 "device_serial": "127.0.0.1:12345", 155 "vnc_port": 56789, 156 "build_target": self.BUILD_TARGET, 157 "gcs_bucket_build_id": self.BUILD_ID, 158 "logs": self.LOGS 159 }]}) 160 161 def testCreateDevicesMultipleDevices(self): 162 """Test Create Devices with multiple cuttlefish devices.""" 163 forwarded_ports_1 = mock.Mock(adb_port=12345, vnc_port=56789) 164 forwarded_ports_2 = mock.Mock(adb_port=23456, vnc_port=67890) 165 self.Patch(self.device_factory, "GetVncPorts", return_value=[6444, 6445]) 166 self.Patch(self.device_factory, "GetAdbPorts", return_value=[6520, 6521]) 167 self.Patch(utils, "PickFreePort", return_value=12345) 168 mock_auto_connect = self.Patch( 169 utils, "AutoConnect", side_effects=[forwarded_ports_1, 170 forwarded_ports_2]) 171 cfg = self._CreateCfg() 172 _report = common_operations.CreateDevices(self.CMD, cfg, 173 self.device_factory, 1, 174 "cuttlefish", 175 autoconnect=True, 176 client_adb_port=None) 177 self.assertEqual(2, mock_auto_connect.call_count) 178 mock_auto_connect.assert_any_call( 179 ip_addr="127.0.0.1", rsa_key_file="cfg/private/key", 180 target_vnc_port=6444, target_adb_port=6520, 181 ssh_user=constants.GCE_USER, client_adb_port=None, 182 extra_args_ssh_tunnel="extra args") 183 mock_auto_connect.assert_any_call( 184 ip_addr="127.0.0.1", rsa_key_file="cfg/private/key", 185 target_vnc_port=6445, target_adb_port=6521, 186 ssh_user=constants.GCE_USER, client_adb_port=None, 187 extra_args_ssh_tunnel="extra args") 188 self.assertEqual(_report.command, self.CMD) 189 self.assertEqual(_report.status, report.Status.SUCCESS) 190 191 def testCreateDevicesInternalIP(self): 192 """Test Create Devices and report internal IP.""" 193 cfg = self._CreateCfg() 194 _report = common_operations.CreateDevices(self.CMD, cfg, 195 self.device_factory, 1, 196 self.AVD_TYPE, 197 report_internal_ip=True) 198 self.assertEqual(_report.command, self.CMD) 199 self.assertEqual(_report.status, report.Status.SUCCESS) 200 self.assertEqual( 201 _report.data, 202 {"devices": [{ 203 "ip": self.IP.internal + ":6520", 204 "instance_name": self.INSTANCE, 205 "branch": self.BRANCH, 206 "build_id": self.BUILD_ID, 207 "build_target": self.BUILD_TARGET, 208 "gcs_bucket_build_id": self.BUILD_ID, 209 "logs": self.LOGS 210 }]}) 211 212 def testCreateDevicesWithSshParameters(self): 213 """Test Create Devices with ssh user and key.""" 214 forwarded_ports = mock.Mock(adb_port=12345, vnc_port=56789) 215 mock_auto_connect = self.Patch(utils, "AutoConnect", 216 return_value=forwarded_ports) 217 mock_establish_webrtc = self.Patch(utils, "EstablishWebRTCSshTunnel") 218 self.Patch(utils, "PickFreePort", return_value=12345) 219 cfg = self._CreateCfg() 220 _report = common_operations.CreateDevices( 221 self.CMD, cfg, self.device_factory, 1, constants.TYPE_CF, 222 autoconnect=True, connect_webrtc=True, 223 ssh_user="user", ssh_private_key_path="private/key") 224 225 mock_auto_connect.assert_called_with( 226 ip_addr="127.0.0.1", rsa_key_file="private/key", 227 target_vnc_port=6444, target_adb_port=6520, ssh_user="user", 228 client_adb_port=None, extra_args_ssh_tunnel="extra args") 229 mock_establish_webrtc.assert_called_with( 230 ip_addr="127.0.0.1", rsa_key_file="private/key", 231 ssh_user="user", extra_args_ssh_tunnel="extra args", 232 webrtc_local_port=12345) 233 self.assertEqual(_report.status, report.Status.SUCCESS) 234 235 def testGetErrorType(self): 236 """Test GetErrorType.""" 237 # Test with CheckGCEZonesQuotaError() 238 error = errors.CheckGCEZonesQuotaError() 239 expected_result = constants.GCE_QUOTA_ERROR 240 self.assertEqual(common_operations._GetErrorType(error), expected_result) 241 242 # Test with DownloadArtifactError() 243 error = errors.DownloadArtifactError() 244 expected_result = constants.ACLOUD_DOWNLOAD_ARTIFACT_ERROR 245 self.assertEqual(common_operations._GetErrorType(error), expected_result) 246 247 # Test with DeviceConnectionError() 248 error = errors.DeviceConnectionError() 249 expected_result = constants.ACLOUD_SSH_CONNECT_ERROR 250 self.assertEqual(common_operations._GetErrorType(error), expected_result) 251 252 # Test with ACLOUD_UNKNOWN_ERROR 253 error = errors.DriverError() 254 expected_result = constants.ACLOUD_UNKNOWN_ERROR 255 self.assertEqual(common_operations._GetErrorType(error), expected_result) 256 257 # Test with error message about GCE quota issue 258 error = errors.DriverError("Quota exceeded for quota read group.") 259 expected_result = constants.GCE_QUOTA_ERROR 260 self.assertEqual(common_operations._GetErrorType(error), expected_result) 261 262 error = errors.DriverError("ZONE_RESOURCE_POOL_EXHAUSTED_WITH_DETAILS") 263 expected_result = constants.GCE_QUOTA_ERROR 264 self.assertEqual(common_operations._GetErrorType(error), expected_result) 265 266 def testCreateDevicesWithFetchCvdWrapper(self): 267 """Test Create Devices with FetchCvdWrapper.""" 268 self.Patch( 269 self.device_factory, 270 "GetFetchCvdWrapperLogIfExist", return_value={"fetch_log": "abc"}) 271 cfg = self._CreateCfg() 272 _report = common_operations.CreateDevices(self.CMD, cfg, 273 self.device_factory, 1, 274 constants.TYPE_CF) 275 self.assertEqual(_report.command, self.CMD) 276 self.assertEqual(_report.status, report.Status.SUCCESS) 277 self.assertEqual( 278 _report.data, 279 {"devices": [{ 280 "ip": self.IP.external + ":6520", 281 "instance_name": self.INSTANCE, 282 "branch": self.BRANCH, 283 "build_id": self.BUILD_ID, 284 "build_target": self.BUILD_TARGET, 285 "gcs_bucket_build_id": self.BUILD_ID, 286 "logs": self.LOGS, 287 "fetch_cvd_wrapper_log": { 288 "fetch_log": "abc" 289 }, 290 }]}) 291 292 293if __name__ == "__main__": 294 unittest.main() 295