diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index 35cf695f1e3a..ad231a2a348d 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -7,7 +7,7 @@ import logging import voluptuous as vol from homeassistant.setup import async_prepare_setup_platform -from homeassistant.core import CoreState +from homeassistant.core import CoreState, Context from homeassistant.loader import bind_hass from homeassistant.const import ( ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF, @@ -280,15 +280,21 @@ class AutomationEntity(ToggleEntity, RestoreEntity): This method is a coroutine. """ - if skip_condition or self._cond_func(variables): - self.async_set_context(context) - self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, { - ATTR_NAME: self._name, - ATTR_ENTITY_ID: self.entity_id, - }, context=context) - await self._async_action(self.entity_id, variables, context) - self._last_triggered = utcnow() - await self.async_update_ha_state() + if not skip_condition and not self._cond_func(variables): + return + + # Create a new context referring to the old context. + parent_id = None if context is None else context.id + trigger_context = Context(parent_id=parent_id) + + self.async_set_context(trigger_context) + self.hass.bus.async_fire(EVENT_AUTOMATION_TRIGGERED, { + ATTR_NAME: self._name, + ATTR_ENTITY_ID: self.entity_id, + }, context=trigger_context) + await self._async_action(self.entity_id, variables, trigger_context) + self._last_triggered = utcnow() + await self.async_update_ha_state() async def async_will_remove_from_hass(self): """Remove listeners when removing automation from HASS.""" diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py index 825f402aef21..972862e7a9c0 100644 --- a/homeassistant/components/recorder/migration.py +++ b/homeassistant/components/recorder/migration.py @@ -220,6 +220,15 @@ def _apply_update(engine, new_version, old_version): _create_index(engine, "states", "ix_states_context_user_id") elif new_version == 7: _create_index(engine, "states", "ix_states_entity_id") + elif new_version == 8: + # Pending migration, want to group a few. + pass + # _add_columns(engine, "events", [ + # 'context_parent_id CHARACTER(36)', + # ]) + # _add_columns(engine, "states", [ + # 'context_parent_id CHARACTER(36)', + # ]) else: raise ValueError("No schema migration defined for version {}" .format(new_version)) diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index d1be17b83d50..bea2b12b3702 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -34,16 +34,20 @@ class Events(Base): # type: ignore created = Column(DateTime(timezone=True), default=datetime.utcnow) context_id = Column(String(36), index=True) context_user_id = Column(String(36), index=True) + # context_parent_id = Column(String(36), index=True) @staticmethod def from_event(event): """Create an event database object from a native event.""" - return Events(event_type=event.event_type, - event_data=json.dumps(event.data, cls=JSONEncoder), - origin=str(event.origin), - time_fired=event.time_fired, - context_id=event.context.id, - context_user_id=event.context.user_id) + return Events( + event_type=event.event_type, + event_data=json.dumps(event.data, cls=JSONEncoder), + origin=str(event.origin), + time_fired=event.time_fired, + context_id=event.context.id, + context_user_id=event.context.user_id, + # context_parent_id=event.context.parent_id, + ) def to_native(self): """Convert to a natve HA Event.""" @@ -81,6 +85,7 @@ class States(Base): # type: ignore created = Column(DateTime(timezone=True), default=datetime.utcnow) context_id = Column(String(36), index=True) context_user_id = Column(String(36), index=True) + # context_parent_id = Column(String(36), index=True) __table_args__ = ( # Used for fetching the state of entities at a specific time @@ -99,6 +104,7 @@ class States(Base): # type: ignore entity_id=entity_id, context_id=event.context.id, context_user_id=event.context.user_id, + # context_parent_id=event.context.parent_id, ) # State got deleted diff --git a/homeassistant/core.py b/homeassistant/core.py index 48ef4f462729..253900a39ef8 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -409,6 +409,10 @@ class Context: type=str, default=None, ) + parent_id = attr.ib( + type=Optional[str], + default=None + ) id = attr.ib( type=str, default=attr.Factory(lambda: uuid.uuid4().hex), @@ -418,6 +422,7 @@ class Context: """Return a dictionary representation of the context.""" return { 'id': self.id, + 'parent_id': self.parent_id, 'user_id': self.user_id, } diff --git a/tests/components/automation/test_event.py b/tests/components/automation/test_event.py index 4b669fc13562..8ca7f6b13f5e 100644 --- a/tests/components/automation/test_event.py +++ b/tests/components/automation/test_event.py @@ -41,7 +41,7 @@ async def test_if_fires_on_event(hass, calls): hass.bus.async_fire('test_event', context=context) await hass.async_block_till_done() assert 1 == len(calls) - assert calls[0].context is context + assert calls[0].context.parent_id == context.id await common.async_turn_off(hass) await hass.async_block_till_done() diff --git a/tests/components/automation/test_geo_location.py b/tests/components/automation/test_geo_location.py index 928296c8d276..92ded1a07db2 100644 --- a/tests/components/automation/test_geo_location.py +++ b/tests/components/automation/test_geo_location.py @@ -68,7 +68,7 @@ async def test_if_fires_on_zone_enter(hass, calls): await hass.async_block_till_done() assert 1 == len(calls) - assert calls[0].context is context + assert calls[0].context.parent_id == context.id assert 'geo_location - geo_location.entity - hello - hello - test' == \ calls[0].data['some'] @@ -221,7 +221,7 @@ async def test_if_fires_on_zone_appear(hass, calls): await hass.async_block_till_done() assert 1 == len(calls) - assert calls[0].context is context + assert calls[0].context.parent_id == context.id assert 'geo_location - geo_location.entity - - hello - test' == \ calls[0].data['some'] diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index 12c97507a137..a019f65afcf5 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -369,38 +369,47 @@ async def test_shared_context(hass, calls): }) context = Context() - automation_mock = Mock() + first_automation_listener = Mock() event_mock = Mock() - hass.bus.async_listen('test_event2', automation_mock) + hass.bus.async_listen('test_event2', first_automation_listener) hass.bus.async_listen(EVENT_AUTOMATION_TRIGGERED, event_mock) hass.bus.async_fire('test_event', context=context) await hass.async_block_till_done() # Ensure events was fired - assert automation_mock.call_count == 1 + assert first_automation_listener.call_count == 1 assert event_mock.call_count == 2 - # Ensure context carries through the event - args, kwargs = automation_mock.call_args - assert args[0].context == context + # Verify automation triggered evenet for 'hello' automation + args, kwargs = event_mock.call_args_list[0] + first_trigger_context = args[0].context + assert first_trigger_context.parent_id == context.id + # Ensure event data has all attributes set + assert args[0].data.get(ATTR_NAME) is not None + assert args[0].data.get(ATTR_ENTITY_ID) is not None - for call in event_mock.call_args_list: - args, kwargs = call - assert args[0].context == context - # Ensure event data has all attributes set - assert args[0].data.get(ATTR_NAME) is not None - assert args[0].data.get(ATTR_ENTITY_ID) is not None + # Ensure context set correctly for event fired by 'hello' automation + args, kwargs = first_automation_listener.call_args + assert args[0].context is first_trigger_context - # Ensure the automation state shares the same context + # Ensure the 'hello' automation state has the right context state = hass.states.get('automation.hello') assert state is not None - assert state.context == context + assert state.context is first_trigger_context + + # Verify automation triggered evenet for 'bye' automation + args, kwargs = event_mock.call_args_list[1] + second_trigger_context = args[0].context + assert second_trigger_context.parent_id == first_trigger_context.id + # Ensure event data has all attributes set + assert args[0].data.get(ATTR_NAME) is not None + assert args[0].data.get(ATTR_ENTITY_ID) is not None # Ensure the service call from the second automation # shares the same context assert len(calls) == 1 - assert calls[0].context == context + assert calls[0].context is second_trigger_context async def test_services(hass, calls): diff --git a/tests/components/automation/test_numeric_state.py b/tests/components/automation/test_numeric_state.py index 92a5f3b8b921..803a15e9634f 100644 --- a/tests/components/automation/test_numeric_state.py +++ b/tests/components/automation/test_numeric_state.py @@ -45,7 +45,7 @@ async def test_if_fires_on_entity_change_below(hass, calls): hass.states.async_set('test.entity', 9, context=context) await hass.async_block_till_done() assert 1 == len(calls) - assert calls[0].context is context + assert calls[0].context.parent_id == context.id # Set above 12 so the automation will fire again hass.states.async_set('test.entity', 12) @@ -134,7 +134,7 @@ async def test_if_not_fires_on_entity_change_below_to_below(hass, calls): hass.states.async_set('test.entity', 9, context=context) await hass.async_block_till_done() assert 1 == len(calls) - assert calls[0].context is context + assert calls[0].context.parent_id == context.id # already below so should not fire again hass.states.async_set('test.entity', 5) diff --git a/tests/components/automation/test_state.py b/tests/components/automation/test_state.py index abe02638f265..53c1eaab3d99 100644 --- a/tests/components/automation/test_state.py +++ b/tests/components/automation/test_state.py @@ -55,7 +55,7 @@ async def test_if_fires_on_entity_change(hass, calls): hass.states.async_set('test.entity', 'world', context=context) await hass.async_block_till_done() assert 1 == len(calls) - assert calls[0].context is context + assert calls[0].context.parent_id == context.id assert 'state - test.entity - hello - world - None' == \ calls[0].data['some'] diff --git a/tests/components/automation/test_template.py b/tests/components/automation/test_template.py index c326c7f03f48..f803f97f4abd 100644 --- a/tests/components/automation/test_template.py +++ b/tests/components/automation/test_template.py @@ -257,7 +257,7 @@ async def test_if_fires_on_change_with_template_advanced(hass, calls): hass.states.async_set('test.entity', 'world', context=context) await hass.async_block_till_done() assert 1 == len(calls) - assert calls[0].context is context + assert calls[0].context.parent_id == context.id assert 'template - test.entity - hello - world' == \ calls[0].data['some'] diff --git a/tests/components/automation/test_zone.py b/tests/components/automation/test_zone.py index 04ffeaf13aad..d5bfd9fdf885 100644 --- a/tests/components/automation/test_zone.py +++ b/tests/components/automation/test_zone.py @@ -66,7 +66,7 @@ async def test_if_fires_on_zone_enter(hass, calls): await hass.async_block_till_done() assert 1 == len(calls) - assert calls[0].context is context + assert calls[0].context.parent_id == context.id assert 'zone - test.entity - hello - hello - test' == \ calls[0].data['some'] diff --git a/tests/test_core.py b/tests/test_core.py index e2ed249f441d..5e23fab36e77 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -310,6 +310,7 @@ class TestEvent(unittest.TestCase): 'time_fired': now, 'context': { 'id': event.context.id, + 'parent_id': None, 'user_id': event.context.user_id, }, } @@ -1076,3 +1077,16 @@ async def test_service_call_event_contains_original_data(hass): assert len(calls) == 1 assert calls[0].data['number'] == 23 assert calls[0].context is context + + +def test_context(): + """Test context init.""" + c = ha.Context() + assert c.user_id is None + assert c.parent_id is None + assert c.id is not None + + c = ha.Context(23, 100) + assert c.user_id == 23 + assert c.parent_id == 100 + assert c.id is not None