Complete strict typing for recorder (#71274)

* Complete strict typing for recorder

* update tests

* Update tests/components/recorder/test_migrate.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Update tests/components/recorder/test_migrate.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Remove the asserts

* remove ignore comments

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
J. Nick Koston 2022-05-04 12:22:50 -05:00 committed by GitHub
parent 13ce0a7d6a
commit eb77f8db85
10 changed files with 166 additions and 309 deletions
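The thrust of the change: recorder helpers that previously accepted the whole `Recorder` instance as an untyped `Any` now receive exactly the dependencies they use, chiefly a session factory typed `Callable[[], Session]`. A minimal sketch of the before/after pattern (names simplified; the real helpers wrap the session in `session_scope`):

```python
from collections.abc import Callable
from typing import Any

from sqlalchemy.orm.session import Session


# Before: the helper took the untyped instance and reached into it,
# so mypy could not verify the call or the attribute access.
def _create_index_before(instance: Any, table_name: str, index_name: str) -> None:
    session = instance.get_session()
    ...


# After: the helper declares the one dependency it actually uses.
def _create_index_after(
    session_maker: Callable[[], Session], table_name: str, index_name: str
) -> None:
    session = session_maker()
    ...
```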

.strict-typing

@@ -177,23 +177,7 @@ homeassistant.components.pure_energie.*
homeassistant.components.rainmachine.*
homeassistant.components.rdw.*
homeassistant.components.recollect_waste.*
homeassistant.components.recorder
homeassistant.components.recorder.const
homeassistant.components.recorder.core
homeassistant.components.recorder.backup
homeassistant.components.recorder.executor
homeassistant.components.recorder.history
homeassistant.components.recorder.models
homeassistant.components.recorder.pool
homeassistant.components.recorder.purge
homeassistant.components.recorder.repack
homeassistant.components.recorder.run_history
homeassistant.components.recorder.services
homeassistant.components.recorder.statistics
homeassistant.components.recorder.system_health
homeassistant.components.recorder.tasks
homeassistant.components.recorder.util
homeassistant.components.recorder.websocket_api
homeassistant.components.recorder.*
homeassistant.components.remote.*
homeassistant.components.renault.*
homeassistant.components.ridwell.*

homeassistant/components/recorder/core.py

@@ -171,7 +171,7 @@ class Recorder(threading.Thread):
self._pending_event_data: dict[str, EventData] = {}
self._pending_expunge: list[States] = []
self.event_session: Session | None = None
self.get_session: Callable[[], Session] | None = None
self._get_session: Callable[[], Session] | None = None
self._completed_first_database_setup: bool | None = None
self.async_migration_event = asyncio.Event()
self.migration_in_progress = False
@@ -205,6 +205,12 @@
"""Return if the recorder is recording."""
return self._event_listener is not None
def get_session(self) -> Session:
"""Get a new sqlalchemy session."""
if self._get_session is None:
raise RuntimeError("The database connection has not been established")
return self._get_session()
def queue_task(self, task: RecorderTask) -> None:
"""Add a task to the recorder queue."""
self._queue.put(task)
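With this hunk, `get_session` stops being a public optional attribute and becomes a real method over a private `_get_session`, so call sites can drop their `assert`s and ignore comments, and a caller that races shutdown gets a clear `RuntimeError` instead of trying to call `None`. The shape of the accessor, reduced to a standalone sketch:

```python
from __future__ import annotations

from collections.abc import Callable

from sqlalchemy.orm.session import Session


class RecorderSketch:
    """Hypothetical, trimmed-down stand-in for the Recorder class."""

    def __init__(self) -> None:
        # Set by _setup_connection, cleared again by _close_connection.
        self._get_session: Callable[[], Session] | None = None

    def get_session(self) -> Session:
        """Get a new sqlalchemy session."""
        if self._get_session is None:
            raise RuntimeError("The database connection has not been established")
        return self._get_session()
```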
@@ -459,7 +465,7 @@
@callback
def _async_setup_periodic_tasks(self) -> None:
"""Prepare periodic tasks."""
if self.hass.is_stopping or not self.get_session:
if self.hass.is_stopping or not self._get_session:
# Home Assistant is shutting down
return
@@ -591,7 +597,7 @@
while tries <= self.db_max_retries:
try:
self._setup_connection()
return migration.get_schema_version(self)
return migration.get_schema_version(self.get_session)
except Exception as err: # pylint: disable=broad-except
_LOGGER.exception(
"Error during connection setup: %s (retrying in %s seconds)",
@@ -619,7 +625,9 @@
self.hass.add_job(self._async_migration_started)
try:
migration.migrate_schema(self, current_version)
migration.migrate_schema(
self.hass, self.engine, self.get_session, current_version
)
except exc.DatabaseError as err:
if self._handle_database_error(err):
return True
@@ -896,7 +904,6 @@
def _open_event_session(self) -> None:
"""Open the event session."""
assert self.get_session is not None
self.event_session = self.get_session()
self.event_session.expire_on_commit = False
@@ -1011,7 +1018,7 @@
sqlalchemy_event.listen(self.engine, "connect", setup_recorder_connection)
Base.metadata.create_all(self.engine)
self.get_session = scoped_session(sessionmaker(bind=self.engine, future=True))
self._get_session = scoped_session(sessionmaker(bind=self.engine, future=True))
_LOGGER.debug("Connected to recorder database")
def _close_connection(self) -> None:
@@ -1019,11 +1026,10 @@
assert self.engine is not None
self.engine.dispose()
self.engine = None
self.get_session = None
self._get_session = None
def _setup_run(self) -> None:
"""Log the start of the current run and schedule any needed jobs."""
assert self.get_session is not None
with session_scope(session=self.get_session()) as session:
end_incomplete_runs(session, self.run_history.recording_start)
self.run_history.start(session)
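`_setup_run` keeps funneling the fresh session through `session_scope`. For readers outside the recorder codebase, a simplified sketch of what such a commit-or-rollback helper does (the real one in `recorder/util.py` also accepts `hass` and an `exception_filter`):

```python
from collections.abc import Generator
from contextlib import contextmanager

from sqlalchemy.orm.session import Session


@contextmanager
def session_scope_sketch(*, session: Session) -> Generator[Session, None, None]:
    """Commit on success, roll back on error, always close."""
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
```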

homeassistant/components/recorder/migration.py

@@ -1,11 +1,13 @@
"""Schema migration helpers."""
from collections.abc import Callable, Iterable
import contextlib
from datetime import timedelta
import logging
from typing import Any
from typing import cast
import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import (
DatabaseError,
InternalError,
@@ -13,9 +15,12 @@ from sqlalchemy.exc import (
ProgrammingError,
SQLAlchemyError,
)
from sqlalchemy.orm.session import Session
from sqlalchemy.schema import AddConstraint, DropConstraint
from sqlalchemy.sql.expression import true
from homeassistant.core import HomeAssistant
from .models import (
SCHEMA_VERSION,
TABLE_STATES,
@@ -33,7 +38,7 @@ from .util import session_scope
_LOGGER = logging.getLogger(__name__)
def raise_if_exception_missing_str(ex, match_substrs):
def raise_if_exception_missing_str(ex: Exception, match_substrs: Iterable[str]) -> None:
"""Raise an exception if the exception and cause do not contain the match substrs."""
lower_ex_strs = [str(ex).lower(), str(ex.__cause__).lower()]
for str_sub in match_substrs:
@@ -44,10 +49,9 @@ def raise_if_exception_missing_str(ex, match_substrs):
raise ex
def get_schema_version(instance: Any) -> int:
def get_schema_version(session_maker: Callable[[], Session]) -> int:
"""Get the schema version."""
assert instance.get_session is not None
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
res = (
session.query(SchemaChanges)
.order_by(SchemaChanges.change_id.desc())
@@ -61,7 +65,7 @@
"No schema version found. Inspected version: %s", current_version
)
return current_version
return cast(int, current_version)
def schema_is_current(current_version: int) -> bool:
@@ -69,21 +73,27 @@
return current_version == SCHEMA_VERSION
def migrate_schema(instance: Any, current_version: int) -> None:
def migrate_schema(
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
current_version: int,
) -> None:
"""Check if the schema needs to be upgraded."""
assert instance.get_session is not None
_LOGGER.warning("Database is about to upgrade. Schema version: %s", current_version)
for version in range(current_version, SCHEMA_VERSION):
new_version = version + 1
_LOGGER.info("Upgrading recorder db schema to version %s", new_version)
_apply_update(instance, new_version, current_version)
with session_scope(session=instance.get_session()) as session:
_apply_update(hass, engine, session_maker, new_version, current_version)
with session_scope(session=session_maker()) as session:
session.add(SchemaChanges(schema_version=new_version))
_LOGGER.info("Upgrade to version %s done", new_version)
def _create_index(instance, table_name, index_name):
def _create_index(
session_maker: Callable[[], Session], table_name: str, index_name: str
) -> None:
"""Create an index for the specified table.
The index name should match the name given for the index
@@ -104,7 +114,7 @@ def _create_index(instance, table_name, index_name):
"be patient!",
index_name,
)
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
index.create(connection)
@@ -117,7 +127,9 @@
_LOGGER.debug("Finished creating %s", index_name)
def _drop_index(instance, table_name, index_name):
def _drop_index(
session_maker: Callable[[], Session], table_name: str, index_name: str
) -> None:
"""Drop an index from a specified table.
There is no universal way to do something like `DROP INDEX IF EXISTS`
@@ -132,7 +144,7 @@ def _drop_index(instance, table_name, index_name):
success = False
# Engines like DB2/Oracle
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
connection.execute(text(f"DROP INDEX {index_name}"))
@@ -143,7 +155,7 @@ def _drop_index(instance, table_name, index_name):
# Engines like SQLite, SQL Server
if not success:
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
connection.execute(
@@ -160,7 +172,7 @@ def _drop_index(instance, table_name, index_name):
if not success:
# Engines like MySQL, MS Access
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
connection.execute(
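All the `_drop_index` hunks swap `instance` for `session_maker`, but the control flow is unchanged: since there is no portable `DROP INDEX IF EXISTS`, the helper tries each dialect family's syntax in turn and treats `SQLAlchemyError` as "try the next form". A condensed sketch of that fallback chain:

```python
from collections.abc import Callable

from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.session import Session


def _drop_index_sketch(
    session_maker: Callable[[], Session], table_name: str, index_name: str
) -> None:
    """Try each dialect family's DROP INDEX syntax until one succeeds."""
    statements = (
        f"DROP INDEX {index_name}",  # engines like DB2/Oracle
        f"DROP INDEX {table_name}.{index_name}",  # engines like SQLite, SQL Server
        f"DROP INDEX {index_name} ON {table_name}",  # engines like MySQL, MS Access
    )
    for statement in statements:
        with session_maker() as session:
            try:
                session.connection().execute(text(statement))
                session.commit()
                return
            except SQLAlchemyError:
                session.rollback()
    # The real helper logs a warning here rather than raising.
```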
@@ -194,7 +206,9 @@ def _drop_index(instance, table_name, index_name):
)
def _add_columns(instance, table_name, columns_def):
def _add_columns(
session_maker: Callable[[], Session], table_name: str, columns_def: list[str]
) -> None:
"""Add columns to a table."""
_LOGGER.warning(
"Adding columns %s to table %s. Note: this can take several "
@@ -206,7 +220,7 @@ def _add_columns(instance, table_name, columns_def):
columns_def = [f"ADD {col_def}" for col_def in columns_def]
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
connection.execute(
@@ -223,7 +237,7 @@ def _add_columns(instance, table_name, columns_def):
_LOGGER.info("Unable to use quick column add. Adding 1 by 1")
for column_def in columns_def:
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
connection.execute(
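`_add_columns` follows the same renaming and keeps its two-phase strategy: one `ALTER TABLE` carrying every `ADD` clause first, then one statement per column for engines (such as SQLite) that reject the batched form. A sketch under the same session-factory convention, with error handling simplified:

```python
from collections.abc import Callable

from sqlalchemy import text
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm.session import Session


def _add_columns_sketch(
    session_maker: Callable[[], Session], table_name: str, columns_def: list[str]
) -> None:
    add_defs = [f"ADD {col_def}" for col_def in columns_def]
    with session_maker() as session:
        try:
            session.connection().execute(
                text(f"ALTER TABLE {table_name} {', '.join(add_defs)}")
            )
            session.commit()
            return
        except (OperationalError, ProgrammingError):
            session.rollback()  # batched add rejected; fall back to one by one
    for add_def in add_defs:
        with session_maker() as session:
            session.connection().execute(text(f"ALTER TABLE {table_name} {add_def}"))
            session.commit()
```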
@@ -242,7 +256,12 @@ def _add_columns(instance, table_name, columns_def):
)
def _modify_columns(instance, engine, table_name, columns_def):
def _modify_columns(
session_maker: Callable[[], Session],
engine: Engine,
table_name: str,
columns_def: list[str],
) -> None:
"""Modify columns in a table."""
if engine.dialect.name == "sqlite":
_LOGGER.debug(
@@ -274,7 +293,7 @@ def _modify_columns(instance, engine, table_name, columns_def):
else:
columns_def = [f"MODIFY {col_def}" for col_def in columns_def]
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
connection.execute(
@@ -289,7 +308,7 @@ def _modify_columns(instance, engine, table_name, columns_def):
_LOGGER.info("Unable to use quick column modify. Modifying 1 by 1")
for column_def in columns_def:
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
connection.execute(
@@ -305,7 +324,9 @@ def _modify_columns(instance, engine, table_name, columns_def):
)
def _update_states_table_with_foreign_key_options(instance, engine):
def _update_states_table_with_foreign_key_options(
session_maker: Callable[[], Session], engine: Engine
) -> None:
"""Add the options to foreign key constraints."""
inspector = sqlalchemy.inspect(engine)
alters = []
@@ -333,7 +354,7 @@ def _update_states_table_with_foreign_key_options(instance, engine):
)
for alter in alters:
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
connection.execute(DropConstraint(alter["old_fk"]))
@@ -346,7 +367,9 @@ def _update_states_table_with_foreign_key_options(instance, engine):
)
def _drop_foreign_key_constraints(instance, engine, table, columns):
def _drop_foreign_key_constraints(
session_maker: Callable[[], Session], engine: Engine, table: str, columns: list[str]
) -> None:
"""Drop foreign key constraints for a table on specific columns."""
inspector = sqlalchemy.inspect(engine)
drops = []
@@ -364,7 +387,7 @@ def _drop_foreign_key_constraints(instance, engine, table, columns):
)
for drop in drops:
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
try:
connection = session.connection()
connection.execute(DropConstraint(drop))
@@ -376,19 +399,24 @@ def _drop_foreign_key_constraints(instance, engine, table, columns):
)
def _apply_update(instance, new_version, old_version): # noqa: C901
def _apply_update( # noqa: C901
hass: HomeAssistant,
engine: Engine,
session_maker: Callable[[], Session],
new_version: int,
old_version: int,
) -> None:
"""Perform operations to bring schema up to date."""
engine = instance.engine
dialect = engine.dialect.name
big_int = "INTEGER(20)" if dialect == "mysql" else "INTEGER"
if new_version == 1:
_create_index(instance, "events", "ix_events_time_fired")
_create_index(session_maker, "events", "ix_events_time_fired")
elif new_version == 2:
# Create compound start/end index for recorder_runs
_create_index(instance, "recorder_runs", "ix_recorder_runs_start_end")
_create_index(session_maker, "recorder_runs", "ix_recorder_runs_start_end")
# Create indexes for states
_create_index(instance, "states", "ix_states_last_updated")
_create_index(session_maker, "states", "ix_states_last_updated")
elif new_version == 3:
# There used to be a new index here, but it was removed in version 4.
pass
@@ -398,41 +426,41 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
if old_version == 3:
# Remove index that was added in version 3
_drop_index(instance, "states", "ix_states_created_domain")
_drop_index(session_maker, "states", "ix_states_created_domain")
if old_version == 2:
# Remove index that was added in version 2
_drop_index(instance, "states", "ix_states_entity_id_created")
_drop_index(session_maker, "states", "ix_states_entity_id_created")
# Remove indexes that were added in version 0
_drop_index(instance, "states", "states__state_changes")
_drop_index(instance, "states", "states__significant_changes")
_drop_index(instance, "states", "ix_states_entity_id_created")
_drop_index(session_maker, "states", "states__state_changes")
_drop_index(session_maker, "states", "states__significant_changes")
_drop_index(session_maker, "states", "ix_states_entity_id_created")
_create_index(instance, "states", "ix_states_entity_id_last_updated")
_create_index(session_maker, "states", "ix_states_entity_id_last_updated")
elif new_version == 5:
# Create supporting index for States.event_id foreign key
_create_index(instance, "states", "ix_states_event_id")
_create_index(session_maker, "states", "ix_states_event_id")
elif new_version == 6:
_add_columns(
instance,
session_maker,
"events",
["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
)
_create_index(instance, "events", "ix_events_context_id")
_create_index(instance, "events", "ix_events_context_user_id")
_create_index(session_maker, "events", "ix_events_context_id")
_create_index(session_maker, "events", "ix_events_context_user_id")
_add_columns(
instance,
session_maker,
"states",
["context_id CHARACTER(36)", "context_user_id CHARACTER(36)"],
)
_create_index(instance, "states", "ix_states_context_id")
_create_index(instance, "states", "ix_states_context_user_id")
_create_index(session_maker, "states", "ix_states_context_id")
_create_index(session_maker, "states", "ix_states_context_user_id")
elif new_version == 7:
_create_index(instance, "states", "ix_states_entity_id")
_create_index(session_maker, "states", "ix_states_entity_id")
elif new_version == 8:
_add_columns(instance, "events", ["context_parent_id CHARACTER(36)"])
_add_columns(instance, "states", ["old_state_id INTEGER"])
_create_index(instance, "events", "ix_events_context_parent_id")
_add_columns(session_maker, "events", ["context_parent_id CHARACTER(36)"])
_add_columns(session_maker, "states", ["old_state_id INTEGER"])
_create_index(session_maker, "events", "ix_events_context_parent_id")
elif new_version == 9:
# We now get the context from events with a join
# since its always there on state_changed events
@@ -443,35 +471,35 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
# sqlalchemy alembic to make that work
#
# no longer dropping ix_states_context_id since its recreated in 28
_drop_index(instance, "states", "ix_states_context_user_id")
_drop_index(session_maker, "states", "ix_states_context_user_id")
# This index won't be there if they were not running
# nightly but we don't treat that as a critical issue
_drop_index(instance, "states", "ix_states_context_parent_id")
_drop_index(session_maker, "states", "ix_states_context_parent_id")
# Redundant keys on composite index:
# We already have ix_states_entity_id_last_updated
_drop_index(instance, "states", "ix_states_entity_id")
_create_index(instance, "events", "ix_events_event_type_time_fired")
_drop_index(instance, "events", "ix_events_event_type")
_drop_index(session_maker, "states", "ix_states_entity_id")
_create_index(session_maker, "events", "ix_events_event_type_time_fired")
_drop_index(session_maker, "events", "ix_events_event_type")
elif new_version == 10:
# Now done in step 11
pass
elif new_version == 11:
_create_index(instance, "states", "ix_states_old_state_id")
_update_states_table_with_foreign_key_options(instance, engine)
_create_index(session_maker, "states", "ix_states_old_state_id")
_update_states_table_with_foreign_key_options(session_maker, engine)
elif new_version == 12:
if engine.dialect.name == "mysql":
_modify_columns(instance, engine, "events", ["event_data LONGTEXT"])
_modify_columns(instance, engine, "states", ["attributes LONGTEXT"])
_modify_columns(session_maker, engine, "events", ["event_data LONGTEXT"])
_modify_columns(session_maker, engine, "states", ["attributes LONGTEXT"])
elif new_version == 13:
if engine.dialect.name == "mysql":
_modify_columns(
instance,
session_maker,
engine,
"events",
["time_fired DATETIME(6)", "created DATETIME(6)"],
)
_modify_columns(
instance,
session_maker,
engine,
"states",
[
@@ -481,12 +509,14 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
],
)
elif new_version == 14:
_modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"])
_modify_columns(session_maker, engine, "events", ["event_type VARCHAR(64)"])
elif new_version == 15:
# This dropped the statistics table, done again in version 18.
pass
elif new_version == 16:
_drop_foreign_key_constraints(instance, engine, TABLE_STATES, ["old_state_id"])
_drop_foreign_key_constraints(
session_maker, engine, TABLE_STATES, ["old_state_id"]
)
elif new_version == 17:
# This dropped the statistics table, done again in version 18.
pass
@@ -511,13 +541,13 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
elif new_version == 19:
# This adds the statistic runs table, insert a fake run to prevent duplicating
# statistics.
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
session.add(StatisticsRuns(start=get_start_time()))
elif new_version == 20:
# This changed the precision of statistics from float to double
if engine.dialect.name in ["mysql", "postgresql"]:
_modify_columns(
instance,
session_maker,
engine,
"statistics",
[
@@ -539,7 +569,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
table,
)
with contextlib.suppress(SQLAlchemyError):
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
connection = session.connection()
connection.execute(
# Using LOCK=EXCLUSIVE to prevent the database from corrupting
@@ -574,7 +604,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
# Block 5-minute statistics for one hour from the last run, or it will overlap
# with existing hourly statistics. Don't block on a database with no existing
# statistics.
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
if session.query(Statistics.id).count() and (
last_run_string := session.query(
func.max(StatisticsRuns.start)
@@ -590,7 +620,7 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
# When querying the database, be careful to only explicitly query for columns
# which were present in schema version 21. If querying the table, SQLAlchemy
# will refer to future columns.
with session_scope(session=instance.get_session()) as session:
with session_scope(session=session_maker()) as session:
for sum_statistic in session.query(StatisticsMeta.id).filter_by(
has_sum=true()
):
@@ -617,48 +647,52 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
)
elif new_version == 23:
# Add name column to StatisticsMeta
_add_columns(instance, "statistics_meta", ["name VARCHAR(255)"])
_add_columns(session_maker, "statistics_meta", ["name VARCHAR(255)"])
elif new_version == 24:
# Recreate statistics indices to block duplicated statistics
_drop_index(instance, "statistics", "ix_statistics_statistic_id_start")
_drop_index(session_maker, "statistics", "ix_statistics_statistic_id_start")
_drop_index(
instance,
session_maker,
"statistics_short_term",
"ix_statistics_short_term_statistic_id_start",
)
try:
_create_index(instance, "statistics", "ix_statistics_statistic_id_start")
_create_index(
instance,
session_maker, "statistics", "ix_statistics_statistic_id_start"
)
_create_index(
session_maker,
"statistics_short_term",
"ix_statistics_short_term_statistic_id_start",
)
except DatabaseError:
# There may be duplicated statistics entries, delete duplicated statistics
# and try again
with session_scope(session=instance.get_session()) as session:
delete_duplicates(instance, session)
_create_index(instance, "statistics", "ix_statistics_statistic_id_start")
with session_scope(session=session_maker()) as session:
delete_duplicates(hass, session)
_create_index(
instance,
session_maker, "statistics", "ix_statistics_statistic_id_start"
)
_create_index(
session_maker,
"statistics_short_term",
"ix_statistics_short_term_statistic_id_start",
)
elif new_version == 25:
_add_columns(instance, "states", [f"attributes_id {big_int}"])
_create_index(instance, "states", "ix_states_attributes_id")
_add_columns(session_maker, "states", [f"attributes_id {big_int}"])
_create_index(session_maker, "states", "ix_states_attributes_id")
elif new_version == 26:
_create_index(instance, "statistics_runs", "ix_statistics_runs_start")
_create_index(session_maker, "statistics_runs", "ix_statistics_runs_start")
elif new_version == 27:
_add_columns(instance, "events", [f"data_id {big_int}"])
_create_index(instance, "events", "ix_events_data_id")
_add_columns(session_maker, "events", [f"data_id {big_int}"])
_create_index(session_maker, "events", "ix_events_data_id")
elif new_version == 28:
_add_columns(instance, "events", ["origin_idx INTEGER"])
_add_columns(session_maker, "events", ["origin_idx INTEGER"])
# We never use the user_id or parent_id index
_drop_index(instance, "events", "ix_events_context_user_id")
_drop_index(instance, "events", "ix_events_context_parent_id")
_drop_index(session_maker, "events", "ix_events_context_user_id")
_drop_index(session_maker, "events", "ix_events_context_parent_id")
_add_columns(
instance,
session_maker,
"states",
[
"origin_idx INTEGER",
@@ -667,14 +701,14 @@ def _apply_update(instance, new_version, old_version): # noqa: C901
"context_parent_id VARCHAR(36)",
],
)
_create_index(instance, "states", "ix_states_context_id")
_create_index(session_maker, "states", "ix_states_context_id")
# Once there are no longer any state_changed events
# in the events table we can drop the index on states.event_id
else:
raise ValueError(f"No schema migration defined for version {new_version}")
def _inspect_schema_version(session):
def _inspect_schema_version(session: Session) -> int:
"""Determine the schema version by inspecting the db structure.
When the schema version is not present in the db, either db was just
@@ -696,4 +730,4 @@ def _inspect_schema_version(session):
# Version 1 schema changes not found, this db needs to be migrated.
current_version = SchemaChanges(schema_version=0)
session.add(current_version)
return current_version.schema_version
return cast(int, current_version.schema_version)
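The two `cast(int, ...)` additions exist because, without a SQLAlchemy mypy plugin, ORM column attributes read back as `Any` on instances; `cast` asserts the invariant for the type checker at zero runtime cost. A minimal illustration (the helper name is hypothetical):

```python
from typing import Any, cast


def _coerce_schema_version(row: Any) -> int:
    """cast() only informs mypy; it performs no conversion or check at runtime."""
    return cast(int, row.schema_version)
```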

homeassistant/components/recorder/purge.py

@@ -47,7 +47,7 @@ def purge_old_data(
)
using_sqlite = instance.using_sqlite()
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
with session_scope(session=instance.get_session()) as session:
# Purge a max of MAX_ROWS_TO_PURGE, based on the oldest states or events record
(
event_ids,
@@ -515,7 +515,7 @@ def _purge_filtered_events(
def purge_entity_data(instance: Recorder, entity_filter: Callable[[str], bool]) -> bool:
"""Purge states and events of specified entities."""
using_sqlite = instance.using_sqlite()
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
with session_scope(session=instance.get_session()) as session:
selected_entity_ids: list[str] = [
entity_id
for (entity_id,) in session.query(distinct(States.entity_id)).all()
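The deleted `# type: ignore[misc]` comments were only needed because `instance.get_session` used to be typed `Callable[[], Session] | None`, making every `instance.get_session()` call a possible call of `None`. Turning it into a proper method with a non-optional return type fixes every call site at once. A minimal illustration, with hypothetical class names:

```python
from __future__ import annotations

from collections.abc import Callable

from sqlalchemy.orm.session import Session


class BeforeRecorder:
    # Optional attribute: mypy flags every `self.get_session()` call,
    # forcing asserts or `# type: ignore[misc]` comments at call sites.
    get_session: Callable[[], Session] | None = None


class AfterRecorder:
    def get_session(self) -> Session:
        """Non-optional return type: call sites type-check unaided."""
        raise RuntimeError("The database connection has not been established")
```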

homeassistant/components/recorder/statistics.py

@@ -377,7 +377,7 @@ def _delete_duplicates_from_table(
return (total_deleted_rows, all_non_identical_duplicates)
def delete_duplicates(instance: Recorder, session: Session) -> None:
def delete_duplicates(hass: HomeAssistant, session: Session) -> None:
"""Identify and delete duplicated statistics.
A backup will be made of duplicated statistics before it is deleted.
@@ -391,7 +391,7 @@ def delete_duplicates(instance: Recorder, session: Session) -> None:
if non_identical_duplicates:
isotime = dt_util.utcnow().isoformat()
backup_file_name = f"deleted_statistics.{isotime}.json"
backup_path = instance.hass.config.path(STORAGE_DIR, backup_file_name)
backup_path = hass.config.path(STORAGE_DIR, backup_file_name)
os.makedirs(os.path.dirname(backup_path), exist_ok=True)
with open(backup_path, "w", encoding="utf8") as backup_file:
@@ -551,7 +551,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
end = start + timedelta(minutes=5)
# Return if we already have 5-minute statistics for the requested period
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
with session_scope(session=instance.get_session()) as session:
if session.query(StatisticsRuns).filter_by(start=start).first():
_LOGGER.debug("Statistics already compiled for %s-%s", start, end)
return True
@@ -578,7 +578,7 @@ def compile_statistics(instance: Recorder, start: datetime) -> bool:
# Insert collected statistics in the database
with session_scope(
session=instance.get_session(), # type: ignore[misc]
session=instance.get_session(),
exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session:
for stats in platform_stats:
@@ -768,7 +768,7 @@ def _configured_unit(unit: str | None, units: UnitSystem) -> str | None:
def clear_statistics(instance: Recorder, statistic_ids: list[str]) -> None:
"""Clear statistics for a list of statistic_ids."""
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
with session_scope(session=instance.get_session()) as session:
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id.in_(statistic_ids)
).delete(synchronize_session=False)
@@ -778,7 +778,7 @@ def update_statistics_metadata(
instance: Recorder, statistic_id: str, unit_of_measurement: str | None
) -> None:
"""Update statistics metadata for a statistic_id."""
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
with session_scope(session=instance.get_session()) as session:
session.query(StatisticsMeta).filter(
StatisticsMeta.statistic_id == statistic_id
).update({StatisticsMeta.unit_of_measurement: unit_of_measurement})
@@ -1376,7 +1376,7 @@ def add_external_statistics(
"""Process an add_external_statistics job."""
with session_scope(
session=instance.get_session(), # type: ignore[misc]
session=instance.get_session(),
exception_filter=_filter_unique_constraint_integrity_error(instance),
) as session:
old_metadata_dict = get_metadata_with_session(
@@ -1403,7 +1403,7 @@ def adjust_statistics(
) -> bool:
"""Process an add_statistics job."""
with session_scope(session=instance.get_session()) as session: # type: ignore[misc]
with session_scope(session=instance.get_session()) as session:
metadata = get_metadata_with_session(
instance.hass, session, statistic_ids=(statistic_id,)
)
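Note the `session_scope(..., exception_filter=...)` call sites that lose their ignore comments: the statistics writers pass a filter so an expected unique-constraint violation (for example, a racing duplicate insert) is suppressed instead of aborting the write. A hedged usage sketch; the filter body and helper name are assumptions, while `session_scope`'s `exception_filter` parameter is taken from the diff:

```python
from datetime import datetime

from sqlalchemy.exc import IntegrityError

from homeassistant.components.recorder import Recorder
from homeassistant.components.recorder.models import StatisticsRuns
from homeassistant.components.recorder.util import session_scope


def _ignore_duplicates(err: Exception) -> bool:
    """Hypothetical filter: returning True tells session_scope to suppress."""
    return isinstance(err, IntegrityError)


def insert_run(instance: Recorder, start: datetime) -> None:
    # instance.get_session() is now a typed call: no ignore comment needed.
    with session_scope(
        session=instance.get_session(),
        exception_filter=_ignore_duplicates,
    ) as session:
        session.add(StatisticsRuns(start=start))
```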

homeassistant/components/recorder/tasks.py

@@ -65,8 +65,6 @@ class PurgeTask(RecorderTask):
def run(self, instance: Recorder) -> None:
"""Purge the database."""
assert instance.get_session is not None
if purge.purge_old_data(
instance, self.purge_before, self.repack, self.apply_filter
):

mypy.ini

@@ -1710,183 +1710,7 @@ no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.const]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.core]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.backup]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.executor]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.history]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.models]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.pool]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.purge]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.repack]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.run_history]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.services]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.statistics]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.system_health]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.tasks]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.util]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.recorder.websocket_api]
[mypy-homeassistant.components.recorder.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true

tests/components/recorder/test_init.py

@@ -138,6 +138,8 @@ async def test_shutdown_closes_connections(hass, recorder_mock):
await hass.async_block_till_done()
assert len(pool.shutdown.mock_calls) == 1
with pytest.raises(RuntimeError):
assert instance.get_session()
async def test_state_gets_saved_when_set_before_start_event(
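The added assertion pins down the new failure mode: once `_close_connection` has run, `_get_session` is `None` and `get_session()` raises `RuntimeError`, where the old attribute-based design raised a bare `TypeError` from calling `None`. A standalone miniature of the same check, with a hypothetical stand-in class:

```python
import pytest


class _ClosedRecorder:
    """Hypothetical stand-in for a recorder whose connection was torn down."""

    _get_session = None

    def get_session(self):
        if self._get_session is None:
            raise RuntimeError("The database connection has not been established")
        return self._get_session()


def test_get_session_after_shutdown() -> None:
    with pytest.raises(RuntimeError):
        _ClosedRecorder().get_session()
```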

tests/components/recorder/test_migrate.py

@@ -60,9 +60,12 @@ async def test_schema_update_calls(hass):
await async_wait_recording_done(hass)
assert recorder.util.async_migration_in_progress(hass) is False
instance = recorder.get_instance(hass)
engine = instance.engine
session_maker = instance.get_session
update.assert_has_calls(
[
call(hass.data[DATA_INSTANCE], version + 1, 0)
call(hass, engine, session_maker, version + 1, 0)
for version in range(0, models.SCHEMA_VERSION)
]
)
@@ -327,10 +330,10 @@ async def test_schema_migrate(hass, start_version):
assert recorder.util.async_migration_in_progress(hass) is not True
def test_invalid_update():
def test_invalid_update(hass):
"""Test that an invalid new version raises an exception."""
with pytest.raises(ValueError):
migration._apply_update(Mock(), -1, 0)
migration._apply_update(hass, Mock(), Mock(), -1, 0)
@pytest.mark.parametrize(
@@ -351,7 +354,9 @@ def test_modify_column(engine_type, substr):
instance.get_session = Mock(return_value=session)
engine = Mock()
engine.dialect.name = engine_type
migration._modify_columns(instance, engine, "events", ["event_type VARCHAR(64)"])
migration._modify_columns(
instance.get_session, engine, "events", ["event_type VARCHAR(64)"]
)
if substr:
assert substr in connection.execute.call_args[0][0].text
else:
@@ -365,8 +370,12 @@ def test_forgiving_add_column():
session.execute(text("CREATE TABLE hello (id int)"))
instance = Mock()
instance.get_session = Mock(return_value=session)
migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])
migration._add_columns(instance, "hello", ["context_id CHARACTER(36)"])
migration._add_columns(
instance.get_session, "hello", ["context_id CHARACTER(36)"]
)
migration._add_columns(
instance.get_session, "hello", ["context_id CHARACTER(36)"]
)
def test_forgiving_add_index():
@@ -376,7 +385,7 @@ def test_forgiving_add_index():
with Session(engine) as session:
instance = Mock()
instance.get_session = Mock(return_value=session)
migration._create_index(instance, "states", "ix_states_context_id")
migration._create_index(instance.get_session, "states", "ix_states_context_id")
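These test rewrites work because a `Mock(return_value=session)` is itself a zero-argument callable returning the session, i.e. already a valid `Callable[[], Session]`; passing `instance.get_session` where the helpers expect `session_maker` therefore needs no further shimming:

```python
from unittest.mock import Mock

session = object()  # stand-in for a real sqlalchemy Session
session_maker = Mock(return_value=session)

# The mock satisfies the Callable[[], Session] contract directly.
assert session_maker() is session
```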
@pytest.mark.parametrize(

tests/components/recorder/test_statistics.py

@@ -740,7 +740,7 @@ def test_delete_duplicates_no_duplicates(hass_recorder, caplog):
hass = hass_recorder()
wait_recording_done(hass)
with session_scope(hass=hass) as session:
delete_duplicates(hass.data[DATA_INSTANCE], session)
delete_duplicates(hass, session)
assert "duplicated statistics rows" not in caplog.text
assert "Found non identical" not in caplog.text
assert "Found duplicated" not in caplog.text