#!/usr/bin/env python3
import argparse
import codecs
import datetime
import difflib
import glob
import logging
import os
import re
import shutil
import time

import paramiko


def get_args():
    """Parse the command line.

    Returns:
        argparse.Namespace with the attributes type, host, user, password,
        sshkey, backupdir, address, logfile, loglevel, diffprint, difflog
        and noerrors.
    """
    parser = argparse.ArgumentParser(
        description='routerbackup - save routeros/ios/nexus/asa/comware router configuration')

    parser.add_argument('-t', '--type',
                        required=True,
                        action='store',
                        help='device type {routeros,ios,nexus,asa,comware}')
    parser.add_argument('-H', '--host',
                        required=True,
                        action='store',
                        help='host name, used also as backup filename base ')
    parser.add_argument('-u', '--user',
                        required=True,
                        action='store',
                        help='user name')
    parser.add_argument('-p', '--password',
                        required=False,
                        action='store',
                        help='password')
    parser.add_argument('-i', '--sshkey',
                        required=False,
                        action='store',
                        help='ssh private key file')
    parser.add_argument('-o', '--backupdir',
                        required=True,
                        action='store',
                        help='backup output directory')
    parser.add_argument('-a', '--address',
                        required=False,
                        action='store',
                        help='host address (default: hostname)')
    parser.add_argument('-l', '--logfile',
                        required=False,
                        action='store',
                        help='log file (no log if not given)')
    parser.add_argument('-L', '--loglevel',
                        required=False,
                        action='store',
                        help='log level')
    parser.add_argument('-d', '--diffprint',
                        required=False,
                        action='store_true',
                        help='print diff if configuration has changed')
    parser.add_argument('-D', '--difflog',
                        required=False,
                        action='store_true',
                        help='write diff to log when configuration has changed')
    parser.add_argument('-E', '--noerrors',
                        required=False,
                        action='store_true',
                        help='don\'t print any error messages')

    return parser.parse_args()


