1
mirror of https://github.com/home-assistant/core synced 2024-08-02 23:40:32 +02:00

Add and restore context in recorder (#15859)

This commit is contained in:
Paulus Schoutsen 2018-08-10 18:09:01 +02:00 committed by GitHub
parent da916d7b27
commit 9512bb9587
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 73 additions and 14 deletions

View File

@ -114,6 +114,27 @@ def _drop_index(engine, table_name, index_name):
"critical operation.", index_name, table_name)
def _add_columns(engine, table_name, columns_def):
"""Add columns to a table."""
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
columns_def = ['ADD COLUMN {}'.format(col_def) for col_def in columns_def]
try:
engine.execute(text("ALTER TABLE {table} {columns_def}".format(
table=table_name,
columns_def=', '.join(columns_def))))
return
except SQLAlchemyError:
pass
for column_def in columns_def:
engine.execute(text("ALTER TABLE {table} {column_def}".format(
table=table_name,
column_def=column_def)))
def _apply_update(engine, new_version, old_version):
"""Perform operations to bring schema up to date."""
if new_version == 1:
@ -146,6 +167,19 @@ def _apply_update(engine, new_version, old_version):
elif new_version == 5:
# Create supporting index for States.event_id foreign key
_create_index(engine, "states", "ix_states_event_id")
elif new_version == 6:
_add_columns(engine, "events", [
'context_id CHARACTER(36)',
'context_user_id CHARACTER(36)',
])
_create_index(engine, "events", "ix_events_context_id")
_create_index(engine, "events", "ix_events_context_user_id")
_add_columns(engine, "states", [
'context_id CHARACTER(36)',
'context_user_id CHARACTER(36)',
])
_create_index(engine, "states", "ix_states_context_id")
_create_index(engine, "states", "ix_states_context_user_id")
else:
raise ValueError("No schema migration defined for version {}"
.format(new_version))

View File

@ -9,14 +9,15 @@ from sqlalchemy import (
from sqlalchemy.ext.declarative import declarative_base
import homeassistant.util.dt as dt_util
from homeassistant.core import Event, EventOrigin, State, split_entity_id
from homeassistant.core import (
Context, Event, EventOrigin, State, split_entity_id)
from homeassistant.remote import JSONEncoder
# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()
SCHEMA_VERSION = 5
SCHEMA_VERSION = 6
_LOGGER = logging.getLogger(__name__)
@ -31,6 +32,8 @@ class Events(Base): # type: ignore
origin = Column(String(32))
time_fired = Column(DateTime(timezone=True), index=True)
created = Column(DateTime(timezone=True), default=datetime.utcnow)
context_id = Column(String(36), index=True)
context_user_id = Column(String(36), index=True)
@staticmethod
def from_event(event):
@ -38,16 +41,23 @@ class Events(Base): # type: ignore
return Events(event_type=event.event_type,
event_data=json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin),
time_fired=event.time_fired)
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id)
def to_native(self):
"""Convert to a natve HA Event."""
context = Context(
id=self.context_id,
user_id=self.context_user_id
)
try:
return Event(
self.event_type,
json.loads(self.event_data),
EventOrigin(self.origin),
_process_timestamp(self.time_fired)
_process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
@ -69,6 +79,8 @@ class States(Base): # type: ignore
last_updated = Column(DateTime(timezone=True), default=datetime.utcnow,
index=True)
created = Column(DateTime(timezone=True), default=datetime.utcnow)
context_id = Column(String(36), index=True)
context_user_id = Column(String(36), index=True)
__table_args__ = (
# Used for fetching the state of entities at a specific time
@ -82,7 +94,11 @@ class States(Base): # type: ignore
entity_id = event.data['entity_id']
state = event.data.get('new_state')
dbstate = States(entity_id=entity_id)
dbstate = States(
entity_id=entity_id,
context_id=event.context.id,
context_user_id=event.context.user_id,
)
# State got deleted
if state is None:
@ -103,12 +119,17 @@ class States(Base): # type: ignore
def to_native(self):
"""Convert to an HA state object."""
context = Context(
id=self.context_id,
user_id=self.context_user_id
)
try:
return State(
self.entity_id, self.state,
json.loads(self.attributes),
_process_timestamp(self.last_changed),
_process_timestamp(self.last_updated)
_process_timestamp(self.last_updated),
context=context,
)
except ValueError:
# When json.loads fails

View File

@ -423,7 +423,8 @@ class Event:
self.event_type == other.event_type and
self.data == other.data and
self.origin == other.origin and
self.time_fired == other.time_fired)
self.time_fired == other.time_fired and
self.context == other.context)
class EventBus:
@ -695,7 +696,8 @@ class State:
return (self.__class__ == other.__class__ and # type: ignore
self.entity_id == other.entity_id and
self.state == other.state and
self.attributes == other.attributes)
self.attributes == other.attributes and
self.context == other.context)
def __repr__(self) -> str:
"""Return the representation of the states."""

View File

@ -266,7 +266,7 @@ def mock_state_change_event(hass, new_state, old_state=None):
if old_state:
event_data['old_state'] = old_state
hass.bus.fire(EVENT_STATE_CHANGED, event_data)
hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context)
@asyncio.coroutine

View File

@ -60,7 +60,7 @@ class TestStates(unittest.TestCase):
'entity_id': 'sensor.temperature',
'old_state': None,
'new_state': state,
})
}, context=state.context)
assert state == States.from_event(event).to_native()
def test_from_event_to_delete_state(self):

View File

@ -83,9 +83,10 @@ class TestComponentHistory(unittest.TestCase):
self.wait_recording_done()
# Get states returns everything before POINT
self.assertEqual(states,
sorted(history.get_states(self.hass, future),
key=lambda state: state.entity_id))
for state1, state2 in zip(
states, sorted(history.get_states(self.hass, future),
key=lambda state: state.entity_id)):
assert state1 == state2
# Test get_state here because we have a DB setup
self.assertEqual(

View File

@ -246,8 +246,9 @@ class TestEvent(unittest.TestCase):
"""Test events."""
now = dt_util.utcnow()
data = {'some': 'attr'}
context = ha.Context()
event1, event2 = [
ha.Event('some_type', data, time_fired=now)
ha.Event('some_type', data, time_fired=now, context=context)
for _ in range(2)
]