1#!/usr/bin/env python
2#
3# Copyright 2018 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Common operations to create remote devices."""
17
18import logging
19import os
20
21from acloud import errors
22from acloud.public import avd
23from acloud.public import report
24from acloud.internal import constants
25from acloud.internal.lib import utils
26from acloud.internal.lib.adb_tools import AdbTools
27
28
logger = logging.getLogger(__name__)

# Substrings searched for in str(error) by _GetErrorType; any match classifies
# the error as a GCE quota/resource exhaustion problem.
_GCE_QUOTA_ERROR_KEYWORDS = [
    "Quota exceeded for quota",
    "ZONE_RESOURCE_POOL_EXHAUSTED",
    "ZONE_RESOURCE_POOL_EXHAUSTED_WITH_DETAILS",
]

# Creation stage a device was in when it failed to boot -> error type that is
# recorded in the report for that failure.
_DICT_ERROR_TYPE = {
    constants.STAGE_INIT: constants.ACLOUD_INIT_ERROR,
    constants.STAGE_GCE: constants.ACLOUD_CREATE_GCE_ERROR,
    constants.STAGE_SSH_CONNECT: constants.ACLOUD_SSH_CONNECT_ERROR,
    constants.STAGE_ARTIFACT: constants.ACLOUD_DOWNLOAD_ARTIFACT_ERROR,
    constants.STAGE_BOOT_UP: constants.ACLOUD_BOOT_UP_ERROR,
}
41
42
def CreateSshKeyPairIfNecessary(cfg):
    """Create ssh key pair if necessary.

    A key pair is generated only when both ssh_public_key_path and
    ssh_private_key_path are set in the config. If either is missing, a
    warning describing the consequence is logged and nothing is created.

    Args:
        cfg: An AcloudConfig instance.

    Raises:
        This function raises nothing itself; any exception raised by
        utils.CreateSshKeyPairIfNotExist propagates unchanged.
    """
    if not cfg.ssh_public_key_path:
        logger.warning(
            "ssh_public_key_path is not specified in acloud config. "
            "Project-wide public key will "
            "be used when creating AVD instances. "
            "Please ensure you have the correct private half of "
            "a project-wide public key if you want to ssh into the "
            "instances after creation.")
        return

    # From here on the public key path is known to be set.
    if not cfg.ssh_private_key_path:
        logger.warning(
            "Only ssh_public_key_path is specified in acloud config, "
            "but ssh_private_key_path is missing. "
            "Please ensure you have the correct private half "
            "if you want to ssh into the instances after creation.")
        return

    utils.CreateSshKeyPairIfNotExist(cfg.ssh_private_key_path,
                                     cfg.ssh_public_key_path)