class Config:
    """Base class: fetch a device configuration over SSH, diff it, store it.

    Subclasses implement get_config(), which sets self.config (the main
    configuration text) and fills the 'content' of their entries in
    self.extra. All files are stored as <workdir>/<filebase>.<suffix>.
    """

    # configurations shorter than this (in characters) are considered bad
    min_config_length = 200

    def __init__(self, workdir):
        self.workdir = workdir
        self.ssh = None             # paramiko.SSHClient, set by connect()
        self.host = None
        self.user = None
        self.passwd = None
        self.sshkey = None          # private key file name
        self.config = None          # configuration text, set by get_config()
        # extra files: suffix -> {'type': 'text' or 'binary', 'content': ...}
        self.extra = {}
        self.filebase = None
        self.prev_config = None     # configuration text loaded from disk
        self.prev_timestamp = None  # label of the previous config in the diff
        self.prev_diff = []         # diff lines computed by has_changed()

    def connect(self, host=None, user=None, passwd=None, sshkey=None):
        """(Re)connect via SSH.

        Truthy arguments replace the stored values; omitted/falsy ones keep
        the values from an earlier call, so connect() without arguments
        reconnects with the same settings. A previously opened client is
        closed first.

        Raises:
            ValueError: if host or user is missing, or if neither a password
                nor an ssh key is known.
        """
        if host:
            self.host = host
        if user:
            self.user = user
        if passwd:
            self.passwd = passwd
        if sshkey:
            self.sshkey = sshkey
        if not self.host or not self.user:
            raise ValueError('internal error: missing host or user')
        if not self.passwd and not self.sshkey:
            raise ValueError('no password and no ssh key given')

        if self.ssh is not None:
            # don't leak the previous connection when reconnecting
            self.ssh.close()
        self.ssh = paramiko.SSHClient()
        self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        try:
            self.ssh.connect(self.host, username=self.user,
                password=self.passwd, key_filename=self.sshkey)
        except Exception:
            # disabling sha2 makes sha1 default for older devices (routeros 6)
            self.ssh.close()  # drop the transport of the failed attempt
            self.ssh.connect(self.host, username=self.user,
                password=self.passwd, key_filename=self.sshkey,
                disabled_algorithms=dict(pubkeys=['rsa-sha2-256', 'rsa-sha2-512']))

    def is_bad_config(self):
        """Return a reason string if self.config looks unusable, else False."""
        if not self.config:
            return 'is None.'
        if len(self.config) < self.min_config_length:
            return f'length < {self.min_config_length} characters.'
        return False  # is not bad

    def _prepare_config_for_diff(self, configlines):
        """Hook for subclasses: filter/normalize lines before diffing.

        The base implementation returns the lines unchanged.
        """
        return configlines

    def set_filebase(self, base):
        self.filebase = base

    def load_previous(self):
        """Load the stored configuration and its last-change time, if any.

        Sets prev_config (None if no readable .config file exists) and
        prev_timestamp (a datetime read from the .lastchange file, or the
        string 'previous' if that file is missing or unreadable).
        """
        try:
            with open(f'{self.workdir}/{self.filebase}.config', 'r') as f:
                self.prev_config = f.read()
        except (OSError, ValueError):  # ValueError covers UnicodeDecodeError
            log.info('No previous configuration found.')
            self.prev_config = None
            return

        self.prev_timestamp = 'previous'
        try:
            with open(f'{self.workdir}/{self.filebase}.lastchange', 'r') as f:
                oldunixts = int(f.readline())
            self.prev_timestamp = datetime.datetime.fromtimestamp(oldunixts)
        except (OSError, ValueError, OverflowError):
            # best effort: the timestamp only labels the diff header
            pass

    def has_previous(self):
        return bool(self.prev_config)

    def has_changed(self):
        """Compare self.config with the previous one and store the diff.

        Returns True if there is no previous configuration or if the two
        differ after _prepare_config_for_diff(); the unified diff lines
        (trailing whitespace stripped) are stored in self.prev_diff.
        """
        self.prev_diff = []
        if not self.prev_config:
            return True
        if self.prev_config == self.config:
            return False

        old = self._prepare_config_for_diff(self.prev_config.splitlines(keepends=True))
        new = self._prepare_config_for_diff(self.config.splitlines(keepends=True))
        self.prev_diff = [
            line.rstrip()
            for line in difflib.unified_diff(
                old, new,
                fromfile=f'{self.filebase} {self.prev_timestamp}',
                tofile=f'{self.filebase} now')
        ]
        return bool(self.prev_diff)

    def print_changed(self, diffprint=False, difflog=False):
        """Print and/or log the diff computed by has_changed()."""
        for line in self.prev_diff:
            if diffprint:
                print(line)
            if difflog:
                log.info(line)

    def write_to_disk(self, has_changed=True):
        """Write the config and extra files; update .lastchange if changed.

        Extra entries whose content is still None are skipped.
        """
        basepath = f'{self.workdir}/{self.filebase}'
        with open(f'{basepath}.config', 'w') as f:
            f.write(self.config)
        for suffix, extra in self.extra.items():
            if extra['content'] is None:
                # nothing was fetched for it; f.write(None) would raise
                continue
            mode = 'wb' if extra['type'] == 'binary' else 'w'
            with open(f'{basepath}.{suffix}', mode) as f:
                f.write(extra['content'])
        if has_changed:
            with open(f'{basepath}.lastchange', 'w') as f:
                # unix time (read back by load_previous) plus a readable form
                f.write(f'{int(time.time())}\n{time.strftime("%F %T")}\n')
        log.info(f'Written {"differing" if has_changed else "similar"} configuration to: {basepath}.*')

    def backup_old(self):
        """Move the stored files into a '<filebase>+<mtime>' directory.

        The directory name uses the modification time of the stored .config
        file; nothing happens if there is none. Only regular files named
        '<filebase>.<suffix>' with no further dot in <suffix> are moved, so
        files of another host whose name merely starts with '<filebase>.'
        (e.g. 'r1.lan.config' when backing up 'r1') are left alone.
        """
        basepath = f'{self.workdir}/{self.filebase}'
        try:
            info = os.stat(basepath + '.config')
        except FileNotFoundError:
            return
        oldtimestamp = datetime.datetime.fromtimestamp(info.st_mtime).strftime('%F_%H%M%S')

        backuppath = f'{basepath}+{oldtimestamp}'
        os.mkdir(backuppath)
        prefix_len = len(self.filebase) + 1  # '<filebase>.'
        for path in glob.glob(glob.escape(basepath) + '.*'):
            suffix = os.path.basename(path)[prefix_len:]
            if '.' in suffix or not os.path.isfile(path):
                continue
            shutil.move(path, backuppath)
        log.info(f'Moved old backup into: {backuppath}')


