Add generic classes BaseFlowHandler and BaseFlowManager (#111814)

* Add generic classes BaseFlowHandler and BaseFlowManager

* Migrate zwave_js

* Update tests

* Update tests

* Address review comments
This commit is contained in:
Erik Montnemery 2024-02-29 16:52:39 +01:00 committed by GitHub
parent 3a8b6412ed
commit a0e558c457
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 341 additions and 273 deletions

View File

@ -91,6 +91,8 @@ async def auth_manager_from_config(
class AuthManagerFlowManager(data_entry_flow.FlowManager):
"""Manage authentication flows."""
_flow_result = FlowResult
def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None:
"""Init auth manager flows."""
super().__init__(hass)
@ -110,7 +112,7 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
return await auth_provider.async_login_flow(context)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: FlowResult
self, flow: data_entry_flow.BaseFlowHandler, result: FlowResult
) -> FlowResult:
"""Return a user as result of login flow."""
flow = cast(LoginFlow, flow)

View File

@ -96,6 +96,8 @@ class MultiFactorAuthModule:
class SetupFlow(data_entry_flow.FlowHandler):
"""Handler for the setup flow."""
_flow_result = FlowResult
def __init__(
self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str
) -> None:

View File

@ -184,6 +184,8 @@ async def load_auth_provider_module(
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
_flow_result = FlowResult
def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider

View File

@ -38,6 +38,8 @@ _LOGGER = logging.getLogger(__name__)
class MfaFlowManager(data_entry_flow.FlowManager):
"""Manage multi factor authentication flows."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow( # type: ignore[override]
self,
handler_key: str,
@ -54,7 +56,7 @@ class MfaFlowManager(data_entry_flow.FlowManager):
return await mfa_module.async_setup_flow(user_id)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
"""Complete an mfs setup flow."""
_LOGGER.debug("flow_result: %s", result)

View File

@ -48,9 +48,11 @@ class ConfirmRepairFlow(RepairsFlow):
)
class RepairsFlowManager(data_entry_flow.FlowManager):
class RepairsFlowManager(data_entry_flow.BaseFlowManager[data_entry_flow.FlowResult]):
"""Manage repairs flows."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow(
self,
handler_key: str,
@ -82,7 +84,7 @@ class RepairsFlowManager(data_entry_flow.FlowManager):
return flow
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
"""Complete a fix flow."""
if result.get("type") != data_entry_flow.FlowResultType.ABORT:

View File

@ -7,9 +7,11 @@ from homeassistant import data_entry_flow
from homeassistant.core import HomeAssistant
class RepairsFlow(data_entry_flow.FlowHandler):
class RepairsFlow(data_entry_flow.BaseFlowHandler[data_entry_flow.FlowResult]):
"""Handle a flow for fixing an issue."""
_flow_result = data_entry_flow.FlowResult
issue_id: str
data: dict[str, str | int | float | None] | None

View File

@ -11,7 +11,6 @@ from serial.tools import list_ports
import voluptuous as vol
from zwave_js_server.version import VersionInfo, get_server_version
from homeassistant import config_entries, exceptions
from homeassistant.components import usb
from homeassistant.components.hassio import (
AddonError,
@ -22,14 +21,21 @@ from homeassistant.components.hassio import (
is_hassio,
)
from homeassistant.components.zeroconf import ZeroconfServiceInfo
from homeassistant.config_entries import (
SOURCE_USB,
ConfigEntriesFlowManager,
ConfigEntry,
ConfigEntryBaseFlow,
ConfigEntryState,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
OptionsFlowManager,
)
from homeassistant.const import CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import (
AbortFlow,
FlowHandler,
FlowManager,
FlowResult,
)
from homeassistant.data_entry_flow import AbortFlow, BaseFlowManager
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from . import disconnect_client
@ -156,7 +162,7 @@ async def async_get_usb_ports(hass: HomeAssistant) -> dict[str, str]:
return await hass.async_add_executor_job(get_usb_ports)
class BaseZwaveJSFlow(FlowHandler, ABC):
class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
"""Represent the base config flow for Z-Wave JS."""
def __init__(self) -> None:
@ -176,12 +182,12 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
@property
@abstractmethod
def flow_manager(self) -> FlowManager:
def flow_manager(self) -> BaseFlowManager:
"""Return the flow manager of the flow."""
async def async_step_install_addon(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Install Z-Wave JS add-on."""
if not self.install_task:
self.install_task = self.hass.async_create_task(self._async_install_addon())
@ -207,13 +213,13 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
async def async_step_install_failed(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Add-on installation failed."""
return self.async_abort(reason="addon_install_failed")
async def async_step_start_addon(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Start Z-Wave JS add-on."""
if not self.start_task:
self.start_task = self.hass.async_create_task(self._async_start_addon())
@ -237,7 +243,7 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
async def async_step_start_failed(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Add-on start failed."""
return self.async_abort(reason="addon_start_failed")
@ -275,13 +281,13 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
@abstractmethod
async def async_step_configure_addon(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Ask for config for Z-Wave JS add-on."""
@abstractmethod
async def async_step_finish_addon_setup(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Prepare info needed to complete the config entry.
Get add-on discovery info and server version info.
@ -325,7 +331,7 @@ class BaseZwaveJSFlow(FlowHandler, ABC):
return discovery_info_config
class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
class ZWaveJSConfigFlow(BaseZwaveJSFlow, ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Z-Wave JS."""
VERSION = 1
@ -338,19 +344,19 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
self._usb_discovery = False
@property
def flow_manager(self) -> config_entries.ConfigEntriesFlowManager:
def flow_manager(self) -> ConfigEntriesFlowManager:
"""Return the correct flow manager."""
return self.hass.config_entries.flow
@staticmethod
@callback
def async_get_options_flow(
config_entry: config_entries.ConfigEntry,
config_entry: ConfigEntry,
) -> OptionsFlowHandler:
"""Return the options flow."""
return OptionsFlowHandler(config_entry)
async def async_step_import(self, data: dict[str, Any]) -> FlowResult:
async def async_step_import(self, data: dict[str, Any]) -> ConfigFlowResult:
"""Handle imported data.
This step will be used when importing data
@ -364,7 +370,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle the initial step."""
if is_hassio(self.hass):
return await self.async_step_on_supervisor()
@ -373,7 +379,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_zeroconf(
self, discovery_info: ZeroconfServiceInfo
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle zeroconf discovery."""
home_id = str(discovery_info.properties["homeId"])
await self.async_set_unique_id(home_id)
@ -384,7 +390,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_zeroconf_confirm(
self, user_input: dict | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Confirm the setup."""
if user_input is not None:
return await self.async_step_manual({CONF_URL: self.ws_address})
@ -398,7 +404,9 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
},
)
async def async_step_usb(self, discovery_info: usb.UsbServiceInfo) -> FlowResult:
async def async_step_usb(
self, discovery_info: usb.UsbServiceInfo
) -> ConfigFlowResult:
"""Handle USB Discovery."""
if not is_hassio(self.hass):
return self.async_abort(reason="discovery_requires_supervisor")
@ -441,7 +449,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_usb_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle USB Discovery confirmation."""
if user_input is None:
return self.async_show_form(
@ -455,7 +463,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_manual(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle a manual configuration."""
if user_input is None:
return self.async_show_form(
@ -491,7 +499,9 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
step_id="manual", data_schema=get_manual_schema(user_input), errors=errors
)
async def async_step_hassio(self, discovery_info: HassioServiceInfo) -> FlowResult:
async def async_step_hassio(
self, discovery_info: HassioServiceInfo
) -> ConfigFlowResult:
"""Receive configuration from add-on discovery info.
This flow is triggered by the Z-Wave JS add-on.
@ -517,7 +527,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_hassio_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Confirm the add-on discovery."""
if user_input is not None:
return await self.async_step_on_supervisor(
@ -528,7 +538,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_on_supervisor(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle logic when on Supervisor host."""
if user_input is None:
return self.async_show_form(
@ -563,7 +573,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_configure_addon(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Ask for config for Z-Wave JS add-on."""
addon_info = await self._async_get_addon_info()
addon_config = addon_info.options
@ -628,7 +638,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_finish_addon_setup(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Prepare info needed to complete the config entry.
Get add-on discovery info and server version info.
@ -638,7 +648,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
discovery_info = await self._async_get_addon_discovery_info()
self.ws_address = f"ws://{discovery_info['host']}:{discovery_info['port']}"
if not self.unique_id or self.context["source"] == config_entries.SOURCE_USB:
if not self.unique_id or self.context["source"] == SOURCE_USB:
if not self.version_info:
try:
self.version_info = await async_get_version_info(
@ -664,7 +674,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
return self._async_create_entry_from_vars()
@callback
def _async_create_entry_from_vars(self) -> FlowResult:
def _async_create_entry_from_vars(self) -> ConfigFlowResult:
"""Return a config entry for the flow."""
# Abort any other flows that may be in progress
for progress in self._async_in_progress():
@ -685,10 +695,10 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN):
)
class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
class OptionsFlowHandler(BaseZwaveJSFlow, OptionsFlow):
"""Handle an options flow for Z-Wave JS."""
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
def __init__(self, config_entry: ConfigEntry) -> None:
"""Set up the options flow."""
super().__init__()
self.config_entry = config_entry
@ -696,7 +706,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
self.revert_reason: str | None = None
@property
def flow_manager(self) -> config_entries.OptionsFlowManager:
def flow_manager(self) -> OptionsFlowManager:
"""Return the correct flow manager."""
return self.hass.config_entries.options
@ -707,7 +717,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Manage the options."""
if is_hassio(self.hass):
return await self.async_step_on_supervisor()
@ -716,7 +726,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_manual(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle a manual configuration."""
if user_input is None:
return self.async_show_form(
@ -759,7 +769,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_on_supervisor(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle logic when on Supervisor host."""
if user_input is None:
return self.async_show_form(
@ -780,7 +790,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_configure_addon(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Ask for config for Z-Wave JS add-on."""
addon_info = await self._async_get_addon_info()
addon_config = addon_info.options
@ -819,7 +829,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
if (
self.config_entry.data.get(CONF_USE_ADDON)
and self.config_entry.state == config_entries.ConfigEntryState.LOADED
and self.config_entry.state == ConfigEntryState.LOADED
):
# Disconnect integration before restarting add-on.
await disconnect_client(self.hass, self.config_entry)
@ -868,13 +878,13 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
async def async_step_start_failed(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Add-on start failed."""
return await self.async_revert_addon_config(reason="addon_start_failed")
async def async_step_finish_addon_setup(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Prepare info needed to complete the config entry update.
Get add-on discovery info and server version info.
@ -918,7 +928,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
self.hass.config_entries.async_schedule_reload(self.config_entry.entry_id)
return self.async_create_entry(title=TITLE, data={})
async def async_revert_addon_config(self, reason: str) -> FlowResult:
async def async_revert_addon_config(self, reason: str) -> ConfigFlowResult:
"""Abort the options flow.
If the add-on options have been changed, revert those and restart add-on.
@ -944,11 +954,11 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow):
return await self.async_step_configure_addon(addon_config_input)
class CannotConnect(exceptions.HomeAssistantError):
class CannotConnect(HomeAssistantError):
"""Indicate connection error."""
class InvalidInput(exceptions.HomeAssistantError):
class InvalidInput(HomeAssistantError):
"""Error to indicate input data is invalid."""
def __init__(self, error: str) -> None:

View File

@ -242,6 +242,9 @@ UPDATE_ENTRY_CONFIG_ENTRY_ATTRS = {
}
ConfigFlowResult = FlowResult
class ConfigEntry:
"""Hold a configuration entry."""
@ -903,7 +906,7 @@ class ConfigEntry:
@callback
def async_get_active_flows(
self, hass: HomeAssistant, sources: set[str]
) -> Generator[FlowResult, None, None]:
) -> Generator[ConfigFlowResult, None, None]:
"""Get any active flows of certain sources for this entry."""
return (
flow
@ -970,9 +973,11 @@ class FlowCancelledError(Exception):
"""Error to indicate that a flow has been cancelled."""
class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
"""Manage all the config entry flows that are in progress."""
_flow_result = ConfigFlowResult
def __init__(
self,
hass: HomeAssistant,
@ -1010,7 +1015,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Start a configuration flow."""
if not context or "source" not in context:
raise KeyError("Context not set or doesn't have a source set")
@ -1024,7 +1029,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
and await _support_single_config_entry_only(self.hass, handler)
and self.config_entries.async_entries(handler, include_ignore=False)
):
return FlowResult(
return ConfigFlowResult(
type=data_entry_flow.FlowResultType.ABORT,
flow_id=flow_id,
handler=handler,
@ -1065,7 +1070,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
handler: str,
context: dict,
data: Any,
) -> tuple[data_entry_flow.FlowHandler, FlowResult]:
) -> tuple[ConfigFlow, ConfigFlowResult]:
"""Run the init in a task to allow it to be canceled at shutdown."""
flow = await self.async_create_flow(handler, context=context, data=data)
if not flow:
@ -1093,8 +1098,8 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
self._discovery_debouncer.async_shutdown()
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
) -> ConfigFlowResult:
"""Finish a config flow and add an entry."""
flow = cast(ConfigFlow, flow)
@ -1128,7 +1133,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
and flow.context["source"] != SOURCE_IGNORE
and self.config_entries.async_entries(flow.handler, include_ignore=False)
):
return FlowResult(
return ConfigFlowResult(
type=data_entry_flow.FlowResultType.ABORT,
flow_id=flow.flow_id,
handler=flow.handler,
@ -1213,7 +1218,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
return flow
async def async_post_init(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
) -> None:
"""After a flow is initialised trigger new flow notifications."""
source = flow.context["source"]
@ -1852,7 +1857,13 @@ def _async_abort_entries_match(
raise data_entry_flow.AbortFlow("already_configured")
class ConfigFlow(data_entry_flow.FlowHandler):
class ConfigEntryBaseFlow(data_entry_flow.BaseFlowHandler[ConfigFlowResult]):
"""Base class for config and option flows."""
_flow_result = ConfigFlowResult
class ConfigFlow(ConfigEntryBaseFlow):
"""Base class for config flows with some helpers."""
def __init_subclass__(cls, *, domain: str | None = None, **kwargs: Any) -> None:
@ -2008,7 +2019,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
self,
include_uninitialized: bool = False,
match_context: dict[str, Any] | None = None,
) -> list[data_entry_flow.FlowResult]:
) -> list[ConfigFlowResult]:
"""Return other in progress flows for current domain."""
return [
flw
@ -2020,22 +2031,18 @@ class ConfigFlow(data_entry_flow.FlowHandler):
if flw["flow_id"] != self.flow_id
]
async def async_step_ignore(
self, user_input: dict[str, Any]
) -> data_entry_flow.FlowResult:
async def async_step_ignore(self, user_input: dict[str, Any]) -> ConfigFlowResult:
"""Ignore this config flow."""
await self.async_set_unique_id(user_input["unique_id"], raise_on_progress=False)
return self.async_create_entry(title=user_input["title"], data={})
async def async_step_unignore(
self, user_input: dict[str, Any]
) -> data_entry_flow.FlowResult:
async def async_step_unignore(self, user_input: dict[str, Any]) -> ConfigFlowResult:
"""Rediscover a config entry by it's unique_id."""
return self.async_abort(reason="not_implemented")
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initiated by the user."""
return self.async_abort(reason="not_implemented")
@ -2068,14 +2075,14 @@ class ConfigFlow(data_entry_flow.FlowHandler):
async def _async_step_discovery_without_unique_id(
self,
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by discovery."""
await self._async_handle_discovery_without_unique_id()
return await self.async_step_user()
async def async_step_discovery(
self, discovery_info: DiscoveryInfoType
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by discovery."""
return await self._async_step_discovery_without_unique_id()
@ -2085,7 +2092,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
*,
reason: str,
description_placeholders: Mapping[str, str] | None = None,
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Abort the config flow."""
# Remove reauth notification if no reauth flows are in progress
if self.source == SOURCE_REAUTH and not any(
@ -2104,55 +2111,53 @@ class ConfigFlow(data_entry_flow.FlowHandler):
async def async_step_bluetooth(
self, discovery_info: BluetoothServiceInfoBleak
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by Bluetooth discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by DHCP discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_hassio(
self, discovery_info: HassioServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by HASS IO discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_integration_discovery(
self, discovery_info: DiscoveryInfoType
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by integration specific discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_homekit(
self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by Homekit discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_mqtt(
self, discovery_info: MqttServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by MQTT discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_ssdp(
self, discovery_info: SsdpServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by SSDP discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_usb(
self, discovery_info: UsbServiceInfo
) -> data_entry_flow.FlowResult:
async def async_step_usb(self, discovery_info: UsbServiceInfo) -> ConfigFlowResult:
"""Handle a flow initialized by USB discovery."""
return await self._async_step_discovery_without_unique_id()
async def async_step_zeroconf(
self, discovery_info: ZeroconfServiceInfo
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Handle a flow initialized by Zeroconf discovery."""
return await self._async_step_discovery_without_unique_id()
@ -2165,7 +2170,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
description: str | None = None,
description_placeholders: Mapping[str, str] | None = None,
options: Mapping[str, Any] | None = None,
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Finish config flow and create a config entry."""
result = super().async_create_entry(
title=title,
@ -2175,6 +2180,8 @@ class ConfigFlow(data_entry_flow.FlowHandler):
)
result["options"] = options or {}
result["minor_version"] = self.MINOR_VERSION
result["version"] = self.VERSION
return result
@ -2188,7 +2195,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
data: Mapping[str, Any] | UndefinedType = UNDEFINED,
options: Mapping[str, Any] | UndefinedType = UNDEFINED,
reason: str = "reauth_successful",
) -> data_entry_flow.FlowResult:
) -> ConfigFlowResult:
"""Update config entry, reload config entry and finish config flow."""
result = self.hass.config_entries.async_update_entry(
entry=entry,
@ -2202,9 +2209,11 @@ class ConfigFlow(data_entry_flow.FlowHandler):
return self.async_abort(reason=reason)
class OptionsFlowManager(data_entry_flow.FlowManager):
class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
"""Flow to set options for a configuration entry."""
_flow_result = ConfigFlowResult
def _async_get_config_entry(self, config_entry_id: str) -> ConfigEntry:
"""Return config entry or raise if not found."""
entry = self.hass.config_entries.async_get_entry(config_entry_id)
@ -2229,8 +2238,8 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
return handler.async_get_options_flow(entry)
async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
) -> ConfigFlowResult:
"""Finish an options flow and update options for configuration entry.
Flow.handler and entry_id is the same thing to map flow with entry.
@ -2249,7 +2258,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
result["result"] = True
return result
async def _async_setup_preview(self, flow: data_entry_flow.FlowHandler) -> None:
async def _async_setup_preview(self, flow: data_entry_flow.BaseFlowHandler) -> None:
"""Set up preview for an option flow handler."""
entry = self._async_get_config_entry(flow.handler)
await _load_integration(self.hass, entry.domain, {})
@ -2258,7 +2267,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager):
await flow.async_setup_preview(self.hass)
class OptionsFlow(data_entry_flow.FlowHandler):
class OptionsFlow(ConfigEntryBaseFlow):
"""Base class for config options flows."""
handler: str

View File

@ -11,7 +11,7 @@ from enum import StrEnum
from functools import partial
import logging
from types import MappingProxyType
from typing import Any, Required, TypedDict
from typing import Any, Generic, Required, TypedDict, TypeVar
import voluptuous as vol
@ -75,6 +75,7 @@ FLOW_NOT_COMPLETE_STEPS = {
FlowResultType.MENU,
}
STEP_ID_OPTIONAL_STEPS = {
FlowResultType.EXTERNAL_STEP,
FlowResultType.FORM,
@ -83,6 +84,9 @@ STEP_ID_OPTIONAL_STEPS = {
}
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult")
@dataclass(slots=True)
class BaseServiceInfo:
"""Base class for discovery ServiceInfo."""
@ -163,26 +167,6 @@ class FlowResult(TypedDict, total=False):
version: int
@callback
def _async_flow_handler_to_flow_result(
flows: Iterable[FlowHandler], include_uninitialized: bool
) -> list[FlowResult]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = []
for flow in flows:
if not include_uninitialized and flow.cur_step is None:
continue
result = FlowResult(
flow_id=flow.flow_id,
handler=flow.handler,
context=flow.context,
)
if flow.cur_step:
result["step_id"] = flow.cur_step["step_id"]
results.append(result)
return results
def _map_error_to_schema_errors(
schema_errors: dict[str, Any],
error: vol.Invalid,
@ -206,9 +190,11 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message
class FlowManager(abc.ABC):
class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
"""Manage all the flows that are in progress."""
_flow_result: Callable[..., _FlowResultT]
def __init__(
self,
hass: HomeAssistant,
@ -216,9 +202,9 @@ class FlowManager(abc.ABC):
"""Initialize the flow manager."""
self.hass = hass
self._preview: set[str] = set()
self._progress: dict[str, FlowHandler] = {}
self._handler_progress_index: dict[str, set[FlowHandler]] = {}
self._init_data_process_index: dict[type, set[FlowHandler]] = {}
self._progress: dict[str, BaseFlowHandler] = {}
self._handler_progress_index: dict[str, set[BaseFlowHandler]] = {}
self._init_data_process_index: dict[type, set[BaseFlowHandler]] = {}
@abc.abstractmethod
async def async_create_flow(
@ -227,7 +213,7 @@ class FlowManager(abc.ABC):
*,
context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> FlowHandler:
) -> BaseFlowHandler[_FlowResultT]:
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
@ -235,11 +221,13 @@ class FlowManager(abc.ABC):
@abc.abstractmethod
async def async_finish_flow(
self, flow: FlowHandler, result: FlowResult
) -> FlowResult:
self, flow: BaseFlowHandler, result: _FlowResultT
) -> _FlowResultT:
"""Finish a data entry flow."""
async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None:
async def async_post_init(
self, flow: BaseFlowHandler, result: _FlowResultT
) -> None:
"""Entry has finished executing its first step asynchronously."""
@callback
@ -262,16 +250,16 @@ class FlowManager(abc.ABC):
return False
@callback
def async_get(self, flow_id: str) -> FlowResult:
def async_get(self, flow_id: str) -> _FlowResultT:
"""Return a flow in progress as a partial FlowResult."""
if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow
return _async_flow_handler_to_flow_result([flow], False)[0]
return self._async_flow_handler_to_flow_result([flow], False)[0]
@callback
def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]:
def async_progress(self, include_uninitialized: bool = False) -> list[_FlowResultT]:
"""Return the flows in progress as a partial FlowResult."""
return _async_flow_handler_to_flow_result(
return self._async_flow_handler_to_flow_result(
self._progress.values(), include_uninitialized
)
@ -281,13 +269,13 @@ class FlowManager(abc.ABC):
handler: str,
include_uninitialized: bool = False,
match_context: dict[str, Any] | None = None,
) -> list[FlowResult]:
) -> list[_FlowResultT]:
"""Return the flows in progress by handler as a partial FlowResult.
If match_context is specified, only return flows with a context that
is a superset of match_context.
"""
return _async_flow_handler_to_flow_result(
return self._async_flow_handler_to_flow_result(
self._async_progress_by_handler(handler, match_context),
include_uninitialized,
)
@ -298,9 +286,9 @@ class FlowManager(abc.ABC):
init_data_type: type,
matcher: Callable[[Any], bool],
include_uninitialized: bool = False,
) -> list[FlowResult]:
) -> list[_FlowResultT]:
"""Return flows in progress init matching by data type as a partial FlowResult."""
return _async_flow_handler_to_flow_result(
return self._async_flow_handler_to_flow_result(
(
progress
for progress in self._init_data_process_index.get(init_data_type, set())
@ -312,7 +300,7 @@ class FlowManager(abc.ABC):
@callback
def _async_progress_by_handler(
self, handler: str, match_context: dict[str, Any] | None
) -> list[FlowHandler]:
) -> list[BaseFlowHandler[_FlowResultT]]:
"""Return the flows in progress by handler.
If match_context is specified, only return flows with a context that
@ -329,7 +317,7 @@ class FlowManager(abc.ABC):
async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
) -> FlowResult:
) -> _FlowResultT:
"""Start a data entry flow."""
if context is None:
context = {}
@ -352,9 +340,9 @@ class FlowManager(abc.ABC):
async def async_configure(
self, flow_id: str, user_input: dict | None = None
) -> FlowResult:
) -> _FlowResultT:
"""Continue a data entry flow."""
result: FlowResult | None = None
result: _FlowResultT | None = None
while not result or result["type"] == FlowResultType.SHOW_PROGRESS_DONE:
result = await self._async_configure(flow_id, user_input)
flow = self._progress.get(flow_id)
@ -364,7 +352,7 @@ class FlowManager(abc.ABC):
async def _async_configure(
self, flow_id: str, user_input: dict | None = None
) -> FlowResult:
) -> _FlowResultT:
"""Continue a data entry flow."""
if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow
@ -458,7 +446,7 @@ class FlowManager(abc.ABC):
self._async_remove_flow_progress(flow_id)
@callback
def _async_add_flow_progress(self, flow: FlowHandler) -> None:
def _async_add_flow_progress(self, flow: BaseFlowHandler[_FlowResultT]) -> None:
"""Add a flow to in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@ -467,7 +455,9 @@ class FlowManager(abc.ABC):
self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
@callback
def _async_remove_flow_from_index(self, flow: FlowHandler) -> None:
def _async_remove_flow_from_index(
self, flow: BaseFlowHandler[_FlowResultT]
) -> None:
"""Remove a flow from in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@ -492,17 +482,24 @@ class FlowManager(abc.ABC):
_LOGGER.exception("Error removing %s flow: %s", flow.handler, err)
async def _async_handle_step(
self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None
) -> FlowResult:
self,
flow: BaseFlowHandler[_FlowResultT],
step_id: str,
user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT:
"""Handle a step of a flow."""
self._raise_if_step_does_not_exist(flow, step_id)
method = f"async_step_{step_id}"
try:
result: FlowResult = await getattr(flow, method)(user_input)
result: _FlowResultT = await getattr(flow, method)(user_input)
except AbortFlow as err:
result = _create_abort_data(
flow.flow_id, flow.handler, err.reason, err.description_placeholders
result = self._flow_result(
type=FlowResultType.ABORT,
flow_id=flow.flow_id,
handler=flow.handler,
reason=err.reason,
description_placeholders=err.description_placeholders,
)
# Setup the flow handler's preview if needed
@ -521,7 +518,8 @@ class FlowManager(abc.ABC):
if (
result["type"] == FlowResultType.SHOW_PROGRESS
and (progress_task := result.pop("progress_task", None))
# Mypy does not agree with using pop on _FlowResultT
and (progress_task := result.pop("progress_task", None)) # type: ignore[arg-type]
and progress_task != flow.async_get_progress_task()
):
# The flow's progress task was changed, register a callback on it
@ -532,8 +530,9 @@ class FlowManager(abc.ABC):
def schedule_configure(_: asyncio.Task) -> None:
self.hass.async_create_task(call_configure())
progress_task.add_done_callback(schedule_configure)
flow.async_set_progress_task(progress_task)
# The mypy ignores are a consequence of mypy not accepting the pop above
progress_task.add_done_callback(schedule_configure) # type: ignore[attr-defined]
flow.async_set_progress_task(progress_task) # type: ignore[arg-type]
elif result["type"] != FlowResultType.SHOW_PROGRESS:
flow.async_cancel_progress_task()
@ -560,7 +559,9 @@ class FlowManager(abc.ABC):
return result
def _raise_if_step_does_not_exist(self, flow: FlowHandler, step_id: str) -> None:
def _raise_if_step_does_not_exist(
self, flow: BaseFlowHandler, step_id: str
) -> None:
"""Raise if the step does not exist."""
method = f"async_step_{step_id}"
@ -570,18 +571,45 @@ class FlowManager(abc.ABC):
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
)
async def _async_setup_preview(self, flow: FlowHandler) -> None:
async def _async_setup_preview(self, flow: BaseFlowHandler) -> None:
"""Set up preview for a flow handler."""
if flow.handler not in self._preview:
self._preview.add(flow.handler)
await flow.async_setup_preview(self.hass)
@callback
def _async_flow_handler_to_flow_result(
self, flows: Iterable[BaseFlowHandler], include_uninitialized: bool
) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = []
for flow in flows:
if not include_uninitialized and flow.cur_step is None:
continue
result = self._flow_result(
flow_id=flow.flow_id,
handler=flow.handler,
context=flow.context,
)
if flow.cur_step:
result["step_id"] = flow.cur_step["step_id"]
results.append(result)
return results
class FlowHandler:
class FlowManager(BaseFlowManager[FlowResult]):
"""Manage all the flows that are in progress."""
_flow_result = FlowResult
class BaseFlowHandler(Generic[_FlowResultT]):
"""Handle a data entry flow."""
_flow_result: Callable[..., _FlowResultT]
# Set by flow manager
cur_step: FlowResult | None = None
cur_step: _FlowResultT | None = None
# While not purely typed, it makes typehinting more useful for us
# and removes the need for constant None checks or asserts.
@ -657,12 +685,12 @@ class FlowHandler:
description_placeholders: Mapping[str, str | None] | None = None,
last_step: bool | None = None,
preview: str | None = None,
) -> FlowResult:
) -> _FlowResultT:
"""Return the definition of a form to gather user input.
The step_id parameter is deprecated and will be removed in a future release.
"""
flow_result = FlowResult(
flow_result = self._flow_result(
type=FlowResultType.FORM,
flow_id=self.flow_id,
handler=self.handler,
@ -684,11 +712,9 @@ class FlowHandler:
data: Mapping[str, Any],
description: str | None = None,
description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult:
) -> _FlowResultT:
"""Finish flow."""
flow_result = FlowResult(
version=self.VERSION,
minor_version=self.MINOR_VERSION,
flow_result = self._flow_result(
type=FlowResultType.CREATE_ENTRY,
flow_id=self.flow_id,
handler=self.handler,
@ -707,10 +733,14 @@ class FlowHandler:
*,
reason: str,
description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult:
) -> _FlowResultT:
"""Abort the flow."""
return _create_abort_data(
self.flow_id, self.handler, reason, description_placeholders
return self._flow_result(
type=FlowResultType.ABORT,
flow_id=self.flow_id,
handler=self.handler,
reason=reason,
description_placeholders=description_placeholders,
)
@callback
@ -720,12 +750,12 @@ class FlowHandler:
step_id: str | None = None,
url: str,
description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult:
) -> _FlowResultT:
"""Return the definition of an external step for the user to take.
The step_id parameter is deprecated and will be removed in a future release.
"""
flow_result = FlowResult(
flow_result = self._flow_result(
type=FlowResultType.EXTERNAL_STEP,
flow_id=self.flow_id,
handler=self.handler,
@ -737,9 +767,9 @@ class FlowHandler:
return flow_result
@callback
def async_external_step_done(self, *, next_step_id: str) -> FlowResult:
def async_external_step_done(self, *, next_step_id: str) -> _FlowResultT:
"""Return the definition of an external step for the user to take."""
return FlowResult(
return self._flow_result(
type=FlowResultType.EXTERNAL_STEP_DONE,
flow_id=self.flow_id,
handler=self.handler,
@ -754,7 +784,7 @@ class FlowHandler:
progress_action: str,
description_placeholders: Mapping[str, str] | None = None,
progress_task: asyncio.Task[Any] | None = None,
) -> FlowResult:
) -> _FlowResultT:
"""Show a progress message to the user, without user input allowed.
The step_id parameter is deprecated and will be removed in a future release.
@ -777,7 +807,7 @@ class FlowHandler:
if progress_task is None:
self.deprecated_show_progress = True
flow_result = FlowResult(
flow_result = self._flow_result(
type=FlowResultType.SHOW_PROGRESS,
flow_id=self.flow_id,
handler=self.handler,
@ -790,9 +820,9 @@ class FlowHandler:
return flow_result
@callback
def async_show_progress_done(self, *, next_step_id: str) -> FlowResult:
def async_show_progress_done(self, *, next_step_id: str) -> _FlowResultT:
"""Mark the progress done."""
return FlowResult(
return self._flow_result(
type=FlowResultType.SHOW_PROGRESS_DONE,
flow_id=self.flow_id,
handler=self.handler,
@ -806,13 +836,13 @@ class FlowHandler:
step_id: str | None = None,
menu_options: list[str] | dict[str, str],
description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult:
) -> _FlowResultT:
"""Show a navigation menu to the user.
Options dict maps step_id => i18n label
The step_id parameter is deprecated and will be removed in a future release.
"""
flow_result = FlowResult(
flow_result = self._flow_result(
type=FlowResultType.MENU,
flow_id=self.flow_id,
handler=self.handler,
@ -853,21 +883,10 @@ class FlowHandler:
self.__progress_task = progress_task
@callback
def _create_abort_data(
flow_id: str,
handler: str,
reason: str,
description_placeholders: Mapping[str, str] | None = None,
) -> FlowResult:
"""Return the definition of an external step for the user to take."""
return FlowResult(
type=FlowResultType.ABORT,
flow_id=flow_id,
handler=handler,
reason=reason,
description_placeholders=description_placeholders,
)
class FlowHandler(BaseFlowHandler[FlowResult]):
"""Handle a data entry flow."""
_flow_result = FlowResult
# These can be removed if no deprecated constant are in this module anymore

View File

@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from homeassistant import config_entries
from homeassistant.components import onboarding
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult
from .typing import DiscoveryInfoType
@ -46,7 +45,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by the user."""
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -57,7 +56,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Confirm setup."""
if user_input is None and onboarding.async_is_onboarded(self.hass):
self._set_confirm_only()
@ -87,7 +86,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_discovery(
self, discovery_info: DiscoveryInfoType
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -98,7 +97,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_bluetooth(
self, discovery_info: BluetoothServiceInfoBleak
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by bluetooth discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -107,7 +106,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
return await self.async_step_confirm()
async def async_step_dhcp(self, discovery_info: DhcpServiceInfo) -> FlowResult:
async def async_step_dhcp(
self, discovery_info: DhcpServiceInfo
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by dhcp discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -118,7 +119,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_homekit(
self, discovery_info: ZeroconfServiceInfo
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by Homekit discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -127,7 +128,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
return await self.async_step_confirm()
async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult:
async def async_step_mqtt(
self, discovery_info: MqttServiceInfo
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by mqtt discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -138,7 +141,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
async def async_step_zeroconf(
self, discovery_info: ZeroconfServiceInfo
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by Zeroconf discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -147,7 +150,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
return await self.async_step_confirm()
async def async_step_ssdp(self, discovery_info: SsdpServiceInfo) -> FlowResult:
async def async_step_ssdp(
self, discovery_info: SsdpServiceInfo
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by Ssdp discovery."""
if self._async_in_progress() or self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -156,7 +161,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]):
return await self.async_step_confirm()
async def async_step_import(self, _: dict[str, Any] | None) -> FlowResult:
async def async_step_import(
self, _: dict[str, Any] | None
) -> config_entries.ConfigFlowResult:
"""Handle a flow initialized by import."""
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
@ -205,7 +212,7 @@ class WebhookFlowHandler(config_entries.ConfigFlow):
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Handle a user initiated set up flow to create a webhook."""
if not self._allow_multiple and self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")

View File

@ -25,7 +25,6 @@ from yarl import URL
from homeassistant import config_entries
from homeassistant.components import http
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.loader import async_get_application_credentials
from .aiohttp_client import async_get_clientsession
@ -253,7 +252,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
async def async_step_pick_implementation(
self, user_input: dict | None = None
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Handle a flow start."""
implementations = await async_get_implementations(self.hass, self.DOMAIN)
@ -286,7 +285,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
async def async_step_auth(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Create an entry for auth."""
# Flow has been triggered by external data
if user_input is not None:
@ -314,7 +313,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
async def async_step_creation(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Create config entry from external data."""
_LOGGER.debug("Creating config entry from external data")
@ -353,14 +352,18 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
{"auth_implementation": self.flow_impl.domain, "token": token}
)
async def async_step_authorize_rejected(self, data: None = None) -> FlowResult:
async def async_step_authorize_rejected(
self, data: None = None
) -> config_entries.ConfigFlowResult:
"""Step to handle flow rejection."""
return self.async_abort(
reason="user_rejected_authorize",
description_placeholders={"error": self.external_data["error"]},
)
async def async_oauth_create_entry(self, data: dict) -> FlowResult:
async def async_oauth_create_entry(
self, data: dict
) -> config_entries.ConfigFlowResult:
"""Create an entry for the flow.
Ok to override if you want to fetch extra info or even add another step.
@ -369,7 +372,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> config_entries.ConfigFlowResult:
"""Handle a flow start."""
return await self.async_step_pick_implementation(user_input)

View File

@ -18,7 +18,7 @@ from . import config_validation as cv
class _BaseFlowManagerView(HomeAssistantView):
"""Foundation for flow manager views."""
def __init__(self, flow_mgr: data_entry_flow.FlowManager) -> None:
def __init__(self, flow_mgr: data_entry_flow.BaseFlowManager) -> None:
"""Initialize the flow manager index view."""
self._flow_mgr = flow_mgr

View File

@ -4,9 +4,9 @@ from __future__ import annotations
from collections.abc import Coroutine
from typing import Any, NamedTuple
from homeassistant.config_entries import ConfigFlowResult
from homeassistant.const import EVENT_HOMEASSISTANT_STARTED
from homeassistant.core import CoreState, Event, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.loader import bind_hass
from homeassistant.util.async_ import gather_with_limited_concurrency
@ -40,7 +40,7 @@ def async_create_flow(
@callback
def _async_init_flow(
hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any
) -> Coroutine[None, None, FlowResult] | None:
) -> Coroutine[None, None, ConfigFlowResult] | None:
"""Create a discovery flow."""
# Avoid spawning flows that have the same initial discovery data
# as ones in progress as it may cause additional device probing

View File

@ -10,9 +10,15 @@ from typing import Any, cast
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
OptionsFlowWithConfigEntry,
)
from homeassistant.core import HomeAssistant, callback, split_entity_id
from homeassistant.data_entry_flow import FlowResult, UnknownHandler
from homeassistant.data_entry_flow import UnknownHandler
from . import entity_registry as er, selector
from .typing import UNDEFINED, UndefinedType
@ -126,7 +132,7 @@ class SchemaCommonFlowHandler:
async def async_step(
self, step_id: str, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle a step."""
if isinstance(self._flow[step_id], SchemaFlowFormStep):
return await self._async_form_step(step_id, user_input)
@ -141,7 +147,7 @@ class SchemaCommonFlowHandler:
async def _async_form_step(
self, step_id: str, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle a form step."""
form_step: SchemaFlowFormStep = cast(SchemaFlowFormStep, self._flow[step_id])
@ -204,7 +210,7 @@ class SchemaCommonFlowHandler:
async def _show_next_step_or_create_entry(
self, form_step: SchemaFlowFormStep
) -> FlowResult:
) -> ConfigFlowResult:
next_step_id_or_end_flow: str | None
if callable(form_step.next_step):
@ -222,7 +228,7 @@ class SchemaCommonFlowHandler:
next_step_id: str,
error: SchemaFlowError | None = None,
user_input: dict[str, Any] | None = None,
) -> FlowResult:
) -> ConfigFlowResult:
"""Show form for next step."""
if isinstance(self._flow[next_step_id], SchemaFlowMenuStep):
menu_step = cast(SchemaFlowMenuStep, self._flow[next_step_id])
@ -271,7 +277,7 @@ class SchemaCommonFlowHandler:
async def _async_menu_step(
self, step_id: str, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle a menu step."""
menu_step: SchemaFlowMenuStep = cast(SchemaFlowMenuStep, self._flow[step_id])
return self._handler.async_show_menu(
@ -280,7 +286,7 @@ class SchemaCommonFlowHandler:
)
class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
class SchemaConfigFlowHandler(ConfigFlow, ABC):
"""Handle a schema based config flow."""
config_flow: Mapping[str, SchemaFlowStep]
@ -294,8 +300,8 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
@callback
def _async_get_options_flow(
config_entry: config_entries.ConfigEntry,
) -> config_entries.OptionsFlow:
config_entry: ConfigEntry,
) -> OptionsFlow:
"""Get the options flow for this handler."""
if cls.options_flow is None:
raise UnknownHandler
@ -324,9 +330,7 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
@classmethod
@callback
def async_supports_options_flow(
cls, config_entry: config_entries.ConfigEntry
) -> bool:
def async_supports_options_flow(cls, config_entry: ConfigEntry) -> bool:
"""Return options flow support for this handler."""
return cls.options_flow is not None
@ -335,13 +339,13 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
step_id: str,
) -> Callable[
[SchemaConfigFlowHandler, dict[str, Any] | None],
Coroutine[Any, Any, FlowResult],
Coroutine[Any, Any, ConfigFlowResult],
]:
"""Generate a step handler."""
async def _async_step(
self: SchemaConfigFlowHandler, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle a config flow step."""
# pylint: disable-next=protected-access
result = await self._common_handler.async_step(step_id, user_input)
@ -382,7 +386,7 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
self,
data: Mapping[str, Any],
**kwargs: Any,
) -> FlowResult:
) -> ConfigFlowResult:
"""Finish config flow and create a config entry."""
self.async_config_flow_finished(data)
return super().async_create_entry(
@ -390,12 +394,12 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC):
)
class SchemaOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
class SchemaOptionsFlowHandler(OptionsFlowWithConfigEntry):
"""Handle a schema based options flow."""
def __init__(
self,
config_entry: config_entries.ConfigEntry,
config_entry: ConfigEntry,
options_flow: Mapping[str, SchemaFlowStep],
async_options_flow_finished: Callable[[HomeAssistant, Mapping[str, Any]], None]
| None = None,
@ -430,13 +434,13 @@ class SchemaOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
step_id: str,
) -> Callable[
[SchemaConfigFlowHandler, dict[str, Any] | None],
Coroutine[Any, Any, FlowResult],
Coroutine[Any, Any, ConfigFlowResult],
]:
"""Generate a step handler."""
async def _async_step(
self: SchemaConfigFlowHandler, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle an options flow step."""
# pylint: disable-next=protected-access
result = await self._common_handler.async_step(step_id, user_input)
@ -449,7 +453,7 @@ class SchemaOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
self,
data: Mapping[str, Any],
**kwargs: Any,
) -> FlowResult:
) -> ConfigFlowResult:
"""Finish config flow and create a config entry."""
if self._async_options_flow_finished:
self._async_options_flow_finished(self.hass, data)

View File

@ -55,11 +55,12 @@ class TypeHintMatch:
)
@dataclass
@dataclass(kw_only=True)
class ClassTypeHintMatch:
"""Class for pattern matching."""
base_class: str
exclude_base_classes: set[str] | None = None
matches: list[TypeHintMatch]
@ -481,6 +482,7 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
"config_flow": [
ClassTypeHintMatch(
base_class="FlowHandler",
exclude_base_classes={"ConfigEntryBaseFlow"},
matches=[
TypeHintMatch(
function_name="async_step_*",
@ -492,6 +494,11 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
ClassTypeHintMatch(
base_class="ConfigFlow",
matches=[
TypeHintMatch(
function_name="async_step123_*",
arg_types={},
return_type=["ConfigFlowResult", "FlowResult"],
),
TypeHintMatch(
function_name="async_get_options_flow",
arg_types={
@ -504,56 +511,66 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
arg_types={
1: "DhcpServiceInfo",
},
return_type="FlowResult",
return_type=["ConfigFlowResult", "FlowResult"],
),
TypeHintMatch(
function_name="async_step_hassio",
arg_types={
1: "HassioServiceInfo",
},
return_type="FlowResult",
return_type=["ConfigFlowResult", "FlowResult"],
),
TypeHintMatch(
function_name="async_step_homekit",
arg_types={
1: "ZeroconfServiceInfo",
},
return_type="FlowResult",
return_type=["ConfigFlowResult", "FlowResult"],
),
TypeHintMatch(
function_name="async_step_mqtt",
arg_types={
1: "MqttServiceInfo",
},
return_type="FlowResult",
return_type=["ConfigFlowResult", "FlowResult"],
),
TypeHintMatch(
function_name="async_step_reauth",
arg_types={
1: "Mapping[str, Any]",
},
return_type="FlowResult",
return_type=["ConfigFlowResult", "FlowResult"],
),
TypeHintMatch(
function_name="async_step_ssdp",
arg_types={
1: "SsdpServiceInfo",
},
return_type="FlowResult",
return_type=["ConfigFlowResult", "FlowResult"],
),
TypeHintMatch(
function_name="async_step_usb",
arg_types={
1: "UsbServiceInfo",
},
return_type="FlowResult",
return_type=["ConfigFlowResult", "FlowResult"],
),
TypeHintMatch(
function_name="async_step_zeroconf",
arg_types={
1: "ZeroconfServiceInfo",
},
return_type="FlowResult",
return_type=["ConfigFlowResult", "FlowResult"],
),
],
),
ClassTypeHintMatch(
base_class="OptionsFlow",
matches=[
TypeHintMatch(
function_name="async_step_*",
arg_types={},
return_type=["ConfigFlowResult", "FlowResult"],
),
],
),
@ -3126,11 +3143,19 @@ class HassTypeHintChecker(BaseChecker):
ancestor: nodes.ClassDef
checked_class_methods: set[str] = set()
ancestors = list(node.ancestors()) # cache result for inside loop
for class_matches in self._class_matchers:
for class_matcher in self._class_matchers:
skip_matcher = False
if exclude_base_classes := class_matcher.exclude_base_classes:
for ancestor in ancestors:
if ancestor.name in exclude_base_classes:
skip_matcher = True
break
if skip_matcher:
continue
for ancestor in ancestors:
if ancestor.name == class_matches.base_class:
if ancestor.name == class_matcher.base_class:
self._visit_class_functions(
node, class_matches.matches, checked_class_methods
node, class_matcher.matches, checked_class_methods
)
def _visit_class_functions(

View File

@ -6,10 +6,9 @@ from typing import Any
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError
from .const import DOMAIN
@ -68,14 +67,14 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str,
return {"title": "Name of the device"}
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
class ConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for NEW_NAME."""
VERSION = 1
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
) -> ConfigFlowResult:
"""Handle the initial step."""
errors: dict[str, str] = {}
if user_input is not None:

View File

@ -147,13 +147,11 @@ async def test_legacy_subscription_repair_flow(
flow_id = data["flow_id"]
assert data == {
"version": 1,
"type": "create_entry",
"flow_id": flow_id,
"handler": DOMAIN,
"description": None,
"description_placeholders": None,
"minor_version": 1,
}
assert not issue_registry.async_get_issue(

View File

@ -941,10 +941,8 @@ async def test_two_step_options_flow(hass: HomeAssistant, client) -> None:
"handler": "test1",
"type": "create_entry",
"title": "Enable disable",
"version": 1,
"description": None,
"description_placeholders": None,
"minor_version": 1,
}

View File

@ -94,13 +94,11 @@ async def test_supervisor_issue_repair_flow(
flow_id = data["flow_id"]
assert data == {
"version": 1,
"type": "create_entry",
"flow_id": flow_id,
"handler": "hassio",
"description": None,
"description_placeholders": None,
"minor_version": 1,
}
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -190,13 +188,11 @@ async def test_supervisor_issue_repair_flow_with_multiple_suggestions(
flow_id = data["flow_id"]
assert data == {
"version": 1,
"type": "create_entry",
"flow_id": flow_id,
"handler": "hassio",
"description": None,
"description_placeholders": None,
"minor_version": 1,
}
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -305,13 +301,11 @@ async def test_supervisor_issue_repair_flow_with_multiple_suggestions_and_confir
flow_id = data["flow_id"]
assert data == {
"version": 1,
"type": "create_entry",
"flow_id": flow_id,
"handler": "hassio",
"description": None,
"description_placeholders": None,
"minor_version": 1,
}
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -386,13 +380,11 @@ async def test_supervisor_issue_repair_flow_skip_confirmation(
flow_id = data["flow_id"]
assert data == {
"version": 1,
"type": "create_entry",
"flow_id": flow_id,
"handler": "hassio",
"description": None,
"description_placeholders": None,
"minor_version": 1,
}
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -486,13 +478,11 @@ async def test_mount_failed_repair_flow(
flow_id = data["flow_id"]
assert data == {
"version": 1,
"type": "create_entry",
"flow_id": flow_id,
"handler": "hassio",
"description": None,
"description_placeholders": None,
"minor_version": 1,
}
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")
@ -598,13 +588,11 @@ async def test_supervisor_issue_docker_config_repair_flow(
flow_id = data["flow_id"]
assert data == {
"version": 1,
"type": "create_entry",
"flow_id": flow_id,
"handler": "hassio",
"description": None,
"description_placeholders": None,
"minor_version": 1,
}
assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234")

View File

@ -244,9 +244,7 @@ async def test_issues_created(
"description_placeholders": None,
"flow_id": flow_id,
"handler": DOMAIN,
"minor_version": 1,
"type": "create_entry",
"version": 1,
}
await ws_client.send_json({"id": 4, "type": "repairs/list_issues"})

View File

@ -338,9 +338,7 @@ async def test_fix_issue(
"description_placeholders": None,
"flow_id": flow_id,
"handler": domain,
"minor_version": 1,
"type": "create_entry",
"version": 1,
}
await ws_client.send_json({"id": 4, "type": "repairs/list_issues"})

View File

@ -63,7 +63,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup(
"""Test existing flows prevent an identical ones from being after startup."""
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
with patch(
"homeassistant.data_entry_flow.FlowManager.async_has_matching_flow",
"homeassistant.data_entry_flow.BaseFlowManager.async_has_matching_flow",
return_value=True,
):
discovery_flow.async_create_flow(

View File

@ -45,7 +45,7 @@ def manager_fixture():
handlers = Registry()
entries = []
class FlowManager(data_entry_flow.FlowManager):
class FlowManager(data_entry_flow.BaseFlowManager):
"""Test flow manager."""
async def async_create_flow(self, handler_key, *, context, data):
@ -105,7 +105,7 @@ async def test_name(hass: HomeAssistant, entity_registry: er.EntityRegistry) ->
@pytest.mark.parametrize("marker", (vol.Required, vol.Optional))
async def test_config_flow_advanced_option(
hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker
) -> None:
"""Test handling of advanced options in config flow."""
manager.hass = hass
@ -200,7 +200,7 @@ async def test_config_flow_advanced_option(
@pytest.mark.parametrize("marker", (vol.Required, vol.Optional))
async def test_options_flow_advanced_option(
hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker
) -> None:
"""Test handling of advanced options in options flow."""
manager.hass = hass
@ -475,7 +475,7 @@ async def test_next_step_function(hass: HomeAssistant) -> None:
async def test_suggested_values(
hass: HomeAssistant, manager: data_entry_flow.FlowManager
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager
) -> None:
"""Test suggested_values handling in SchemaFlowFormStep."""
manager.hass = hass
@ -667,7 +667,7 @@ async def test_options_flow_state(hass: HomeAssistant) -> None:
async def test_options_flow_omit_optional_keys(
hass: HomeAssistant, manager: data_entry_flow.FlowManager
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager
) -> None:
"""Test handling of advanced options in options flow."""
manager.hass = hass

View File

@ -346,7 +346,7 @@ def test_invalid_config_flow_step(
pylint.testutils.MessageTest(
msg_id="hass-return-type",
node=func_node,
args=("FlowResult", "async_step_zeroconf"),
args=(["ConfigFlowResult", "FlowResult"], "async_step_zeroconf"),
line=11,
col_offset=4,
end_line=11,
@ -374,7 +374,7 @@ def test_valid_config_flow_step(
async def async_step_zeroconf(
self,
device_config: ZeroconfServiceInfo
) -> FlowResult:
) -> ConfigFlowResult:
pass
""",
"homeassistant.components.pylint_test.config_flow",

View File

@ -24,9 +24,11 @@ def manager():
handlers = Registry()
entries = []
class FlowManager(data_entry_flow.FlowManager):
class FlowManager(data_entry_flow.BaseFlowManager):
"""Test flow manager."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow(self, handler_key, *, context, data):
"""Test create flow."""
handler = handlers.get(handler_key)
@ -79,7 +81,7 @@ async def test_configure_reuses_handler_instance(manager) -> None:
assert len(manager.mock_created_entries) == 0
async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None:
async def test_configure_two_steps(manager: data_entry_flow.BaseFlowManager) -> None:
"""Test that we reuse instances."""
@manager.mock_reg_handler("test")
@ -211,7 +213,6 @@ async def test_create_saves_data(manager) -> None:
assert len(manager.mock_created_entries) == 1
entry = manager.mock_created_entries[0]
assert entry["version"] == 5
assert entry["handler"] == "test"
assert entry["title"] == "Test Title"
assert entry["data"] == "Test Data"
@ -237,7 +238,6 @@ async def test_discovery_init_flow(manager) -> None:
assert len(manager.mock_created_entries) == 1
entry = manager.mock_created_entries[0]
assert entry["version"] == 5
assert entry["handler"] == "test"
assert entry["title"] == "hello"
assert entry["data"] == data
@ -258,7 +258,7 @@ async def test_finish_callback_change_result_type(hass: HomeAssistant) -> None:
step_id="init", data_schema=vol.Schema({"count": int})
)
class FlowManager(data_entry_flow.FlowManager):
class FlowManager(data_entry_flow.BaseFlowManager):
async def async_create_flow(self, handler_name, *, context, data):
"""Create a test flow."""
return TestFlow()
@ -775,7 +775,7 @@ async def test_async_get_unknown_flow(manager) -> None:
async def test_async_has_matching_flow(
hass: HomeAssistant, manager: data_entry_flow.FlowManager
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager
) -> None:
"""Test we can check for matching flows."""
manager.hass = hass
@ -951,7 +951,7 @@ async def test_show_menu(hass: HomeAssistant, manager, menu_options) -> None:
async def test_find_flows_by_init_data_type(
manager: data_entry_flow.FlowManager,
manager: data_entry_flow.BaseFlowManager,
) -> None:
"""Test we can find flows by init data type."""