Add data entry flow helper (#13935)

* Extract data entry flows HTTP views into helper

* Remove use of domain

* Lint

* Fix tests

* Update doc
This commit is contained in:
Paulus Schoutsen 2018-04-17 05:44:32 -04:00 committed by Pascal Vizeli
parent 6e9669c18d
commit 534aa0e4b5
5 changed files with 132 additions and 84 deletions

View File

@ -1,11 +1,10 @@
"""Http views to control the config manager."""
import asyncio
import voluptuous as vol
from homeassistant import config_entries, data_entry_flow
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.helpers.data_entry_flow import (
FlowManagerIndexView, FlowManagerResourceView)
REQUIREMENTS = ['voluptuous-serialize==1']
@ -16,8 +15,10 @@ def async_setup(hass):
"""Enable the Home Assistant views."""
hass.http.register_view(ConfigManagerEntryIndexView)
hass.http.register_view(ConfigManagerEntryResourceView)
hass.http.register_view(ConfigManagerFlowIndexView)
hass.http.register_view(ConfigManagerFlowResourceView)
hass.http.register_view(
ConfigManagerFlowIndexView(hass.config_entries.flow))
hass.http.register_view(
ConfigManagerFlowResourceView(hass.config_entries.flow))
hass.http.register_view(ConfigManagerAvailableFlowView)
return True
@ -78,7 +79,7 @@ class ConfigManagerEntryResourceView(HomeAssistantView):
return self.json(result)
class ConfigManagerFlowIndexView(HomeAssistantView):
class ConfigManagerFlowIndexView(FlowManagerIndexView):
"""View to create config flows."""
url = '/api/config/config_entries/flow'
@ -97,78 +98,13 @@ class ConfigManagerFlowIndexView(HomeAssistantView):
flw for flw in hass.config_entries.flow.async_progress()
if flw['source'] != data_entry_flow.SOURCE_USER])
@RequestDataValidator(vol.Schema({
vol.Required('domain'): str,
}))
@asyncio.coroutine
def post(self, request, data):
"""Handle a POST request."""
hass = request.app['hass']
try:
result = yield from hass.config_entries.flow.async_init(
data['domain'])
except data_entry_flow.UnknownHandler:
return self.json_message('Invalid handler specified', 404)
except data_entry_flow.UnknownStep:
return self.json_message('Handler does not support init', 400)
result = _prepare_json(result)
return self.json(result)
class ConfigManagerFlowResourceView(HomeAssistantView):
class ConfigManagerFlowResourceView(FlowManagerResourceView):
"""View to interact with the flow manager."""
url = '/api/config/config_entries/flow/{flow_id}'
name = 'api:config:config_entries:flow:resource'
@asyncio.coroutine
def get(self, request, flow_id):
"""Get the current state of a data_entry_flow."""
hass = request.app['hass']
try:
result = yield from hass.config_entries.flow.async_configure(
flow_id)
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
result = _prepare_json(result)
return self.json(result)
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
@asyncio.coroutine
def post(self, request, flow_id, data):
"""Handle a POST request."""
hass = request.app['hass']
try:
result = yield from hass.config_entries.flow.async_configure(
flow_id, data)
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
except vol.Invalid:
return self.json_message('User input malformed', 400)
result = _prepare_json(result)
return self.json(result)
@asyncio.coroutine
def delete(self, request, flow_id):
"""Cancel a flow in progress."""
hass = request.app['hass']
try:
hass.config_entries.flow.async_abort(flow_id)
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
return self.json_message('Flow aborted')
class ConfigManagerAvailableFlowView(HomeAssistantView):
"""View to query available flows."""

View File