class RouterosConfig(Config):
    """MikroTik RouterOS (major version 6 or 7) configuration backup.

    The main configuration consists of the terse export, the user export and
    the certificate list. Extra files: 'resource' (system resource output),
    'iproute' (ip address/route/cloud output) and 'backup' (binary file made
    by '/system backup save').
    """

    def __init__(self, workdir):
        super().__init__(workdir)
        self.extra['resource'] = {'type': 'text', 'content': None}
        self.extra['iproute'] = {'type': 'text', 'content': None}
        self.extra['backup'] = {'type': 'binary', 'content': None}

    def connect(self, host=None, user=None, passwd=None, key=None):
        """Connect like Config.connect(), appending '+c' to the user name.

        NOTE(review): '+c' is a RouterOS login-name option, presumably used
        to disable colored console output - confirm.
        """
        if user:
            user += '+c'
        super().connect(host, user, passwd, key)

    def get_config(self):
        """Fetch the configuration and the extra outputs from the router.

        Raises:
            RuntimeError: if the RouterOS major version is not 6 or 7 (or not
                found), or if '/system backup save' does not report success.
        """
        stdin, stdout, stderr = self.ssh.exec_command('/system resource print')
        resource = []
        ros_version = None
        for line in stdout:
            line = line.strip()
            resource.append(line)
            if m := re.match(r'version: *(\d+)', line):
                ros_version = int(m[1])

        if ros_version == 6:
            export_cmds = [
                '/export terse',
                '/user export terse'
                ]
            certificate_cmd = '/certificate print detail'
        elif ros_version == 7:
            export_cmds = [
                '/export terse show-sensitive',
                '/user/export terse show-sensitive'
                ]
            certificate_cmd = '/certificate/print show-ids detail'
        else:
            raise RuntimeError("routeros version not found in system resource")

        export = []
        for cmd in export_cmds:
            stdin, stdout, stderr = self.ssh.exec_command(cmd)
            export.append('# '+ cmd.center(76, '='))
            preamble = True
            append_to_last = False
            for line in stdout:
                line = line.rstrip()
                if preamble:
                    # drop the '# ... by RouterOS ...' header line
                    # (presumably it carries the export time - confirm)
                    if re.match(r'# .* by RouterOS ', line):
                        continue
                    # the first line not starting with '#' (an empty line
                    # included) ends the preamble
                    if not line.startswith('#'):
                        preamble = False
                    export.append(line)
                    continue
                # a '# ... not ...' comment is dropped and the line after it
                # is joined onto the previous output line
                if re.match(r'# .* not ', line):
                    append_to_last = True
                    continue
                if append_to_last:
                    export[-1] += ' ' + line
                    append_to_last = False
                else:
                    export.append(line)

        stdin, stdout, stderr = self.ssh.exec_command(certificate_cmd)
        export.append('# '+ certificate_cmd.center(76, '='))
        for line in stdout:
            # drop days-valid/expires-after fields (presumably because they
            # change over time without any configuration change)
            export.append(re.sub(r' *(days-valid|expires-after)=[^ ]*', '', line.rstrip()))

        iproute = []
        for cmd in ('/ip address print', '/ip route print', '/ip cloud print'):
            stdin, stdout, stderr = self.ssh.exec_command(cmd)
            iproute.append('# '+ cmd.center(76, '='))
            iproute.extend(line.rstrip() for line in stdout)

        stdin, stdout, stderr = self.ssh.exec_command('/system backup save dont-encrypt=yes name=autobck')
        output = str(stdout.read(), 'utf8', 'backslashreplace')
        if not re.search('Configuration backup saved', output):
            raise RuntimeError(f'"/system backup" failed: {output.strip()}')

        # close the SFTP session when done instead of leaking it
        with self.ssh.open_sftp() as sftp:
            with sftp.open('autobck.backup', 'r') as sf:
                backup = sf.read()

        self.config = '\n'.join(export)
        self.extra['resource']['content'] = '\n'.join(resource)
        self.extra['iproute']['content'] = '\n'.join(iproute)
        self.extra['backup']['content'] = backup

    def _prepare_config_for_diff(self, configlines):
        """Mask sa-src-address/sa-dst-address values of ipsec policy lines.

        Changes of these values alone then do not count as a configuration
        change (they are still stored on disk by write_to_disk()).
        """
        prepared = []
        for line in configlines:
            if re.match(r'/ip[ /]ipsec[ /]policy[ /]add', line):
                line = re.sub(r'(sa-(src|dst)-address)=[^ ]*', r'\1=...', line)
            prepared.append(line)
        return prepared


