1# Copyright 2019 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Ssh Utilities."""
15from __future__ import print_function
16import logging
17
18import re
19import subprocess
20import sys
21import threading
22
23from acloud import errors
24from acloud.internal import constants
25from acloud.internal.lib import utils
26
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Option string placed right after the ssh/scp binary by Ssh.GetBaseCmd;
# "%(rsa_key_file)s" is filled in with the private key path.
_SSH_CMD = ("-i %(rsa_key_file)s -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no")
# Login user and host arguments; Ssh.GetBaseCmd appends these only for ssh.
_SSH_IDENTITY = "-l %(login_user)s %(ip_addr)s"
# Default retry count used by ShellCmdWithRetry, Ssh.Run and Ssh.WaitForSsh.
SSH_CMD_DEFAULT_RETRY = 5
# Passed as sleep_multiplier to utils.RetryExceptionType by ShellCmdWithRetry
# (presumably seconds -- confirm against utils).
_SSH_CMD_RETRY_SLEEP = 3
# Timeout, in seconds, of each "uptime" probe issued by Ssh.WaitForSsh.
_CONNECTION_TIMEOUT = 10
# Number of trailing output lines _GetErrorMessage reports when the output
# contains no "message"/"response" field.
_MAX_REPORTED_ERROR_LINES = 10
# Captures, as group "content", the quoted value of a "message" or "response"
# field in the ssh output.
_ERROR_MSG_RE = re.compile(r".*]\s*\"(?:message|response)\"\s:\s\"(?P<content>.*)\"")
# Literal "\u2018" / "\u2019" escape text (left/right single quotation marks);
# _FilterUnusedContent replaces each with an apostrophe.
_ERROR_MSG_TO_QUOTE_RE = r"(\\u2019)|(\\u2018)"
# A "<style ... /style>" span; applied with re.DOTALL so it may cross lines.
_ERROR_MSG_DEL_STYLE_RE = r"(<style.+\/style>)"
# Bare opening/closing tags of a fixed set of elements, plus a, span, meta,
# html and "<!...>" tags that carry attributes or other content.
_ERROR_MSG_DEL_TAGS_RE = (r"(<[\/]*(a|b|p|span|ins|code|title)>)|"
                          r"(<(a|span|meta|html|!)[^>]*>)")
41
42
def _SshCallWait(cmd, timeout=None):
    """Runs a single SSH command and waits for it to exit, discarding output.

    - SSH returns code 0 for "Successful execution".
    - Uses wait() rather than communicate(), so only the exit of the launched
      process is awaited and no output is collected.

    The command's stdout and stderr are sent to the null device. Nothing in
    this function reads the output, and an unread pipe would block the child
    once the pipe buffer fills up, which in turn would stall wait().

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments. It is executed through the shell.
        timeout: Optional number of seconds after which the process is
                 killed. A falsy value (None or 0) disables the timeout.

    Returns:
        Integer, the exit status of the process. 0 indicates that it ran
        successfully.
    """
    logger.info("Running command \"%s\"", cmd)
    # stderr is merged into stdout, so both streams end up on the null device.
    process = subprocess.Popen(cmd, shell=True, stdin=None,
                               stdout=subprocess.DEVNULL,
                               stderr=subprocess.STDOUT)
    if timeout:
        # TODO: if process is killed, out error message to log.
        timer = threading.Timer(timeout, process.kill)
        timer.start()
    process.wait()
    if timeout:
        # The process finished on its own; the pending kill is not needed.
        timer.cancel()
    return process.returncode
