Improve calls to async_show_progress in improv_ble (#107790)

This commit is contained in:
Erik Montnemery 2024-01-14 09:37:54 +01:00 committed by GitHub
parent 93d363ea57
commit d4cb055d75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 35 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import Callable, Coroutine
from dataclasses import dataclass
import logging
from typing import Any, TypeVar
@ -325,14 +325,15 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return
if not self._provision_task:
self._provision_task = self.hass.async_create_task(
self._resume_flow_when_done(_do_provision())
)
self._provision_task = self.hass.async_create_task(_do_provision())
if not self._provision_task.done():
return self.async_show_progress(
step_id="do_provision", progress_action="provisioning"
step_id="do_provision",
progress_action="provisioning",
progress_task=self._provision_task,
)
await self._provision_task
self._provision_task = None
return self.async_show_progress_done(next_step_id="provision_done")
@ -347,14 +348,6 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self._provision_result = None
return result
async def _resume_flow_when_done(self, awaitable: Awaitable) -> None:
try:
await awaitable
finally:
self.hass.async_create_task(
self.hass.config_entries.flow.async_configure(flow_id=self.flow_id)
)
async def async_step_authorize(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
@ -378,14 +371,15 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
except AbortFlow as err:
return self.async_abort(reason=err.reason)
self._authorize_task = self.hass.async_create_task(
self._resume_flow_when_done(authorized_event.wait())
)
self._authorize_task = self.hass.async_create_task(authorized_event.wait())
if not self._authorize_task.done():
return self.async_show_progress(
step_id="authorize", progress_action="authorize"
step_id="authorize",
progress_action="authorize",
progress_task=self._authorize_task,
)
await self._authorize_task
self._authorize_task = None
if self._unsub:
self._unsub()

View File

@ -265,10 +265,7 @@ async def _test_common_success(
assert result["type"] == FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "provisioning"
assert result["step_id"] == "do_provision"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
assert result["step_id"] == "provision_done"
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result.get("description_placeholders") == placeholders
@ -321,10 +318,7 @@ async def _test_common_success_w_authorize(
assert result["progress_action"] == "authorize"
assert result["step_id"] == "authorize"
mock_subscribe_state_updates.assert_awaited_once()
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
assert result["step_id"] == "provision"
await hass.async_block_till_done()
with patch(
f"{IMPROV_BLE}.config_flow.ImprovBLEClient.need_authorization",
@ -337,10 +331,7 @@ async def _test_common_success_w_authorize(
assert result["type"] == FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "provisioning"
assert result["step_id"] == "do_provision"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
assert result["step_id"] == "provision_done"
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["description_placeholders"] == {"url": "http://blabla.local"}
@ -578,10 +569,7 @@ async def _test_provision_error(hass: HomeAssistant, exc) -> None:
assert result["type"] == FlowResultType.SHOW_PROGRESS
assert result["progress_action"] == "provisioning"
assert result["step_id"] == "do_provision"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.SHOW_PROGRESS_DONE
assert result["step_id"] == "provision_done"
await hass.async_block_till_done()
return result["flow_id"]