@ -338,7 +338,7 @@ class ConfigEntries:
if component not in self.hass.config.components:
return True
await entry.async_unload(
return await entry.async_unload(
self.hass, component=getattr(self.hass.components, component))
async def _async_save_entry(self, result):
@ -362,6 +362,8 @@ class ConfigEntries:
await async_setup_component(
self.hass, entry.domain, self._hass_config)
return entry
async def _async_create_flow(self, handler):
"""Create a flow for specified handler.

View File

@ -34,12 +34,12 @@ class UnknownStep(FlowError):
class FlowManager:
"""Manage all the flows that are in progress."""
def __init__(self, hass, async_create_flow, async_save_entry):
def __init__(self, hass, async_create_flow, async_finish_flow):
"""Initialize the flow manager."""
self.hass = hass
self._progress = {}
self._async_create_flow = async_create_flow
self._async_save_entry = async_save_entry
self._async_finish_flow = async_finish_flow
@callback
def async_progress(self):
@ -113,10 +113,8 @@ class FlowManager:
if result['type'] == RESULT_TYPE_ABORT:
return result
# We pass a copy of the result because we're going to mutate our
# version afterwards and don't want to cause unexpected bugs.
await self._async_save_entry(dict(result))
result.pop('data')
# We pass a copy of the result because we're mutating our version
result['result'] = await self._async_finish_flow(dict(result))
return result

View File

@ -0,0 +1,106 @@
"""Helpers for the data entry flow."""
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.data_validator import RequestDataValidator
def _prepare_json(result):
"""Convert result for JSON."""
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
data = result.copy()
data.pop('result')
data.pop('data')
return data
elif result['type'] != data_entry_flow.RESULT_TYPE_FORM:
return result
import voluptuous_serialize
data = result.copy()
schema = data['data_schema']
if schema is None:
data['data_schema'] = []
else:
data['data_schema'] = voluptuous_serialize.convert(schema)
return data
class FlowManagerIndexView(HomeAssistantView):
"""View to create config flows."""
def __init__(self, flow_mgr):
"""Initialize the flow manager index view."""
self._flow_mgr = flow_mgr
async def get(self, request):
"""List flows that are in progress."""
return self.json(self._flow_mgr.async_progress())
@RequestDataValidator(vol.Schema({
vol.Required('handler'): vol.Any(str, list),
}))
async def post(self, request, data):
"""Handle a POST request."""
if isinstance(data['handler'], list):
handler = tuple(data['handler'])
else:
handler = data['handler']
try:
result = await self._flow_mgr.async_init(handler)
except data_entry_flow.UnknownHandler:
return self.json_message('Invalid handler specified', 404)
except data_entry_flow.UnknownStep:
return self.json_message('Handler does not support init', 400)
result = _prepare_json(result)
return self.json(result)
class FlowManagerResourceView(HomeAssistantView):
"""View to interact with the flow manager."""
def __init__(self, flow_mgr):
"""Initialize the flow manager resource view."""
self._flow_mgr = flow_mgr
async def get(self, request, flow_id):
"""Get the current state of a data_entry_flow."""
try:
result = await self._flow_mgr.async_configure(flow_id)
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
result = _prepare_json(result)
return self.json(result)
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post(self, request, flow_id, data):
"""Handle a POST request."""
try:
result = await self._flow_mgr.async_configure(flow_id, data)
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
except vol.Invalid:
return self.json_message('User input malformed', 400)
result = _prepare_json(result)
return self.json(result)
async def delete(self, request, flow_id):
"""Cancel a flow in progress."""
try:
self._flow_mgr.async_abort(flow_id)
except data_entry_flow.UnknownFlow:
return self.json_message('Invalid flow specified', 404)
return self.json_message('Flow aborted')

View File

@ -17,6 +17,12 @@ from homeassistant.loader import set_component
from tests.common import MockConfigEntry, MockModule, mock_coro_func
@pytest.fixture(scope='session', autouse=True)
def mock_test_component():
"""Ensure a component called 'test' exists."""
set_component('test', MockModule('test'))
@pytest.fixture
def client(hass, aiohttp_client):
"""Fixture that can interact with the config manager API."""
@ -111,7 +117,7 @@ def test_initialize_flow(hass, client):
with patch.dict(HANDLERS, {'test': TestFlow}):
resp = yield from client.post('/api/config/config_entries/flow',
json={'domain': 'test'})
json={'handler': 'test'})
assert resp.status == 200
data = yield from resp.json()
@ -150,7 +156,7 @@ def test_abort(hass, client):
with patch.dict(HANDLERS, {'test': TestFlow}):
resp = yield from client.post('/api/config/config_entries/flow',
json={'domain': 'test'})
json={'handler': 'test'})
assert resp.status == 200
data = yield from resp.json()
@ -180,7 +186,7 @@ def test_create_account(hass, client):
with patch.dict(HANDLERS, {'test': TestFlow}):
resp = yield from client.post('/api/config/config_entries/flow',
json={'domain': 'test'})
json={'handler': 'test'})
assert resp.status == 200
data = yield from resp.json()
@ -220,7 +226,7 @@ def test_two_step_flow(hass, client):
with patch.dict(HANDLERS, {'test': TestFlow}):
resp = yield from client.post('/api/config/config_entries/flow',
json={'domain': 'test'})
json={'handler': 'test'})
assert resp.status == 200
data = yield from resp.json()
flow_id = data.pop('flow_id')
@ -305,7 +311,7 @@ def test_get_progress_flow(hass, client):
with patch.dict(HANDLERS, {'test': TestFlow}):
resp = yield from client.post('/api/config/config_entries/flow',
json={'domain': 'test'})
json={'handler': 'test'})
assert resp.status == 200
data = yield from resp.json()