class IosConfig(Config):
    """Cisco IOS: 'sh run' as configuration, IP aliases/routes as extra file."""

    def __init__(self, workdir):
        super().__init__(workdir)
        self.extra['iproute'] = {'type': 'text', 'content': None}

    def get_config(self):
        """Fetch 'sh run' and the 'iproute' extra output."""
        stdin, stdout, stderr = self.ssh.exec_command('sh run')
        shrun = [line.rstrip() for line in stdout]

        iproute = []
        for cmd in ('sh ip aliases', 'sh ip route'):
            # IOS allows only 1 exec per connection: close the used one
            # (instead of leaking it) and reconnect with the same settings
            self.ssh.close()
            self.connect()
            stdin, stdout, stderr = self.ssh.exec_command(cmd)
            iproute.append('# '+ cmd.center(76, '='))
            iproute.extend(line.rstrip() for line in stdout)

        self.config = '\n'.join(shrun)
        self.extra['iproute']['content'] = '\n'.join(iproute)

    def is_bad_config(self):
        """Return a reason string if the config is bad, else False.

        The generic checks of the base class come first (they also protect
        the regex below from a None config).
        """
        if reason := super().is_bad_config():
            return reason
        if not re.search(r'end\s*$', self.config):
            return "does not end with 'end'."
        return False

    def _prepare_config_for_diff(self, configlines):
        """Drop blank lines and '! Last configuration change at' lines.

        Such lines alone then do not count as a configuration change.
        """
        prepared = []
        for line in configlines:
            if not re.search(r'\S', line):
                continue
            if re.match('! Last configuration change at ', line):
                continue
            prepared.append(line)
        return prepared


class NexusConfig(Config):
    """Cisco Nexus: 'sh run' as configuration, IP interfaces/routes as extra."""

    def __init__(self, workdir):
        super().__init__(workdir)
        self.extra['iproute'] = {'type': 'text', 'content': None}

    def get_config(self):
        """Fetch 'sh run' and the 'iproute' extra output."""
        stdin, stdout, stderr = self.ssh.exec_command('sh run')
        shrun = [line.rstrip() for line in stdout]

        iproute = []
        for cmd in ('sh ip interface brief', 'sh ip route'):
            stdin, stdout, stderr = self.ssh.exec_command(cmd)
            iproute.append('# '+ cmd.center(76, '='))
            iproute.extend(line.rstrip() for line in stdout)

        self.config = '\n'.join(shrun)
        self.extra['iproute']['content'] = '\n'.join(iproute)

    def _prepare_config_for_diff(self, configlines):
        """Drop blank lines and the '!Running configuration last done at' /
        '!Time' lines, so these alone do not count as a configuration change.
        """
        prepared = []
        for line in configlines:
            if not re.search(r'\S', line):
                continue
            if re.match('!Running configuration last done at', line):
                continue
            if re.match('!Time', line):
                continue
            prepared.append(line)
        return prepared