73
74
class DevicePool:
    """A class that manages a pool of virtual devices.

    Attributes:
        devices: A list of devices in the pool.
    """

    def __init__(self, device_factory, devices=None):
        """Constructs a new DevicePool.

        Args:
            device_factory: A device factory capable of producing a goldfish or
                cuttlefish device. The device factory must expose an attribute with
                the credentials that can be used to retrieve information from the
                constructed device.
            devices: List of devices managed by this pool.
        """
        self._devices = devices or []
        self._device_factory = device_factory
        self._compute_client = device_factory.GetComputeClient()

    def CreateDevices(self, num):
        """Creates |num| devices for given build_target and build_id.

        Each created host instance is wrapped in an avd.AndroidVirtualDevice
        and appended to the pool.

        Args:
            num: Number of devices to create.
        """
        client = self._compute_client
        # Create host instances for cuttlefish/goldfish device.
        # Currently one instance supports only 1 device.
        for _ in range(num):
            instance = self._device_factory.CreateInstance()
            ip = client.GetInstanceIP(instance)
            # These attributes may not exist on every compute client, so fall
            # back to neutral defaults when they are absent.
            time_info = {
                stage: round(exec_time, 2) for stage, exec_time in
                getattr(client, "execution_time", {}).items()}
            stage = getattr(client, "stage", 0)
            openwrt = getattr(client, "openwrt", False)
            gce_hostname = getattr(client, "gce_hostname", None)
            self._devices.append(
                avd.AndroidVirtualDevice(ip=ip, instance_name=instance,
                                         time_info=time_info, stage=stage,
                                         openwrt=openwrt,
                                         gce_hostname=gce_hostname))

    @utils.TimeExecute(function_description="Waiting for AVD(s) to boot up",
                       result_evaluator=utils.BootEvaluator)
    def WaitForBoot(self, boot_timeout_secs):
        """Waits for all devices to boot up.

        Args:
            boot_timeout_secs: Integer, the maximum time in seconds used to
                               wait for the AVD to boot.

        Returns:
            A dictionary that contains all the failures.
            The key is the name of the instance that fails to boot,
            and the value is an errors.DeviceBootError object.
        """
        failures = {}
        for device in self._devices:
            try:
                self._compute_client.WaitForBoot(device.instance_name, boot_timeout_secs)
            except errors.DeviceBootError as e:
                # Record the failure and keep waiting on the other devices.
                failures[device.instance_name] = e
        return failures

    def UpdateReport(self, reporter):
        """Update report from compute client.

        Args:
            reporter: Report object.
        """
        reporter.UpdateData(self._compute_client.dict_report)

    def CollectSerialPortLogs(self, output_file,
                              port=constants.DEFAULT_SERIAL_PORT):
        """Tar the instance serial logs into specified output_file.

        For each device, the log is written to a temporary
        "<instance>_serial_<port>.log" file. All files are then passed to
        utils.MakeTarFile together with output_file.

        Args:
            output_file: String, the output tar file path
            port: The serial port number to be collected
        """
        # For emulator, the serial log is the virtual host serial log.
        # For GCE AVD device, the serial log is the AVD device serial log.
        with utils.TempDir() as tempdir:
            src_dict = {}
            for device in self._devices:
                logger.info("Store instance %s serial port %s output to %s",
                            device.instance_name, port, output_file)
                serial_log = self._compute_client.GetSerialPortOutput(
                    instance=device.instance_name, port=port)
                file_name = "%s_serial_%s.log" % (device.instance_name, port)
                file_path = os.path.join(tempdir, file_name)
                src_dict[file_path] = file_name
                # The log is explicitly encoded to UTF-8 bytes, so the file
                # must be opened in binary mode. On Python 3, writing bytes
                # to a text-mode ("w") file raises TypeError.
                with open(file_path, "wb") as f:
                    f.write(serial_log.encode("utf-8"))
            utils.MakeTarFile(src_dict, output_file)

    def SetDeviceBuildInfo(self):
        """Add devices build info."""
        for device in self._devices:
            device.build_info = self._device_factory.GetBuildInfoDict()

    @property
    def devices(self):
        """Returns a list of devices in the pool.

        Returns:
            A list of devices in the pool.
        """
        return self._devices
188
def _GetErrorType(error):
    """Classify an exception as one of the reportable acloud error types.

    Known exception classes are checked first, in priority order. If none
    match, the exception message is scanned for GCE quota keywords.

    Args:
        error: The exception object to classify.

    Returns:
        String of error type, e.g. constants.GCE_QUOTA_ERROR. Falls back to
        constants.ACLOUD_UNKNOWN_ERROR when nothing matches.
    """
    # Order matters: the first matching exception class wins.
    checks_in_priority = (
        (errors.CheckGCEZonesQuotaError, constants.GCE_QUOTA_ERROR),
        (errors.DownloadArtifactError,
         constants.ACLOUD_DOWNLOAD_ARTIFACT_ERROR),
        (errors.DeviceConnectionError, constants.ACLOUD_SSH_CONNECT_ERROR),
    )
    for exception_class, error_type in checks_in_priority:
        if isinstance(error, exception_class):
            return error_type

    message = str(error)
    if any(keyword in message for keyword in _GCE_QUOTA_ERROR_KEYWORDS):
        return constants.GCE_QUOTA_ERROR
    return constants.ACLOUD_UNKNOWN_ERROR
