Make recorder execute avoid native conversion by default (#36938)

This commit is contained in:
J. Nick Koston 2020-06-21 23:58:57 -05:00 committed by GitHub
parent f4b8a95205
commit edad387b12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 42 additions and 35 deletions

View File

@ -109,7 +109,7 @@ def _get_significant_states(
query = query.order_by(States.entity_id, States.last_updated) query = query.order_by(States.entity_id, States.last_updated)
states = execute(query, to_native=False) states = execute(query)
if _LOGGER.isEnabledFor(logging.DEBUG): if _LOGGER.isEnabledFor(logging.DEBUG):
elapsed = time.perf_counter() - timer_start elapsed = time.perf_counter() - timer_start
@ -144,9 +144,7 @@ def state_changes_during_period(hass, start_time, end_time=None, entity_id=None)
entity_ids = [entity_id] if entity_id is not None else None entity_ids = [entity_id] if entity_id is not None else None
states = execute( states = execute(query.order_by(States.entity_id, States.last_updated))
query.order_by(States.entity_id, States.last_updated), to_native=False
)
return _sorted_states_to_json(hass, session, states, start_time, entity_ids) return _sorted_states_to_json(hass, session, states, start_time, entity_ids)
@ -169,8 +167,7 @@ def get_last_state_changes(hass, number_of_states, entity_id):
states = execute( states = execute(
query.order_by(States.entity_id, States.last_updated.desc()).limit( query.order_by(States.entity_id, States.last_updated.desc()).limit(
number_of_states number_of_states
), )
to_native=False,
) )
return _sorted_states_to_json( return _sorted_states_to_json(
@ -271,7 +268,9 @@ def _get_states_with_session(
return [ return [
state state
for state in (States.to_native(row) for row in execute(query, to_native=False)) for state in (
States.to_native(row, validate_entity_id=False) for row in execute(query)
)
if not state.attributes.get(ATTR_HIDDEN, False) if not state.attributes.get(ATTR_HIDDEN, False)
] ]
@ -331,7 +330,8 @@ def _sorted_states_to_json(
[ [
native_state native_state
for native_state in ( for native_state in (
States.to_native(db_state) for db_state in group States.to_native(db_state, validate_entity_id=False)
for db_state in group
) )
if ( if (
domain != SCRIPT_DOMAIN domain != SCRIPT_DOMAIN
@ -347,7 +347,7 @@ def _sorted_states_to_json(
# in-between only provide the "state" and the # in-between only provide the "state" and the
# "last_changed". # "last_changed".
if not ent_results: if not ent_results:
ent_results.append(States.to_native(next(group))) ent_results.append(States.to_native(next(group), validate_entity_id=False))
initial_state = ent_results[-1] initial_state = ent_results[-1]
prev_state = ent_results[-1] prev_state = ent_results[-1]
@ -355,7 +355,7 @@ def _sorted_states_to_json(
for db_state in group: for db_state in group:
if ATTR_HIDDEN in db_state.attributes and States.to_native( if ATTR_HIDDEN in db_state.attributes and States.to_native(
db_state db_state, validate_entity_id=False
).attributes.get(ATTR_HIDDEN, False): ).attributes.get(ATTR_HIDDEN, False):
continue continue
@ -382,7 +382,7 @@ def _sorted_states_to_json(
# There was at least one state change # There was at least one state change
# replace the last minimal state with # replace the last minimal state with
# a full state # a full state
ent_results[-1] = States.to_native(prev_state) ent_results[-1] = States.to_native(prev_state, validate_entity_id=False)
# Filter out the empty lists if some states had 0 results. # Filter out the empty lists if some states had 0 results.
return {key: val for key, val in result.items() if val} return {key: val for key, val in result.items() if val}

View File

@ -311,7 +311,7 @@ class Plant(Entity):
) )
.order_by(States.last_updated.asc()) .order_by(States.last_updated.asc())
) )
states = execute(query) states = execute(query, to_native=True, validate_entity_ids=False)
for state in states: for state in states:
# filter out all None, NaN and "unknown" states # filter out all None, NaN and "unknown" states

View File

@ -128,7 +128,7 @@ class States(Base): # type: ignore
return dbstate return dbstate
def to_native(self): def to_native(self, validate_entity_id=True):
"""Convert to an HA state object.""" """Convert to an HA state object."""
context = Context(id=self.context_id, user_id=self.context_user_id) context = Context(id=self.context_id, user_id=self.context_user_id)
try: try:
@ -139,9 +139,7 @@ class States(Base): # type: ignore
process_timestamp(self.last_changed), process_timestamp(self.last_changed),
process_timestamp(self.last_updated), process_timestamp(self.last_updated),
context=context, context=context,
# Temp, because database can still store invalid entity IDs validate_entity_id=validate_entity_id,
# Remove with 1.0 or in 2020.
temp_invalid_id_bypass=True,
) )
except ValueError: except ValueError:
# When json.loads fails # When json.loads fails

View File

@ -54,7 +54,7 @@ def commit(session, work):
return False return False
def execute(qry, to_native=True): def execute(qry, to_native=False, validate_entity_ids=True):
"""Query the database and convert the objects to HA native form. """Query the database and convert the objects to HA native form.
This method also retries a few times in the case of stale connections. This method also retries a few times in the case of stale connections.
@ -64,7 +64,12 @@ def execute(qry, to_native=True):
timer_start = time.perf_counter() timer_start = time.perf_counter()
if to_native: if to_native:
result = [ result = [
row for row in (row.to_native() for row in qry) if row is not None row
for row in (
row.to_native(validate_entity_id=validate_entity_ids)
for row in qry
)
if row is not None
] ]
else: else:
result = list(qry) result = list(qry)

View File

@ -332,7 +332,7 @@ class StatisticsSensor(Entity):
query = query.order_by(States.last_updated.desc()).limit( query = query.order_by(States.last_updated.desc()).limit(
self._sampling_size self._sampling_size
) )
states = execute(query) states = execute(query, to_native=True, validate_entity_ids=False)
for state in reversed(states): for state in reversed(states):
self._add_state_to_queue(state) self._add_state_to_queue(state)

View File

@ -739,14 +739,12 @@ class State:
last_changed: Optional[datetime.datetime] = None, last_changed: Optional[datetime.datetime] = None,
last_updated: Optional[datetime.datetime] = None, last_updated: Optional[datetime.datetime] = None,
context: Optional[Context] = None, context: Optional[Context] = None,
# Temp, because database can still store invalid entity IDs validate_entity_id: Optional[bool] = True,
# Remove with 1.0 or in 2020.
temp_invalid_id_bypass: Optional[bool] = False,
) -> None: ) -> None:
"""Initialize a new state.""" """Initialize a new state."""
state = str(state) state = str(state)
if not valid_entity_id(entity_id) and not temp_invalid_id_bypass: if validate_entity_id and not valid_entity_id(entity_id):
raise InvalidEntityFormatError( raise InvalidEntityFormatError(
f"Invalid entity id encountered: {entity_id}. " f"Invalid entity id encountered: {entity_id}. "
"Format should be <domain>.<object_id>" "Format should be <domain>.<object_id>"

View File

@ -2,12 +2,14 @@
from datetime import datetime from datetime import datetime
import unittest import unittest
import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
from homeassistant.components.recorder.models import Base, Events, RecorderRuns, States from homeassistant.components.recorder.models import Base, Events, RecorderRuns, States
from homeassistant.const import EVENT_STATE_CHANGED from homeassistant.const import EVENT_STATE_CHANGED
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.exceptions import InvalidEntityFormatError
from homeassistant.util import dt from homeassistant.util import dt
ENGINE = None ENGINE = None
@ -155,8 +157,11 @@ class TestRecorderRuns(unittest.TestCase):
def test_states_from_native_invalid_entity_id(): def test_states_from_native_invalid_entity_id():
"""Test loading a state from an invalid entity ID.""" """Test loading a state from an invalid entity ID."""
event = States() state = States()
event.entity_id = "test.invalid__id" state.entity_id = "test.invalid__id"
event.attributes = "{}" state.attributes = "{}"
state = event.to_native() with pytest.raises(InvalidEntityFormatError):
state = state.to_native()
state = state.to_native(validate_entity_id=False)
assert state.entity_id == "test.invalid__id" assert state.entity_id == "test.invalid__id"

View File

@ -47,7 +47,7 @@ def test_recorder_bad_execute(hass_recorder):
hass_recorder() hass_recorder()
def to_native(): def to_native(validate_entity_id=True):
"""Rasie exception.""" """Rasie exception."""
raise SQLAlchemyError() raise SQLAlchemyError()
@ -57,6 +57,6 @@ def test_recorder_bad_execute(hass_recorder):
with pytest.raises(SQLAlchemyError), patch( with pytest.raises(SQLAlchemyError), patch(
"homeassistant.components.recorder.time.sleep" "homeassistant.components.recorder.time.sleep"
) as e_mock: ) as e_mock:
util.execute((mck1,)) util.execute((mck1,), to_native=True)
assert e_mock.call_count == 2 assert e_mock.call_count == 2

View File

@ -3,8 +3,6 @@ from datetime import datetime, timedelta
import statistics import statistics
import unittest import unittest
import pytest
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.statistics.sensor import StatisticsSensor from homeassistant.components.statistics.sensor import StatisticsSensor
from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT, STATE_UNKNOWN, TEMP_CELSIUS from homeassistant.const import ATTR_UNIT_OF_MEASUREMENT, STATE_UNKNOWN, TEMP_CELSIUS
@ -17,6 +15,7 @@ from tests.common import (
get_test_home_assistant, get_test_home_assistant,
init_recorder_component, init_recorder_component,
) )
from tests.components.recorder.common import wait_recording_done
class TestStatisticsSensor(unittest.TestCase): class TestStatisticsSensor(unittest.TestCase):
@ -321,11 +320,12 @@ class TestStatisticsSensor(unittest.TestCase):
) == state.attributes.get("max_age") ) == state.attributes.get("max_age")
assert self.change_rate == state.attributes.get("change_rate") assert self.change_rate == state.attributes.get("change_rate")
@pytest.mark.skip("Flaky in CI")
def test_initialize_from_database(self): def test_initialize_from_database(self):
"""Test initializing the statistics from the database.""" """Test initializing the statistics from the database."""
# enable the recorder # enable the recorder
init_recorder_component(self.hass) init_recorder_component(self.hass)
self.hass.block_till_done()
self.hass.data[recorder.DATA_INSTANCE].block_till_done()
# store some values # store some values
for value in self.values: for value in self.values:
self.hass.states.set( self.hass.states.set(
@ -333,7 +333,7 @@ class TestStatisticsSensor(unittest.TestCase):
) )
self.hass.block_till_done() self.hass.block_till_done()
# wait for the recorder to really store the data # wait for the recorder to really store the data
self.hass.data[recorder.DATA_INSTANCE].block_till_done() wait_recording_done(self.hass)
# only now create the statistics component, so that it must read the # only now create the statistics component, so that it must read the
# data from the database # data from the database
assert setup_component( assert setup_component(
@ -357,7 +357,6 @@ class TestStatisticsSensor(unittest.TestCase):
state = self.hass.states.get("sensor.test") state = self.hass.states.get("sensor.test")
assert str(self.mean) == state.state assert str(self.mean) == state.state
@pytest.mark.skip("Flaky in CI")
def test_initialize_from_database_with_maxage(self): def test_initialize_from_database_with_maxage(self):
"""Test initializing the statistics from the database.""" """Test initializing the statistics from the database."""
mock_data = { mock_data = {
@ -381,6 +380,8 @@ class TestStatisticsSensor(unittest.TestCase):
# enable the recorder # enable the recorder
init_recorder_component(self.hass) init_recorder_component(self.hass)
self.hass.block_till_done()
self.hass.data[recorder.DATA_INSTANCE].block_till_done()
with patch( with patch(
"homeassistant.components.statistics.sensor.dt_util.utcnow", new=mock_now "homeassistant.components.statistics.sensor.dt_util.utcnow", new=mock_now
@ -397,7 +398,7 @@ class TestStatisticsSensor(unittest.TestCase):
mock_data["return_time"] += timedelta(hours=1) mock_data["return_time"] += timedelta(hours=1)
# wait for the recorder to really store the data # wait for the recorder to really store the data
self.hass.data[recorder.DATA_INSTANCE].block_till_done() wait_recording_done(self.hass)
# only now create the statistics component, so that it must read # only now create the statistics component, so that it must read
# the data from the database # the data from the database
assert setup_component( assert setup_component(