# Python SSH library (python ssh库)

# 1. Version 1 (版本1)
# -*- coding: utf-8 -*-

import time
import paramiko
import sys
import re
from tenacity import retry, stop_after_attempt, wait_fixed
import functools
from concurrent import futures
from threading import Lock
import socket
from paramiko.ssh_exception import SSHException
from paramiko.ssh_exception import AuthenticationException


class NetmikoTimeoutException(SSHException):
    """Timeout raised by the SSH session helpers (waiting for output or for the channel lock)."""


class NetmikoAuthenticationException(AuthenticationException):
    """Authentication failure; a subclass of paramiko's AuthenticationException."""


# Largest number of bytes requested from the channel in one recv() call.
MAX_BUFFER = 65535
# Default multiplier for read-loop sleeps (copied into each session object).
global_delay_factor = 1
# Alternate-capitalisation aliases for the exception classes above.
NetMikoTimeoutException = NetmikoTimeoutException
NetMikoAuthenticationException = NetmikoAuthenticationException

# One shared worker thread that runs every function wrapped by @timeout.
executor = futures.ThreadPoolExecutor(max_workers=1)


def timeout(timeout):
    """Return a decorator that waits at most *timeout* seconds for the call's result.

    The decorated function is submitted to the module-level ``executor`` and
    the caller blocks on ``Future.result(timeout=timeout)``. If the result is
    not ready in time, ``concurrent.futures.TimeoutError`` is raised.

    Note: a timed-out call is not cancelled; it keeps occupying the executor's
    worker thread, which delays later @timeout-wrapped calls.
    """
    def decorator(func):
        # Apply wraps as a decorator so the wrapper keeps func's name/doc and
        # exposes it as __wrapped__ (previously wraps() was called and discarded).
        @functools.wraps(func)
        def wrapper(*args, **kw):
            return executor.submit(func, *args, **kw).result(timeout=timeout)
        return wrapper
    return decorator


class ParaSession(object):
    """One paramiko SSH connection with an interactive shell and optional SFTP.

    The constructor connects immediately and opens an interactive shell as
    ``self.channel``. Failures are printed, not raised: if connecting or
    opening the shell fails, ``self.channel`` is never assigned, and callers
    detect that with ``hasattr(obj, 'channel')``.
    """

    def __init__(self, hostname, password, port=22, username='root', timeout=60):
        """
        :param hostname: host to connect to.
        :param password: password for *username*.
        :param port: SSH port.
        :param username: login user.
        :param timeout: connect timeout in seconds, passed to SSHClient.connect().
        """
        self.t = None     # paramiko.Transport backing self.sftp (set by sftp_client)
        self.sftp = None  # paramiko.SFTPClient (set by sftp_client)
        # "closed" flags make close() idempotent and safe after partial setup.
        self._closed = True          # SSHClient
        self._channel_closed = True  # interactive shell channel
        self._sftp_closed = True     # SFTP client and its Transport
        self.hostname = hostname
        self.password = password
        self.port = port
        self.username = username
        self.timeout = timeout
        print('- start to create SSH connection -')
        self.client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts any unknown host key (no MITM protection).
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            self.client.connect(hostname=hostname,
                                port=port,
                                username=username,
                                password=password,
                                timeout=timeout)
            self._closed = False
        except Exception as e:
            print(f'Connection Error And please check the ENV: {str(e)}')
        else:
            try:
                self.channel = self.client.invoke_shell()
                self._channel_closed = False
                print('- Connection created successfully -')
            except Exception as e:
                self._close_quietly()
                print(f'Open channel failed And please check the ENV: {str(e)}')

    def sftp_client(self):
        """Open an SFTP client over a new Transport to the same host.

        :return: the ``paramiko.SFTPClient`` (also stored as ``self.sftp``), or
            None if opening failed. On failure the whole session, including
            the SSH client, is closed.
        """
        try:
            self.t = paramiko.Transport((self.hostname, self.port))
            # Mark as open as soon as the Transport exists, so close() tears it
            # down even if authentication or the SFTP handshake fails below.
            self._sftp_closed = False
            self.t.connect(username=self.username, password=self.password)
            self.sftp = paramiko.SFTPClient.from_transport(self.t)
            print('- SFTP created successfully -')
            return self.sftp
        except Exception as e:
            self._close_quietly()
            print(f'Open sftp failed And please check the ENV: {str(e)}')
            return None

    def close(self):
        """Close the shell channel, SSH client, SFTP client and its Transport.

        Only parts that are currently open are closed; calling it again is a no-op.
        """
        if not self._closed:
            try:
                if not self._channel_closed:
                    self.channel.close()
                    self._channel_closed = True
            finally:
                # Release the client even if closing the channel failed.
                self.client.close()
                self._closed = True
        if not self._sftp_closed:
            try:
                if self.sftp is not None:
                    self.sftp.close()
            finally:
                # sftp_client() opened its own Transport; closing the
                # SFTPClient closes only its channel, so close the Transport too.
                if self.t is not None:
                    self.t.close()
                self._sftp_closed = True

    def _close_quietly(self):
        """Best-effort close() for error paths and the finalizer; errors are ignored."""
        try:
            self.close()
        except Exception:
            pass

    def __del__(self):
        # close() is idempotent and handles each resource separately, so call it
        # unconditionally (the old check on _closed skipped an open SFTP client).
        self._close_quietly()