208
# pylint: disable=too-many-locals,unused-argument,too-many-branches,too-many-statements
def CreateDevices(command, cfg, device_factory, num, avd_type,
                  report_internal_ip=False, autoconnect=False,
                  serial_log_file=None, client_adb_port=None,
                  boot_timeout_secs=None, unlock_screen=False,
                  wait_for_boot=True, connect_webrtc=False,
                  ssh_private_key_path=None,
                  ssh_user=constants.GCE_USER):
    """Create a set of devices using the given factory.

    Main jobs in create devices.
        1. Create GCE instance: Launch instance in GCP(Google Cloud Platform).
        2. Starting up AVD: Wait device boot up.

    Args:
        command: The name of the command, used for reporting.
        cfg: An AcloudConfig instance.
        device_factory: A factory capable of producing a single device.
        num: The number of devices to create.
        avd_type: String, the AVD type(cuttlefish, goldfish...).
        report_internal_ip: Boolean to report the internal ip instead of
                            external ip.
        autoconnect: Boolean, whether to auto connect to device.
        serial_log_file: String, the file path to tar the serial logs.
        client_adb_port: Integer, Specify port for adb forwarding.
        boot_timeout_secs: Integer, boot timeout secs.
        unlock_screen: Boolean, whether to unlock screen after invoke vnc client.
        wait_for_boot: Boolean, True to check serial log include boot up
                       message.
        connect_webrtc: Boolean, whether to auto connect webrtc to device.
        ssh_private_key_path: String, the private key for SSH tunneling.
                              Falls back to cfg.ssh_private_key_path when
                              not given.
        ssh_user: String, the user name for SSH tunneling.

    Raises:
        Exceptions other than errors.DriverError and
        errors.CheckGCEZonesQuotaError propagate to the caller. Those two are
        caught and recorded in the returned report with status FAIL.

    Returns:
        A Report instance.
    """
    reporter = report.Report(command=command)
    try:
        CreateSshKeyPairIfNecessary(cfg)
        device_pool = DevicePool(device_factory)
        device_pool.CreateDevices(num)
        device_pool.SetDeviceBuildInfo()
        if wait_for_boot:
            failures = device_pool.WaitForBoot(boot_timeout_secs)
        else:
            failures = device_factory.GetFailures()

        if failures:
            reporter.SetStatus(report.Status.BOOT_FAIL)
        else:
            reporter.SetStatus(report.Status.SUCCESS)

        # Collect logs
        logs = device_factory.GetLogs()
        if serial_log_file:
            device_pool.CollectSerialPortLogs(
                serial_log_file, port=constants.DEFAULT_SERIAL_PORT)

        device_pool.UpdateReport(reporter)
        # An explicitly passed key takes precedence over the config value.
        rsa_key_file = ssh_private_key_path or cfg.ssh_private_key_path
        # Write result to report.
        for device in device_pool.devices:
            ip = (device.ip.internal if report_internal_ip
                  else device.ip.external)
            # Re-read for every device because it is cleared below for
            # devices that are reached through their GCE hostname.
            extra_args_ssh_tunnel = cfg.extra_args_ssh_tunnel
            # TODO(b/154175542): Report multiple devices.
            vnc_ports = device_factory.GetVncPorts()
            adb_ports = device_factory.GetAdbPorts()
            if not vnc_ports[0] and not adb_ports[0]:
                # Neither port is set; use the defaults for this AVD type.
                vnc_ports[0], adb_ports[0] = utils.AVD_PORT_DICT[avd_type]

            device_dict = {
                "ip": ip + (":" + str(adb_ports[0]) if adb_ports[0] else ""),
                "instance_name": device.instance_name
            }
            if device.build_info:
                device_dict.update(device.build_info)
            if device.time_info:
                device_dict.update(device.time_info)
            if device.openwrt:
                device_dict.update(device_factory.GetOpenWrtInfoDict())
            if device.gce_hostname:
                device_dict[constants.GCE_HOSTNAME] = device.gce_hostname
                logger.debug(
                    "To connect with hostname, erase the extra_args_ssh_tunnel: %s",
                    extra_args_ssh_tunnel)
                extra_args_ssh_tunnel = ""
            if autoconnect and reporter.status == report.Status.SUCCESS:
                forwarded_ports = _EstablishAdbVncConnections(
                    device.gce_hostname or ip, vnc_ports, adb_ports,
                    client_adb_port, ssh_user,
                    ssh_private_key_path=rsa_key_file,
                    extra_args_ssh_tunnel=extra_args_ssh_tunnel,
                    unlock_screen=unlock_screen)
                if forwarded_ports:
                    forwarded_port = forwarded_ports[0]
                    device_dict[constants.VNC_PORT] = forwarded_port.vnc_port
                    device_dict[constants.ADB_PORT] = forwarded_port.adb_port
                    device_dict[constants.DEVICE_SERIAL] = (
                        constants.REMOTE_INSTANCE_ADB_SERIAL %
                        forwarded_port.adb_port)
            if connect_webrtc and reporter.status == report.Status.SUCCESS:
                webrtc_local_port = utils.PickFreePort()
                device_dict[constants.WEBRTC_PORT] = webrtc_local_port
                utils.EstablishWebRTCSshTunnel(
                    ip_addr=device.gce_hostname or ip,
                    webrtc_local_port=webrtc_local_port,
                    rsa_key_file=rsa_key_file,
                    ssh_user=ssh_user,
                    extra_args_ssh_tunnel=extra_args_ssh_tunnel)
            if device.instance_name in logs:
                device_dict[constants.LOGS] = logs[device.instance_name]
            if hasattr(device_factory, 'GetFetchCvdWrapperLogIfExist'):
                fetch_cvd_wrapper_log = device_factory.GetFetchCvdWrapperLogIfExist()
                if fetch_cvd_wrapper_log:
                    device_dict["fetch_cvd_wrapper_log"] = fetch_cvd_wrapper_log
            if device.instance_name in failures:
                reporter.SetErrorType(constants.ACLOUD_BOOT_UP_ERROR)
                if device.stage:
                    # Use .get so a stage missing from _DICT_ERROR_TYPE keeps
                    # the generic boot-up error. Indexing would raise a
                    # KeyError that the except clause below does not catch.
                    reporter.SetErrorType(_DICT_ERROR_TYPE.get(
                        device.stage, constants.ACLOUD_BOOT_UP_ERROR))
                reporter.AddData(key="devices_failing_boot", value=device_dict)
                reporter.AddError(str(failures[device.instance_name]))
            else:
                reporter.AddData(key="devices", value=device_dict)
    except (errors.DriverError, errors.CheckGCEZonesQuotaError) as e:
        reporter.SetErrorType(_GetErrorType(e))
        reporter.AddError(str(e))
        reporter.SetStatus(report.Status.FAIL)
    return reporter
