1
mirror of https://github.com/home-assistant/core synced 2024-07-18 12:02:20 +02:00

Use StaticPool for recorder and NullPool for all other threads with sqlite3 (#49693)

This commit is contained in:
J. Nick Koston 2021-04-26 19:20:31 -10:00 committed by GitHub
parent d9714e6b79
commit b27e9e376d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 86 additions and 6 deletions

View File

@ -43,6 +43,7 @@ import homeassistant.util.dt as dt_util
from . import migration, purge
from .const import CONF_DB_INTEGRITY_CHECK, DATA_INSTANCE, DOMAIN, SQLITE_URL_PREFIX
from .models import Base, Events, RecorderRuns, States
from .pool import RecorderPool
from .util import (
dburl_to_path,
end_incomplete_runs,
@ -783,6 +784,8 @@ class Recorder(threading.Thread):
kwargs["connect_args"] = {"check_same_thread": False}
kwargs["poolclass"] = StaticPool
kwargs["pool_reset_on_return"] = None
elif self.db_url.startswith(SQLITE_URL_PREFIX):
kwargs["poolclass"] = RecorderPool
else:
kwargs["echo"] = False

View File

@ -0,0 +1,34 @@
"""A pool for sqlite connections."""
import threading
from sqlalchemy.pool import NullPool, StaticPool
class RecorderPool(StaticPool, NullPool):
"""A hybird of NullPool and StaticPool.
When called from the creating thread acts like StaticPool
When called from any other thread, acts like NullPool
"""
def __init__(self, *args, **kw): # pylint: disable=super-init-not-called
"""Create the pool."""
self._tid = threading.current_thread().ident
StaticPool.__init__(self, *args, **kw)
def _do_return_conn(self, conn):
if threading.current_thread().ident == self._tid:
return super()._do_return_conn(conn)
conn.close()
def dispose(self):
"""Dispose of the connection."""
if threading.current_thread().ident == self._tid:
return super().dispose()
def _do_get(self):
if threading.current_thread().ident == self._tid:
return super()._do_get()
return super( # pylint: disable=bad-super-call
NullPool, self
)._create_connection()

View File

@ -1,9 +1,10 @@
"""The tests for the Recorder component."""
# pylint: disable=protected-access
from datetime import datetime, timedelta
import sqlite3
from unittest.mock import patch
from sqlalchemy.exc import OperationalError, SQLAlchemyError
from sqlalchemy.exc import DatabaseError, OperationalError, SQLAlchemyError
from homeassistant.components import recorder
from homeassistant.components.recorder import (
@ -885,6 +886,9 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
hass.states.async_set("test.lost", "on", {})
sqlite3_exception = DatabaseError("statement", {}, [])
sqlite3_exception.__cause__ = sqlite3.DatabaseError()
with patch.object(
hass.data[DATA_INSTANCE].event_session,
"close",
@ -894,11 +898,16 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
await hass.async_add_executor_job(corrupt_db_file, test_db_file)
await async_wait_recording_done_without_instance(hass)
# This state will not be recorded because
# the database corruption will be discovered
# and we will have to rollback to recover
hass.states.async_set("test.one", "off", {})
await async_wait_recording_done_without_instance(hass)
with patch.object(
hass.data[DATA_INSTANCE].event_session,
"commit",
side_effect=[sqlite3_exception, None],
):
# This state will not be recorded because
# the database corruption will be discovered
# and we will have to rollback to recover
hass.states.async_set("test.one", "off", {})
await async_wait_recording_done_without_instance(hass)
assert "Unrecoverable sqlite3 database corruption detected" in caplog.text
assert "The system will rename the corrupt database file" in caplog.text

View File

@ -0,0 +1,34 @@
"""Test pool."""
import threading
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from homeassistant.components.recorder.pool import RecorderPool
def test_recorder_pool():
"""Test RecorderPool gives the same connection in the creating thread."""
engine = create_engine("sqlite://", poolclass=RecorderPool)
get_session = sessionmaker(bind=engine)
connections = []
def _get_connection_twice():
session = get_session()
connections.append(session.connection().connection.connection)
session.close()
session = get_session()
connections.append(session.connection().connection.connection)
session.close()
_get_connection_twice()
assert connections[0] == connections[1]
new_thread = threading.Thread(target=_get_connection_twice)
new_thread.start()
new_thread.join()
assert connections[2] != connections[3]