refactored config flow #26

This commit is contained in:
tomaae 2020-04-08 20:31:47 +02:00
parent 8eb4e2a2c0
commit 2d88ed8aaf
8 changed files with 131 additions and 142 deletions

View file

@ -32,7 +32,7 @@
"device_tracker": {
"data": {
"scan_interval": "Scan interval (requires HA restart)",
"track_arp": "Show client MAC and IP on interfaces",
"track_iface_clients": "Show client MAC and IP on interfaces",
"unit_of_measurement": "Unit of measurement"
}
}

View file

@ -32,7 +32,7 @@
"device_tracker": {
"data": {
"scan_interval": "Период сканирования (требуется перезагрузка HA)",
"track_arp": "Показывать в интерфейсе MAC и IP клиентов",
"track_iface_clients": "Показывать в интерфейсе MAC и IP клиентов",
"unit_of_measurement": "Единицы измерения"
}
}

View file

@ -32,7 +32,7 @@
"device_tracker": {
"data": {
"scan_interval": "Interval skenovania (vyžaduje sa reštart HA)",
"track_arp": "Zobraziť klientske MAC a IP na rozhraniach",
"track_iface_clients": "Zobraziť klientske MAC a IP na rozhraniach",
"unit_of_measurement": "Merná jednotka"
}
}

View file

@ -2,21 +2,11 @@
import logging
from homeassistant.const import (
CONF_NAME,
CONF_HOST,
CONF_PORT,
CONF_UNIT_OF_MEASUREMENT,
CONF_USERNAME,
CONF_PASSWORD,
CONF_SSL,
)
from homeassistant.exceptions import ConfigEntryNotReady
from .const import (
DOMAIN,
DATA_CLIENT,
DEFAULT_TRAFFIC_TYPE,
)
from .mikrotik_controller import MikrotikControllerData
@ -38,21 +28,7 @@ async def async_setup(hass, _config):
# ---------------------------
async def async_setup_entry(hass, config_entry):
"""Set up Mikrotik Router as config entry."""
name = config_entry.data[CONF_NAME]
host = config_entry.data[CONF_HOST]
port = config_entry.data[CONF_PORT]
username = config_entry.data[CONF_USERNAME]
password = config_entry.data[CONF_PASSWORD]
use_ssl = config_entry.data[CONF_SSL]
if CONF_UNIT_OF_MEASUREMENT in config_entry.data:
traffic_type = config_entry.data[CONF_UNIT_OF_MEASUREMENT]
else:
traffic_type = DEFAULT_TRAFFIC_TYPE
mikrotik_controller = MikrotikControllerData(
hass, config_entry, name, host, port, username, password, use_ssl,
traffic_type
)
mikrotik_controller = MikrotikControllerData(hass, config_entry)
await mikrotik_controller.hwinfo_update()
await mikrotik_controller.async_update()

View file