342
343
344def _EstablishAdbVncConnections(ip, vnc_ports, adb_ports, client_adb_port,
345                                ssh_user, ssh_private_key_path,
346                                extra_args_ssh_tunnel, unlock_screen):
347    """Establish the adb and vnc connections.
348
349    Create the ssh tunnels with adb ports and vnc ports. Then unlock the device
350    screen via the adb port.
351
352    Args:
353        ip: String, the IPv4 address.
354        vnc_ports: List of integer, the vnc ports.
355        adb_ports: List of integer, the adb ports.
356        client_adb_port: Integer, Specify port for adb forwarding.
357        ssh_user: String, the user name for SSH tunneling.
358        ssh_private_key_path: String, the private key for SSH tunneling.
359        extra_args_ssh_tunnel: String, extra args for ssh tunnel connection.
360        unlock_screen: Boolean, whether to unlock screen after invoking vnc client.
361
362    Returns:
363        A list of namedtuple of (vnc_port, adb_port)
364    """
365    forwarded_ports = []
366    for vnc_port, adb_port in zip(vnc_ports, adb_ports):
367        forwarded_port = utils.AutoConnect(
368            ip_addr=ip,
369            rsa_key_file=ssh_private_key_path,
370            target_vnc_port=vnc_port,
371            target_adb_port=adb_port,
372            ssh_user=ssh_user,
373            client_adb_port=client_adb_port,
374            extra_args_ssh_tunnel=extra_args_ssh_tunnel)
375        forwarded_ports.append(forwarded_port)
376        if unlock_screen:
377            AdbTools(forwarded_port.adb_port).AutoUnlockScreen()
378    return forwarded_ports
379