Allow chaining contexts (#21028)

* Allow chaining contexts

* Add stubbed out migration
This commit is contained in:
Paulus Schoutsen 2019-03-01 10:08:38 -08:00 committed by GitHub
parent b39846fb6b
commit 52f337ef00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 88 additions and 39 deletions

View File

@ -7,7 +7,7 @@ import logging
import voluptuous as vol
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.core import CoreState
from homeassistant.core import CoreState, Context
from homeassistant.loader import bind_hass
from homeassistant.const import (
ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF,
@ -280,15 +280,21 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
This method is a coroutine.
"""
if skip_condition or self._cond_func(variables):
self.async_set_context(context)
self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, {
ATTR_NAME: self._name,
ATTR_ENTITY_ID: self.entity_id,
}, context=context)
await self._async_action(self.entity_id, variables, context)
self._last_triggered = utcnow()
await self.async_update_ha_state()
if not skip_condition and not self._cond_func(variables):
return
# Create a new context referring to the old context.
parent_id = None if context is None else context.id
trigger_context = Context(parent_id=parent_id)
self.async_set_context(trigger_context)
self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, {
ATTR_NAME: self._name,
ATTR_ENTITY_ID: self.entity_id,
}, context=trigger_context)
await self._async_action(self.entity_id, variables, trigger_context)
self._last_triggered = utcnow()
await self.async_update_ha_state()
async def async_will_remove_from_hass(self):
"""Remove listeners when removing automation from HASS."""

View File

@ -220,6 +220,15 @@ def _apply_update(engine, new_version, old_version):
_create_index(engine, "states", "ix_states_context_user_id")
elif new_version == 7:
_create_index(engine, "states", "ix_states_entity_id")
elif new_version == 8:
# Pending migration, want to group a few.
pass
# _add_columns(engine, "events", [
# 'context_parent_id CHARACTER(36)',
# ])
# _add_columns(engine, "states", [
# 'context_parent_id CHARACTER(36)',
# ])
else:
raise ValueError("No schema migration defined for version {}"
.format(new_version))

View File

@ -34,16 +34,20 @@ class Events(Base): # type: ignore
created = Column(DateTime(timezone=True), default=datetime.utcnow)
context_id = Column(String(36), index=True)
context_user_id = Column(String(36), index=True)
# context_parent_id = Column(String(36), index=True)
@staticmethod
def from_event(event):
"""Create an event database object from a native event."""
return Events(event_type=event.event_type,
event_data=json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id)
return Events(
event_type=event.event_type,
event_data=json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin),
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id,
# context_parent_id=event.context.parent_id,
)
def to_native(self):
"""Convert to a natve HA Event."""
@ -81,6 +85,7 @@ class States(Base): # type: ignore
created = Column(DateTime(timezone=True), default=datetime.utcnow)
context_id = Column(String(36), index=True)
context_user_id = Column(String(36), index=True)
# context_parent_id = Column(String(36), index=True)
__table_args__ = (
# Used for fetching the state of entities at a specific time
@ -99,6 +104,7 @@ class States(Base): # type: ignore
entity_id=entity_id,
context_id=event.context.id,
context_user_id=event.context.user_id,
# context_parent_id=event.context.parent_id,
)
# State got deleted

View File

@ -409,6 +409,10 @@ class Context:
type=str,
default=None,
)
parent_id = attr.ib(
type=Optional[str],
default=None
)
id = attr.ib(
type=str,
default=attr.Factory(lambda: uuid.uuid4().hex),
@ -418,6 +422,7 @@ class Context:
"""Return a dictionary representation of the context."""
return {
'id': self.id,
'parent_id': self.parent_id,
'user_id': self.user_id,
}

View File

@ -41,7 +41,7 @@ async def test_if_fires_on_event(hass, calls):
hass.bus.async_fire('test_event', context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
await common.async_turn_off(hass)
await hass.async_block_till_done()

View File

@ -68,7 +68,7 @@ async def test_if_fires_on_zone_enter(hass, calls):
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'geo_location - geo_location.entity - hello - hello - test' == \
calls[0].data['some']
@ -221,7 +221,7 @@ async def test_if_fires_on_zone_appear(hass, calls):
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'geo_location - geo_location.entity - - hello - test' == \
calls[0].data['some']

View File

@ -369,38 +369,47 @@ async def test_shared_context(hass, calls):
})
context = Context()
automation_mock = Mock()
first_automation_listener = Mock()
event_mock = Mock()
hass.bus.async_listen('test_event2', automation_mock)
hass.bus.async_listen('test_event2', first_automation_listener)
hass.bus.async_listen(EVENT_AUTOMATION_TRIGGERED, event_mock)
hass.bus.async_fire('test_event', context=context)
await hass.async_block_till_done()
# Ensure events was fired
assert automation_mock.call_count == 1
assert first_automation_listener.call_count == 1
assert event_mock.call_count == 2
# Ensure context carries through the event
args, kwargs = automation_mock.call_args
assert args[0].context == context
# Verify automation triggered evenet for 'hello' automation
args, kwargs = event_mock.call_args_list[0]
first_trigger_context = args[0].context
assert first_trigger_context.parent_id == context.id
# Ensure event data has all attributes set
assert args[0].data.get(ATTR_NAME) is not None
assert args[0].data.get(ATTR_ENTITY_ID) is not None
for call in event_mock.call_args_list:
args, kwargs = call
assert args[0].context == context
# Ensure event data has all attributes set
assert args[0].data.get(ATTR_NAME) is not None
assert args[0].data.get(ATTR_ENTITY_ID) is not None
# Ensure context set correctly for event fired by 'hello' automation
args, kwargs = first_automation_listener.call_args
assert args[0].context is first_trigger_context
# Ensure the automation state shares the same context
# Ensure the 'hello' automation state has the right context
state = hass.states.get('automation.hello')
assert state is not None
assert state.context == context
assert state.context is first_trigger_context
# Verify automation triggered evenet for 'bye' automation
args, kwargs = event_mock.call_args_list[1]
second_trigger_context = args[0].context
assert second_trigger_context.parent_id == first_trigger_context.id
# Ensure event data has all attributes set
assert args[0].data.get(ATTR_NAME) is not None
assert args[0].data.get(ATTR_ENTITY_ID) is not None
# Ensure the service call from the second automation
# shares the same context
assert len(calls) == 1
assert calls[0].context == context
assert calls[0].context is second_trigger_context
async def test_services(hass, calls):

View File

@ -45,7 +45,7 @@ async def test_if_fires_on_entity_change_below(hass, calls):
hass.states.async_set('test.entity', 9, context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
# Set above 12 so the automation will fire again
hass.states.async_set('test.entity', 12)
@ -134,7 +134,7 @@ async def test_if_not_fires_on_entity_change_below_to_below(hass, calls):
hass.states.async_set('test.entity', 9, context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
# already below so should not fire again
hass.states.async_set('test.entity', 5)

View File

@ -55,7 +55,7 @@ async def test_if_fires_on_entity_change(hass, calls):
hass.states.async_set('test.entity', 'world', context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'state - test.entity - hello - world - None' == \
calls[0].data['some']

View File

@ -257,7 +257,7 @@ async def test_if_fires_on_change_with_template_advanced(hass, calls):
hass.states.async_set('test.entity', 'world', context=context)
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'template - test.entity - hello - world' == \
calls[0].data['some']

View File

@ -66,7 +66,7 @@ async def test_if_fires_on_zone_enter(hass, calls):
await hass.async_block_till_done()
assert 1 == len(calls)
assert calls[0].context is context
assert calls[0].context.parent_id == context.id
assert 'zone - test.entity - hello - hello - test' == \
calls[0].data['some']

View File

@ -310,6 +310,7 @@ class TestEvent(unittest.TestCase):
'time_fired': now,
'context': {
'id': event.context.id,
'parent_id': None,
'user_id': event.context.user_id,
},
}
@ -1076,3 +1077,16 @@ async def test_service_call_event_contains_original_data(hass):
assert len(calls) == 1
assert calls[0].data['number'] == 23
assert calls[0].context is context
def test_context():
"""Test context init."""
c = ha.Context()
assert c.user_id is None
assert c.parent_id is None
assert c.id is not None
c = ha.Context(23, 100)
assert c.user_id == 23
assert c.parent_id == 100
assert c.id is not None