@ -21,12 +21,18 @@ from homeassistant.core import callback
from .const import (
DOMAIN,
CONF_TRACK_ARP,
DEFAULT_TRACK_ARP,
CONF_TRACK_IFACE_CLIENTS,
DEFAULT_TRACK_IFACE_CLIENTS,
CONF_SCAN_INTERVAL,
DEFAULT_SCAN_INTERVAL,
DEFAULT_TRAFFIC_TYPE,
TRAFFIC_TYPES,
LIST_UNIT_OF_MEASUREMENT,
DEFAULT_UNIT_OF_MEASUREMENT,
DEFAULT_HOST,
DEFAULT_USERNAME,
DEFAULT_PASSWORD,
DEFAULT_PORT,
DEFAULT_NAME,
DEFAULT_SSL,
)
from .mikrotikapi import MikrotikAPI
@ -77,11 +83,11 @@ class MikrotikControllerConfigFlow(ConfigFlow, domain=DOMAIN):
# Test connection
api = MikrotikAPI(
host=user_input["host"],
username=user_input["username"],
password=user_input["password"],
port=user_input["port"],
use_ssl=user_input["ssl"]
host=user_input[CONF_HOST],
username=user_input[CONF_USERNAME],
password=user_input[CONF_PASSWORD],
port=user_input[CONF_PORT],
use_ssl=user_input[CONF_SSL]
)
if not api.connect():
errors[CONF_HOST] = api.error
@ -89,48 +95,42 @@ class MikrotikControllerConfigFlow(ConfigFlow, domain=DOMAIN):
# Save instance
if not errors:
return self.async_create_entry(
title=user_input[CONF_NAME], data=user_input
title=user_input[CONF_NAME],
data=user_input
)
return self._show_config_form(
host=user_input["host"],
username=user_input["username"],
password=user_input["password"],
port=user_input["port"],
name=user_input["name"],
use_ssl=user_input["ssl"],
user_input=user_input,
errors=errors,
)
return self._show_config_form(errors=errors)
return self._show_config_form(
user_input={
CONF_NAME: DEFAULT_NAME,
CONF_HOST: DEFAULT_HOST,
CONF_USERNAME: DEFAULT_USERNAME,
CONF_PASSWORD: DEFAULT_PASSWORD,
CONF_PORT: DEFAULT_PORT,
CONF_SSL: DEFAULT_SSL,
},
errors=errors
)
# ---------------------------
# _show_config_form
# ---------------------------
def _show_config_form(
self,
host="10.0.0.1",
username="admin",
password="admin",
port=0,
name="Mikrotik",
use_ssl=False,
errors=None,
):
def _show_config_form(self, user_input, errors=None):
"""Show the configuration form to edit data."""
return self.async_show_form(
step_id="user",
data_schema=vol.Schema(
{
vol.Required(CONF_HOST, default=host): str,
vol.Required(CONF_USERNAME, default=username): str,
vol.Required(CONF_PASSWORD, default=password): str,
vol.Optional(
CONF_UNIT_OF_MEASUREMENT, default=DEFAULT_TRAFFIC_TYPE
): vol.In(TRAFFIC_TYPES),
vol.Optional(CONF_PORT, default=port): int,
vol.Optional(CONF_NAME, default=name): str,
vol.Optional(CONF_SSL, default=use_ssl): bool,
vol.Required(CONF_NAME, default=user_input[CONF_NAME]): str,
vol.Required(CONF_HOST, default=user_input[CONF_HOST]): str,
vol.Required(CONF_USERNAME, default=user_input[CONF_USERNAME]): str,
vol.Required(CONF_PASSWORD, default=user_input[CONF_PASSWORD]): str,
vol.Optional(CONF_PORT, default=user_input[CONF_PORT]): int,
vol.Optional(CONF_SSL, default=user_input[CONF_SSL]): bool,
}
),
errors=errors,
@ -163,9 +163,9 @@ class MikrotikControllerOptionsFlowHandler(OptionsFlow):
data_schema=vol.Schema(
{
vol.Optional(
CONF_TRACK_ARP,
CONF_TRACK_IFACE_CLIENTS,
default=self.config_entry.options.get(
CONF_TRACK_ARP, DEFAULT_TRACK_ARP
CONF_TRACK_IFACE_CLIENTS, DEFAULT_TRACK_IFACE_CLIENTS
),
): bool,
vol.Optional(
@ -177,9 +177,9 @@ class MikrotikControllerOptionsFlowHandler(OptionsFlow):
vol.Optional(
CONF_UNIT_OF_MEASUREMENT,
default=self.config_entry.options.get(
CONF_UNIT_OF_MEASUREMENT, DEFAULT_TRAFFIC_TYPE
CONF_UNIT_OF_MEASUREMENT, DEFAULT_UNIT_OF_MEASUREMENT
),
): vol.In(TRAFFIC_TYPES),
): vol.In(LIST_UNIT_OF_MEASUREMENT),
}
),
)

View file

