#!/usr/bin/env python3
#
# Copyright 2022 - Google
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Sequence

import paramiko

# Number of attempts made by connect() and ssh_close() before giving up.
COMMAND_RETRY_TIMES = 3


class RunCommandError(Exception):
    """Raised when a remote command writes anything to stderr."""


class NotConnectedError(Exception):
    """Raised when an operation is attempted without an SSH connection."""


class RemoteClient:
    """The SSH client class interacts with the test machine.

    Attributes:
        host: A string representing the IP address of amarisoft.
        port: A string representing the default port of SSH.
        username: A string representing the username of amarisoft.
        password: A string representing the password of amarisoft.
        ssh: A SSH client.
        sftp: A SFTP client, or None while no connection is established.
    """

    def __init__(self,
                 host: str,
                 username: str,
                 password: str,
                 port: str = '22') -> None:
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.ssh = paramiko.SSHClient()
        self.sftp = None

    def ssh_is_connected(self) -> bool:
        """Checks whether the SSH connection is established.

        Returns:
            True if SSH is connected, False otherwise.
        """
        if not self.ssh:
            return False
        # get_transport() yields None before connect() and after close(), so
        # it must be checked before calling is_active() on it.
        transport = self.ssh.get_transport()
        return transport is not None and transport.is_active()

    def ssh_close(self) -> bool:
        """Closes the SSH connection.

        Returns:
            True if ssh session closed, False otherwise.
        """
        for _ in range(COMMAND_RETRY_TIMES):
            if not self.ssh_is_connected():
                break
            self.ssh.close()

        closed = not self.ssh_is_connected()
        if closed:
            # The SFTP session was opened on the transport that is now gone;
            # drop the stale handle so it cannot be mistaken for a live one.
            self.sftp = None
        return closed

    def connect(self) -> bool:
        """Creates the SSH connection and opens an SFTP session on it.

        Up to COMMAND_RETRY_TIMES attempts are made; each failed attempt is
        logged and the partially opened session is closed before retrying.

        Returns:
            True if success, False otherwise.
        """
        self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        for attempt in range(1, COMMAND_RETRY_TIMES + 1):
            try:
                self.ssh.connect(self.host, self.port, self.username,
                                 self.password)
                transport = self.ssh.get_transport()
                transport.set_keepalive(1)
                self.sftp = paramiko.SFTPClient.from_transport(transport)
                return True
            except Exception as err:  # pylint: disable=broad-except
                logging.warning('ssh connect to %s failed (attempt %d/%d): %s',
                                self.host, attempt, COMMAND_RETRY_TIMES, err)
                self.ssh_close()
        return False

    def run_cmd(self, cmd: str) -> Sequence[str]:
        """Runs shell command.

        Args:
            cmd: Command to be executed.

        Returns:
            Standard output of the shell command, as a list of lines.

        Raises:
            RunCommandError: Raised when the command wrote to stderr.
            NotConnectedError: Raised when run command without SSH connect.
        """
        if not self.ssh_is_connected():
            raise NotConnectedError('ssh remote has not been established')

        logging.debug('ssh remote -> %s', cmd)
        _, stdout, stderr = self.ssh.exec_command(cmd)
        # Any stderr output at all is treated as a failure of the command.
        err = stderr.readlines()
        if err:
            logging.error('command failed.')
            raise RunCommandError(err)
        return stdout.readlines()

    def is_file_exist(self, file: str) -> bool:
        """Checks whether the target file exists on the remote machine.

        Args:
            file: Target file with absolute path.

        Returns:
            True if file exist, False otherwise.

        Raises:
            RunCommandError: Raised when the remote check wrote to stderr.
            NotConnectedError: Raised when run command without SSH connect.
        """
        return any('exist' in line for line in self.run_cmd(
            f'if [ -f "{file}" ]; then echo -e "exist"; fi'))

    def _ensure_sftp_ready(self) -> None:
        """Raises NotConnectedError unless both SSH and SFTP are usable."""
        if not self.ssh_is_connected():
            raise NotConnectedError('ssh remote has not been established')
        if not self.sftp:
            raise NotConnectedError('sftp remote has not been established')

    def sftp_upload(self, src: str, dst: str) -> bool:
        """Uploads a local file to remote side.

        Args:
            src: The local file to upload, with absolute path.
            dst: The absolute remote path to put the file, with file name.
                For example:
                sftp_upload(
                    '/usr/local/google/home/zoeyliu/Desktop/sample_config.yml',
                    '/root/sample_config.yml')

        Returns:
            True if file upload success, False otherwise.

        Raises:
            NotConnectedError: Raised when run command without SSH connect.
        """
        self._ensure_sftp_ready()

        logging.info('[local] %s -> [remote] %s', src, dst)
        self.sftp.put(src, dst)
        # dst lives on the remote machine, so verify it over SSH.
        return self.is_file_exist(dst)

    def sftp_download(self, src: str, dst: str) -> bool:
        """Downloads a file to local.

        Args:
            src: The remote file to download, with absolute path.
            dst: The local path to put the file.

        Returns:
            True if file download success, False otherwise.

        Raises:
            NotConnectedError: Raised when run command without SSH connect.
        """
        self._ensure_sftp_ready()

        logging.info('[remote] %s -> [local] %s', src, dst)
        self.sftp.get(src, dst)
        # dst is a local path, so it must be verified on the local
        # filesystem rather than by querying the remote machine.
        return os.path.isfile(dst)

    def sftp_list_dir(self, path: str) -> Sequence[str]:
        """Lists the names of the entries in the given remote path.

        Args:
            path: The path of the list.

        Returns:
            The sorted names of the entries in the given path.

        Raises:
            NotConnectedError: Raised when run command without SSH connect.
        """
        self._ensure_sftp_ready()
        return sorted(self.sftp.listdir(path))