class ConfigWithShell(Config):
    """Base for devices whose output is read through an interactive shell."""

    def _shell_command(self, channel, cmd,
            waitfor=None,
            waitmax=None,
            send_timeout=10,
            recv_timeout=20
            ):
        """Send *cmd* to an interactive shell channel and collect the output.

        Args:
            channel: shell channel (as returned by SSHClient.invoke_shell()).
            cmd: command line to send; a newline is appended.
            waitfor: regex; return once some output was received, no more
                data is pending and the output so far matches it.
            waitmax: seconds since sending; return once this has elapsed,
                some output was received and no more data is pending.
            send_timeout: seconds to wait for the channel to accept the command.
            recv_timeout: seconds since sending after which to give up.

        Returns:
            The received output split into lines.

        Raises:
            RuntimeError: on send or receive timeout.
        """
        start = time.time()
        while not channel.send_ready():
            if time.time() - start > send_timeout:
                raise RuntimeError('shell send timeout')
            time.sleep(0.1)
        channel.send(cmd +'\n')

        # decode incrementally: a multi-byte UTF-8 character split across two
        # recv() chunks would otherwise be turned into backslash escapes
        decoder = codecs.getincrementaldecoder('utf8')('backslashreplace')
        start = time.time()
        output = ''
        while True:
            while not channel.recv_ready():
                if output and (
                        (waitfor and re.search(waitfor, output))
                        or (waitmax and time.time() - start > waitmax)):
                    # flush a possibly incomplete trailing byte sequence
                    output += decoder.decode(b'', final=True)
                    return output.splitlines()
                if time.time() - start > recv_timeout:
                    raise RuntimeError(f'shell recv timeout (after receiving {len(output)} characters)')
                time.sleep(0.1)
            while channel.recv_ready():
                output += decoder.decode(channel.recv(65536))


class AsaConfig(ConfigWithShell):
    """Cisco ASA: 'sh run' via an interactive shell, IP info as extra file."""

    def __init__(self, workdir):
        super().__init__(workdir)
        self.extra['iproute'] = {'type': 'text', 'content': None}

    def get_config(self):
        """Fetch 'sh run' and the 'iproute' extra output over one shell."""
        channel = self.ssh.invoke_shell()
        try:
            self._shell_command(channel, 'term pager 0', waitfor=r'#\s*$', waitmax=1)
            shrun = self._shell_command(channel, 'sh run', waitfor=': end')

            # strip late previous lines and command echo: keep from the
            # ': Saved' line on if it is among the first 5 lines
            for i, line in enumerate(shrun[:5]):
                if re.search(r'^\s*: Saved\s*$', line):
                    del shrun[:i]
                    break
            # strip garbage after the ': end' line if it is among the last
            # 5 lines (bounded, so short output cannot raise IndexError)
            for i in range(len(shrun) - 1, max(len(shrun) - 6, -1), -1):
                if re.search(r'^\s*: end\s*$', shrun[i]):
                    del shrun[i + 1:]
                    break

            iproute = []
            for cmd in ('sh ip address', 'sh route'):
                iproute.append('# '+ cmd.center(76, '='))
                iproute += self._shell_command(channel, cmd, waitfor=r'#\s*$', waitmax=2)
        finally:
            channel.close()

        self.config = '\n'.join(shrun)
        self.extra['iproute']['content'] = '\n'.join(iproute)

    def is_bad_config(self):
        """Return a reason string if the config is bad, else False.

        The generic checks of the base class come first (they also protect
        the regex below from a None config).
        """
        if reason := super().is_bad_config():
            return reason
        if not re.search(': end', self.config):
            return "does not contain ': end'."
        return False


class ComwareConfig(ConfigWithShell):
    """Comware: 'display current-configuration' via an interactive shell."""

    def get_config(self):
        """Fetch the current configuration over one shell channel."""
        channel = self.ssh.invoke_shell()
        try:
            self._shell_command(channel, 'screen-length disable', waitfor=r'>\s*$', waitmax=1)
            dicur = self._shell_command(channel, 'display current-configuration',
                                        waitfor=r'(?m:^return\s*$)')
        finally:
            channel.close()

        # strip garbage after the 'return' line if it is among the last 5
        # lines (bounded, so short output cannot raise IndexError)
        for i in range(len(dicur) - 1, max(len(dicur) - 6, -1), -1):
            if re.search(r'^\s*return\s*$', dicur[i]):
                del dicur[i + 1:]
                break
        self.config = '\n'.join(dicur)

    def is_bad_config(self):
        """Return a reason string if the config is bad, else False.

        The generic checks of the base class come first (they also protect
        the regex below from a None config).
        """
        if reason := super().is_bad_config():
            return reason
        if not re.search('return', self.config):
            return "does not contain 'return'."
        return False