class BaseSSHSession(object):
    """Interactive shell session over SSH, modelled on netmiko's BaseConnection.

    Construction opens a :class:`ParaSession`, uses its shell channel as
    ``self.remote_conn`` and detects the shell prompt. Commands are written to
    the channel, and output is read back until the prompt (or another regex)
    appears.
    """

    def __init__(self, hostname, username, password, su_password, port=22, timeout=100,
                 session_timeout=60):
        """
        :param hostname: host to connect to.
        :param username: login user.
        :param password: login password.
        :param su_password: password sent when switching to root.
        :param port: SSH port.
        :param timeout: seconds to wait for expected output when reading.
        :param session_timeout: seconds to wait for the channel lock before
            NetmikoTimeoutException is raised.
        """
        self.hostname = hostname
        self.password = password
        self.username = username
        self.su_password = su_password
        self.port = port
        self.global_delay_factor = global_delay_factor
        self.base_prompt = None
        self.root_prompt = None
        self.timeout = timeout
        # Read by _timeout_exceeded() while waiting for the channel lock.
        self.session_timeout = session_timeout
        self.RETURN = "\n"
        # Created before any channel I/O so every read helper can use it.
        self._session_locker = Lock()
        # Extra ParaSessions opened by open_sftp_session(); kept alive here so
        # their SFTP clients stay usable, and closed by close_all_session().
        self._sftp_sessions = []
        self.ssh_obj = ParaSession(hostname=hostname, username=username, password=password, port=port)
        # ParaSession reports connection failures by printing and leaving
        # ``channel`` unset, so this raises AttributeError in that case.
        self.remote_conn = self.ssh_obj.channel
        self.session_preparation()

    # NOTE: reaching this method while debugging does not necessarily mean a
    # timeout occurred - debug carefully. These decorators were tried here and
    # are left disabled:
    # @retry(stop=stop_after_attempt(2), wait=wait_fixed(5))
    # @timeout(6)
    def session_preparation(self, delay_factor=1):
        """Detect and store the base prompt, then drain pending channel output.

        :param delay_factor: unused; kept for interface compatibility.
        """
        self.set_base_prompt()
        self.clear_buffer()

    def exec_ssh_cmd(self, cmd, delay_factor=2):
        """Run *cmd* in the shell and return what it printed.

        Sends ``cmd + RETURN``, sleeps *delay_factor* seconds, then reads until
        ``self.base_prompt`` appears or ``self.timeout`` seconds pass.

        :return: all text read after sending the command (typically the echoed
            command, its output and the trailing prompt). Returns None if the
            session has no shell channel or the command could not be sent.
            Read errors are printed, and whatever was read so far is returned.
        """
        time.sleep(delay_factor * 0.1)
        if not hasattr(self.ssh_obj, 'channel'):
            self.close_all_session()
            print('session error and please check the ENV.')
            return None
        try:
            self.write_channel(cmd + self.RETURN)
        except Exception as e:
            print(f'CMD send errors: {str(e)}')
            return None
        time.sleep(delay_factor)
        print('Exec ssh command: %s' % str(cmd))
        buff = ''
        deadline = time.time() + self.timeout
        try:
            while self.base_prompt not in buff:
                resp = self.read_channel()
                if resp:
                    # Accumulate: output can arrive over several reads.
                    buff += resp
                elif time.time() > deadline:
                    print(f'CMD receive errors: prompt {self.base_prompt!r} '
                          f'not seen within {self.timeout}s')
                    break
                else:
                    # Command still running; don't spin the CPU.
                    time.sleep(0.1)
        except Exception as e:
            print(f'CMD receive errors: {str(e)}')
        return buff

    def open_sftp_session(self, hostname, password, port=22, username='root'):
        """Open a separate ParaSession to *hostname* and return its SFTP client.

        The ParaSession is stored on this object and closed by
        close_all_session(). Without that reference it would be collected as
        soon as this method returned, and its __del__ would close the SFTP
        client being returned.

        :return: ``paramiko.SFTPClient``, or None if opening failed.
        """
        session = ParaSession(hostname=hostname, password=password, port=port, username=username)
        self._sftp_sessions.append(session)
        return session.sftp_client()

    def _get_sftp(self):
        """Return the primary session's SFTP client, opening it on first use; None on failure."""
        ssh_obj = getattr(self, 'ssh_obj', None)
        if ssh_obj is None:
            return None
        sftp = getattr(ssh_obj, 'sftp', None)
        if sftp is None:
            sftp = ssh_obj.sftp_client()
        return sftp

    def get_remotefile(self, remote_path, local_path, session='default'):
        """Download *remote_path* to *local_path* over SFTP.

        :return: True on success, False if the transfer failed, or None if no
            SFTP client could be opened (all sessions are then closed).
        """
        sftp = self._get_sftp()
        if sftp is None:
            self.close_all_session()
            print(f'session: {session} is error and please check the ENV.')
            return None
        try:
            sftp.get(remote_path, local_path)
            return True
        except Exception as e:
            print(f'Get remote file errors: {str(e)}')
            return False

    def put_localfile(self, local_path, remote_path, session='default'):
        """Upload *local_path* to *remote_path* over SFTP.

        :return: True on success, False if the transfer failed, or None if no
            SFTP client could be opened (all sessions are then closed).
        """
        sftp = self._get_sftp()
        if sftp is None:
            self.close_all_session()
            print(f'session: {session} is error and please check the ENV.')
            return None
        try:
            sftp.put(local_path, remote_path)
            return True
        except Exception as e:
            print(f'Put local file errors: {str(e)}')
            return False

    def close_session(self):
        """Close the primary ParaSession (shell channel, SSH client and its SFTP client)."""
        ssh_obj = getattr(self, 'ssh_obj', None)
        if ssh_obj is not None:
            ssh_obj.close()

    def close_all_session(self):
        """Close the primary session and every session opened by open_sftp_session()."""
        self.close_session()
        sessions = getattr(self, '_sftp_sessions', [])
        while sessions:
            sessions.pop().close()

    def __del__(self):
        # Best-effort cleanup; exceptions cannot propagate out of __del__ anyway.
        try:
            self.close_all_session()
        except Exception:
            pass

    def _read_channel(self):
        """Return all data currently buffered on the channel, without blocking.

        :raises EOFError: if a read returns no bytes (remote closed the channel).
        """
        output = ""
        # recv() would block when nothing is pending, so only read while
        # recv_ready() reports data.
        while self.remote_conn.recv_ready():
            outbuf = self.remote_conn.recv(MAX_BUFFER)
            if len(outbuf) == 0:
                raise EOFError("Channel stream closed by remote device.")
            output += outbuf.decode("utf-8", "ignore")
        return output

    def normalize_linefeeds(self, a_string):
        """Convert every CR/LF line-ending variant in *a_string* to a single LF."""
        # Longest variants first so CR CR CR LF collapses to one LF, not several.
        newline = re.compile("(\r\r\r\n|\r\r\n|\r\n|\n\r)")
        a_string = newline.sub("\n", a_string)
        # Any remaining lone CR also becomes LF.
        return re.sub("\r", "\n", a_string)

    def read_channel(self):
        """Public wrapper around _read_channel()."""
        return self._read_channel()

    def write_bytes(self, out_data, encoding="ascii"):
        """Return *out_data* as bytes for writing to the channel.

        str is encoded as UTF-8 when ``encoding == "utf-8"``; otherwise as ASCII
        with non-ASCII characters silently dropped. bytes are returned as-is.

        :raises ValueError: for any other type.
        """
        if sys.version_info[0] >= 3:
            if isinstance(out_data, str):
                if encoding == "utf-8":
                    return out_data.encode("utf-8")
                return out_data.encode("ascii", "ignore")
            if isinstance(out_data, bytes):
                return out_data
        msg = "Invalid value for out_data neither unicode nor byte string: {}".format(
            out_data
        )
        raise ValueError(msg)

    def _write_channel(self, out_data):
        """Send *out_data* (converted by write_bytes) over the channel."""
        self.remote_conn.sendall(self.write_bytes(out_data))

    def write_channel(self, out_data):
        """Public wrapper around _write_channel()."""
        self._write_channel(out_data)

    def find_prompt(self, delay_factor=1):
        """Return the current shell prompt: the last line of pending output, stripped.

        If nothing is pending, sends RETURN and retries with growing sleeps
        (at most 13 attempts).

        :raises ValueError: if no prompt could be read.
        """
        sleep_time = delay_factor * 0.1
        time.sleep(sleep_time)
        prompt = self.read_channel().strip()
        count = 0
        # Nudge the shell with RETURN until it prints something.
        while count <= 12 and not prompt:
            prompt = self.read_channel().strip()
            if not prompt:
                self.write_channel(self.RETURN)
                time.sleep(sleep_time)
                if sleep_time <= 3:
                    # Double the sleep time while it is small
                    sleep_time *= 2
                else:
                    sleep_time += 1
            count += 1
        # If the output has several lines, the prompt is the last one
        prompt = self.normalize_linefeeds(prompt)
        prompt = prompt.split("\n")[-1]
        prompt = prompt.strip()
        if not prompt:
            raise ValueError(f"Unable to find prompt: {prompt}")
        time.sleep(delay_factor * 0.1)
        return prompt

    def clear_buffer(self, backoff=True, delay_factor=1):
        """Read and discard any data available in the channel (up to 10 reads)."""
        sleep_time = 0.1 * delay_factor
        for _ in range(10):
            time.sleep(sleep_time)
            data = self.read_channel()
            if not data:
                break
            if backoff:
                # Double the wait between reads, capped at 3 seconds.
                sleep_time = min(sleep_time * 2, 3)

    # Deprecated (to be removed): use enable() instead.
    def root_su(self, password, delay_factor=1):
        """Switch to root with ``su``, answering the password prompt, and return the root prompt.

        :raises ValueError: if no prompt ending in '#' was obtained.
        """
        sleep_time = delay_factor * 0.1
        waite_for_password = re.compile("Password:")
        prompt = self.read_channel().strip()
        count = 0
        while count <= 13 and not prompt:
            prompt = self.read_channel().strip()
            if not prompt:
                self.write_channel(self.RETURN)
                time.sleep(sleep_time)
                if sleep_time <= 3:
                    sleep_time *= 2
                else:
                    sleep_time += 1
            else:
                prompt = self.normalize_linefeeds(prompt)
                prompt = prompt.split("\n")[-1]
                prompt = prompt.strip()
                if prompt.endswith('$'):
                    self.write_channel('su' + self.RETURN)
                    time.sleep(sleep_time)
                    prompt = self.read_channel().strip()
                if waite_for_password.search(prompt):
                    self.write_channel(password + self.RETURN)
                    time.sleep(sleep_time)
                    prompt = self.read_channel().strip()
                if prompt.endswith('#'):
                    return prompt
            count += 1
        # Previously fell through and returned None when a non-root prompt was
        # read, so callers failed later with an unrelated TypeError.
        raise ValueError(f"Unable to find prompt: {prompt}")

    def set_base_prompt(self, delay_factor=1, prompt_terminator="$"):
        """Store the current prompt minus its terminator character as ``self.base_prompt``.

        :raises ValueError: if the prompt does not end with *prompt_terminator*.
        """
        prompt = self.find_prompt(delay_factor=delay_factor)
        if prompt[-1] not in prompt_terminator:
            raise ValueError(f"Prompt not found: {repr(prompt)}")
        self.base_prompt = prompt[:-1]
        return self.base_prompt

    # Deprecated (to be removed): relies on root_su().
    def set_root_prompt(self, delay_factor=1, prompt_terminator="#"):
        """Become root via root_su() and store the root prompt minus its terminator."""
        prompt = self.root_su(self.su_password, delay_factor=delay_factor)
        if prompt[-1] not in prompt_terminator:
            raise ValueError(f"Router prompt not found: {repr(prompt)}")
        self.root_prompt = prompt[:-1]
        return self.root_prompt

    def check_base_prompt(self, check_sre, prompt_terminator="$"):
        """Return True if the base prompt followed by *prompt_terminator* occurs in *check_sre*."""
        return self.base_prompt + prompt_terminator in check_sre

    def check_root_prompt(self, check_sre, prompt_terminator="#"):
        """Return True if the root prompt followed by *prompt_terminator* occurs in *check_sre*."""
        return self.root_prompt + prompt_terminator in check_sre

    def enable(self, cmd="", pattern="ssword", secret="", re_flags=re.IGNORECASE):
        """Enter privileged mode: send *cmd*, wait for *pattern*, then send *secret*.

        :return: the output read during the exchange.
        :raises ValueError: if privileged mode was not reached.
        """
        output = ""
        msg = (
            "Failed to enter su mode. Please ensure you pass "
            "the 'secret' argument to ConnectHandler."
        )
        if not self.check_enable_mode():
            self.write_channel(self.normalize_cmd(cmd))
            try:
                output += self.read_until_prompt_or_pattern(
                    pattern=pattern, re_flags=re_flags
                )
                self.write_channel(self.normalize_cmd(secret))
                output += self.read_until_prompt()
            except NetmikoTimeoutException as err:
                raise ValueError(msg) from err
            if not self.check_enable_mode():
                raise ValueError(msg)
        return output

    def exit_enable_mode(self, exit_command=""):
        """Leave privileged mode by sending *exit_command*.

        :raises ValueError: if still in privileged mode afterwards.
        """
        output = ""
        if self.check_enable_mode():
            self.write_channel(self.normalize_cmd(exit_command))
            output += self.read_until_prompt()
            if self.check_enable_mode():
                raise ValueError("Failed to exit enable mode.")
        return output

    def normalize_cmd(self, command):
        """Strip trailing whitespace from *command* and append exactly one RETURN."""
        command = command.rstrip()
        command += self.RETURN
        return command

    def check_enable_mode(self, check_string=""):
        """Send RETURN, read up to the prompt, and report whether *check_string* is in it."""
        self.write_channel(self.RETURN)
        output = self.read_until_prompt()
        return check_string in output

    def read_until_prompt_or_pattern(self, pattern="", re_flags=0):
        """Read until the base prompt or the regex *pattern* appears."""
        combined_pattern = re.escape(self.base_prompt)
        if pattern:
            combined_pattern = r"({}|{})".format(combined_pattern, pattern)
        return self._read_channel_expect(combined_pattern, re_flags=re_flags)

    def _read_channel_expect(self, pattern="", re_flags=0, max_loops=150):
        """Poll the channel until the regex *pattern* matches the accumulated output.

        :param pattern: regex to wait for; defaults to the escaped base prompt.
        :param max_loops: polling attempts; the default 150 is a sentinel that
            means "about ``self.timeout`` seconds".
        :return: everything read, including the matched text.
        :raises NetmikoTimeoutException: if the pattern does not appear in time.
        """
        output = ""
        if not pattern:
            # If base_prompt never appears in the output, this waits out the
            # full timeout; subclasses should override set_base_prompt() so the
            # base prompt is accurate.
            pattern = re.escape(self.base_prompt)
        loop_delay = 0.1
        # Default to making loop time roughly equivalent to self.timeout
        if max_loops == 150:
            max_loops = int(self.timeout / loop_delay)
        for _ in range(1, max_loops):
            # Acquire outside the try so a failed acquire never releases a lock
            # that another caller holds.
            self._lock_netmiko_session()
            try:
                # Non-blocking read: a bare recv() would block forever here
                # because no channel timeout is configured.
                output += self._read_channel()
            except socket.timeout as err:
                raise NetmikoTimeoutException(
                    "Timed-out reading channel, data not available."
                ) from err
            finally:
                self._unlock_netmiko_session()

            if re.search(pattern, output, flags=re_flags):
                return output
            time.sleep(loop_delay * self.global_delay_factor)
        raise NetmikoTimeoutException(
            f"Timed-out reading channel, pattern not found in output: {pattern}"
        )

    def read_until_prompt(self, *args, **kwargs):
        """Read until the base prompt appears (see _read_channel_expect)."""
        return self._read_channel_expect(*args, **kwargs)

    def _lock_netmiko_session(self, start=None):
        """Acquire the channel lock, polling every 0.1 s.

        :raises NetmikoTimeoutException: if not acquired within ``self.session_timeout``.
        """
        if not start:
            start = time.time()
        while not self._session_locker.acquire(False) and not self._timeout_exceeded(
            start, "The netmiko channel is not available!"
        ):
            time.sleep(0.1)
        return True

    def _unlock_netmiko_session(self):
        """Release the channel lock if it is held."""
        if self._session_locker.locked():
            self._session_locker.release()

    def _timeout_exceeded(self, start, msg="Timeout exceeded!"):
        """Raise NetmikoTimeoutException(*msg*) if ``session_timeout`` has passed since *start*.

        :return: False otherwise (including when *start* is falsy).
        """
        if not start:
            # Must provide a comparison time
            return False
        if time.time() - start > self.session_timeout:
            raise NetmikoTimeoutException(msg)
        return False