68
69
def _SshCall(cmd, timeout=None):
    """Executes one SSH command through the shell and returns its exit status.

    - SSH returns code 0 for "Successful execution".
    - communicate() is used, so the call returns once the process has exited
      and its output pipe has been drained; the output itself is discarded.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional number of seconds after which the process is
                 killed. A falsy value (None or 0) disables the timeout.

    Returns:
        Integer, the exit status of the process. 0 indicates that it ran
        successfully.
    """
    logger.info("Running command \"%s\"", cmd)
    popen_options = {
        "shell": True,
        "stdin": None,
        "stdout": subprocess.PIPE,
        # Merge stderr into the same pipe as stdout.
        "stderr": subprocess.STDOUT,
    }
    child = subprocess.Popen(cmd, **popen_options)

    # Arm a watchdog that kills the child only when a timeout was requested.
    watchdog = None
    if timeout:
        # TODO: if process is killed, out error message to log.
        watchdog = threading.Timer(timeout, child.kill)
        watchdog.start()

    child.communicate()

    if watchdog is not None:
        watchdog.cancel()
    return child.returncode
95
96
def _SshLogOutput(cmd, timeout=None, show_output=False, hide_error_msg=False):
    """Runs one SSH command, reports its output and maps its exit status.

    The combined stdout/stderr is either printed to stderr or logged at debug
    level. Exit status 255 is what SSH returns for "failed to connect", so it
    is treated as a connection failure rather than a failure of the remote
    command and is raised as a different exception type.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional number of seconds after which the process is
                 killed. A falsy value (None or 0) disables the timeout.
        show_output: Boolean, True to print the command output to the screen
                     even when the command succeeds.
        hide_error_msg: Boolean, True to never print the output to the
                        screen; it is logged at debug level instead.

    Returns:
        A string, the combined stdout and stderr of the command.

    Raises:
        errors.SshConnectFail: Exit status 255 and the output contains
            constants.ERROR_MSG_SSO_INVALID.
        errors.DeviceConnectionError: Any other exit status 255.
        errors.LaunchCVDFail: Non-zero exit status and the output contains
            one of the "not supported" messages from constants.
        subprocess.CalledProcessError: Any other non-zero exit status.
    """
    # Prefix with "exec" so the command replaces the shell process instead of
    # running as a child of it; otherwise killing the shell on timeout would
    # leave the real command running.
    cmd = "exec " + cmd
    logger.info("Running command \"%s\"", cmd)
    child = subprocess.Popen(cmd, shell=True, stdin=None,
                             stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                             universal_newlines=True)

    watchdog = None
    if timeout:
        # TODO: if process is killed, out error message to log.
        watchdog = threading.Timer(timeout, child.kill)
        watchdog.start()

    output = child.communicate()[0]
    exit_code = child.returncode

    if output:
        text = output.strip()
        print_to_screen = (show_output or exit_code != 0) and not hide_error_msg
        if print_to_screen:
            print(text, file=sys.stderr)
        else:
            # fetch_cvd and launch_cvd can be noisy, so left at debug
            logger.debug(text)

    if watchdog is not None:
        watchdog.cancel()

    if exit_code == 0:
        return output

    if exit_code == 255:
        error_msg = (f"Failed to send command to instance {cmd}\n"
                     f"Error message: {_GetErrorMessage(output)}")
        if constants.ERROR_MSG_SSO_INVALID in output:
            raise errors.SshConnectFail(error_msg)
        raise errors.DeviceConnectionError(error_msg)

    # Known launch failures are reported with their own message.
    for known_failure in (constants.ERROR_MSG_VNC_NOT_SUPPORT,
                          constants.ERROR_MSG_WEBRTC_NOT_SUPPORT):
        if known_failure in output:
            raise errors.LaunchCVDFail(known_failure)
    raise subprocess.CalledProcessError(exit_code, cmd)
152
153
def _GetErrorMessage(stdout):
    """Extracts a short error message from the ssh output.

    When the output contains a "message" or "response" field, the content of
    the first one found is returned after being cleaned up by
    _FilterUnusedContent. Otherwise the last _MAX_REPORTED_ERROR_LINES lines
    of the output are returned.

    Args:
        stdout: String of the ssh output.

    Returns:
        String of the formatted ssh output.
    """
    first_hit = _ERROR_MSG_RE.search(stdout)
    if first_hit is not None:
        return _FilterUnusedContent(first_hit.group("content"))
    # No recognizable field: fall back to the tail of the raw output.
    tail_lines = stdout.splitlines()[-_MAX_REPORTED_ERROR_LINES:]
    return "\n".join(tail_lines)
172
def _FilterUnusedContent(content):
    """Strips html styling and tags from an error message.

    The substitutions are applied in order: escaped unicode quote sequences
    become an apostrophe, style elements are deleted, selected html tags are
    deleted, and literal backslash-n sequences become a space.

    Args:
        content: String, html content.

    Returns:
        String without html style or tags.
    """
    # (pattern, replacement, regex flags), applied top to bottom.
    substitutions = (
        (_ERROR_MSG_TO_QUOTE_RE, "'", 0),
        (_ERROR_MSG_DEL_STYLE_RE, "", re.DOTALL),
        (_ERROR_MSG_DEL_TAGS_RE, "", 0),
        (r"\\n", " ", 0),
    )
    for pattern, replacement, flags in substitutions:
        content = re.sub(pattern, replacement, content, flags=flags)
    return content
189
190
def ShellCmdWithRetry(cmd, timeout=None, show_output=False,
                      retry=SSH_CMD_DEFAULT_RETRY):
    """Runs a shell command on remote device, retrying on failure.

    The command is run by _SshLogOutput under utils.RetryExceptionType. A
    retry is triggered by a connection error, a launch_cvd failure or a
    non-zero exit status. _SSH_CMD_RETRY_SLEEP is passed as the sleep
    multiplier and utils.DEFAULT_RETRY_BACKOFF_FACTOR as the backoff factor;
    how the wait between attempts is derived from them is defined by utils.

    Args:
        cmd: String of the full SSH command to run, including the SSH binary
             and its arguments.
        timeout: Optional integer, number of seconds to give.
        show_output: Boolean, True to show command output in screen.
        retry: Integer, the retry times.

    Returns:
        A string, stdout and stderr.

    Raises:
        errors.DeviceConnectionError: Failed to connect to the instance.
        errors.LaunchCVDFail: Happened on launch_cvd with specific pattern of
            error message.
        subprocess.CalledProcessError: The process exited with an error on
            the instance.
    """
    retriable_errors = (
        errors.DeviceConnectionError,
        errors.LaunchCVDFail,
        subprocess.CalledProcessError,
    )
    # Keyword arguments forwarded to _SshLogOutput on every attempt.
    ssh_kwargs = {
        "cmd": cmd,
        "timeout": timeout,
        "show_output": show_output,
    }
    return utils.RetryExceptionType(
        exception_types=retriable_errors,
        max_retries=retry,
        functor=_SshLogOutput,
        sleep_multiplier=_SSH_CMD_RETRY_SLEEP,
        retry_backoff_factor=utils.DEFAULT_RETRY_BACKOFF_FACTOR,
        **ssh_kwargs)
225
226
class IP:
    """Holds an external and an internal IP address."""

    def __init__(self, external=None, internal=None, ip=None):
        """Stores the address pair.

        Args:
            external: String, external ip.
            internal: String, internal ip.
            ip: String, fallback address. It is used for the external ip when
                external is not set, and likewise for the internal ip when
                internal is not set.
        """
        self.external = external if external else ip
        self.internal = internal if internal else ip
239
240
class Ssh():
    """A class that control the remote instance via the IP address.

    Attributes:
        _ip: String, the address used to reach the instance: the internal or
             external ip taken from the IP object, or the GCE hostname when
             one is given.
        _user: String of user login into the instance.
        _ssh_private_key_path: Path to the private key file.
        _extra_args_ssh_tunnel: String, extra args for ssh or scp. Set to
                                None when connecting by hostname.
    """
    def __init__(self, ip, user, ssh_private_key_path,
                 extra_args_ssh_tunnel=None, report_internal_ip=False,
                 gce_hostname=None):
        """Stores the connection settings.

        Args:
            ip: Object with "external" and "internal" attributes (see IP).
            user: String of user login into the instance.
            ssh_private_key_path: Path to the private key file.
            extra_args_ssh_tunnel: String, extra args for ssh or scp.
            report_internal_ip: Boolean, True to use internal ip.
            gce_hostname: String, the hostname for ssh connect. When given it
                          replaces the ip and extra_args_ssh_tunnel is
                          dropped.
        """
        self._ip = ip.internal if report_internal_ip else ip.external
        self._user = user
        self._ssh_private_key_path = ssh_private_key_path
        self._extra_args_ssh_tunnel = extra_args_ssh_tunnel
        if gce_hostname:
            self._ip = gce_hostname
            self._extra_args_ssh_tunnel = None
            logger.debug(
                "To connect with hostname, erase the extra_args_ssh_tunnel: %s",
                extra_args_ssh_tunnel)

    def Run(self, target_command, timeout=None, show_output=False,
            retry=SSH_CMD_DEFAULT_RETRY):
        """Run a shell command over SSH on a remote instance.

        Example:
            ssh:
                base_cmd_list is ["ssh", "-i", "~/private_key_path" ,"-l" , "user", "1.1.1.1"]
                target_command is "remote command"
            scp:
                base_cmd_list is ["scp", "-i", "~/private_key_path"]
                target_command is "{src_file} {dst_file}"

        Args:
            target_command: String, text of command to run on the remote instance.
            timeout: Integer, the maximum time to wait for the command to respond.
            show_output: Boolean, True to show command output in screen.
            retry: Integer, the retry times.

        Returns:
            A string, stdout and stderr.
        """
        return ShellCmdWithRetry(
            self.GetBaseCmd(constants.SSH_BIN) + " " + target_command,
            timeout,
            show_output,
            retry)

    def GetBaseCmd(self, execute_bin):
        """Get a base command over SSH on a remote instance.

        Example:
            execute bin is ssh:
                ssh -i ~/private_key_path $extra_args -l user 1.1.1.1
            execute bin is scp:
                scp -i ~/private_key_path $extra_args

        Args:
            execute_bin: String, execute type, e.g. ssh or scp.

        Returns:
            Strings of base connection command.

        Raises:
            errors.UnknownType: Don't support the execute bin.
        """
        base_cmd = [utils.FindExecutable(execute_bin)]
        base_cmd.append(_SSH_CMD % {"rsa_key_file": self._ssh_private_key_path})
        if self._extra_args_ssh_tunnel:
            base_cmd.append(self._extra_args_ssh_tunnel)

        if execute_bin == constants.SSH_BIN:
            # Only ssh takes the login user and host as part of the base
            # command; scp callers append "user@host:path" themselves.
            base_cmd.append(_SSH_IDENTITY %
                            {"login_user":self._user, "ip_addr":self._ip})
            return " ".join(base_cmd)
        if execute_bin == constants.SCP_BIN:
            return " ".join(base_cmd)

        raise errors.UnknownType("Don't support the execute bin %s." % execute_bin)

    def GetCmdOutput(self, cmd):
        """Runs a single SSH command and get its output.

        The command is run once, without retry, and its exit status is not
        checked.

        Args:
            cmd: String, text of command to run on the remote instance.

        Returns:
            String of the command output (stdout and stderr combined).
        """
        # "exec" makes the ssh binary replace the shell process.
        ssh_cmd = "exec " + self.GetBaseCmd(constants.SSH_BIN) + " " + cmd
        logger.info("Running command \"%s\"", ssh_cmd)
        process = subprocess.Popen(ssh_cmd, shell=True, stdin=None,
                                   stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                                   universal_newlines=True)
        stdout, _ = process.communicate()
        return stdout

    def CheckSshConnection(self, timeout):
        """Run remote 'uptime' ssh command to check ssh connection.

        The command output is not printed to the screen.

        Args:
            timeout: Integer, the maximum time to wait for the command to respond.

        Raises:
            errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
        """
        remote_cmd = [self.GetBaseCmd(constants.SSH_BIN)]
        remote_cmd.append("uptime")
        try:
            _SshLogOutput(" ".join(remote_cmd), timeout, hide_error_msg=True)
        except subprocess.CalledProcessError as e:
            # Report a failing probe as a connection error so that callers
            # only have to handle one exception type.
            raise errors.DeviceConnectionError(
                "Ssh isn't ready in the remote instance.") from e

    @utils.TimeExecute(function_description="Waiting for SSH server")
    def WaitForSsh(self, timeout=None, max_retry=SSH_CMD_DEFAULT_RETRY):
        """Wait until the remote instance is ready to accept commands over SSH.

        CheckSshConnection is retried on connection errors. If it still fails
        after the retries, the probe is run one more time with its output
        visible before the error is raised.

        Args:
            timeout: Integer, the maximum time in seconds to wait for the
                     command to respond. Defaults to
                     constants.DEFAULT_SSH_TIMEOUT when not set.
            max_retry: Integer, the maximum number of retry.

        Raises:
            errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
        """
        ssh_timeout = timeout or constants.DEFAULT_SSH_TIMEOUT
        # The timeout is divided by 1 + 2 + ... + max_retry, presumably so
        # that the growing waits between attempts add up to ssh_timeout
        # (depends on utils.RetryExceptionType -- confirm). With no retries
        # the sum is 0, so skip the division instead of raising
        # ZeroDivisionError.
        retry_weight = sum(range(max_retry + 1))
        sleep_multiplier = ssh_timeout / retry_weight if retry_weight else 0
        logger.debug("Retry with interval time: %s secs", str(sleep_multiplier))
        try:
            utils.RetryExceptionType(
                exception_types=errors.DeviceConnectionError,
                max_retries=max_retry,
                functor=self.CheckSshConnection,
                sleep_multiplier=sleep_multiplier,
                retry_backoff_factor=utils.DEFAULT_RETRY_BACKOFF_FACTOR,
                timeout=_CONNECTION_TIMEOUT)
        except errors.DeviceConnectionError as connect_error:
            ssh_cmd = "%s uptime" % self.GetBaseCmd(constants.SSH_BIN)
            # Run the probe once more without hiding its output so the
            # failure details reach the screen; this call raises by itself if
            # the command fails again.
            _SshLogOutput(ssh_cmd, timeout=_CONNECTION_TIMEOUT)
            raise errors.DeviceConnectionError(
                "Ssh connect timeout.\nYou can try the ssh connect command to "
                "get detail information: '%s'" % ssh_cmd) from connect_error

    def ScpPushFile(self, src_file, dst_file):
        """Scp push file to remote.

        Args:
            src_file: The source file path to be pushed.
            dst_file: The destination file path the file is pushed to.
        """
        scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
        scp_command.append(src_file)
        scp_command.append("%s@%s:%s" %(self._user, self._ip, dst_file))
        ShellCmdWithRetry(" ".join(scp_command))

    def ScpPushFiles(self, src_files, dst_dir):
        """Push files to one specific folder of remote instance via scp command.

        Args:
            src_files: The source file path list to be pushed.
            dst_dir: The destination directory the files to be pushed to.
        """
        scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
        scp_command.extend(src_files)
        scp_command.append("%s@%s:%s" % (self._user, self._ip, dst_dir))
        ShellCmdWithRetry(" ".join(scp_command))

    def ScpPullFile(self, src_file, dst_file):
        """Scp pull file from remote.

        Args:
            src_file: The source file path to be pulled.
            dst_file: The destination file path the file is pulled to.
        """
        scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
        scp_command.append("%s@%s:%s" %(self._user, self._ip, src_file))
        scp_command.append(dst_file)
        ShellCmdWithRetry(" ".join(scp_command))
424