From b27e9e376dffd4d26fc86914e83d09e59e46c3a2 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 26 Apr 2021 19:20:31 -1000 Subject: [PATCH] Use StaticPool for recorder and NullPool for all other threads with sqlite3 (#49693) --- homeassistant/components/recorder/__init__.py | 3 ++ homeassistant/components/recorder/pool.py | 34 +++++++++++++++++++ tests/components/recorder/test_init.py | 21 ++++++++---- tests/components/recorder/test_pool.py | 34 +++++++++++++++++++ 4 files changed, 86 insertions(+), 6 deletions(-) create mode 100644 homeassistant/components/recorder/pool.py create mode 100644 tests/components/recorder/test_pool.py diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 9e9592f86870..b4be8852f551 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -43,6 +43,7 @@ import homeassistant.util.dt as dt_util from . import migration, purge from .const import CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, DOMAIN, SQLITE_URL_PREFIX from .models import Base, Events, RecorderRuns, States +from .pool import RecorderPool from .util import ( dburl_to_path, end_incomplete_runs, @@ -783,6 +784,8 @@ class Recorder(threading.Thread): kwargs["connect_args"] = {"check_same_thread": False} kwargs["poolclass"] = StaticPool kwargs["pool_reset_on_return"] = None + elif self.db_url.startswith(SQLITE_URL_PREFIX): + kwargs["poolclass"] = RecorderPool else: kwargs["echo"] = False diff --git a/homeassistant/components/recorder/pool.py b/homeassistant/components/recorder/pool.py new file mode 100644 index 000000000000..9ee89d248cce --- /dev/null +++ b/homeassistant/components/recorder/pool.py @@ -0,0 +1,34 @@ +"""A pool for sqlite connections.""" +import threading + +from sqlalchemy.pool import NullPool, StaticPool + + +class RecorderPool(StaticPool, NullPool): + """A hybird of NullPool and StaticPool. + + When called from the creating thread acts like StaticPool + When called from any other thread, acts like NullPool + """ + + def __init__(self, *args, **kw): # pylint: disable=super-init-not-called + """Create the pool.""" + self._tid = threading.current_thread().ident + StaticPool.__init__(self, *args, **kw) + + def _do_return_conn(self, conn): + if threading.current_thread().ident == self._tid: + return super()._do_return_conn(conn) + conn.close() + + def dispose(self): + """Dispose of the connection.""" + if threading.current_thread().ident == self._tid: + return super().dispose() + + def _do_get(self): + if threading.current_thread().ident == self._tid: + return super()._do_get() + return super( # pylint: disable=bad-super-call + NullPool, self + )._create_connection() diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index dddba971aadc..70271634ff52 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -1,9 +1,10 @@ """The tests for the Recorder component.""" # pylint: disable=protected-access from datetime import datetime, timedelta +import sqlite3 from unittest.mock import patch -from sqlalchemy.exc import OperationalError, SQLAlchemyError +from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError from homeassistant.components import recorder from homeassistant.components.recorder import ( @@ -885,6 +886,9 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog): hass.states.async_set("test.lost", "on", {}) + sqlite3_exception = DatabaseError("statement", {}, []) + sqlite3_exception.__cause__ = sqlite3.DatabaseError() + with patch.object( hass.data[DATA_INSTANCE].event_session, "close", @@ -894,11 +898,16 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog): await hass.async_add_executor_job(corrupt_db_file, test_db_file) await async_wait_recording_done_without_instance(hass) - # This state will not be recorded because - # the database corruption will be discovered - # and we will have to rollback to recover - hass.states.async_set("test.one", "off", {}) - await async_wait_recording_done_without_instance(hass) + with patch.object( + hass.data[DATA_INSTANCE].event_session, + "commit", + side_effect=[sqlite3_exception, None], + ): + # This state will not be recorded because + # the database corruption will be discovered + # and we will have to rollback to recover + hass.states.async_set("test.one", "off", {}) + await async_wait_recording_done_without_instance(hass) assert "Unrecoverable sqlite3 database corruption detected" in caplog.text assert "The system will rename the corrupt database file" in caplog.text diff --git a/tests/components/recorder/test_pool.py b/tests/components/recorder/test_pool.py new file mode 100644 index 000000000000..e59dc18fc8bc --- /dev/null +++ b/tests/components/recorder/test_pool.py @@ -0,0 +1,34 @@ +"""Test pool.""" +import threading + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from homeassistant.components.recorder.pool import RecorderPool + + +def test_recorder_pool(): + """Test RecorderPool gives the same connection in the creating thread.""" + + engine = create_engine("sqlite://", poolclass=RecorderPool) + get_session = sessionmaker(bind=engine) + + connections = [] + + def _get_connection_twice(): + session = get_session() + connections.append(session.connection().connection.connection) + session.close() + + session = get_session() + connections.append(session.connection().connection.connection) + session.close() + + _get_connection_twice() + assert connections[0] == connections[1] + + new_thread = threading.Thread(target=_get_connection_twice) + new_thread.start() + new_thread.join() + + assert connections[2] != connections[3]