@ -5,14 +5,22 @@ DEFAULT_NAME = "Mikrotik Router"
DATA_CLIENT = "client"
ATTRIBUTION = "Data provided by Mikrotik"
CONF_SCAN_INTERVAL = "scan_interval"
CONF_UNIT_OF_MEASUREMENT = "unit_of_measurement"
DEFAULT_SCAN_INTERVAL = 30
CONF_TRACK_ARP = "track_arp"
DEFAULT_TRACK_ARP = True
DEFAULT_ENCODING = "ISO-8859-1"
DEFAULT_LOGIN_METHOD = "plain"
DEFAULT_TRAFFIC_TYPE = "Kbps"
TRAFFIC_TYPES = ["bps", "Kbps", "Mbps", "B/s", "KB/s", "MB/s"]
DEFAULT_HOST = "10.0.0.1"
DEFAULT_USERNAME = "admin"
DEFAULT_PASSWORD = "admin"
DEFAULT_PORT = 0
DEFAULT_NAME = "Mikrotik"
DEFAULT_SSL = False
LIST_UNIT_OF_MEASUREMENT = ["bps", "Kbps", "Mbps", "B/s", "KB/s", "MB/s"]
CONF_TRACK_IFACE_CLIENTS = "track_iface_clients"
CONF_SCAN_INTERVAL = "scan_interval"
DEFAULT_TRACK_IFACE_CLIENTS = True
DEFAULT_SCAN_INTERVAL = 30
DEFAULT_UNIT_OF_MEASUREMENT = "Kbps"

View file