class LinuxBaseConnection(BaseSSHSession):
    """Session to a Linux shell where "enable mode" means a root shell opened with ``su``.

    The base prompt is narrowed to the host part of a ``user@host:path$``
    prompt. That way the same pattern matches both the user prompt and the
    root (``#``) prompt.
    """

    def check_enable_mode(self, check_string="#"):
        """Return True if the prompt read back contains *check_string* ('#' for root)."""
        return super().check_enable_mode(check_string=check_string)

    def enable(self, cmd="su", pattern="ssword", secret=None, re_flags=re.IGNORECASE):
        """Switch to root with *cmd*, answering the password prompt with *secret*.

        :param secret: password for *cmd*. Defaults to the ``su_password``
            given to the constructor, instead of a password hard-coded here.
        """
        if secret is None:
            secret = self.su_password
        return super().enable(cmd=cmd, pattern=pattern, secret=secret, re_flags=re_flags)

    def exit_enable_mode(self, exit_command="disable"):
        """Leave enable mode by sending *exit_command*.

        NOTE(review): the default "disable" appears to come from netmiko's
        Cisco drivers. A root shell opened with ``su`` is normally left with
        ``exit``, so pass ``exit_command="exit"`` on Linux hosts.
        """
        return super().exit_enable_mode(exit_command=exit_command)

    def set_base_prompt(self, delay_factor=1, prompt_terminator="$"):
        """Set ``self.base_prompt`` to the host name in the prompt (``u@myhost:~$`` -> ``myhost``).

        :raises ValueError: if the prompt contains no ``@host:`` part.
        """
        prompt = super().set_base_prompt(delay_factor=delay_factor,
                                         prompt_terminator=prompt_terminator)
        # \w needs its backslash ("@(w.*):" only matched hosts starting with a
        # literal 'w'); [^:]* stops at the first colon, after the host name.
        match = re.search(r"@(\w[^:]*):", prompt)
        if match is None:
            raise ValueError(f"Unable to extract host name from prompt: {prompt!r}")
        self.base_prompt = match.group(1)
        return self.base_prompt


if __name__ == '__main__':
    import os

    # Demo connection details; override with environment variables rather
    # than editing credentials into the source.
    host_ip = os.environ.get('SSH_HOST', '10.101.35.249')
    user_name = os.environ.get('SSH_USER', 'int4')
    pass_word = os.environ.get('SSH_PASSWORD', 'nokia123')
    su_password = os.environ.get('SSH_SU_PASSWORD', pass_word)

    ssh_session_obj = LinuxBaseConnection(host_ip, user_name, pass_word, su_password)
    try:
        result_su = ssh_session_obj.exec_ssh_cmd('ifconfig')
        print(result_su)

        # Become root via enable() (the older root_su() approach is deprecated).
        ssh_session_obj.enable()

        result_su = ssh_session_obj.exec_ssh_cmd('whoami')
        print(result_su)
        result_su = ssh_session_obj.exec_ssh_cmd('ifconfig')
        print(result_su)
    finally:
        # Close explicitly instead of relying on __del__ at interpreter exit.
        ssh_session_obj.close_all_session()

# Source (原文地址): https://www.cnblogs.com/amize/p/15146720.html