1
mirror of https://github.com/home-assistant/core synced 2024-09-06 10:29:55 +02:00
ha-core/tests/test_data_entry_flow.py
Jason Hu 00c6f56cc8
Allow finish_flow callback to change data entry result type (#16100)
* Allow finish_flow callback to change data entry result type

* Add unit test
2018-08-21 10:48:24 -07:00

248 lines
8.2 KiB
Python

"""Test the flow classes."""
import pytest
import voluptuous as vol
from homeassistant import data_entry_flow
from homeassistant.util.decorator import Registry
@pytest.fixture
def manager():
"""Return a flow manager."""
handlers = Registry()
entries = []
async def async_create_flow(handler_name, *, context, data):
handler = handlers.get(handler_name)
if handler is None:
raise data_entry_flow.UnknownHandler
flow = handler()
flow.init_step = context.get('init_step', 'init') \
if context is not None else 'init'
flow.source = context.get('source') \
if context is not None else 'user_input'
return flow
async def async_add_entry(flow, result):
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
result['source'] = flow.context.get('source') \
if flow.context is not None else 'user'
entries.append(result)
return result
manager = data_entry_flow.FlowManager(
None, async_create_flow, async_add_entry)
manager.mock_created_entries = entries
manager.mock_reg_handler = handlers.register
return manager
async def test_configure_reuses_handler_instance(manager):
"""Test that we reuse instances."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
handle_count = 0
async def async_step_init(self, user_input=None):
self.handle_count += 1
return self.async_show_form(
errors={'base': str(self.handle_count)},
step_id='init')
form = await manager.async_init('test')
assert form['errors']['base'] == '1'
form = await manager.async_configure(form['flow_id'])
assert form['errors']['base'] == '2'
assert len(manager.async_progress()) == 1
assert len(manager.mock_created_entries) == 0
async def test_configure_two_steps(manager):
"""Test that we reuse instances."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
async def async_step_first(self, user_input=None):
if user_input is not None:
self.init_data = user_input
return await self.async_step_second()
return self.async_show_form(
step_id='first',
data_schema=vol.Schema([str])
)
async def async_step_second(self, user_input=None):
if user_input is not None:
return self.async_create_entry(
title='Test Entry',
data=self.init_data + user_input
)
return self.async_show_form(
step_id='second',
data_schema=vol.Schema([str])
)
form = await manager.async_init('test', context={'init_step': 'first'})
with pytest.raises(vol.Invalid):
form = await manager.async_configure(
form['flow_id'], 'INCORRECT-DATA')
form = await manager.async_configure(
form['flow_id'], ['INIT-DATA'])
form = await manager.async_configure(
form['flow_id'], ['SECOND-DATA'])
assert form['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
result = manager.mock_created_entries[0]
assert result['handler'] == 'test'
assert result['data'] == ['INIT-DATA', 'SECOND-DATA']
async def test_show_form(manager):
"""Test that abort removes the flow from progress."""
schema = vol.Schema({
vol.Required('username'): str,
vol.Required('password'): str
})
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_show_form(
step_id='init',
data_schema=schema,
errors={
'username': 'Should be unique.'
}
)
form = await manager.async_init('test')
assert form['type'] == 'form'
assert form['data_schema'] is schema
assert form['errors'] == {
'username': 'Should be unique.'
}
async def test_abort_removes_instance(manager):
"""Test that abort removes the flow from progress."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
is_new = True
async def async_step_init(self, user_input=None):
old = self.is_new
self.is_new = False
return self.async_abort(reason=str(old))
form = await manager.async_init('test')
assert form['reason'] == 'True'
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
form = await manager.async_init('test')
assert form['reason'] == 'True'
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0
async def test_create_saves_data(manager):
"""Test creating a config entry."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
async def async_step_init(self, user_input=None):
return self.async_create_entry(
title='Test Title',
data='Test Data'
)
await manager.async_init('test')
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
entry = manager.mock_created_entries[0]
assert entry['version'] == 5
assert entry['handler'] == 'test'
assert entry['title'] == 'Test Title'
assert entry['data'] == 'Test Data'
assert entry['source'] == 'user'
async def test_discovery_init_flow(manager):
"""Test a flow initialized by discovery."""
@manager.mock_reg_handler('test')
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5
async def async_step_init(self, info):
return self.async_create_entry(title=info['id'], data=info)
data = {
'id': 'hello',
'token': 'secret'
}
await manager.async_init(
'test', context={'source': 'discovery'}, data=data)
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1
entry = manager.mock_created_entries[0]
assert entry['version'] == 5
assert entry['handler'] == 'test'
assert entry['title'] == 'hello'
assert entry['data'] == data
assert entry['source'] == 'discovery'
async def test_finish_callback_change_result_type(hass):
"""Test finish callback can change result type."""
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
async def async_step_init(self, input):
"""Return init form with one input field 'count'."""
if input is not None:
return self.async_create_entry(title='init', data=input)
return self.async_show_form(
step_id='init',
data_schema=vol.Schema({'count': int}))
async def async_create_flow(handler_name, *, context, data):
"""Create a test flow."""
return TestFlow()
async def async_finish_flow(flow, result):
"""Redirect to init form if count <= 1."""
if result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
if (result['data'] is None or
result['data'].get('count', 0) <= 1):
return flow.async_show_form(
step_id='init',
data_schema=vol.Schema({'count': int}))
else:
result['result'] = result['data']['count']
return result
manager = data_entry_flow.FlowManager(
hass, async_create_flow, async_finish_flow)
result = await manager.async_init('test')
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['step_id'] == 'init'
result = await manager.async_configure(result['flow_id'], {'count': 0})
assert result['type'] == data_entry_flow.RESULT_TYPE_FORM
assert result['step_id'] == 'init'
assert 'result' not in result
result = await manager.async_configure(result['flow_id'], {'count': 2})
assert result['type'] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result['result'] == 2