def setup_logging(hostname, logfile, loglevel):
    """Configure logging and set the module-global logger ``log``.

    Args:
        hostname: included in every log line as '<hostname>[<pid>]'.
        logfile: append to this file; log to stderr if not given.
        loglevel: level name such as 'DEBUG' or 'info' (case-insensitive);
            missing or unknown names fall back to INFO.
    """
    if logfile:
        handler = logging.FileHandler(logfile, 'a')
    else:
        handler = logging.StreamHandler()
    # braces in the host name would be parsed as '{}'-style format fields
    safe_host = hostname.replace('{', '{{').replace('}', '}}')
    fmt = logging.Formatter(
        "{asctime} "+ safe_host +"[{process}] {name}|{levelname} {message}",
        style='{')
    handler.setFormatter(fmt)
    # getLevelName() maps a known name to its int, anything else to a string
    level = logging.getLevelName(loglevel.upper()) if loglevel else logging.INFO
    if not isinstance(level, int):
        level = logging.INFO
    logging.basicConfig(handlers=[handler], level=level)
    if level > logging.DEBUG:
        # paramiko is too verbose
        logging.getLogger('paramiko').setLevel(logging.WARNING)
    global log
    log = logging.getLogger('main')


def main():
    """Command line entry point: back up the configuration of one device.

    Fetches the configuration up to three times, retrying when it is bad or
    differs from the stored one (to tell a real change from an incomplete
    fetch). If it changed, the diff is printed/logged as requested and the
    old files are moved away; the current files are always written.
    Errors are logged (and printed unless --noerrors) instead of raised.
    """
    args = get_args()
    setup_logging(args.host, args.logfile, args.loglevel)
    device_classes = {
        'routeros': RouterosConfig,
        'ios': IosConfig,
        'nexus': NexusConfig,
        'asa': AsaConfig,
        'comware': ComwareConfig,
    }
    max_attempts = 3
    cf = None
    try:
        if args.type not in device_classes:
            raise ValueError(f'Invalid devtype: {args.type}')
        cf = device_classes[args.type](workdir=args.backupdir)

        filebase = args.host
        cf.set_filebase(filebase)
        log.debug(f'Backup filename base: {filebase}')
        cf.load_previous()

        address = args.address or args.host
        log.info(f'Connecting to {args.type}: {args.user}@{address} with{" password" if args.password else""}{" ssh-key" if args.sshkey else""}')
        start = time.time()
        os.environ['HOME'] = '/' # prevent paramiko from searching private keys
        has_changed = False
        changed_count = 0
        for _ in range(max_attempts):
            cf.connect(address, args.user, args.password, args.sshkey or '')
            cf.get_config()

            if reason := cf.is_bad_config():
                log.warning(f'Bad configuration: {reason}')
                time.sleep(1)
                continue

            has_changed = cf.has_changed()
            if not has_changed:
                # first try: good
                # second try: first was probably incomplete -> good
                break
            if not cf.has_previous():
                break
            changed_count += 1
            log.info(f'Got differing configuration on {changed_count}. attempt.')
            if changed_count >= 2:
                # got changed config twice, that is no mistake
                break
            time.sleep(3)

        # the last fetched configuration is what gets written below: never
        # let a bad one replace (and move away) the stored good one
        if reason := cf.is_bad_config():
            raise RuntimeError(f'Giving up after {max_attempts} attempts, bad configuration: {reason}')

        if has_changed:
            cf.print_changed(diffprint=args.diffprint, difflog=args.difflog)
            cf.backup_old()
        # overwrite unchanged config, too, to update modification time and
        # also store changes that get skipped by _prepare_config_for_diff():
        cf.write_to_disk(has_changed)
        log.info(f'Finished in {time.time()-start:.1f}s.')
    except KeyboardInterrupt:
        log.error('Killed by KeyboardInterrupt.')
    except Exception as e:
        log.exception('Exception: %s', e)
        if not args.noerrors:
            print(f'Error: {e}')
    finally:
        if cf is not None and cf.ssh is not None:
            cf.ssh.close()

# Run the backup only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# vim: set ft=python tabstop=4 shiftwidth=4 expandtab smarttab:
