diff --git a/.github/generate_releasenotes.py b/.github/generate_releasenotes.py
index 31a2534..4beb498 100644
--- a/.github/generate_releasenotes.py
+++ b/.github/generate_releasenotes.py
@@ -70,6 +70,8 @@ def get_integration_commits(github, skip=True):
             continue
         if " workflow" in msg:
             continue
+        if " test" in msg:
+            continue
         if "docs" in msg:
             continue
         if "dev debug" in msg:
diff --git a/.github/generate_requirements.py b/.github/generate_requirements.py
new file mode 100644
index 0000000..ec1f4e8
--- /dev/null
+++ b/.github/generate_requirements.py
@@ -0,0 +1,22 @@
+import configparser
+
+
+def main():
+    parser = configparser.ConfigParser()
+    parser.read("Pipfile")
+
+    packages = "packages"
+    with open("requirements.txt", "w") as f:
+        for key in parser[packages]:
+            value = parser[packages][key]
+            f.write(key + value.replace('"', "") + "\n")
+
+    devpackages = "dev-packages"
+    with open("requirements_tests.txt", "w") as f:
+        for key in parser[devpackages]:
+            value = parser[devpackages][key]
+            f.write(key + value.replace('"', "") + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8ab11d4..1397c89 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -12,7 +12,7 @@ on:
 jobs:
   black:
-    name: Black
+    name: Python Code Format Check
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v2
@@ -29,10 +29,14 @@ jobs:
         uses: actions/setup-python@v1
         with:
           python-version: 3.8
+      - name: Generate Requirements lists
+        run: |
+          python3 .github/generate_requirements.py
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install -r requirements.txt
+          pip install -r requirements_tests.txt
       - name: Lint with flake8
         run: |
           pip install flake8
           flake8 . --count --select=E9,F63,F7,F82 --ignore W503,E722 --show-source --statistics
           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
           flake8 . --count --exit-zero --max-complexity=15 --max-line-length=127 --statistics
-      #- name: Test with pytest
-      #  run: |
-      #    pip install pytest
-      #    pytest
+      - name: Test with pytest
+        run: |
+          pip install pytest
+          pytest
   sonarcloud:
     name: SonarCloud
     runs-on: ubuntu-latest
diff --git a/Pipfile b/Pipfile
new file mode 100644
index 0000000..24b4e06
--- /dev/null
+++ b/Pipfile
@@ -0,0 +1,27 @@
+[[source]]
+name = "pypi"
+url = "https://pypi.org/simple"
+verify_ssl = true
+
+[dev-packages]
+wheel = ">=0.34"
+pygithub = ">=1.47"
+homeassistant = ">=0.108.4"
+sqlalchemy = "==1.3.16"
+codecov = "==2.0.15"
+mock-open = "==1.3.1"
+mypy = "==0.770"
+pre-commit = "==2.2.0"
+pylint = "==2.4.4"
+astroid = "==2.3.3"
+pylint-strict-informational = "==0.1"
+pytest-aiohttp = "==0.3.0"
+pytest-cov = "==2.8.1"
+pytest-sugar = "==0.9.2"
+pytest-timeout = "==1.3.3"
+pytest = "==5.3.5"
+requests_mock = "==1.7.0"
+responses = "==0.10.6"
+
+[packages]
+librouteros = "==3.0.0"
diff --git a/custom_components/mikrotik_router/mikrotik_controller.py b/custom_components/mikrotik_router/mikrotik_controller.py
index 1fe2fd7..d1c85d9 100644
--- a/custom_components/mikrotik_router/mikrotik_controller.py
+++ b/custom_components/mikrotik_router/mikrotik_controller.py
@@ -30,6 +30,7 @@ from .const import (
     CONF_SCAN_INTERVAL,
     DEFAULT_SCAN_INTERVAL,
     DEFAULT_UNIT_OF_MEASUREMENT,
+    CONF_TRACK_HOSTS_TIMEOUT,
 )
 from .exceptions import ApiEntryNotFound
 from .helper import parse_api
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..b1e5d9c
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,20 @@
+"""Tests for the Mikrotik Router component."""
+
+from custom_components.mikrotik_router import config_flow
+
+MOCK_DATA = {
+    config_flow.CONF_NAME: config_flow.DEFAULT_DEVICE_NAME,
+    config_flow.CONF_HOST: config_flow.DEFAULT_HOST,
+    config_flow.CONF_USERNAME: config_flow.DEFAULT_USERNAME,
+    config_flow.CONF_PASSWORD: config_flow.DEFAULT_PASSWORD,
+    config_flow.CONF_PORT: config_flow.DEFAULT_PORT,
+    config_flow.CONF_SSL: config_flow.DEFAULT_SSL,
+}
+
+MOCK_OPTIONS = {
+    config_flow.CONF_SCAN_INTERVAL: config_flow.DEFAULT_SCAN_INTERVAL,
+    config_flow.CONF_UNIT_OF_MEASUREMENT: config_flow.DEFAULT_UNIT_OF_MEASUREMENT,
+    config_flow.CONF_TRACK_IFACE_CLIENTS: config_flow.DEFAULT_TRACK_IFACE_CLIENTS,
+    config_flow.CONF_TRACK_HOSTS: config_flow.DEFAULT_TRACK_HOSTS,
+    config_flow.CONF_TRACK_HOSTS_TIMEOUT: config_flow.DEFAULT_TRACK_HOST_TIMEOUT,
+}
diff --git a/tests/common.py b/tests/common.py
new file mode 100644
index 0000000..f39d458
--- /dev/null
+++ b/tests/common.py
@@ -0,0 +1,1148 @@
+"""Test the helper method for writing tests."""
+import asyncio
+import collections
+from collections import OrderedDict
+from contextlib import contextmanager
+from datetime import timedelta
+import functools as ft
+from io import StringIO
+import json
+import logging
+import os
+import sys
+import threading
+from unittest.mock import MagicMock, Mock, patch
+import uuid
+
+from aiohttp.test_utils import unused_port as get_test_instance_port  # noqa
+
+from homeassistant import auth, config_entries, core as ha, loader
+from homeassistant.auth import (
+    auth_store,
+    models as auth_models,
+    permissions as auth_permissions,
+    providers as auth_providers,
+)
+from homeassistant.auth.permissions import system_policies
+from homeassistant.components import mqtt, recorder
+from homeassistant.components.device_automation import (  # noqa: F401
+    _async_get_device_automation_capabilities as async_get_device_automation_capabilities,
+
_async_get_device_automations as async_get_device_automations, +) +from homeassistant.components.mqtt.models import Message +from homeassistant.config import async_process_component_config +from homeassistant.const import ( + ATTR_DISCOVERED, + ATTR_SERVICE, + DEVICE_DEFAULT_NAME, + EVENT_HOMEASSISTANT_CLOSE, + EVENT_PLATFORM_DISCOVERED, + EVENT_STATE_CHANGED, + EVENT_TIME_CHANGED, + STATE_OFF, + STATE_ON, +) +from homeassistant.core import State +from homeassistant.helpers import ( + area_registry, + device_registry, + entity, + entity_platform, + entity_registry, + intent, + restore_state, + storage, +) +from homeassistant.helpers.json import JSONEncoder +from homeassistant.setup import async_setup_component, setup_component +from homeassistant.util.async_ import run_callback_threadsafe +import homeassistant.util.dt as date_util +from homeassistant.util.unit_system import METRIC_SYSTEM +import homeassistant.util.yaml.loader as yaml_loader + +_LOGGER = logging.getLogger(__name__) +INSTANCES = [] +CLIENT_ID = "https://example.com/app" +CLIENT_REDIRECT_URI = "https://example.com/app/callback" + + +def threadsafe_callback_factory(func): + """Create threadsafe functions out of callbacks. + + Callback needs to have `hass` as first argument. + """ + + @ft.wraps(func) + def threadsafe(*args, **kwargs): + """Call func threadsafe.""" + hass = args[0] + return run_callback_threadsafe( + hass.loop, ft.partial(func, *args, **kwargs) + ).result() + + return threadsafe + + +def threadsafe_coroutine_factory(func): + """Create threadsafe functions out of coroutine. + + Callback needs to have `hass` as first argument. + """ + + @ft.wraps(func) + def threadsafe(*args, **kwargs): + """Call func threadsafe.""" + hass = args[0] + return asyncio.run_coroutine_threadsafe( + func(*args, **kwargs), hass.loop + ).result() + + return threadsafe + + +def get_test_config_dir(*add_path): + """Return a path to a test config dir.""" + return os.path.join(os.path.dirname(__file__), "testing_config", *add_path) + + +def get_test_home_assistant(): + """Return a Home Assistant object pointing at test config directory.""" + if sys.platform == "win32": + loop = asyncio.ProactorEventLoop() + else: + loop = asyncio.new_event_loop() + + asyncio.set_event_loop(loop) + hass = loop.run_until_complete(async_test_home_assistant(loop)) + + stop_event = threading.Event() + + def run_loop(): + """Run event loop.""" + # pylint: disable=protected-access + loop._thread_ident = threading.get_ident() + loop.run_forever() + stop_event.set() + + orig_stop = hass.stop + + def start_hass(*mocks): + """Start hass.""" + asyncio.run_coroutine_threadsafe(hass.async_start(), loop).result() + + def stop_hass(): + """Stop hass.""" + orig_stop() + stop_event.wait() + loop.close() + + hass.start = start_hass + hass.stop = stop_hass + + threading.Thread(name="LoopThread", target=run_loop, daemon=False).start() + + return hass + + +# pylint: disable=protected-access +async def async_test_home_assistant(loop): + """Return a Home Assistant object pointing at test config dir.""" + hass = ha.HomeAssistant(loop) + store = auth_store.AuthStore(hass) + hass.auth = auth.AuthManager(hass, store, {}, {}) + ensure_auth_manager_loaded(hass.auth) + INSTANCES.append(hass) + + orig_async_add_job = hass.async_add_job + orig_async_add_executor_job = hass.async_add_executor_job + orig_async_create_task = hass.async_create_task + + def async_add_job(target, *args): + """Add job.""" + if isinstance(target, Mock): + return mock_coro(target(*args)) + return 
orig_async_add_job(target, *args) + + def async_add_executor_job(target, *args): + """Add executor job.""" + if isinstance(target, Mock): + return mock_coro(target(*args)) + return orig_async_add_executor_job(target, *args) + + def async_create_task(coroutine): + """Create task.""" + if isinstance(coroutine, Mock): + return mock_coro() + return orig_async_create_task(coroutine) + + hass.async_add_job = async_add_job + hass.async_add_executor_job = async_add_executor_job + hass.async_create_task = async_create_task + + hass.config.location_name = "test home" + hass.config.config_dir = get_test_config_dir() + hass.config.latitude = 32.87336 + hass.config.longitude = -117.22743 + hass.config.elevation = 0 + hass.config.time_zone = date_util.get_time_zone("US/Pacific") + hass.config.units = METRIC_SYSTEM + hass.config.skip_pip = True + + hass.config_entries = config_entries.ConfigEntries(hass, {}) + hass.config_entries._entries = [] + hass.config_entries._store._async_ensure_stop_listener = lambda: None + + hass.state = ha.CoreState.running + + # Mock async_start + orig_start = hass.async_start + + async def mock_async_start(): + """Start the mocking.""" + # We only mock time during tests and we want to track tasks + with patch("homeassistant.core._async_create_timer"), patch.object( + hass, "async_stop_track_tasks" + ): + await orig_start() + + hass.async_start = mock_async_start + + @ha.callback + def clear_instance(event): + """Clear global instance.""" + INSTANCES.remove(hass) + + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, clear_instance) + + return hass + + +def async_mock_service(hass, domain, service, schema=None): + """Set up a fake service & return a calls log list to this service.""" + calls = [] + + @ha.callback + def mock_service_log(call): # pylint: disable=unnecessary-lambda + """Mock service call.""" + calls.append(call) + + hass.services.async_register(domain, service, mock_service_log, schema=schema) + + return calls + + +mock_service = threadsafe_callback_factory(async_mock_service) + + +@ha.callback +def async_mock_intent(hass, intent_typ): + """Set up a fake intent handler.""" + intents = [] + + class MockIntentHandler(intent.IntentHandler): + intent_type = intent_typ + + @asyncio.coroutine + def async_handle(self, intent): + """Handle the intent.""" + intents.append(intent) + return intent.create_response() + + intent.async_register(hass, MockIntentHandler()) + + return intents + + +@ha.callback +def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False): + """Fire the MQTT message.""" + if isinstance(payload, str): + payload = payload.encode("utf-8") + msg = Message(topic, payload, qos, retain) + hass.data["mqtt"]._mqtt_handle_message(msg) + + +fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message) + + +@ha.callback +def async_fire_time_changed(hass, time): + """Fire a time changes event.""" + hass.bus.async_fire(EVENT_TIME_CHANGED, {"now": date_util.as_utc(time)}) + + +fire_time_changed = threadsafe_callback_factory(async_fire_time_changed) + + +def fire_service_discovered(hass, service, info): + """Fire the MQTT message.""" + hass.bus.fire( + EVENT_PLATFORM_DISCOVERED, {ATTR_SERVICE: service, ATTR_DISCOVERED: info} + ) + + +@ha.callback +def async_fire_service_discovered(hass, service, info): + """Fire the MQTT message.""" + hass.bus.async_fire( + EVENT_PLATFORM_DISCOVERED, {ATTR_SERVICE: service, ATTR_DISCOVERED: info} + ) + + +def load_fixture(filename): + """Load a fixture.""" + path = os.path.join(os.path.dirname(__file__), 
"fixtures", filename) + with open(path, encoding="utf-8") as fptr: + return fptr.read() + + +def mock_state_change_event(hass, new_state, old_state=None): + """Mock state change envent.""" + event_data = {"entity_id": new_state.entity_id, "new_state": new_state} + + if old_state: + event_data["old_state"] = old_state + + hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context) + + +async def async_mock_mqtt_component(hass, config=None): + """Mock the MQTT component.""" + if config is None: + config = {mqtt.CONF_BROKER: "mock-broker"} + + async def _async_fire_mqtt_message(topic, payload, qos, retain): + async_fire_mqtt_message(hass, topic, payload, qos, retain) + + with patch("paho.mqtt.client.Client") as mock_client: + mock_client().connect.return_value = 0 + mock_client().subscribe.return_value = (0, 0) + mock_client().unsubscribe.return_value = (0, 0) + mock_client().publish.return_value = (0, 0) + mock_client().publish.side_effect = _async_fire_mqtt_message + + result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config}) + assert result + await hass.async_block_till_done() + + hass.data["mqtt"] = MagicMock( + spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"] + ) + + return hass.data["mqtt"] + + +mock_mqtt_component = threadsafe_coroutine_factory(async_mock_mqtt_component) + + +@ha.callback +def mock_component(hass, component): + """Mock a component is setup.""" + if component in hass.config.components: + AssertionError(f"Integration {component} is already setup") + + hass.config.components.add(component) + + +def mock_registry(hass, mock_entries=None): + """Mock the Entity Registry.""" + registry = entity_registry.EntityRegistry(hass) + registry.entities = mock_entries or OrderedDict() + + hass.data[entity_registry.DATA_REGISTRY] = registry + return registry + + +def mock_area_registry(hass, mock_entries=None): + """Mock the Area Registry.""" + registry = area_registry.AreaRegistry(hass) + registry.areas = mock_entries or OrderedDict() + + hass.data[area_registry.DATA_REGISTRY] = registry + return registry + + +def mock_device_registry(hass, mock_entries=None): + """Mock the Device Registry.""" + registry = device_registry.DeviceRegistry(hass) + registry.devices = mock_entries or OrderedDict() + + hass.data[device_registry.DATA_REGISTRY] = registry + return registry + + +class MockGroup(auth_models.Group): + """Mock a group in Home Assistant.""" + + def __init__(self, id=None, name="Mock Group", policy=system_policies.ADMIN_POLICY): + """Mock a group.""" + kwargs = {"name": name, "policy": policy} + if id is not None: + kwargs["id"] = id + + super().__init__(**kwargs) + + def add_to_hass(self, hass): + """Test helper to add entry to hass.""" + return self.add_to_auth_manager(hass.auth) + + def add_to_auth_manager(self, auth_mgr): + """Test helper to add entry to hass.""" + ensure_auth_manager_loaded(auth_mgr) + auth_mgr._store._groups[self.id] = self + return self + + +class MockUser(auth_models.User): + """Mock a user in Home Assistant.""" + + def __init__( + self, + id=None, + is_owner=False, + is_active=True, + name="Mock User", + system_generated=False, + groups=None, + ): + """Initialize mock user.""" + kwargs = { + "is_owner": is_owner, + "is_active": is_active, + "name": name, + "system_generated": system_generated, + "groups": groups or [], + "perm_lookup": None, + } + if id is not None: + kwargs["id"] = id + super().__init__(**kwargs) + + def add_to_hass(self, hass): + """Test helper to add entry to hass.""" + return 
self.add_to_auth_manager(hass.auth) + + def add_to_auth_manager(self, auth_mgr): + """Test helper to add entry to hass.""" + ensure_auth_manager_loaded(auth_mgr) + auth_mgr._store._users[self.id] = self + return self + + def mock_policy(self, policy): + """Mock a policy for a user.""" + self._permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup) + + +async def register_auth_provider(hass, config): + """Register an auth provider.""" + provider = await auth_providers.auth_provider_from_config( + hass, hass.auth._store, config + ) + assert provider is not None, "Invalid config specified" + key = (provider.type, provider.id) + providers = hass.auth._providers + + if key in providers: + raise ValueError("Provider already registered") + + providers[key] = provider + return provider + + +@ha.callback +def ensure_auth_manager_loaded(auth_mgr): + """Ensure an auth manager is considered loaded.""" + store = auth_mgr._store + if store._users is None: + store._set_defaults() + + +class MockModule: + """Representation of a fake module.""" + + # pylint: disable=invalid-name + def __init__( + self, + domain=None, + dependencies=None, + setup=None, + requirements=None, + config_schema=None, + platform_schema=None, + platform_schema_base=None, + async_setup=None, + async_setup_entry=None, + async_unload_entry=None, + async_migrate_entry=None, + async_remove_entry=None, + partial_manifest=None, + ): + """Initialize the mock module.""" + self.__name__ = f"homeassistant.components.{domain}" + self.__file__ = f"homeassistant/components/{domain}" + self.DOMAIN = domain + self.DEPENDENCIES = dependencies or [] + self.REQUIREMENTS = requirements or [] + # Overlay to be used when generating manifest from this module + self._partial_manifest = partial_manifest + + if config_schema is not None: + self.CONFIG_SCHEMA = config_schema + + if platform_schema is not None: + self.PLATFORM_SCHEMA = platform_schema + + if platform_schema_base is not None: + self.PLATFORM_SCHEMA_BASE = platform_schema_base + + if setup is not None: + # We run this in executor, wrap it in function + self.setup = lambda *args: setup(*args) + + if async_setup is not None: + self.async_setup = async_setup + + if setup is None and async_setup is None: + self.async_setup = mock_coro_func(True) + + if async_setup_entry is not None: + self.async_setup_entry = async_setup_entry + + if async_unload_entry is not None: + self.async_unload_entry = async_unload_entry + + if async_migrate_entry is not None: + self.async_migrate_entry = async_migrate_entry + + if async_remove_entry is not None: + self.async_remove_entry = async_remove_entry + + def mock_manifest(self): + """Generate a mock manifest to represent this module.""" + return { + **loader.manifest_from_legacy_module(self.DOMAIN, self), + **(self._partial_manifest or {}), + } + + +class MockPlatform: + """Provide a fake platform.""" + + __name__ = "homeassistant.components.light.bla" + __file__ = "homeassistant/components/blah/light" + + # pylint: disable=invalid-name + def __init__( + self, + setup_platform=None, + dependencies=None, + platform_schema=None, + async_setup_platform=None, + async_setup_entry=None, + scan_interval=None, + ): + """Initialize the platform.""" + self.DEPENDENCIES = dependencies or [] + + if platform_schema is not None: + self.PLATFORM_SCHEMA = platform_schema + + if scan_interval is not None: + self.SCAN_INTERVAL = scan_interval + + if setup_platform is not None: + # We run this in executor, wrap it in function + self.setup_platform = lambda *args: 
setup_platform(*args) + + if async_setup_platform is not None: + self.async_setup_platform = async_setup_platform + + if async_setup_entry is not None: + self.async_setup_entry = async_setup_entry + + if setup_platform is None and async_setup_platform is None: + self.async_setup_platform = mock_coro_func() + + +class MockEntityPlatform(entity_platform.EntityPlatform): + """Mock class with some mock defaults.""" + + def __init__( + self, + hass, + logger=None, + domain="test_domain", + platform_name="test_platform", + platform=None, + scan_interval=timedelta(seconds=15), + entity_namespace=None, + ): + """Initialize a mock entity platform.""" + if logger is None: + logger = logging.getLogger("homeassistant.helpers.entity_platform") + + # Otherwise the constructor will blow up. + if isinstance(platform, Mock) and isinstance(platform.PARALLEL_UPDATES, Mock): + platform.PARALLEL_UPDATES = 0 + + super().__init__( + hass=hass, + logger=logger, + domain=domain, + platform_name=platform_name, + platform=platform, + scan_interval=scan_interval, + entity_namespace=entity_namespace, + ) + + +class MockToggleEntity(entity.ToggleEntity): + """Provide a mock toggle device.""" + + def __init__(self, name, state, unique_id=None): + """Initialize the mock entity.""" + self._name = name or DEVICE_DEFAULT_NAME + self._state = state + self.calls = [] + + @property + def name(self): + """Return the name of the entity if any.""" + self.calls.append(("name", {})) + return self._name + + @property + def state(self): + """Return the state of the entity if any.""" + self.calls.append(("state", {})) + return self._state + + @property + def is_on(self): + """Return true if entity is on.""" + self.calls.append(("is_on", {})) + return self._state == STATE_ON + + def turn_on(self, **kwargs): + """Turn the entity on.""" + self.calls.append(("turn_on", kwargs)) + self._state = STATE_ON + + def turn_off(self, **kwargs): + """Turn the entity off.""" + self.calls.append(("turn_off", kwargs)) + self._state = STATE_OFF + + def last_call(self, method=None): + """Return the last call.""" + if not self.calls: + return None + if method is None: + return self.calls[-1] + try: + return next(call for call in reversed(self.calls) if call[0] == method) + except StopIteration: + return None + + +class MockConfigEntry(config_entries.ConfigEntry): + """Helper for creating config entries that adds some defaults.""" + + def __init__( + self, + *, + domain="test", + data=None, + version=1, + entry_id=None, + source=config_entries.SOURCE_USER, + title="Mock Title", + state=None, + options={}, + system_options={}, + connection_class=config_entries.CONN_CLASS_UNKNOWN, + unique_id=None, + ): + """Initialize a mock config entry.""" + kwargs = { + "entry_id": entry_id or uuid.uuid4().hex, + "domain": domain, + "data": data or {}, + "system_options": system_options, + "options": options, + "version": version, + "title": title, + "connection_class": connection_class, + "unique_id": unique_id, + } + if source is not None: + kwargs["source"] = source + if state is not None: + kwargs["state"] = state + super().__init__(**kwargs) + + def add_to_hass(self, hass): + """Test helper to add entry to hass.""" + hass.config_entries._entries.append(self) + + def add_to_manager(self, manager): + """Test helper to add entry to entry manager.""" + manager._entries.append(self) + + +def patch_yaml_files(files_dict, endswith=True): + """Patch load_yaml with a dictionary of yaml files.""" + # match using endswith, start search with longest string + matchlist = 
sorted(list(files_dict.keys()), key=len) if endswith else [] + + def mock_open_f(fname, **_): + """Mock open() in the yaml module, used by load_yaml.""" + # Return the mocked file on full match + if fname in files_dict: + _LOGGER.debug("patch_yaml_files match %s", fname) + res = StringIO(files_dict[fname]) + setattr(res, "name", fname) + return res + + # Match using endswith + for ends in matchlist: + if fname.endswith(ends): + _LOGGER.debug("patch_yaml_files end match %s: %s", ends, fname) + res = StringIO(files_dict[ends]) + setattr(res, "name", fname) + return res + + # Fallback for hass.components (i.e. services.yaml) + if "homeassistant/components" in fname: + _LOGGER.debug("patch_yaml_files using real file: %s", fname) + return open(fname, encoding="utf-8") + + # Not found + raise FileNotFoundError(f"File not found: {fname}") + + return patch.object(yaml_loader, "open", mock_open_f, create=True) + + +def mock_coro(return_value=None, exception=None): + """Return a coro that returns a value or raise an exception.""" + return mock_coro_func(return_value, exception)() + + +def mock_coro_func(return_value=None, exception=None): + """Return a method to create a coro function that returns a value.""" + + @asyncio.coroutine + def coro(*args, **kwargs): + """Fake coroutine.""" + if exception: + raise exception + return return_value + + return coro + + +@contextmanager +def assert_setup_component(count, domain=None): + """Collect valid configuration from setup_component. + + - count: The amount of valid platforms that should be setup + - domain: The domain to count is optional. It can be automatically + determined most of the time + + Use as a context manager around setup.setup_component + with assert_setup_component(0) as result_config: + setup_component(hass, domain, start_config) + # using result_config is optional + """ + config = {} + + async def mock_psc(hass, config_input, integration): + """Mock the prepare_setup_component to capture config.""" + domain_input = integration.domain + res = await async_process_component_config(hass, config_input, integration) + config[domain_input] = None if res is None else res.get(domain_input) + _LOGGER.debug( + "Configuration for %s, Validated: %s, Original %s", + domain_input, + config[domain_input], + config_input.get(domain_input), + ) + return res + + assert isinstance(config, dict) + with patch("homeassistant.config.async_process_component_config", mock_psc): + yield config + + if domain is None: + assert len(config) == 1, "assert_setup_component requires DOMAIN: {}".format( + list(config.keys()) + ) + domain = list(config.keys())[0] + + res = config.get(domain) + res_len = 0 if res is None else len(res) + assert ( + res_len == count + ), f"setup_component failed, expected {count} got {res_len}: {res}" + + +def init_recorder_component(hass, add_config=None): + """Initialize the recorder.""" + config = dict(add_config) if add_config else {} + config[recorder.CONF_DB_URL] = "sqlite://" # In memory DB + + with patch("homeassistant.components.recorder.migration.migrate_schema"): + assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: config}) + assert recorder.DOMAIN in hass.config.components + _LOGGER.info("In-memory recorder successfully started") + + +def mock_restore_cache(hass, states): + """Mock the DATA_RESTORE_CACHE.""" + key = restore_state.DATA_RESTORE_STATE_TASK + data = restore_state.RestoreStateData(hass) + now = date_util.utcnow() + + last_states = {} + for state in states: + restored_state = state.as_dict() + 
restored_state["attributes"] = json.loads( + json.dumps(restored_state["attributes"], cls=JSONEncoder) + ) + last_states[state.entity_id] = restore_state.StoredState( + State.from_dict(restored_state), now + ) + data.last_states = last_states + _LOGGER.debug("Restore cache: %s", data.last_states) + assert len(data.last_states) == len(states), f"Duplicate entity_id? {states}" + + async def get_restore_state_data() -> restore_state.RestoreStateData: + return data + + # Patch the singleton task in hass.data to return our new RestoreStateData + hass.data[key] = hass.async_create_task(get_restore_state_data()) + + +class MockDependency: + """Decorator to mock install a dependency.""" + + def __init__(self, root, *args): + """Initialize decorator.""" + self.root = root + self.submodules = args + + def __enter__(self): + """Start mocking.""" + + def resolve(mock, path): + """Resolve a mock.""" + if not path: + return mock + + return resolve(getattr(mock, path[0]), path[1:]) + + base = MagicMock() + to_mock = { + f"{self.root}.{tom}": resolve(base, tom.split(".")) + for tom in self.submodules + } + to_mock[self.root] = base + + self.patcher = patch.dict("sys.modules", to_mock) + self.patcher.start() + return base + + def __exit__(self, *exc): + """Stop mocking.""" + self.patcher.stop() + return False + + def __call__(self, func): + """Apply decorator.""" + + def run_mocked(*args, **kwargs): + """Run with mocked dependencies.""" + with self as base: + args = list(args) + [base] + func(*args, **kwargs) + + return run_mocked + + +class MockEntity(entity.Entity): + """Mock Entity class.""" + + def __init__(self, **values): + """Initialize an entity.""" + self._values = values + + if "entity_id" in values: + self.entity_id = values["entity_id"] + + @property + def name(self): + """Return the name of the entity.""" + return self._handle("name") + + @property + def should_poll(self): + """Return the ste of the polling.""" + return self._handle("should_poll") + + @property + def unique_id(self): + """Return the unique ID of the entity.""" + return self._handle("unique_id") + + @property + def state(self): + """Return the state of the entity.""" + return self._handle("state") + + @property + def available(self): + """Return True if entity is available.""" + return self._handle("available") + + @property + def device_info(self): + """Info how it links to a device.""" + return self._handle("device_info") + + @property + def device_class(self): + """Info how device should be classified.""" + return self._handle("device_class") + + @property + def unit_of_measurement(self): + """Info on the units the entity state is in.""" + return self._handle("unit_of_measurement") + + @property + def capability_attributes(self): + """Info about capabilities.""" + return self._handle("capability_attributes") + + @property + def supported_features(self): + """Info about supported features.""" + return self._handle("supported_features") + + @property + def entity_registry_enabled_default(self): + """Return if the entity should be enabled when first added to the entity registry.""" + return self._handle("entity_registry_enabled_default") + + def _handle(self, attr): + """Return attribute value.""" + if attr in self._values: + return self._values[attr] + return getattr(super(), attr) + + +@contextmanager +def mock_storage(data=None): + """Mock storage. + + Data is a dict {'key': {'version': version, 'data': data}} + + Written data will be converted to JSON to ensure JSON parsing works. 
+ """ + if data is None: + data = {} + + orig_load = storage.Store._async_load + + async def mock_async_load(store): + """Mock version of load.""" + if store._data is None: + # No data to load + if store.key not in data: + return None + + mock_data = data.get(store.key) + + if "data" not in mock_data or "version" not in mock_data: + _LOGGER.error('Mock data needs "version" and "data"') + raise ValueError('Mock data needs "version" and "data"') + + store._data = mock_data + + # Route through original load so that we trigger migration + loaded = await orig_load(store) + _LOGGER.info("Loading data for %s: %s", store.key, loaded) + return loaded + + def mock_write_data(store, path, data_to_write): + """Mock version of write data.""" + _LOGGER.info("Writing data to %s: %s", store.key, data_to_write) + # To ensure that the data can be serialized + data[store.key] = json.loads(json.dumps(data_to_write, cls=store._encoder)) + + async def mock_remove(store): + """Remove data.""" + data.pop(store.key, None) + + with patch( + "homeassistant.helpers.storage.Store._async_load", + side_effect=mock_async_load, + autospec=True, + ), patch( + "homeassistant.helpers.storage.Store._write_data", + side_effect=mock_write_data, + autospec=True, + ), patch( + "homeassistant.helpers.storage.Store.async_remove", + side_effect=mock_remove, + autospec=True, + ): + yield data + + +async def flush_store(store): + """Make sure all delayed writes of a store are written.""" + if store._data is None: + return + + store._async_cleanup_final_write_listener() + store._async_cleanup_delay_listener() + await store._async_handle_write_data() + + +async def get_system_health_info(hass, domain): + """Get system health info.""" + return await hass.data["system_health"]["info"][domain](hass) + + +def mock_integration(hass, module): + """Mock an integration.""" + integration = loader.Integration( + hass, f"homeassistant.components.{module.DOMAIN}", None, module.mock_manifest() + ) + + _LOGGER.info("Adding mock integration: %s", module.DOMAIN) + hass.data.setdefault(loader.DATA_INTEGRATIONS, {})[module.DOMAIN] = integration + hass.data.setdefault(loader.DATA_COMPONENTS, {})[module.DOMAIN] = module + + +def mock_entity_platform(hass, platform_path, module): + """Mock a entity platform. + + platform_path is in form light.hue. Will create platform + hue.light. + """ + domain, platform_name = platform_path.split(".") + mock_platform(hass, f"{platform_name}.{domain}", module) + + +def mock_platform(hass, platform_path, module=None): + """Mock a platform. + + platform_path is in form hue.config_flow. 
+ """ + domain, platform_name = platform_path.split(".") + integration_cache = hass.data.setdefault(loader.DATA_INTEGRATIONS, {}) + module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + + if domain not in integration_cache: + mock_integration(hass, MockModule(domain)) + + _LOGGER.info("Adding mock integration platform: %s", platform_path) + module_cache[platform_path] = module or Mock() + + +def async_capture_events(hass, event_name): + """Create a helper that captures events.""" + events = [] + + @ha.callback + def capture_events(event): + events.append(event) + + hass.bus.async_listen(event_name, capture_events) + + return events + + +@ha.callback +def async_mock_signal(hass, signal): + """Catch all dispatches to a signal.""" + calls = [] + + @ha.callback + def mock_signal_handler(*args): + """Mock service call.""" + calls.append(args) + + hass.helpers.dispatcher.async_dispatcher_connect(signal, mock_signal_handler) + + return calls + + +class hashdict(dict): + """ + hashable dict implementation, suitable for use as a key into other dicts. + + >>> h1 = hashdict({"apples": 1, "bananas":2}) + >>> h2 = hashdict({"bananas": 3, "mangoes": 5}) + >>> h1+h2 + hashdict(apples=1, bananas=3, mangoes=5) + >>> d1 = {} + >>> d1[h1] = "salad" + >>> d1[h1] + 'salad' + >>> d1[h2] + Traceback (most recent call last): + ... + KeyError: hashdict(bananas=3, mangoes=5) + + based on answers from + http://stackoverflow.com/questions/1151658/python-hashable-dicts + + """ + + def __key(self): + return tuple(sorted(self.items())) + + def __repr__(self): # noqa: D105 no docstring + return ", ".join(f"{i[0]!s}={i[1]!r}" for i in self.__key()) + + def __hash__(self): # noqa: D105 no docstring + return hash(self.__key()) + + def __setitem__(self, key, value): # noqa: D105 no docstring + raise TypeError(f"{self.__class__.__name__} does not support item assignment") + + def __delitem__(self, key): # noqa: D105 no docstring + raise TypeError(f"{self.__class__.__name__} does not support item assignment") + + def clear(self): # noqa: D102 no docstring + raise TypeError(f"{self.__class__.__name__} does not support item assignment") + + def pop(self, *args, **kwargs): # noqa: D102 no docstring + raise TypeError(f"{self.__class__.__name__} does not support item assignment") + + def popitem(self, *args, **kwargs): # noqa: D102 no docstring + raise TypeError(f"{self.__class__.__name__} does not support item assignment") + + def setdefault(self, *args, **kwargs): # noqa: D102 no docstring + raise TypeError(f"{self.__class__.__name__} does not support item assignment") + + def update(self, *args, **kwargs): # noqa: D102 no docstring + raise TypeError(f"{self.__class__.__name__} does not support item assignment") + + # update is not ok because it mutates the object + # __add__ is ok because it creates a new object + # while the new object is under construction, it's ok to mutate it + def __add__(self, right): # noqa: D105 no docstring + result = hashdict(self) + dict.update(result, right) + return result + + +def assert_lists_same(a, b): + """Compare two lists, ignoring order.""" + assert collections.Counter([hashdict(i) for i in a]) == collections.Counter( + [hashdict(i) for i in b] + ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ab574e3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,253 @@ +"""Set up some common test helper things.""" +import functools +import logging +from unittest.mock import patch + +import pytest +import requests_mock as _requests_mock + +from 
homeassistant import util +from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY +from homeassistant.auth.providers import homeassistant, legacy_api_password +from homeassistant.components.websocket_api.auth import ( + TYPE_AUTH, + TYPE_AUTH_OK, + TYPE_AUTH_REQUIRED, +) +from homeassistant.components.websocket_api.http import URL +from homeassistant.exceptions import ServiceNotFound +from homeassistant.setup import async_setup_component +from homeassistant.util import location + +from tests.ignore_uncaught_exceptions import ( + IGNORE_UNCAUGHT_EXCEPTIONS, + IGNORE_UNCAUGHT_JSON_EXCEPTIONS, +) + +pytest.register_assert_rewrite("tests.common") + +from tests.common import ( # noqa: E402, isort:skip + CLIENT_ID, + INSTANCES, + MockUser, + async_test_home_assistant, + mock_coro, + mock_storage as mock_storage, +) + + +logging.basicConfig(level=logging.DEBUG) +logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + + +def check_real(func): + """Force a function to require a keyword _test_real to be passed in.""" + + @functools.wraps(func) + async def guard_func(*args, **kwargs): + real = kwargs.pop("_test_real", None) + + if not real: + raise Exception( + 'Forgot to mock or pass "_test_real=True" to %s', func.__name__ + ) + + return await func(*args, **kwargs) + + return guard_func + + +# Guard a few functions that would make network connections +location.async_detect_location_info = check_real(location.async_detect_location_info) +util.get_local_ip = lambda: "127.0.0.1" + + +@pytest.fixture(autouse=True) +def verify_cleanup(): + """Verify that the test has cleaned up resources correctly.""" + yield + + if len(INSTANCES) >= 2: + count = len(INSTANCES) + for inst in INSTANCES: + inst.stop() + pytest.exit(f"Detected non stopped instances ({count}), aborting test run") + + +@pytest.fixture +def hass_storage(): + """Fixture to mock storage.""" + with mock_storage() as stored_data: + yield stored_data + + +@pytest.fixture +def hass(loop, hass_storage, request): + """Fixture to provide a test instance of Home Assistant.""" + + def exc_handle(loop, context): + """Handle exceptions by rethrowing them, which will fail the test.""" + exceptions.append(context["exception"]) + orig_exception_handler(loop, context) + + exceptions = [] + hass = loop.run_until_complete(async_test_home_assistant(loop)) + orig_exception_handler = loop.get_exception_handler() + loop.set_exception_handler(exc_handle) + + yield hass + + loop.run_until_complete(hass.async_stop(force=True)) + for ex in exceptions: + if ( + request.module.__name__, + request.function.__name__, + ) in IGNORE_UNCAUGHT_EXCEPTIONS: + continue + if isinstance(ex, ServiceNotFound): + continue + if ( + isinstance(ex, TypeError) + and "is not JSON serializable" in str(ex) + and (request.module.__name__, request.function.__name__) + in IGNORE_UNCAUGHT_JSON_EXCEPTIONS + ): + continue + raise ex + + +@pytest.fixture +def requests_mock(): + """Fixture to provide a requests mocker.""" + with _requests_mock.mock() as m: + yield m + + +@pytest.fixture +def mock_device_tracker_conf(): + """Prevent device tracker from reading/writing data.""" + devices = [] + + async def mock_update_config(path, id, entity): + devices.append(entity) + + with patch( + "homeassistant.components.device_tracker.legacy" + ".DeviceTracker.async_update_config", + side_effect=mock_update_config, + ), patch( + "homeassistant.components.device_tracker.legacy.async_load_config", + side_effect=lambda *args: mock_coro(devices), + ): + yield devices + + +@pytest.fixture 
+def hass_access_token(hass, hass_admin_user): + """Return an access token to access Home Assistant.""" + refresh_token = hass.loop.run_until_complete( + hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID) + ) + return hass.auth.async_create_access_token(refresh_token) + + +@pytest.fixture +def hass_owner_user(hass, local_auth): + """Return a Home Assistant admin user.""" + return MockUser(is_owner=True).add_to_hass(hass) + + +@pytest.fixture +def hass_admin_user(hass, local_auth): + """Return a Home Assistant admin user.""" + admin_group = hass.loop.run_until_complete( + hass.auth.async_get_group(GROUP_ID_ADMIN) + ) + return MockUser(groups=[admin_group]).add_to_hass(hass) + + +@pytest.fixture +def hass_read_only_user(hass, local_auth): + """Return a Home Assistant read only user.""" + read_only_group = hass.loop.run_until_complete( + hass.auth.async_get_group(GROUP_ID_READ_ONLY) + ) + return MockUser(groups=[read_only_group]).add_to_hass(hass) + + +@pytest.fixture +def hass_read_only_access_token(hass, hass_read_only_user): + """Return a Home Assistant read only user.""" + refresh_token = hass.loop.run_until_complete( + hass.auth.async_create_refresh_token(hass_read_only_user, CLIENT_ID) + ) + return hass.auth.async_create_access_token(refresh_token) + + +@pytest.fixture +def legacy_auth(hass): + """Load legacy API password provider.""" + prv = legacy_api_password.LegacyApiPasswordAuthProvider( + hass, + hass.auth._store, + {"type": "legacy_api_password", "api_password": "test-password"}, + ) + hass.auth._providers[(prv.type, prv.id)] = prv + return prv + + +@pytest.fixture +def local_auth(hass): + """Load local auth provider.""" + prv = homeassistant.HassAuthProvider( + hass, hass.auth._store, {"type": "homeassistant"} + ) + hass.auth._providers[(prv.type, prv.id)] = prv + return prv + + +@pytest.fixture +def hass_client(hass, aiohttp_client, hass_access_token): + """Return an authenticated HTTP client.""" + + async def auth_client(): + """Return an authenticated client.""" + return await aiohttp_client( + hass.http.app, headers={"Authorization": f"Bearer {hass_access_token}"} + ) + + return auth_client + + +@pytest.fixture +def hass_ws_client(aiohttp_client, hass_access_token, hass): + """Websocket client fixture connected to websocket server.""" + + async def create_client(hass=hass, access_token=hass_access_token): + """Create a websocket client.""" + assert await async_setup_component(hass, "websocket_api", {}) + + client = await aiohttp_client(hass.http.app) + + with patch("homeassistant.components.http.auth.setup_auth"): + websocket = await client.ws_connect(URL) + auth_resp = await websocket.receive_json() + assert auth_resp["type"] == TYPE_AUTH_REQUIRED + + if access_token is None: + await websocket.send_json( + {"type": TYPE_AUTH, "access_token": "incorrect"} + ) + else: + await websocket.send_json( + {"type": TYPE_AUTH, "access_token": access_token} + ) + + auth_ok = await websocket.receive_json() + assert auth_ok["type"] == TYPE_AUTH_OK + + # wrap in client + websocket.client = client + return websocket + + return create_client diff --git a/tests/ignore_uncaught_exceptions.py b/tests/ignore_uncaught_exceptions.py new file mode 100644 index 0000000..3b569eb --- /dev/null +++ b/tests/ignore_uncaught_exceptions.py @@ -0,0 +1,38 @@ +"""List of modules that have uncaught exceptions today. 
Will be shrunk over time."""
+IGNORE_UNCAUGHT_EXCEPTIONS = [
+    ("tests.components.dyson.test_air_quality", "test_purecool_aiq_attributes"),
+    ("tests.components.dyson.test_air_quality", "test_purecool_aiq_update_state"),
+    (
+        "tests.components.dyson.test_air_quality",
+        "test_purecool_component_setup_only_once",
+    ),
+    ("tests.components.dyson.test_air_quality", "test_purecool_aiq_without_discovery"),
+    (
+        "tests.components.dyson.test_air_quality",
+        "test_purecool_aiq_empty_environment_state",
+    ),
+    (
+        "tests.components.dyson.test_climate",
+        "test_setup_component_with_parent_discovery",
+    ),
+    ("tests.components.dyson.test_fan", "test_purecoollink_attributes"),
+    ("tests.components.dyson.test_fan", "test_purecool_turn_on"),
+    ("tests.components.dyson.test_fan", "test_purecool_set_speed"),
+    ("tests.components.dyson.test_fan", "test_purecool_turn_off"),
+    ("tests.components.dyson.test_fan", "test_purecool_set_dyson_speed"),
+    ("tests.components.dyson.test_fan", "test_purecool_oscillate"),
+    ("tests.components.dyson.test_fan", "test_purecool_set_night_mode"),
+    ("tests.components.dyson.test_fan", "test_purecool_set_auto_mode"),
+    ("tests.components.dyson.test_fan", "test_purecool_set_angle"),
+    ("tests.components.dyson.test_fan", "test_purecool_set_flow_direction_front"),
+    ("tests.components.dyson.test_fan", "test_purecool_set_timer"),
+    ("tests.components.dyson.test_fan", "test_purecool_update_state"),
+    ("tests.components.dyson.test_fan", "test_purecool_update_state_filter_inv"),
+    ("tests.components.dyson.test_fan", "test_purecool_component_setup_only_once"),
+    ("tests.components.dyson.test_sensor", "test_purecool_component_setup_only_once"),
+    ("tests.components.ios.test_init", "test_creating_entry_sets_up_sensor"),
+    ("tests.components.ios.test_init", "test_not_configuring_ios_not_creates_entry"),
+    ("tests.components.local_file.test_camera", "test_file_not_readable"),
+]
+
+IGNORE_UNCAUGHT_JSON_EXCEPTIONS = []
diff --git a/tests/test_config_flow.py b/tests/test_config_flow.py
new file mode 100644
index 0000000..90730b3
--- /dev/null
+++ b/tests/test_config_flow.py
@@ -0,0 +1,186 @@
+from datetime import timedelta
+from unittest.mock import patch
+
+import librouteros
+import pytest
+
+from homeassistant import data_entry_flow
+from custom_components import mikrotik_router
+
+from homeassistant.const import (
+    CONF_NAME,
+    CONF_HOST,
+    CONF_PORT,
+    CONF_USERNAME,
+    CONF_PASSWORD,
+    CONF_SSL,
+)
+
+from . import MOCK_DATA
+
+from tests.common import MockConfigEntry
+
+DEMO_USER_INPUT = {
+    CONF_NAME: "Home router",
+    CONF_HOST: "0.0.0.0",
+    CONF_USERNAME: "username",
+    CONF_PASSWORD: "password",
+    CONF_PORT: 8278,
+    CONF_SSL: True,
+}
+
+DEMO_CONFIG_ENTRY = {
+    CONF_NAME: "Home router",
+    CONF_HOST: "0.0.0.0",
+    CONF_USERNAME: "username",
+    CONF_PASSWORD: "password",
+    CONF_PORT: 8278,
+    CONF_SSL: True,
+    mikrotik_router.mikrotik_controller.CONF_SCAN_INTERVAL: 60,
+    mikrotik_router.mikrotik_controller.CONF_UNIT_OF_MEASUREMENT: "Mbps",
+    mikrotik_router.mikrotik_controller.CONF_TRACK_IFACE_CLIENTS: True,
+    mikrotik_router.mikrotik_controller.CONF_TRACK_HOSTS: True,
+    mikrotik_router.mikrotik_controller.CONF_TRACK_HOSTS_TIMEOUT: 180,
+}
+
+
+@pytest.fixture(name="api")
+def mock_mikrotik_api():
+    """Mock an api."""
+    with patch("librouteros.connect"):
+        yield
+
+
+@pytest.fixture(name="auth_error")
+def mock_api_authentication_error():
+    """Mock an api."""
+    with patch(
+        "librouteros.connect",
+        side_effect=librouteros.exceptions.TrapError("invalid user name or password"),
+    ):
+        yield
+
+
+@pytest.fixture(name="conn_error")
+def mock_api_connection_error():
+    """Mock an api."""
+    with patch(
+        "librouteros.connect", side_effect=librouteros.exceptions.ConnectionClosed
+    ):
+        yield
+
+
+async def test_import(hass, api):
+    """Test import step."""
+    result = await hass.config_entries.flow.async_init(
+        mikrotik_router.DOMAIN, context={"source": "import"}, data=MOCK_DATA
+    )
+
+    assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
+    assert result["title"] == "Mikrotik"
+    assert result["data"][CONF_NAME] == "Mikrotik"
+    assert result["data"][CONF_HOST] == "10.0.0.1"
+    assert result["data"][CONF_USERNAME] == "admin"
+    assert result["data"][CONF_PASSWORD] == "admin"
+    assert result["data"][CONF_PORT] == 0
+    assert result["data"][CONF_SSL] is False
+
+
+async def test_flow_works(hass, api):
+    """Test config flow."""
+
+    result = await hass.config_entries.flow.async_init(
+        mikrotik_router.DOMAIN, context={"source": "user"}
+    )
+    assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
+    assert result["step_id"] == "user"
+
+    result = await hass.config_entries.flow.async_configure(
+        result["flow_id"], user_input=DEMO_USER_INPUT
+    )
+
+    assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
+    assert result["title"] == "Home router"
+    assert result["data"][CONF_NAME] == "Home router"
+    assert result["data"][CONF_HOST] == "0.0.0.0"
+    assert result["data"][CONF_USERNAME] == "username"
+    assert result["data"][CONF_PASSWORD] == "password"
+    assert result["data"][CONF_PORT] == 8278
+    assert result["data"][CONF_SSL] is True
+
+
+async def test_options(hass):
+    """Test updating options."""
+    entry = MockConfigEntry(domain=mikrotik_router.DOMAIN, data=DEMO_CONFIG_ENTRY)
+    entry.add_to_hass(hass)
+
+    result = await hass.config_entries.options.async_init(entry.entry_id)
+
+    assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
+    assert result["step_id"] == "device_tracker"
+
+    result = await hass.config_entries.options.async_configure(
+        result["flow_id"],
+        user_input={
+            mikrotik_router.mikrotik_controller.CONF_SCAN_INTERVAL: 30,
+            mikrotik_router.mikrotik_controller.CONF_UNIT_OF_MEASUREMENT: "Kbps",
+            mikrotik_router.mikrotik_controller.CONF_TRACK_IFACE_CLIENTS: True,
+            mikrotik_router.mikrotik_controller.CONF_TRACK_HOSTS: False,
+            mikrotik_router.mikrotik_controller.CONF_TRACK_HOSTS_TIMEOUT: 180,
+        },
+    )
+
+    assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
+    assert result["data"] == {
+        mikrotik_router.mikrotik_controller.CONF_SCAN_INTERVAL: 30,
+        mikrotik_router.mikrotik_controller.CONF_UNIT_OF_MEASUREMENT: "Kbps",
+        mikrotik_router.mikrotik_controller.CONF_TRACK_IFACE_CLIENTS: True,
+        mikrotik_router.mikrotik_controller.CONF_TRACK_HOSTS: False,
+        mikrotik_router.mikrotik_controller.CONF_TRACK_HOSTS_TIMEOUT: 180,
+    }
+
+
+async def test_name_exists(hass, api):
+    """Test name already configured."""
+
+    entry = MockConfigEntry(domain=mikrotik_router.DOMAIN, data=DEMO_CONFIG_ENTRY)
+    entry.add_to_hass(hass)
+    user_input = DEMO_USER_INPUT.copy()
+    user_input[CONF_HOST] = "0.0.0.1"
+
+    result = await hass.config_entries.flow.async_init(
+        mikrotik_router.DOMAIN, context={"source": "user"}
+    )
+    result = await hass.config_entries.flow.async_configure(
+        result["flow_id"], user_input=user_input
+    )
+
+    assert result["type"] == "form"
+    assert result["errors"] == {"base": "name_exists"}
+
+
+async def test_connection_error(hass, conn_error):
+    """Test error when connection is unsuccessful."""
+
+    result = await hass.config_entries.flow.async_init(
+        mikrotik_router.DOMAIN, context={"source": "user"}
+    )
+    result = await hass.config_entries.flow.async_configure(
+        result["flow_id"], user_input=DEMO_USER_INPUT
+    )
+    assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
+    assert result["errors"] == {"host": "cannot_connect"}
+
+
+async def test_wrong_credentials(hass, auth_error):
+    """Test error when credentials are wrong."""
+
+    result = await hass.config_entries.flow.async_init(
+        mikrotik_router.DOMAIN, context={"source": "user"}
+    )
+    result = await hass.config_entries.flow.async_configure(
+        result["flow_id"], user_input=DEMO_USER_INPUT
+    )
+
+    assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
+    assert result["errors"] == {"host": "cannot_connect"}