1
mirror of https://github.com/home-assistant/core synced 2024-08-02 23:40:32 +02:00

Restore_state helper to restore entity states from the DB on startup (#4614)

* Restore states

* feedback

* Remove component move into recorder

* space

* helper

* Address my own comments

* Improve test coverage

* Add test for light restore state
This commit is contained in:
Johann Kellerman 2017-02-21 09:40:27 +02:00 committed by Paulus Schoutsen
parent 2b9fb73032
commit fdc373f27e
18 changed files with 425 additions and 184 deletions

View File

@ -15,7 +15,6 @@ import voluptuous as vol
from homeassistant.const import (
HTTP_BAD_REQUEST, CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE)
import homeassistant.helpers.config_validation as cv
import homeassistant.util.dt as dt_util
from homeassistant.components import recorder, script
from homeassistant.components.frontend import register_built_in_panel
@ -28,34 +27,22 @@ DOMAIN = 'history'
DEPENDENCIES = ['recorder', 'http']
CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({
CONF_EXCLUDE: vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
}),
CONF_INCLUDE: vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
})
}),
DOMAIN: recorder.FILTER_SCHEMA,
}, extra=vol.ALLOW_EXTRA)
SIGNIFICANT_DOMAINS = ('thermostat', 'climate')
IGNORE_DOMAINS = ('zone', 'scene',)
def last_5_states(entity_id):
"""Return the last 5 states for entity_id."""
entity_id = entity_id.lower()
states = recorder.get_model('States')
return recorder.execute(
recorder.query('States').filter(
(states.entity_id == entity_id) &
(states.last_changed == states.last_updated)
).order_by(states.state_id.desc()).limit(5))
def last_recorder_run():
"""Retireve the last closed recorder run from the DB."""
rec_runs = recorder.get_model('RecorderRuns')
with recorder.session_scope() as session:
res = recorder.query(rec_runs).order_by(rec_runs.end.desc()).first()
if res is None:
return None
session.expunge(res)
return res
def get_significant_states(start_time, end_time=None, entity_id=None,
@ -91,7 +78,7 @@ def get_significant_states(start_time, end_time=None, entity_id=None,
def state_changes_during_period(start_time, end_time=None, entity_id=None):
"""Return states changes during UTC period start_time - end_time."""
states = recorder.get_model('States')
query = recorder.query('States').filter(
query = recorder.query(states).filter(
(states.last_changed == states.last_updated) &
(states.last_changed > start_time))
@ -132,7 +119,7 @@ def get_states(utc_point_in_time, entity_ids=None, run=None, filters=None):
most_recent_state_ids = most_recent_state_ids.group_by(
states.entity_id).subquery()
query = recorder.query('States').join(most_recent_state_ids, and_(
query = recorder.query(states).join(most_recent_state_ids, and_(
states.state_id == most_recent_state_ids.c.max_state_id))
for state in recorder.execute(query):
@ -185,27 +172,13 @@ def setup(hass, config):
filters.included_entities = include[CONF_ENTITIES]
filters.included_domains = include[CONF_DOMAINS]
hass.http.register_view(Last5StatesView)
recorder.get_instance()
hass.http.register_view(HistoryPeriodView(filters))
register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box')
return True
class Last5StatesView(HomeAssistantView):
"""Handle last 5 state view requests."""
url = '/api/history/entity/{entity_id}/recent_states'
name = 'api:history:entity-recent-states'
@asyncio.coroutine
def get(self, request, entity_id):
"""Retrieve last 5 states of entity."""
result = yield from request.app['hass'].loop.run_in_executor(
None, last_5_states, entity_id)
return self.json(result)
class HistoryPeriodView(HomeAssistantView):
"""Handle history period requests."""

View File

@ -15,6 +15,7 @@ from homeassistant.const import (
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.restore_state import async_get_last_state
DOMAIN = 'input_boolean'
@ -139,6 +140,14 @@ class InputBoolean(ToggleEntity):
"""Return true if entity is on."""
return self._state
@asyncio.coroutine
def async_added_to_hass(self):
"""Called when entity about to be added to hass."""
state = yield from async_get_last_state(self.hass, self.entity_id)
if not state:
return
self._state = state.state == 'on'
@asyncio.coroutine
def async_turn_on(self, **kwargs):
"""Turn the entity on."""

View File

@ -22,6 +22,7 @@ from homeassistant.helpers.entity import ToggleEntity
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.restore_state import async_restore_state
import homeassistant.util.color as color_util
from homeassistant.util.async import run_callback_threadsafe
@ -126,6 +127,14 @@ PROFILE_SCHEMA = vol.Schema(
_LOGGER = logging.getLogger(__name__)
def extract_info(state):
"""Extract light parameters from a state object."""
params = {key: state.attributes[key] for key in PROP_TO_ATTR
if key in state.attributes}
params['is_on'] = state.state == STATE_ON
return params
def is_on(hass, entity_id=None):
"""Return if the lights are on based on the statemachine."""
entity_id = entity_id or ENTITY_ID_ALL_LIGHTS
@ -369,3 +378,9 @@ class Light(ToggleEntity):
def supported_features(self):
"""Flag supported features."""
return 0
@asyncio.coroutine
def async_added_to_hass(self):
"""Component added, restore_state using platforms."""
if hasattr(self, 'async_restore_state'):
yield from async_restore_state(self, extract_info)

View File

@ -4,6 +4,7 @@ Demo light platform that implements lights.
For more details about this platform, please refer to the documentation
https://home-assistant.io/components/demo/
"""
import asyncio
import random
from homeassistant.components.light import (
@ -149,3 +150,26 @@ class DemoLight(Light):
# As we have disabled polling, we need to inform
# Home Assistant about updates in our state ourselves.
self.schedule_update_ha_state()
@asyncio.coroutine
def async_restore_state(self, is_on, **kwargs):
"""Restore the demo state."""
self._state = is_on
if 'brightness' in kwargs:
self._brightness = kwargs['brightness']
if 'color_temp' in kwargs:
self._ct = kwargs['color_temp']
if 'rgb_color' in kwargs:
self._rgb = kwargs['rgb_color']
if 'xy_color' in kwargs:
self._xy_color = kwargs['xy_color']
if 'white_value' in kwargs:
self._white = kwargs['white_value']
if 'effect' in kwargs:
self._effect = kwargs['effect']

View File

@ -22,6 +22,7 @@ from homeassistant.const import (
ATTR_ENTITY_ID, CONF_ENTITIES, CONF_EXCLUDE, CONF_DOMAINS,
CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL)
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType, QueryType
@ -42,36 +43,35 @@ CONNECT_RETRY_WAIT = 10
QUERY_RETRY_WAIT = 0.1
ERROR_QUERY = "Error during query: %s"
FILTER_SCHEMA = vol.Schema({
vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
}),
vol.Optional(CONF_INCLUDE, default={}): vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
})
})
CONFIG_SCHEMA = vol.Schema({
DOMAIN: vol.Schema({
DOMAIN: FILTER_SCHEMA.extend({
vol.Optional(CONF_PURGE_DAYS):
vol.All(vol.Coerce(int), vol.Range(min=1)),
vol.Optional(CONF_DB_URL): cv.string,
vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
}),
vol.Optional(CONF_INCLUDE, default={}): vol.Schema({
vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
vol.Optional(CONF_DOMAINS, default=[]):
vol.All(cv.ensure_list, [cv.string])
})
})
}, extra=vol.ALLOW_EXTRA)
_INSTANCE = None # type: Any
_LOGGER = logging.getLogger(__name__)
# These classes will be populated during setup()
# scoped_session, in the same thread session_scope() stays the same
_SESSION = None
@contextmanager
def session_scope():
"""Provide a transactional scope around a series of operations."""
session = _SESSION()
session = _INSTANCE.get_session()
try:
yield session
session.commit()
@ -83,15 +83,28 @@ def session_scope():
session.close()
def get_instance() -> None:
"""Throw error if recorder not initialized."""
if _INSTANCE is None:
raise RuntimeError("Recorder not initialized.")
ident = _INSTANCE.hass.loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError('Cannot be called from within the event loop')
_wait(_INSTANCE.db_ready, "Database not ready")
return _INSTANCE
# pylint: disable=invalid-sequence-index
def execute(qry: QueryType) -> List[Any]:
"""Query the database and convert the objects to HA native form.
This method also retries a few times in the case of stale connections.
"""
_verify_instance()
import sqlalchemy.exc
get_instance()
from sqlalchemy.exc import SQLAlchemyError
with session_scope() as session:
for _ in range(0, RETRIES):
try:
@ -99,7 +112,7 @@ def execute(qry: QueryType) -> List[Any]:
row for row in
(row.to_native() for row in qry)
if row is not None]
except sqlalchemy.exc.SQLAlchemyError as err:
except SQLAlchemyError as err:
_LOGGER.error(ERROR_QUERY, err)
session.rollback()
time.sleep(QUERY_RETRY_WAIT)
@ -111,13 +124,13 @@ def run_information(point_in_time: Optional[datetime]=None):
There is also the run that covers point_in_time.
"""
_verify_instance()
ins = get_instance()
recorder_runs = get_model('RecorderRuns')
if point_in_time is None or point_in_time > _INSTANCE.recording_start:
if point_in_time is None or point_in_time > ins.recording_start:
return recorder_runs(
end=None,
start=_INSTANCE.recording_start,
start=ins.recording_start,
closed_incorrect=False)
with session_scope() as session:
@ -148,17 +161,19 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool:
exclude = config.get(DOMAIN, {}).get(CONF_EXCLUDE, {})
_INSTANCE = Recorder(hass, purge_days=purge_days, uri=db_url,
include=include, exclude=exclude)
_INSTANCE.start()
return True
def query(model_name: Union[str, Any], *args) -> QueryType:
def query(model_name: Union[str, Any], session=None, *args) -> QueryType:
"""Helper to return a query handle."""
_verify_instance()
if session is None:
session = get_instance().get_session()
if isinstance(model_name, str):
return _SESSION().query(get_model(model_name), *args)
return _SESSION().query(model_name, *args)
return session.query(get_model(model_name), *args)
return session.query(model_name, *args)
def get_model(model_name: str) -> Any:
@ -185,6 +200,7 @@ class Recorder(threading.Thread):
self.recording_start = dt_util.utcnow()
self.db_url = uri
self.db_ready = threading.Event()
self.start_recording = threading.Event()
self.engine = None # type: Any
self._run = None # type: Any
@ -195,23 +211,26 @@ class Recorder(threading.Thread):
def start_recording(event):
"""Start recording."""
self.start()
self.start_recording.set()
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_recording)
hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.shutdown)
hass.bus.listen(MATCH_ALL, self.event_listener)
self.get_session = None
def run(self):
"""Start processing events to save."""
from homeassistant.components.recorder.models import Events, States
import sqlalchemy.exc
from sqlalchemy.exc import SQLAlchemyError
while True:
try:
self._setup_connection()
self._setup_run()
self.db_ready.set()
break
except sqlalchemy.exc.SQLAlchemyError as err:
except SQLAlchemyError as err:
_LOGGER.error("Error during connection setup: %s (retrying "
"in %s seconds)", err, CONNECT_RETRY_WAIT)
time.sleep(CONNECT_RETRY_WAIT)
@ -220,6 +239,8 @@ class Recorder(threading.Thread):
async_track_time_interval(
self.hass, self._purge_old_data, timedelta(days=2))
_wait(self.start_recording, "Waiting to start recording")
while True:
event = self.queue.get()
@ -275,10 +296,9 @@ class Recorder(threading.Thread):
def shutdown(self, event):
"""Tell the recorder to shut down."""
global _INSTANCE # pylint: disable=global-statement
_INSTANCE = None
self.queue.put(None)
self.join()
_INSTANCE = None
def block_till_done(self):
"""Block till all events processed."""
@ -286,15 +306,10 @@ class Recorder(threading.Thread):
def block_till_db_ready(self):
"""Block until the database session is ready."""
self.db_ready.wait(10)
while not self.db_ready.is_set():
_LOGGER.warning('Database not ready, waiting another 10 seconds.')
self.db_ready.wait(10)
_wait(self.db_ready, "Database not ready")
def _setup_connection(self):
"""Ensure database is ready to fly."""
global _SESSION # pylint: disable=invalid-name,global-statement
import homeassistant.components.recorder.models as models
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session
@ -312,9 +327,8 @@ class Recorder(threading.Thread):
models.Base.metadata.create_all(self.engine)
session_factory = sessionmaker(bind=self.engine)
_SESSION = scoped_session(session_factory)
self.get_session = scoped_session(session_factory)
self._migrate_schema()
self.db_ready.set()
def _migrate_schema(self):
"""Check if the schema needs to be upgraded."""
@ -396,16 +410,16 @@ class Recorder(threading.Thread):
def _close_connection(self):
"""Close the connection."""
global _SESSION # pylint: disable=invalid-name,global-statement
self.engine.dispose()
self.engine = None
_SESSION = None
self.get_session = None
def _setup_run(self):
"""Log the start of the current run."""
recorder_runs = get_model('RecorderRuns')
with session_scope() as session:
for run in query('RecorderRuns').filter_by(end=None):
for run in query(
recorder_runs, session=session).filter_by(end=None):
run.closed_incorrect = True
run.end = self.recording_start
_LOGGER.warning("Ended unfinished session (id=%s from %s)",
@ -482,13 +496,13 @@ class Recorder(threading.Thread):
return False
def _verify_instance() -> None:
"""Throw error if recorder not initialized."""
if _INSTANCE is None:
raise RuntimeError("Recorder not initialized.")
ident = _INSTANCE.hass.loop.__dict__.get("_thread_ident")
if ident is not None and ident == threading.get_ident():
raise RuntimeError('Cannot be called from within the event loop')
_INSTANCE.block_till_db_ready()
def _wait(event, message):
"""Event wait helper."""
for retry in (10, 20, 30):
event.wait(10)
if event.is_set():
return
msg = message + " ({} seconds)".format(retry)
_LOGGER.warning(msg)
if not event.is_set():
raise HomeAssistantError(msg)

View File

@ -199,7 +199,7 @@ class HistoryStatsSensor(Entity):
if self._start is not None:
try:
start_rendered = self._start.render()
except TemplateError as ex:
except (TemplateError, TypeError) as ex:
HistoryStatsHelper.handle_template_exception(ex, 'start')
return
start = dt_util.parse_datetime(start_rendered)
@ -216,7 +216,7 @@ class HistoryStatsSensor(Entity):
if self._end is not None:
try:
end_rendered = self._end.render()
except TemplateError as ex:
except (TemplateError, TypeError) as ex:
HistoryStatsHelper.handle_template_exception(ex, 'end')
return
end = dt_util.parse_datetime(end_rendered)

View File

@ -288,7 +288,7 @@ class Entity(object):
self.hass.add_job(self.async_update_ha_state(force_refresh))
def remove(self) -> None:
"""Remove entitiy from HASS."""
"""Remove entity from HASS."""
run_coroutine_threadsafe(
self.async_remove(), self.hass.loop
).result()

View File

@ -202,6 +202,10 @@ class EntityComponent(object):
'Invalid entity id: {}'.format(entity.entity_id))
self.entities[entity.entity_id] = entity
if hasattr(entity, 'async_added_to_hass'):
yield from entity.async_added_to_hass()
yield from entity.async_update_ha_state()
return True

View File

@ -0,0 +1,82 @@
"""Support for restoring entity states on startup."""
import asyncio
import logging
from datetime import timedelta
from homeassistant.core import HomeAssistant, CoreState, callback
from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.components.history import get_states, last_recorder_run
from homeassistant.components.recorder import DOMAIN as _RECORDER
import homeassistant.util.dt as dt_util
_LOGGER = logging.getLogger(__name__)
DATA_RESTORE_CACHE = 'restore_state_cache'
_LOCK = 'restore_lock'
def _load_restore_cache(hass: HomeAssistant):
"""Load the restore cache to be used by other components."""
@callback
def remove_cache(event):
"""Remove the states cache."""
hass.data.pop(DATA_RESTORE_CACHE, None)
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, remove_cache)
last_run = last_recorder_run()
if last_run is None or last_run.end is None:
_LOGGER.debug('Not creating cache - no suitable last run found: %s',
last_run)
hass.data[DATA_RESTORE_CACHE] = {}
return
last_end_time = last_run.end - timedelta(seconds=1)
# Unfortunately the recorder_run model do not return offset-aware time
last_end_time = last_end_time.replace(tzinfo=dt_util.UTC)
_LOGGER.debug("Last run: %s - %s", last_run.start, last_end_time)
states = get_states(last_end_time, run=last_run)
# Cache the states
hass.data[DATA_RESTORE_CACHE] = {
state.entity_id: state for state in states}
_LOGGER.debug('Created cache with %s', list(hass.data[DATA_RESTORE_CACHE]))
@asyncio.coroutine
def async_get_last_state(hass, entity_id: str):
"""Helper to restore state."""
if (_RECORDER not in hass.config.components or
hass.state != CoreState.starting):
return None
if DATA_RESTORE_CACHE in hass.data:
return hass.data[DATA_RESTORE_CACHE].get(entity_id)
if _LOCK not in hass.data:
hass.data[_LOCK] = asyncio.Lock(loop=hass.loop)
with (yield from hass.data[_LOCK]):
if DATA_RESTORE_CACHE not in hass.data:
yield from hass.loop.run_in_executor(
None, _load_restore_cache, hass)
return hass.data[DATA_RESTORE_CACHE].get(entity_id)
@asyncio.coroutine
def async_restore_state(entity, extract_info):
"""Helper to call entity.async_restore_state with cached info."""
if entity.hass.state != CoreState.starting:
_LOGGER.debug("Not restoring state: State is not starting: %s",
entity.hass.state)
return
state = yield from async_get_last_state(entity.hass, entity.entity_id)
if not state:
return
yield from entity.async_restore_state(**extract_info(state))

View File

@ -197,8 +197,8 @@ def load_order_components(components: Sequence[str]) -> OrderedSet:
load_order.update(comp_load_order)
# Push some to first place in load order
for comp in ('mqtt_eventstream', 'mqtt', 'logger',
'recorder', 'introduction'):
for comp in ('mqtt_eventstream', 'mqtt', 'recorder',
'introduction', 'logger'):
if comp in load_order:
load_order.promote(comp)

View File

@ -22,7 +22,7 @@ from homeassistant.const import (
STATE_ON, STATE_OFF, DEVICE_DEFAULT_NAME, EVENT_TIME_CHANGED,
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
ATTR_DISCOVERED, SERVER_PORT)
from homeassistant.components import sun, mqtt
from homeassistant.components import sun, mqtt, recorder
from homeassistant.components.http.auth import auth_middleware
from homeassistant.components.http.const import (
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS)
@ -452,3 +452,31 @@ def assert_setup_component(count, domain=None):
res_len = 0 if res is None else len(res)
assert res_len == count, 'setup_component failed, expected {} got {}: {}' \
.format(count, res_len, res)
def init_recorder_component(hass, add_config=None, db_ready_callback=None):
"""Initialize the recorder."""
config = dict(add_config) if add_config else {}
config[recorder.CONF_DB_URL] = 'sqlite://' # In memory DB
saved_recorder = recorder.Recorder
class Recorder2(saved_recorder):
"""Recorder with a callback after db_ready."""
def _setup_connection(self):
"""Setup the connection and run the callback."""
super(Recorder2, self)._setup_connection()
if db_ready_callback:
_LOGGER.debug('db_ready_callback start (db_ready not set,'
'never use get_instance in the callback)')
db_ready_callback()
_LOGGER.debug('db_ready_callback completed')
with patch('homeassistant.components.recorder.Recorder',
side_effect=Recorder2):
assert setup_component(hass, recorder.DOMAIN,
{recorder.DOMAIN: config})
assert recorder.DOMAIN in hass.config.components
recorder.get_instance().block_till_db_ready()
_LOGGER.info("In-memory recorder successfully started")

View File

@ -1,17 +1,20 @@
"""The tests for the demo light component."""
# pylint: disable=protected-access
import asyncio
import unittest
from homeassistant.bootstrap import setup_component
from homeassistant.core import State, CoreState
from homeassistant.bootstrap import setup_component, async_setup_component
import homeassistant.components.light as light
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
from tests.common import get_test_home_assistant
ENTITY_LIGHT = 'light.bed_light'
class TestDemoClimate(unittest.TestCase):
"""Test the demo climate hvac."""
class TestDemoLight(unittest.TestCase):
"""Test the demo light."""
# pylint: disable=invalid-name
def setUp(self):
@ -60,3 +63,36 @@ class TestDemoClimate(unittest.TestCase):
light.turn_off(self.hass, ENTITY_LIGHT)
self.hass.block_till_done()
self.assertFalse(light.is_on(self.hass, ENTITY_LIGHT))
@asyncio.coroutine
def test_restore_state(hass):
"""Test state gets restored."""
hass.config.components.add('recorder')
hass.state = CoreState.starting
hass.data[DATA_RESTORE_CACHE] = {
'light.bed_light': State('light.bed_light', 'on', {
'brightness': 'value-brightness',
'color_temp': 'value-color_temp',
'rgb_color': 'value-rgb_color',
'xy_color': 'value-xy_color',
'white_value': 'value-white_value',
'effect': 'value-effect',
}),
}
yield from async_setup_component(hass, 'light', {
'light': {
'platform': 'demo',
}})
state = hass.states.get('light.bed_light')
assert state is not None
assert state.entity_id == 'light.bed_light'
assert state.state == 'on'
assert state.attributes.get('brightness') == 'value-brightness'
assert state.attributes.get('color_temp') == 'value-color_temp'
assert state.attributes.get('rgb_color') == 'value-rgb_color'
assert state.attributes.get('xy_color') == 'value-xy_color'
assert state.attributes.get('white_value') == 'value-white_value'
assert state.attributes.get('effect') == 'value-effect'

View File

@ -11,8 +11,7 @@ from sqlalchemy import create_engine
from homeassistant.core import callback
from homeassistant.const import MATCH_ALL
from homeassistant.components import recorder
from homeassistant.bootstrap import setup_component
from tests.common import get_test_home_assistant
from tests.common import get_test_home_assistant, init_recorder_component
from tests.components.recorder import models_original
@ -22,18 +21,15 @@ class BaseTestRecorder(unittest.TestCase):
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
db_uri = 'sqlite://' # In memory DB
setup_component(self.hass, recorder.DOMAIN, {
recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}})
init_recorder_component(self.hass)
self.hass.start()
recorder._verify_instance()
recorder._INSTANCE.block_till_done()
recorder.get_instance().block_till_done()
def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started."""
recorder._INSTANCE.shutdown(None)
self.hass.stop()
assert recorder._INSTANCE is None
with self.assertRaises(RuntimeError):
recorder.get_instance()
def _add_test_states(self):
"""Add multiple states to the db for testing."""
@ -228,7 +224,7 @@ class TestMigrateRecorder(BaseTestRecorder):
@patch('sqlalchemy.create_engine', new=create_engine_test)
@patch('homeassistant.components.recorder.Recorder._migrate_schema')
def setUp(self, migrate): # pylint: disable=invalid-name
def setUp(self, migrate): # pylint: disable=invalid-name,arguments-differ
"""Setup things to be run when tests are started.
create_engine is patched to create a db that starts with the old
@ -261,16 +257,12 @@ def hass_recorder():
"""HASS fixture with in-memory recorder."""
hass = get_test_home_assistant()
def setup_recorder(config={}):
def setup_recorder(config=None):
"""Setup with params."""
db_uri = 'sqlite://' # In memory DB
conf = {recorder.CONF_DB_URL: db_uri}
conf.update(config)
assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: conf})
init_recorder_component(hass, config)
hass.start()
hass.block_till_done()
recorder._verify_instance()
recorder._INSTANCE.block_till_done()
recorder.get_instance().block_till_done()
return hass
yield setup_recorder
@ -352,12 +344,12 @@ def test_recorder_errors_exceptions(hass_recorder): \
# Verify the instance fails before setup
with pytest.raises(RuntimeError):
recorder._verify_instance()
recorder.get_instance()
# Setup the recorder
hass_recorder()
recorder._verify_instance()
recorder.get_instance()
# Verify session scope raises (and prints) an exception
with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \

View File

@ -1,16 +1,17 @@
"""The test for the History Statistics sensor platform."""
# pylint: disable=protected-access
import unittest
from datetime import timedelta
import unittest
from unittest.mock import patch
import homeassistant.components.recorder as recorder
import homeassistant.core as ha
import homeassistant.util.dt as dt_util
from homeassistant.bootstrap import setup_component
import homeassistant.components.recorder as recorder
from homeassistant.components.sensor.history_stats import HistoryStatsSensor
import homeassistant.core as ha
from homeassistant.helpers.template import Template
from tests.common import get_test_home_assistant
import homeassistant.util.dt as dt_util
from tests.common import init_recorder_component, get_test_home_assistant
class TestHistoryStatsSensor(unittest.TestCase):
@ -204,12 +205,8 @@ class TestHistoryStatsSensor(unittest.TestCase):
def init_recorder(self):
"""Initialize the recorder."""
db_uri = 'sqlite://'
with patch('homeassistant.core.Config.path', return_value=db_uri):
setup_component(self.hass, recorder.DOMAIN, {
"recorder": {
"db_url": db_uri}})
init_recorder_component(self.hass)
self.hass.start()
recorder._INSTANCE.block_till_db_ready()
recorder.get_instance().block_till_db_ready()
self.hass.block_till_done()
recorder._INSTANCE.block_till_done()
recorder.get_instance().block_till_done()

View File

@ -1,5 +1,5 @@
"""The tests the History component."""
# pylint: disable=protected-access
# pylint: disable=protected-access,invalid-name
from datetime import timedelta
import unittest
from unittest.mock import patch, sentinel
@ -10,68 +10,47 @@ import homeassistant.util.dt as dt_util
from homeassistant.components import history, recorder
from tests.common import (
mock_http_component, mock_state_change_event, get_test_home_assistant)
init_recorder_component, mock_http_component, mock_state_change_event,
get_test_home_assistant)
class TestComponentHistory(unittest.TestCase):
"""Test History component."""
# pylint: disable=invalid-name
def setUp(self):
def setUp(self): # pylint: disable=invalid-name
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
# pylint: disable=invalid-name
def tearDown(self):
def tearDown(self): # pylint: disable=invalid-name
"""Stop everything that was started."""
self.hass.stop()
def init_recorder(self):
"""Initialize the recorder."""
db_uri = 'sqlite://'
with patch('homeassistant.core.Config.path', return_value=db_uri):
setup_component(self.hass, recorder.DOMAIN, {
"recorder": {
"db_url": db_uri}})
init_recorder_component(self.hass)
self.hass.start()
recorder._INSTANCE.block_till_db_ready()
recorder.get_instance().block_till_db_ready()
self.wait_recording_done()
def wait_recording_done(self):
"""Block till recording is done."""
self.hass.block_till_done()
recorder._INSTANCE.block_till_done()
recorder.get_instance().block_till_done()
def test_setup(self):
"""Test setup method of history."""
mock_http_component(self.hass)
config = history.CONFIG_SCHEMA({
ha.DOMAIN: {},
history.DOMAIN: {history.CONF_INCLUDE: {
# ha.DOMAIN: {},
history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_DOMAINS: ['media_player'],
history.CONF_ENTITIES: ['thermostat.test']},
history.CONF_EXCLUDE: {
history.CONF_DOMAINS: ['thermostat'],
history.CONF_ENTITIES: ['media_player.test']}}})
self.assertTrue(setup_component(self.hass, history.DOMAIN, config))
def test_last_5_states(self):
"""Test retrieving the last 5 states."""
self.init_recorder()
states = []
entity_id = 'test.last_5_states'
for i in range(7):
self.hass.states.set(entity_id, "State {}".format(i))
self.wait_recording_done()
if i > 1:
states.append(self.hass.states.get(entity_id))
self.assertEqual(
list(reversed(states)), history.last_5_states(entity_id))
self.assertTrue(setup_component(self.hass, history.DOMAIN, config))
def test_get_states(self):
"""Test getting states at a specific point in time."""
@ -121,6 +100,7 @@ class TestComponentHistory(unittest.TestCase):
entity_id = 'media_player.test'
def set_state(state):
"""Set the state."""
self.hass.states.set(entity_id, state)
self.wait_recording_done()
return self.hass.states.get(entity_id)
@ -311,7 +291,8 @@ class TestComponentHistory(unittest.TestCase):
config = history.CONFIG_SCHEMA({
ha.DOMAIN: {},
history.DOMAIN: {history.CONF_INCLUDE: {
history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_DOMAINS: ['media_player']},
history.CONF_EXCLUDE: {
history.CONF_DOMAINS: ['media_player']}}})
@ -332,7 +313,8 @@ class TestComponentHistory(unittest.TestCase):
config = history.CONFIG_SCHEMA({
ha.DOMAIN: {},
history.DOMAIN: {history.CONF_INCLUDE: {
history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_ENTITIES: ['media_player.test']},
history.CONF_EXCLUDE: {
history.CONF_ENTITIES: ['media_player.test']}}})
@ -351,7 +333,8 @@ class TestComponentHistory(unittest.TestCase):
config = history.CONFIG_SCHEMA({
ha.DOMAIN: {},
history.DOMAIN: {history.CONF_INCLUDE: {
history.DOMAIN: {
history.CONF_INCLUDE: {
history.CONF_DOMAINS: ['media_player'],
history.CONF_ENTITIES: ['thermostat.test']},
history.CONF_EXCLUDE: {
@ -359,7 +342,8 @@ class TestComponentHistory(unittest.TestCase):
history.CONF_ENTITIES: ['media_player.test']}}})
self.check_significant_states(zero, four, states, config)
def check_significant_states(self, zero, four, states, config):
def check_significant_states(self, zero, four, states, config): \
# pylint: disable=no-self-use
"""Check if significant states are retrieved."""
filters = history.Filters()
exclude = config[history.DOMAIN].get(history.CONF_EXCLUDE)
@ -390,6 +374,7 @@ class TestComponentHistory(unittest.TestCase):
script_c = 'script.can_cancel_this_one'
def set_state(entity_id, state, **kwargs):
"""Set the state."""
self.hass.states.set(entity_id, state, **kwargs)
self.wait_recording_done()
return self.hass.states.get(entity_id)

View File

@ -1,15 +1,18 @@
"""The tests for the input_boolean component."""
# pylint: disable=protected-access
import asyncio
import unittest
import logging
from tests.common import get_test_home_assistant
from homeassistant.bootstrap import setup_component
from homeassistant.core import CoreState, State
from homeassistant.bootstrap import setup_component, async_setup_component
from homeassistant.components.input_boolean import (
DOMAIN, is_on, toggle, turn_off, turn_on)
from homeassistant.const import (
STATE_ON, STATE_OFF, ATTR_ICON, ATTR_FRIENDLY_NAME)
from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE
_LOGGER = logging.getLogger(__name__)
@ -103,3 +106,30 @@ class TestInputBoolean(unittest.TestCase):
self.assertEqual('Hello World',
state_2.attributes.get(ATTR_FRIENDLY_NAME))
self.assertEqual('mdi:work', state_2.attributes.get(ATTR_ICON))
@asyncio.coroutine
def test_restore_state(hass):
"""Ensure states are restored on startup."""
hass.data[DATA_RESTORE_CACHE] = {
'input_boolean.b1': State('input_boolean.b1', 'on'),
'input_boolean.b2': State('input_boolean.b2', 'off'),
'input_boolean.b3': State('input_boolean.b3', 'on'),
}
hass.state = CoreState.starting
hass.config.components.add('recorder')
yield from async_setup_component(hass, DOMAIN, {
DOMAIN: {
'b1': None,
'b2': None,
}})
state = hass.states.get('input_boolean.b1')
assert state
assert state.state == 'on'
state = hass.states.get('input_boolean.b2')
assert state
assert state.state == 'off'

View File

@ -1,5 +1,6 @@
"""The tests for the logbook component."""
# pylint: disable=protected-access
# pylint: disable=protected-access,invalid-name
import logging
from datetime import timedelta
import unittest
from unittest.mock import patch
@ -13,7 +14,11 @@ import homeassistant.util.dt as dt_util
from homeassistant.components import logbook
from homeassistant.bootstrap import setup_component
from tests.common import mock_http_component, get_test_home_assistant
from tests.common import (
mock_http_component, init_recorder_component, get_test_home_assistant)
_LOGGER = logging.getLogger(__name__)
class TestComponentLogbook(unittest.TestCase):
@ -24,12 +29,14 @@ class TestComponentLogbook(unittest.TestCase):
def setUp(self):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
init_recorder_component(self.hass) # Force an in memory DB
mock_http_component(self.hass)
self.hass.config.components |= set(['frontend', 'recorder', 'api'])
with patch('homeassistant.components.logbook.'
'register_built_in_panel'):
assert setup_component(self.hass, logbook.DOMAIN,
self.EMPTY_CONFIG)
self.hass.start()
def tearDown(self):
"""Stop everything that was started."""
@ -41,6 +48,7 @@ class TestComponentLogbook(unittest.TestCase):
@ha.callback
def event_listener(event):
"""Append on event."""
calls.append(event)
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
@ -72,6 +80,7 @@ class TestComponentLogbook(unittest.TestCase):
@ha.callback
def event_listener(event):
"""Append on event."""
calls.append(event)
self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener)
@ -242,17 +251,17 @@ class TestComponentLogbook(unittest.TestCase):
entity_id2 = 'sensor.blu'
eventA = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, {
logbook.ATTR_NAME: name,
logbook.ATTR_MESSAGE: message,
logbook.ATTR_DOMAIN: domain,
logbook.ATTR_ENTITY_ID: entity_id,
})
logbook.ATTR_NAME: name,
logbook.ATTR_MESSAGE: message,
logbook.ATTR_DOMAIN: domain,
logbook.ATTR_ENTITY_ID: entity_id,
})
eventB = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, {
logbook.ATTR_NAME: name,
logbook.ATTR_MESSAGE: message,
logbook.ATTR_DOMAIN: domain,
logbook.ATTR_ENTITY_ID: entity_id2,
})
logbook.ATTR_NAME: name,
logbook.ATTR_MESSAGE: message,
logbook.ATTR_DOMAIN: domain,
logbook.ATTR_ENTITY_ID: entity_id2,
})
config = logbook.CONFIG_SCHEMA({
ha.DOMAIN: {},
@ -532,7 +541,8 @@ class TestComponentLogbook(unittest.TestCase):
def create_state_changed_event(self, event_time_fired, entity_id, state,
attributes=None, last_changed=None,
last_updated=None):
last_updated=None): \
# pylint: disable=no-self-use
"""Create state changed event."""
# Logbook only cares about state change events that
# contain an old state but will not actually act on it.

View File

@ -0,0 +1,42 @@
"""The tests for the Restore component."""
import asyncio
from unittest.mock import patch, MagicMock
from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.core import CoreState, State
import homeassistant.util.dt as dt_util
from homeassistant.helpers.restore_state import (
async_get_last_state, DATA_RESTORE_CACHE)
@asyncio.coroutine
def test_caching_data(hass):
"""Test that we cache data."""
hass.config.components.add('recorder')
hass.state = CoreState.starting
states = [
State('input_boolean.b0', 'on'),
State('input_boolean.b1', 'on'),
State('input_boolean.b2', 'on'),
]
with patch('homeassistant.helpers.restore_state.last_recorder_run',
return_value=MagicMock(end=dt_util.utcnow())), \
patch('homeassistant.helpers.restore_state.get_states',
return_value=states):
state = yield from async_get_last_state(hass, 'input_boolean.b1')
assert DATA_RESTORE_CACHE in hass.data
assert hass.data[DATA_RESTORE_CACHE] == {st.entity_id: st for st in states}
assert state is not None
assert state.entity_id == 'input_boolean.b1'
assert state.state == 'on'
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
yield from hass.async_block_till_done()
assert DATA_RESTORE_CACHE not in hass.data