@ -10,14 +10,23 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.util.dt import utcnow
from homeassistant.const import (
CONF_NAME,
CONF_HOST,
CONF_PORT,
CONF_UNIT_OF_MEASUREMENT,
CONF_USERNAME,
CONF_PASSWORD,
CONF_SSL,
)
from .const import (
DOMAIN,
CONF_TRACK_ARP,
DEFAULT_TRACK_ARP,
CONF_TRACK_IFACE_CLIENTS,
DEFAULT_TRACK_IFACE_CLIENTS,
CONF_SCAN_INTERVAL,
CONF_UNIT_OF_MEASUREMENT,
DEFAULT_SCAN_INTERVAL,
DEFAULT_TRAFFIC_TYPE,
DEFAULT_UNIT_OF_MEASUREMENT,
)
from .exceptions import ApiEntryNotFound
from .helper import from_entry, parse_api
@ -32,24 +41,12 @@ _LOGGER = logging.getLogger(__name__)
class MikrotikControllerData:
"""MikrotikController Class"""
def __init__(
self,
hass,
config_entry,
name,
host,
port,
username,
password,
use_ssl,
traffic_type,
):
def __init__(self, hass, config_entry):
"""Initialize MikrotikController."""
self.name = name
self.hass = hass
self.host = host
self.config_entry = config_entry
self.traffic_type = traffic_type
self.name = config_entry.data[CONF_NAME]
self.host = config_entry.data[CONF_HOST]
self.data = {
"routerboard": {},
@ -72,7 +69,13 @@ class MikrotikControllerData:
self.listeners = []
self.lock = asyncio.Lock()
self.api = MikrotikAPI(host, username, password, port, use_ssl)
self.api = MikrotikAPI(
config_entry.data[CONF_HOST],
config_entry.data[CONF_USERNAME],
config_entry.data[CONF_PASSWORD],
config_entry.data[CONF_PORT],
config_entry.data[CONF_SSL]
)
self.nat_removed = {}
@ -83,22 +86,25 @@ class MikrotikControllerData:
self.hass, self.force_fwupdate_check, timedelta(hours=1)
)
def _get_traffic_type_and_div(self):
traffic_type = self.option_traffic_type
if traffic_type == "Kbps":
traffic_div = 0.001
elif traffic_type == "Mbps":
traffic_div = 0.000001
elif traffic_type == "B/s":
traffic_div = 0.125
elif traffic_type == "KB/s":
traffic_div = 0.000125
elif traffic_type == "MB/s":
traffic_div = 0.000000125
# ---------------------------
# _get_unit_of_measurement
# ---------------------------
def _get_unit_of_measurement(self):
uom_type = self.option_unit_of_measurement
if uom_type == "Kbps":
uom_div = 0.001
elif uom_type == "Mbps":
uom_div = 0.000001
elif uom_type == "B/s":
uom_div = 0.125
elif uom_type == "KB/s":
uom_div = 0.000125
elif uom_type == "MB/s":
uom_div = 0.000000125
else:
traffic_type = "bps"
traffic_div = 1
return traffic_type, traffic_div
uom_type = "bps"
uom_div = 1
return uom_type, uom_div
# ---------------------------
# force_update
@ -117,12 +123,12 @@ class MikrotikControllerData:
await self.async_fwupdate_check()
# ---------------------------
# option_track_arp
# option_track_iface_clients
# ---------------------------
@property
def option_track_arp(self):
def option_track_iface_clients(self):
"""Config entry option to not track ARP."""
return self.config_entry.options.get(CONF_TRACK_ARP, DEFAULT_TRACK_ARP)
return self.config_entry.options.get(CONF_TRACK_IFACE_CLIENTS, DEFAULT_TRACK_IFACE_CLIENTS)
# ---------------------------
# option_scan_interval
@ -136,13 +142,13 @@ class MikrotikControllerData:
return timedelta(seconds=scan_interval)
# ---------------------------
# option_traffic_type
# option_unit_of_measurement
# ---------------------------
@property
def option_traffic_type(self):
def option_unit_of_measurement(self):
"""Config entry option to not track ARP."""
return self.config_entry.options.get(
CONF_UNIT_OF_MEASUREMENT, DEFAULT_TRAFFIC_TYPE
CONF_UNIT_OF_MEASUREMENT, DEFAULT_UNIT_OF_MEASUREMENT
)
# ---------------------------
@ -298,18 +304,18 @@ class MikrotikControllerData:
],
)
traffic_type, traffic_div = self._get_traffic_type_and_div()
uom_type, uom_div = self._get_unit_of_measurement()
for uid in self.data["interface"]:
self.data["interface"][uid][
"rx-bits-per-second-attr"] = traffic_type
"rx-bits-per-second-attr"] = uom_type
self.data["interface"][uid][
"tx-bits-per-second-attr"] = traffic_type
"tx-bits-per-second-attr"] = uom_type
self.data["interface"][uid]["rx-bits-per-second"] = round(
self.data["interface"][uid]["rx-bits-per-second"] * traffic_div
self.data["interface"][uid]["rx-bits-per-second"] * uom_div
)
self.data["interface"][uid]["tx-bits-per-second"] = round(
self.data["interface"][uid]["tx-bits-per-second"] * traffic_div
self.data["interface"][uid]["tx-bits-per-second"] * uom_div
)
# ---------------------------
@ -320,7 +326,7 @@ class MikrotikControllerData:
self.data["arp_tmp"] = {}
# Remove data if disabled
if not self.option_track_arp:
if not self.option_track_iface_clients:
for uid in self.data["interface"]:
self.data["interface"][uid]["client-ip-address"] = "disabled"
self.data["interface"][uid]["client-mac-address"] = "disabled"
@ -635,37 +641,37 @@ class MikrotikControllerData:
]
)
traffic_type, traffic_div = self._get_traffic_type_and_div()
uom_type, uom_div = self._get_unit_of_measurement()
for uid in self.data["queue"]:
upload_max_limit_bps, download_max_limit_bps = [int(x) for x in
self.data["queue"][uid]["max-limit"].split('/')]
self.data["queue"][uid]["upload-max-limit"] = \
f"{round(upload_max_limit_bps * traffic_div)} {traffic_type}"
f"{round(upload_max_limit_bps * uom_div)} {uom_type}"
self.data["queue"][uid]["download-max-limit"] = \
f"{round(download_max_limit_bps * traffic_div)} {traffic_type}"
f"{round(download_max_limit_bps * uom_div)} {uom_type}"
upload_limit_at_bps, download_limit_at_bps = [int(x) for x in
self.data["queue"][uid]["limit-at"].split('/')]
self.data["queue"][uid]["upload-limit-at"] = \
f"{round(upload_limit_at_bps * traffic_div)} {traffic_type}"
f"{round(upload_limit_at_bps * uom_div)} {uom_type}"
self.data["queue"][uid]["download-limit-at"] = \
f"{round(download_limit_at_bps * traffic_div)} {traffic_type}"
f"{round(download_limit_at_bps * uom_div)} {uom_type}"
upload_burst_limit_bps, download_burst_limit_bps = [int(x) for x in
self.data["queue"][uid]["burst-limit"].split('/')]
self.data["queue"][uid]["upload-burst-limit"] = \
f"{round(upload_burst_limit_bps * traffic_div)} {traffic_type}"
f"{round(upload_burst_limit_bps * uom_div)} {uom_type}"
self.data["queue"][uid]["download-burst-limit"] = \
f"{round(download_burst_limit_bps * traffic_div)} {traffic_type}"
f"{round(download_burst_limit_bps * uom_div)} {uom_type}"
upload_burst_threshold_bps,\
download_burst_threshold_bps = [int(x) for x in self.data["queue"][uid]["burst-threshold"].split('/')]
self.data["queue"][uid]["upload-burst-threshold"] = \
f"{round(upload_burst_threshold_bps * traffic_div)} {traffic_type}"
f"{round(upload_burst_threshold_bps * uom_div)} {uom_type}"
self.data["queue"][uid]["download-burst-threshold"] = \
f"{round(download_burst_threshold_bps * traffic_div)} {traffic_type}"
f"{round(download_burst_threshold_bps * uom_div)} {uom_type}"
upload_burst_time, download_burst_time = self.data["queue"][uid]["burst-time"].split('/')
self.data["queue"][uid]["upload-burst-time"] = upload_burst_time
@ -857,7 +863,7 @@ class MikrotikControllerData:
"""Get Accounting data from Mikrotik"""
# Check if accounting and account-local-traffic is enabled
accounting_enabled, local_traffic_enabled = self.api.is_accounting_and_local_traffic_enabled()
traffic_type, traffic_div = self._get_traffic_type_and_div()
uom_type, uom_div = self._get_unit_of_measurement()
# Build missing hosts from main hosts dict
for uid, vals in self.data["host"].items():
@ -866,7 +872,7 @@ class MikrotikControllerData:
'address': vals['address'],
'mac-address': vals['mac-address'],
'host-name': vals['host-name'],
'tx-rx-attr': traffic_type,
'tx-rx-attr': uom_type,
'available': False,
'local_accounting': False
}
@ -928,7 +934,7 @@ class MikrotikControllerData:
_LOGGER.warning(f"Address {addr} not found in accounting data, skipping update")
continue
self.data['accounting'][uid]['tx-rx-attr'] = traffic_type
self.data['accounting'][uid]['tx-rx-attr'] = uom_type
self.data['accounting'][uid]['available'] = accounting_enabled
self.data['accounting'][uid]['local_accounting'] = local_traffic_enabled
@ -937,11 +943,11 @@ class MikrotikControllerData:
continue
self.data['accounting'][uid]['wan-tx'] = round(
tmp_accounting_values[addr]['wan-tx'] / time_diff * traffic_div, 2) \
tmp_accounting_values[addr]['wan-tx'] / time_diff * uom_div, 2) \
if tmp_accounting_values[addr]['wan-tx'] else 0.0
self.data['accounting'][uid]['wan-rx'] = round(
tmp_accounting_values[addr]['wan-rx'] / time_diff * traffic_div, 2) \
tmp_accounting_values[addr]['wan-rx'] / time_diff * uom_div, 2) \
if tmp_accounting_values[addr]['wan-rx'] else 0.0
if not local_traffic_enabled:
@ -949,9 +955,9 @@ class MikrotikControllerData:
continue
self.data['accounting'][uid]['lan-tx'] = round(
tmp_accounting_values[addr]['lan-tx'] / time_diff * traffic_div, 2) \
tmp_accounting_values[addr]['lan-tx'] / time_diff * uom_div, 2) \
if tmp_accounting_values[addr]['lan-tx'] else 0.0
self.data['accounting'][uid]['lan-rx'] = round(
tmp_accounting_values[addr]['lan-rx'] / time_diff * traffic_div, 2) \
tmp_accounting_values[addr]['lan-rx'] / time_diff * uom_div, 2) \
if tmp_accounting_values[addr]['lan-rx'] else 0.0

View file

@ -11,8 +11,7 @@
"port": "Port",
"username": "Username",
"password": "Password",
"ssl": "Use SSL",
"unit_of_measurement": "Unit of measurement"
"ssl": "Use SSL"
}
}
},
@ -32,7 +31,7 @@
"device_tracker": {
"data": {
"scan_interval": "Scan interval (requires HA restart)",
"track_arp": "Show client MAC and IP on interfaces",
"track_iface_clients": "Show client MAC and IP on interfaces",
"unit_of_measurement": "Unit of measurement"
}
}