1
mirror of https://github.com/home-assistant/core synced 2024-08-31 05:57:13 +02:00

Use StaticPool for recorder and NullPool for all other threads with sqlite3 (#49693)

This commit is contained in:
J. Nick Koston 2021-04-26 19:20:31 -10:00 committed by GitHub
parent d9714e6b79
commit b27e9e376d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 86 additions and 6 deletions

View File

@ -43,6 +43,7 @@ import homeassistant.util.dt as dt_util
from . import migration, purge from . import migration, purge
from .const import CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, DOMAIN, SQLITE_URL_PREFIX from .const import CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, DOMAIN, SQLITE_URL_PREFIX
from .models import Base, Events, RecorderRuns, States from .models import Base, Events, RecorderRuns, States
from .pool import RecorderPool
from .util import ( from .util import (
dburl_to_path, dburl_to_path,
end_incomplete_runs, end_incomplete_runs,
@ -783,6 +784,8 @@ class Recorder(threading.Thread):
kwargs["connect_args"] = {"check_same_thread": False} kwargs["connect_args"] = {"check_same_thread": False}
kwargs["poolclass"] = StaticPool kwargs["poolclass"] = StaticPool
kwargs["pool_reset_on_return"] = None kwargs["pool_reset_on_return"] = None
elif self.db_url.startswith(SQLITE_URL_PREFIX):
kwargs["poolclass"] = RecorderPool
else: else:
kwargs["echo"] = False kwargs["echo"] = False

View File

@ -0,0 +1,34 @@
"""A pool for sqlite connections."""
import threading
from sqlalchemy.pool import NullPool, StaticPool
class RecorderPool(StaticPool, NullPool):
"""A hybird of NullPool and StaticPool.
When called from the creating thread acts like StaticPool
When called from any other thread, acts like NullPool
"""
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
"""Create the pool."""
self._tid = threading.current_thread().ident
StaticPool.__init__(self, *args, **kw)
def _do_return_conn(self, conn):
if threading.current_thread().ident == self._tid:
return super()._do_return_conn(conn)
conn.close()
def dispose(self):
"""Dispose of the connection."""
if threading.current_thread().ident == self._tid:
return super().dispose()
def _do_get(self):
if threading.current_thread().ident == self._tid:
return super()._do_get()
return super( # pylint: disable=bad-super-call
NullPool, self
)._create_connection()

View File

@ -1,9 +1,10 @@
"""The tests for the Recorder component.""" """The tests for the Recorder component."""
# pylint: disable=protected-access # pylint: disable=protected-access
from datetime import datetime, timedelta from datetime import datetime, timedelta
import sqlite3
from unittest.mock import patch from unittest.mock import patch
from sqlalchemy.exc import OperationalError, SQLAlchemyError from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
from homeassistant.components import recorder from homeassistant.components import recorder
from homeassistant.components.recorder import ( from homeassistant.components.recorder import (
@ -885,6 +886,9 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
hass.states.async_set("test.lost", "on", {}) hass.states.async_set("test.lost", "on", {})
sqlite3_exception = DatabaseError("statement", {}, [])
sqlite3_exception.__cause__ = sqlite3.DatabaseError()
with patch.object( with patch.object(
hass.data[DATA_INSTANCE].event_session, hass.data[DATA_INSTANCE].event_session,
"close", "close",
@ -894,11 +898,16 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
await hass.async_add_executor_job(corrupt_db_file, test_db_file) await hass.async_add_executor_job(corrupt_db_file, test_db_file)
await async_wait_recording_done_without_instance(hass) await async_wait_recording_done_without_instance(hass)
# This state will not be recorded because with patch.object(
# the database corruption will be discovered hass.data[DATA_INSTANCE].event_session,
# and we will have to rollback to recover "commit",
hass.states.async_set("test.one", "off", {}) side_effect=[sqlite3_exception, None],
await async_wait_recording_done_without_instance(hass) ):
# This state will not be recorded because
# the database corruption will be discovered
# and we will have to rollback to recover
hass.states.async_set("test.one", "off", {})
await async_wait_recording_done_without_instance(hass)
assert "Unrecoverable sqlite3 database corruption detected" in caplog.text assert "Unrecoverable sqlite3 database corruption detected" in caplog.text
assert "The system will rename the corrupt database file" in caplog.text assert "The system will rename the corrupt database file" in caplog.text

View File

@ -0,0 +1,34 @@
"""Test pool."""
import threading
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from homeassistant.components.recorder.pool import RecorderPool
def test_recorder_pool():
"""Test RecorderPool gives the same connection in the creating thread."""
engine = create_engine("sqlite://", poolclass=RecorderPool)
get_session = sessionmaker(bind=engine)
connections = []
def _get_connection_twice():
session = get_session()
connections.append(session.connection().connection.connection)
session.close()
session = get_session()
connections.append(session.connection().connection.connection)
session.close()
_get_connection_twice()
assert connections[0] == connections[1]
new_thread = threading.Thread(target=_get_connection_twice)
new_thread.start()
new_thread.join()
assert connections[2] != connections[3]