1
mirror of https://github.com/home-assistant/core synced 2024-08-06 09:34:49 +02:00

Speed up subscribing to mqtt topics on connect (#73685)

* Speed up subscribing to mqtt topics

* update tests

* Remove extra function wrapper

* Recover debug logging for subscriptions

* Small changes and test

* Update homeassistant/components/mqtt/client.py

* Update client.py

Co-authored-by: jbouwh <jan@jbsoft.nl>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
J. Nick Koston 2022-06-22 09:03:18 -05:00 committed by GitHub
parent 54591b8ca1
commit 19b2b33037
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 22 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Iterable
from functools import lru_cache, partial, wraps
import inspect
from itertools import groupby
@ -430,7 +430,7 @@ class MQTT:
# Only subscribe if currently connected.
if self.connected:
self._last_subscribe = time.time()
await self._async_perform_subscription(topic, qos)
await self._async_perform_subscriptions(((topic, qos),))
@callback
def async_remove() -> None:
@ -464,16 +464,37 @@ class MQTT:
_raise_on_error(result)
await self._wait_for_mid(mid)
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
"""Perform a paho-mqtt subscription."""
async def _async_perform_subscriptions(
self, subscriptions: Iterable[tuple[str, int]]
) -> None:
"""Perform MQTT client subscriptions."""
def _process_client_subscriptions() -> list[tuple[int, int]]:
"""Initiate all subscriptions on the MQTT client and return the results."""
subscribe_result_list = []
for topic, qos in subscriptions:
result, mid = self._mqttc.subscribe(topic, qos)
subscribe_result_list.append((result, mid))
_LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
return subscribe_result_list
async with self._paho_lock:
result: int | None = None
result, mid = await self.hass.async_add_executor_job(
self._mqttc.subscribe, topic, qos
results = await self.hass.async_add_executor_job(
_process_client_subscriptions
)
_LOGGER.debug("Subscribing to %s, mid: %s", topic, mid)
_raise_on_error(result)
await self._wait_for_mid(mid)
tasks = []
errors = []
for result, mid in results:
if result == 0:
tasks.append(self._wait_for_mid(mid))
else:
errors.append(result)
if tasks:
await asyncio.gather(*tasks)
if errors:
_raise_on_errors(errors)
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code: int) -> None:
"""On connect callback.
@ -502,10 +523,16 @@ class MQTT:
# Group subscriptions to only re-subscribe once for each topic.
keyfunc = attrgetter("topic")
for topic, subs in groupby(sorted(self.subscriptions, key=keyfunc), keyfunc):
# Re-subscribe with the highest requested qos
max_qos = max(subscription.qos for subscription in subs)
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
self.hass.add_job(
self._async_perform_subscriptions,
[
# Re-subscribe with the highest requested qos
(topic, max(subscription.qos for subscription in subs))
for topic, subs in groupby(
sorted(self.subscriptions, key=keyfunc), keyfunc
)
],
)
if (
CONF_BIRTH_MESSAGE in self.conf
@ -638,15 +665,22 @@ class MQTT:
)
def _raise_on_error(result_code: int | None) -> None:
def _raise_on_errors(result_codes: Iterable[int | None]) -> None:
"""Raise error if error result."""
# pylint: disable-next=import-outside-toplevel
import paho.mqtt.client as mqtt
if result_code is not None and result_code != 0:
raise HomeAssistantError(
f"Error talking to MQTT: {mqtt.error_string(result_code)}"
)
if messages := [
mqtt.error_string(result_code)
for result_code in result_codes
if result_code != 0
]:
raise HomeAssistantError(f"Error talking to MQTT: {', '.join(messages)}")
def _raise_on_error(result_code: int | None) -> None:
"""Raise error if error result."""
_raise_on_errors((result_code,))
def _matcher_for_topic(subscription: str) -> Any:

View File

@ -1312,6 +1312,20 @@ async def test_publish_error(hass, caplog):
assert "Failed to connect to MQTT server: Out of memory." in caplog.text
async def test_subscribe_error(
hass, caplog, mqtt_mock_entry_no_yaml_config, mqtt_client_mock
):
"""Test publish error."""
await mqtt_mock_entry_no_yaml_config()
mqtt_client_mock.on_connect(mqtt_client_mock, None, None, 0)
await hass.async_block_till_done()
with pytest.raises(HomeAssistantError):
# simulate client is not connected error before subscribing
mqtt_client_mock.subscribe.side_effect = lambda *args: (4, None)
await mqtt.async_subscribe(hass, "some-topic", lambda *args: 0)
await hass.async_block_till_done()
async def test_handle_message_callback(
hass, caplog, mqtt_mock_entry_no_yaml_config, mqtt_client_mock
):
@ -1424,6 +1438,7 @@ async def test_setup_mqtt_client_protocol(hass):
@patch("homeassistant.components.mqtt.client.TIMEOUT_ACK", 0.2)
@patch("homeassistant.components.mqtt.PLATFORMS", [])
async def test_handle_mqtt_timeout_on_callback(hass, caplog):
"""Test publish without receiving an ACK callback."""
mid = 0
@ -1764,9 +1779,12 @@ async def test_mqtt_subscribes_topics_on_connect(
assert mqtt_client_mock.disconnect.call_count == 0
expected = {"topic/test": 0, "home/sensor": 2, "still/pending": 1}
calls = {call[1][1]: call[1][2] for call in hass.add_job.mock_calls}
assert calls == expected
assert len(hass.add_job.mock_calls) == 1
assert set(hass.add_job.mock_calls[0][1][1]) == {
("home/sensor", 2),
("still/pending", 1),
("topic/test", 0),
}
async def test_setup_entry_with_config_override(