diff --git a/router_manager/forms.py b/router_manager/forms.py index 46da0d2..58f347a 100644 --- a/router_manager/forms.py +++ b/router_manager/forms.py @@ -2,7 +2,7 @@ from django import forms from crispy_forms.helper import FormHelper from crispy_forms.layout import Layout, Submit, Row, Column, HTML from .models import Router, RouterGroup, SSHKey -from routerlib.functions import test_authentication +from routerlib.functions import test_authentication, connect_to_ssh import ipaddress import socket @@ -12,7 +12,7 @@ class RouterForm(forms.ModelForm): class Meta: model = Router - fields = ['name', 'address', 'username', 'password', 'ssh_key', 'monitoring', 'router_type', 'enabled', 'backup_profile'] + fields = ['name', 'port', 'address', 'username', 'password', 'ssh_key', 'monitoring', 'router_type', 'enabled', 'backup_profile'] def __init__(self, *args, **kwargs): super(RouterForm, self).__init__(*args, **kwargs) @@ -27,7 +27,7 @@ class RouterForm(forms.ModelForm): self.helper.layout = Layout( Row( Column('name', css_class='form-group col-md-6 mb-0'), - Column('address', css_class='form-group col-md-6 mb-0'), + Column('ssh_key', css_class='form-group col-md-6 mb-0'), css_class='form-row' ), Row( @@ -35,7 +35,12 @@ class RouterForm(forms.ModelForm): Column('password', css_class='form-group col-md-6 mb-0'), css_class='form-row' ), - 'ssh_key', + Row( + Column('address', css_class='form-group col-md-6 mb-0'), + Column('port', css_class='form-group col-md-6 mb-0'), + css_class='form-row' + ), + 'backup_profile', 'router_type', 'monitoring', @@ -59,6 +64,7 @@ class RouterForm(forms.ModelForm): address = cleaned_data.get('address') router_type = cleaned_data.get('router_type') backup_profile = cleaned_data.get('backup_profile') + port = cleaned_data.get('port') if name: name = name.strip() @@ -82,6 +88,11 @@ class RouterForm(forms.ModelForm): if backup_profile: raise forms.ValidationError('Monitoring only routers cannot have a backup profile') return cleaned_data + else: + if not port: + raise forms.ValidationError('You must provide a port') + if not 1 <= port <= 65535: + raise forms.ValidationError('Invalid port number') if ssh_key and password: raise forms.ValidationError('You must provide a password or an SSH Key, not both') @@ -94,8 +105,9 @@ class RouterForm(forms.ModelForm): if ssh_key and not password: cleaned_data['password'] = '' + test_authentication_success, test_authentication_message = test_authentication( - router_type, cleaned_data['address'], username, cleaned_data['password'], ssh_key + router_type, cleaned_data['address'], port, username, cleaned_data['password'], ssh_key ) if not test_authentication_success: if test_authentication_message: diff --git a/router_manager/migrations/0016_router_port.py b/router_manager/migrations/0016_router_port.py new file mode 100644 index 0000000..56d0578 --- /dev/null +++ b/router_manager/migrations/0016_router_port.py @@ -0,0 +1,18 @@ +# Generated by Django 5.0.3 on 2024-04-12 17:52 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('router_manager', '0015_alter_routerstatus_status_online'), + ] + + operations = [ + migrations.AddField( + model_name='router', + name='port', + field=models.IntegerField(default=22), + ), + ] diff --git a/router_manager/models.py b/router_manager/models.py index e7fa257..b0d0fa3 100644 --- a/router_manager/models.py +++ b/router_manager/models.py @@ -26,6 +26,7 @@ class Router(models.Model): name = models.CharField(max_length=100, unique=True) internal_notes = models.TextField(null=True, blank=True) address = models.CharField(max_length=100) + port = models.IntegerField(default=22) username = models.CharField(max_length=100, default='admin') password = models.CharField(max_length=100, null=True, blank=True) ssh_key = models.ForeignKey(SSHKey, on_delete=models.SET_NULL, null=True, blank=True) diff --git a/routerlib/backup_functions.py b/routerlib/backup_functions.py index 13bfc18..cb9855a 100644 --- a/routerlib/backup_functions.py +++ b/routerlib/backup_functions.py @@ -75,14 +75,15 @@ def execute_backup(router_backup: RouterBackup): router = router_backup.router backup_name = gen_backup_name(router_backup) file_extension = get_router_backup_file_extension(router.router_type) + ssh_client = None try: if router_backup.router.router_type == 'routeros': - ssh_client = connect_to_ssh(router.address, router.username, router.password, router.ssh_key) + ssh_client = connect_to_ssh(router.address, router.port, router.username, router.password, router.ssh_key) ssh_client.exec_command(f'/system backup save name={backup_name}.{file_extension["binary"]}') ssh_client.exec_command(f'/export file={backup_name}.{file_extension["text"]}') return True, [f'{backup_name}.{file_extension["binary"]}', f'{backup_name}.{file_extension["text"]}'], error_message elif router_backup.router.router_type == 'openwrt': - ssh_client = connect_to_ssh(router.address, router.username, router.password, router.ssh_key) + ssh_client = connect_to_ssh(router.address, router.port, router.username, router.password, router.ssh_key) stdin, stdout, stderr = ssh_client.exec_command('uci export') backup_text = stdout.read().decode('utf-8') if backup_text: @@ -100,7 +101,8 @@ def execute_backup(router_backup: RouterBackup): error_message = f"Failed to execute backup: {str(e)}" return False, [], error_message finally: - ssh_client.close() + if ssh_client: + ssh_client.close() def retrieve_backup(router_backup: RouterBackup): @@ -109,12 +111,13 @@ def retrieve_backup(router_backup: RouterBackup): backup_name = gen_backup_name(router_backup) success = False file_extension = get_router_backup_file_extension(router.router_type) + ssh_client = None try: if router_backup.router.router_type == 'routeros': rsc_file_path = f'/tmp/{backup_name}.{file_extension["text"]}' backup_file_path = f'/tmp/{backup_name}.{file_extension["binary"]}' - ssh_client = connect_to_ssh(router.address, router.username, router.password, router.ssh_key) + ssh_client = connect_to_ssh(router.address, router.port, router.username, router.password, router.ssh_key) scp_client = SCPClient(ssh_client.get_transport()) scp_client.get(f'/{backup_name}.{file_extension["text"]}', rsc_file_path) scp_client.get(f'/{backup_name}.{file_extension["binary"]}', backup_file_path) @@ -139,7 +142,7 @@ def retrieve_backup(router_backup: RouterBackup): elif router_backup.router.router_type == 'openwrt': remote_backup_file_path = f'/tmp/{backup_name}.{file_extension["binary"]}' local_backup_file_path = f'/tmp/{backup_name}.{file_extension["binary"]}' - ssh_client = connect_to_ssh(router.address, router.username, router.password, router.ssh_key) + ssh_client = connect_to_ssh(router.address, router.port, router.username, router.password, router.ssh_key) scp_client = SCPClient(ssh_client.get_transport()) scp_client.get(remote_backup_file_path, local_backup_file_path) with open(local_backup_file_path, 'rb') as backup_file: @@ -156,23 +159,26 @@ def retrieve_backup(router_backup: RouterBackup): except Exception as e: return success, f"Failed to retrieve backup files: {str(e)}" finally: - ssh_client.close() + if ssh_client: + ssh_client.close() return success, error_message def clean_up_backup_files(router_backup: RouterBackup): router = router_backup.router + ssh_client = None try: if router_backup.router.router_type == 'routeros': - ssh_client = connect_to_ssh(router.address, router.username, router.password, router.ssh_key) + ssh_client = connect_to_ssh(router.address, router.port, router.username, router.password, router.ssh_key) ssh_client.exec_command('file remove [find where name~"routerfleet-backup-"]') elif router_backup.router.router_type == 'openwrt': - ssh_client = connect_to_ssh(router.address, router.username, router.password, router.ssh_key) + ssh_client = connect_to_ssh(router.address, router.port, router.username, router.password, router.ssh_key) ssh_client.exec_command('rm /tmp/routerfleet-backup-*') else: print(f"Router type not supported: {router_backup.router.get_router_type_display()}") except Exception as e: print(f"Failed to clean up backup files: {str(e)}") finally: - ssh_client.close() + if ssh_client: + ssh_client.close() diff --git a/routerlib/functions.py b/routerlib/functions.py index 5b0ebd8..5ec3429 100644 --- a/routerlib/functions.py +++ b/routerlib/functions.py @@ -40,18 +40,18 @@ def load_private_key_from_string(key_str): return None -def connect_to_ssh(address, username, password, sshkey=None): +def connect_to_ssh(address, port, username, password, sshkey=None): ssh_client = paramiko.SSHClient() ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) if sshkey: private_key = load_private_key_from_string(sshkey.private_key) - ssh_client.connect(address, username=username, pkey=private_key, look_for_keys=False, timeout=10, allow_agent=False) + ssh_client.connect(address, port=port, username=username, pkey=private_key, look_for_keys=False, timeout=10, allow_agent=False) else: - ssh_client.connect(address, username=username, password=password, look_for_keys=False, timeout=10, allow_agent=False) + ssh_client.connect(address, port=port, username=username, password=password, look_for_keys=False, timeout=10, allow_agent=False) return ssh_client -def test_authentication(router_type, address, username, password, sshkey=None): +def test_authentication(router_type, address, port, username, password, sshkey=None): router_features = get_router_features(router_type) if 'ssh' in router_features: connection_type = 'ssh' @@ -61,14 +61,14 @@ def test_authentication(router_type, address, username, password, sshkey=None): return False, 'Router type not supported' if connection_type == 'ssh': - return test_ssh_authentication(router_type, address, username, password, sshkey) + return test_ssh_authentication(router_type, address, port, username, password, sshkey) elif connection_type == 'telnet': return test_telnet_authentication(address, username, password, sshkey=None) -def test_ssh_authentication(router_type, address, username, password, sshkey=None): +def test_ssh_authentication(router_type, address, port, username, password, sshkey=None): try: - ssh_client = connect_to_ssh(address, username, password, sshkey) + ssh_client = connect_to_ssh(address, port, username, password, sshkey) if router_type == 'routeros': stdin, stdout, stderr = ssh_client.exec_command('/system resource print') output = stdout.read().decode()