1#!/usr/bin/env python3
2#
3#   Copyright 2022 - Google
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
import logging
import os
import shlex
from typing import Sequence

import paramiko
21
# Number of attempts made by the retry loops that open and close the SSH
# session.
COMMAND_RETRY_TIMES = 3
23
24
class RunCommandError(Exception):
  """Signals that a command executed on the remote host reported an error."""
27
28
class NotConnectedError(Exception):
  """Signals that an operation was attempted without an SSH connection."""
31
32
class RemoteClient:
  """The SSH client class interacts with the test machine.

  Attributes:
    host: A string representing the IP address of amarisoft.
    port: A string representing the SSH port (defaults to '22').
    username: A string representing the username of amarisoft.
    password: A string representing the password of amarisoft.
    ssh: A SSH client.
    sftp: A SFTP client, or None while no SFTP session is open.
  """

  def __init__(self,
               host: str,
               username: str,
               password: str,
               port: str = '22') -> None:
    self.host = host
    self.port = port
    self.username = username
    self.password = password
    self.ssh = paramiko.SSHClient()
    self.sftp = None

  def ssh_is_connected(self) -> bool:
    """Checks SSH connect or not.

    Returns:
      True if SSH is connected, False otherwise.
    """
    if not self.ssh:
      return False
    # get_transport() yields None until a connection has been opened, so it
    # must be checked before asking whether the transport is active.
    transport = self.ssh.get_transport()
    return transport is not None and bool(transport.is_active())

  def ssh_close(self) -> bool:
    """Closes the SFTP session (if any) and the SSH connection.

    Returns:
      True if ssh session closed, False otherwise.
    """
    if self.sftp is not None:
      self.sftp.close()
      self.sftp = None

    for _ in range(COMMAND_RETRY_TIMES):
      if not self.ssh_is_connected():
        return True
      self.ssh.close()
    # Re-check so that a successful final close attempt is reported as such.
    return not self.ssh_is_connected()

  def connect(self) -> bool:
    """Creates SSH connection and opens an SFTP session on top of it.

    Returns:
      True if success, False otherwise.
    """
    self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    for attempt in range(1, COMMAND_RETRY_TIMES + 1):
      try:
        self.ssh.connect(self.host, self.port, self.username, self.password)
        transport = self.ssh.get_transport()
        transport.set_keepalive(1)
        self.sftp = paramiko.SFTPClient.from_transport(transport)
        return True
      except Exception as e:  # pylint: disable=broad-except
        logging.warning('ssh connect to %s failed (attempt %d/%d): %s',
                        self.host, attempt, COMMAND_RETRY_TIMES, e)
        # Drop any half-open session before the next attempt.
        self.ssh_close()
    return False

  def _check_sftp_ready(self) -> None:
    """Raises NotConnectedError unless both SSH and SFTP are available."""
    if not self.ssh_is_connected():
      raise NotConnectedError('ssh remote has not been established')
    if not self.sftp:
      raise NotConnectedError('sftp remote has not been established')

  def run_cmd(self, cmd: str) -> Sequence[str]:
    """Runs shell command.

    Any output on the command's standard error is treated as a failure.

    Args:
      cmd: Command to be executed.

    Returns:
      Standard output of the shell command, as a list of lines.

    Raises:
       RunCommandError: Raise error when command failed.
       NotConnectedError: Raised when run command without SSH connect.
    """
    if not self.ssh_is_connected():
      raise NotConnectedError('ssh remote has not been established')

    logging.debug('ssh remote -> %s', cmd)
    _, stdout, stderr = self.ssh.exec_command(cmd)
    # NOTE(review): stderr is drained completely before stdout is read; a
    # command that produces a very large stdout may stall here -- confirm
    # against paramiko's channel buffering behavior.
    err = stderr.readlines()
    if err:
      logging.error('command failed.')
      raise RunCommandError(err)
    return stdout.readlines()

  def is_file_exist(self, file: str) -> bool:
    """Checks target file exist on the remote side.

    Args:
        file: Target file with absolute path.

    Returns:
        True if file exist, false otherwise.
    """
    # Quote the path so that spaces, quotes, '$' and backticks in the file
    # name cannot alter the shell command.
    cmd = f'if [ -f {shlex.quote(file)} ]; then echo -e "exist"; fi'
    return any('exist' in line for line in self.run_cmd(cmd))

  def sftp_upload(self, src: str, dst: str) -> bool:
    """Uploads a local file to remote side.

    Args:
      src: The local file with absolute path.
      dst: The absolute remote path to put the file with file name.
      For example:
        sftp_upload('/usr/local/google/home/zoeyliu/Desktop/sample_config.yml',
        '/root/sample_config.yml')

    Returns:
      True if file upload success, False otherwise.

    Raises:
       NotConnectedError: Raised when run command without SSH connect.
    """
    self._check_sftp_ready()

    logging.info('[local] %s -> [remote] %s', src, dst)
    self.sftp.put(src, dst)
    return self.is_file_exist(dst)

  def sftp_download(self, src: str, dst: str) -> bool:
    """Downloads a file to local.

    Args:
      src: The remote file with absolute path.
      dst: The absolute local path to put the file.

    Returns:
      True if file download success, False otherwise.

    Raises:
       NotConnectedError: Raised when run command without SSH connect.
    """
    self._check_sftp_ready()

    logging.info('[remote] %s -> [local] %s', src, dst)
    self.sftp.get(src, dst)
    # dst is a local path, so it must be verified on the local filesystem
    # rather than through a command on the remote host.
    return os.path.isfile(dst)

  def sftp_list_dir(self, path: str) -> Sequence[str]:
    """Lists the names of the entries in the given path.

    Args:
      path: The path of the list.

    Returns:
      The sorted names of the entries in the given path.

    Raises:
       NotConnectedError: Raised when run command without SSH connect.
    """
    self._check_sftp_ready()
    return sorted(self.sftp.listdir(path))
195
196