diff --git a/homeassistant/components/config/config_entries.py b/homeassistant/components/config/config_entries.py index 68967439b2a0..21fc55ebafc7 100644 --- a/homeassistant/components/config/config_entries.py +++ b/homeassistant/components/config/config_entries.py @@ -54,8 +54,18 @@ class ConfigManagerEntryIndexView(HomeAssistantView): """List available config entries.""" hass = request.app["hass"] - return self.json( - [ + results = [] + + for entry in hass.config_entries.async_entries(): + handler = config_entries.HANDLERS.get(entry.domain) + supports_options = ( + # Guard in case handler is no longer registered (custom compnoent etc) + handler is not None + # pylint: disable=comparison-with-callable + and handler.async_get_options_flow + != config_entries.ConfigFlow.async_get_options_flow + ) + results.append( { "entry_id": entry.entry_id, "domain": entry.domain, @@ -63,14 +73,11 @@ class ConfigManagerEntryIndexView(HomeAssistantView): "source": entry.source, "state": entry.state, "connection_class": entry.connection_class, - "supports_options": hasattr( - config_entries.HANDLERS.get(entry.domain), - "async_get_options_flow", - ), + "supports_options": supports_options, } - for entry in hass.config_entries.async_entries() - ] - ) + ) + + return self.json(results) class ConfigManagerEntryResourceView(HomeAssistantView): diff --git a/homeassistant/helpers/config_entry_flow.py b/homeassistant/helpers/config_entry_flow.py index f341cc2ce029..922878fb3241 100644 --- a/homeassistant/helpers/config_entry_flow.py +++ b/homeassistant/helpers/config_entry_flow.py @@ -1,29 +1,11 @@ """Helpers for data entry flows for config entries.""" -from functools import partial - +from typing import Callable, Awaitable, Union from homeassistant import config_entries from .typing import HomeAssistantType - # mypy: allow-untyped-defs - -def register_discovery_flow(domain, title, discovery_function, connection_class): - """Register flow for discovered integrations that not require auth.""" - config_entries.HANDLERS.register(domain)( - partial( - DiscoveryFlowHandler, domain, title, discovery_function, connection_class - ) - ) - - -def register_webhook_flow(domain, title, description_placeholder, allow_multiple=False): - """Register flow for webhook integrations.""" - config_entries.HANDLERS.register(domain)( - partial( - WebhookFlowHandler, domain, title, description_placeholder, allow_multiple - ) - ) +DiscoveryFunctionType = Callable[[], Union[Awaitable[bool], bool]] class DiscoveryFlowHandler(config_entries.ConfigFlow): @@ -31,7 +13,13 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow): VERSION = 1 - def __init__(self, domain, title, discovery_function, connection_class): + def __init__( + self, + domain: str, + title: str, + discovery_function: DiscoveryFunctionType, + connection_class: str, + ) -> None: """Initialize the discovery config flow.""" self._domain = domain self._title = title @@ -91,12 +79,35 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow): return self.async_create_entry(title=self._title, data={}) +def register_discovery_flow( + domain: str, + title: str, + discovery_function: DiscoveryFunctionType, + connection_class: str, +) -> None: + """Register flow for discovered integrations that not require auth.""" + + class DiscoveryFlow(DiscoveryFlowHandler): + """Discovery flow handler.""" + + def __init__(self) -> None: + super().__init__(domain, title, discovery_function, connection_class) + + config_entries.HANDLERS.register(domain)(DiscoveryFlow) + + class WebhookFlowHandler(config_entries.ConfigFlow): """Handle a webhook config flow.""" VERSION = 1 - def __init__(self, domain, title, description_placeholder, allow_multiple): + def __init__( + self, + domain: str, + title: str, + description_placeholder: dict, + allow_multiple: bool, + ) -> None: """Initialize the discovery config flow.""" self._domain = domain self._title = title @@ -131,6 +142,20 @@ class WebhookFlowHandler(config_entries.ConfigFlow): ) +def register_webhook_flow( + domain: str, title: str, description_placeholder: dict, allow_multiple: bool = False +) -> None: + """Register flow for webhook integrations.""" + + class WebhookFlow(WebhookFlowHandler): + """Webhook flow handler.""" + + def __init__(self) -> None: + super().__init__(domain, title, description_placeholder, allow_multiple) + + config_entries.HANDLERS.register(domain)(WebhookFlow) + + async def webhook_async_remove_entry( hass: HomeAssistantType, entry: config_entries.ConfigEntry ) -> None: diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index efe476b70556..13cd8da0597c 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -37,65 +37,61 @@ def client(hass, hass_client): yield hass.loop.run_until_complete(hass_client()) -@HANDLERS.register("comp1") -class Comp1ConfigFlow: - """Config flow with options flow.""" - - @staticmethod - @callback - def async_get_options_flow(config, options): - """Get options flow.""" - pass - - -@HANDLERS.register("comp2") -class Comp2ConfigFlow: - """Config flow without options flow.""" - - def __init__(self): - """Init.""" - pass - - async def test_get_entries(hass, client): """Test get entries.""" - MockConfigEntry( - domain="comp1", - title="Test 1", - source="bla", - connection_class=core_ce.CONN_CLASS_LOCAL_POLL, - ).add_to_hass(hass) - MockConfigEntry( - domain="comp2", - title="Test 2", - source="bla2", - state=core_ce.ENTRY_STATE_LOADED, - connection_class=core_ce.CONN_CLASS_ASSUMED, - ).add_to_hass(hass) + with patch.dict(HANDLERS, clear=True): - resp = await client.get("/api/config/config_entries/entry") - assert resp.status == 200 - data = await resp.json() - for entry in data: - entry.pop("entry_id") - assert data == [ - { - "domain": "comp1", - "title": "Test 1", - "source": "bla", - "state": "not_loaded", - "connection_class": "local_poll", - "supports_options": True, - }, - { - "domain": "comp2", - "title": "Test 2", - "source": "bla2", - "state": "loaded", - "connection_class": "assumed", - "supports_options": False, - }, - ] + @HANDLERS.register("comp1") + class Comp1ConfigFlow: + """Config flow with options flow.""" + + @staticmethod + @callback + def async_get_options_flow(config, options): + """Get options flow.""" + pass + + hass.helpers.config_entry_flow.register_discovery_flow( + "comp2", "Comp 2", lambda: None, core_ce.CONN_CLASS_ASSUMED + ) + + MockConfigEntry( + domain="comp1", + title="Test 1", + source="bla", + connection_class=core_ce.CONN_CLASS_LOCAL_POLL, + ).add_to_hass(hass) + MockConfigEntry( + domain="comp2", + title="Test 2", + source="bla2", + state=core_ce.ENTRY_STATE_LOADED, + connection_class=core_ce.CONN_CLASS_ASSUMED, + ).add_to_hass(hass) + + resp = await client.get("/api/config/config_entries/entry") + assert resp.status == 200 + data = await resp.json() + for entry in data: + entry.pop("entry_id") + assert data == [ + { + "domain": "comp1", + "title": "Test 1", + "source": "bla", + "state": "not_loaded", + "connection_class": "local_poll", + "supports_options": True, + }, + { + "domain": "comp2", + "title": "Test 2", + "source": "bla2", + "state": "loaded", + "connection_class": "assumed", + "supports_options": False, + }, + ] @asyncio.coroutine