Convert MQTT to use asyncio (#115910)

This commit is contained in:
J. Nick Koston 2024-04-21 22:33:58 +02:00 committed by GitHub
parent 5a24690d79
commit 423544401e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 464 additions and 90 deletions

View File

@ -265,7 +265,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
conf: dict[str, Any]
mqtt_data: MqttData
async def _setup_client() -> tuple[MqttData, dict[str, Any]]:
async def _setup_client(
client_available: asyncio.Future[bool],
) -> tuple[MqttData, dict[str, Any]]:
"""Set up the MQTT client."""
# Fetch configuration
conf = dict(entry.data)
@ -294,7 +296,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
entry.add_update_listener(_async_config_entry_updated)
)
await mqtt_data.client.async_connect()
await mqtt_data.client.async_connect(client_available)
return (mqtt_data, conf)
client_available: asyncio.Future[bool]
@ -303,13 +305,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
else:
client_available = hass.data[DATA_MQTT_AVAILABLE]
setup_ok: bool = False
try:
mqtt_data, conf = await _setup_client()
setup_ok = True
finally:
if not client_available.done():
client_available.set_result(setup_ok)
mqtt_data, conf = await _setup_client(client_available)
async def async_publish_service(call: ServiceCall) -> None:
"""Handle MQTT publish service calls."""

View File

@ -3,12 +3,14 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Iterable
from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable
import contextlib
from dataclasses import dataclass
from functools import lru_cache
from functools import lru_cache, partial
from itertools import chain, groupby
import logging
from operator import attrgetter
import socket
import ssl
import time
from typing import TYPE_CHECKING, Any
@ -35,7 +37,7 @@ from homeassistant.core import (
callback,
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import dispatcher_send
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util
@ -92,6 +94,9 @@ INITIAL_SUBSCRIBE_COOLDOWN = 1.0
SUBSCRIBE_COOLDOWN = 0.1
UNSUBSCRIBE_COOLDOWN = 0.1
TIMEOUT_ACK = 10
RECONNECT_INTERVAL_SECONDS = 10
SocketType = socket.socket | ssl.SSLSocket | Any
SubscribePayloadType = str | bytes # Only bytes if encoding is None
@ -258,7 +263,9 @@ class MqttClientSetup:
# However, that feature is not mandatory so we generate our own.
client_id = mqtt.base62(uuid.uuid4().int, padding=22)
transport = config.get(CONF_TRANSPORT, DEFAULT_TRANSPORT)
self._client = mqtt.Client(client_id, protocol=proto, transport=transport)
self._client = mqtt.Client(
client_id, protocol=proto, transport=transport, reconnect_on_failure=False
)
# Enable logging
self._client.enable_logger()
@ -404,12 +411,17 @@ class MQTT:
self._ha_started = asyncio.Event()
self._cleanup_on_unload: list[Callable[[], None]] = []
self._paho_lock = asyncio.Lock() # Prevents parallel calls to the MQTT client
self._connection_lock = asyncio.Lock()
self._pending_operations: dict[int, asyncio.Event] = {}
self._pending_operations_condition = asyncio.Condition()
self._subscribe_debouncer = EnsureJobAfterCooldown(
INITIAL_SUBSCRIBE_COOLDOWN, self._async_perform_subscriptions
)
self._misc_task: asyncio.Task | None = None
self._reconnect_task: asyncio.Task | None = None
self._should_reconnect: bool = True
self._available_future: asyncio.Future[bool] | None = None
self._max_qos: dict[str, int] = {} # topic, max qos
self._pending_subscriptions: dict[str, int] = {} # topic, qos
self._unsubscribe_debouncer = EnsureJobAfterCooldown(
@ -456,25 +468,140 @@ class MQTT:
while self._cleanup_on_unload:
self._cleanup_on_unload.pop()()
@contextlib.asynccontextmanager
async def _async_connect_in_executor(self) -> AsyncGenerator[None, None]:
# While we are connecting in the executor we need to
# handle on_socket_open and on_socket_register_write
# in the executor as well.
mqttc = self._mqttc
try:
mqttc.on_socket_open = self._on_socket_open
mqttc.on_socket_register_write = self._on_socket_register_write
yield
finally:
# Once the executor job is done, we can switch back to
# handling these in the event loop.
mqttc.on_socket_open = self._async_on_socket_open
mqttc.on_socket_register_write = self._async_on_socket_register_write
def init_client(self) -> None:
"""Initialize paho client."""
self._mqttc = MqttClientSetup(self.conf).client
self._mqttc.on_connect = self._mqtt_on_connect
self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message
self._mqttc.on_publish = self._mqtt_on_callback
self._mqttc.on_subscribe = self._mqtt_on_callback
self._mqttc.on_unsubscribe = self._mqtt_on_callback
mqttc = MqttClientSetup(self.conf).client
# on_socket_unregister_write and _async_on_socket_close
# are only ever called in the event loop
mqttc.on_socket_close = self._async_on_socket_close
mqttc.on_socket_unregister_write = self._async_on_socket_unregister_write
# These will be called in the event loop
mqttc.on_connect = self._async_mqtt_on_connect
mqttc.on_disconnect = self._async_mqtt_on_disconnect
mqttc.on_message = self._async_mqtt_on_message
mqttc.on_publish = self._async_mqtt_on_callback
mqttc.on_subscribe = self._async_mqtt_on_callback
mqttc.on_unsubscribe = self._async_mqtt_on_callback
if will := self.conf.get(CONF_WILL_MESSAGE, DEFAULT_WILL):
will_message = PublishMessage(**will)
self._mqttc.will_set(
mqttc.will_set(
topic=will_message.topic,
payload=will_message.payload,
qos=will_message.qos,
retain=will_message.retain,
)
self._mqttc = mqttc
async def _misc_loop(self) -> None:
"""Start the MQTT client misc loop."""
# pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt
while self._mqttc.loop_misc() == mqtt.MQTT_ERR_SUCCESS:
await asyncio.sleep(1)
@callback
def _async_reader_callback(self, client: mqtt.Client) -> None:
"""Handle reading data from the socket."""
if (status := client.loop_read()) != 0:
self._async_on_disconnect(status)
@callback
def _async_start_misc_loop(self) -> None:
"""Start the misc loop."""
if self._misc_task is None or self._misc_task.done():
_LOGGER.debug("%s: Starting client misc loop", self.config_entry.title)
self._misc_task = self.config_entry.async_create_background_task(
self.hass, self._misc_loop(), name="mqtt misc loop"
)
def _on_socket_open(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Handle socket open."""
self.loop.call_soon_threadsafe(
self._async_on_socket_open, client, userdata, sock
)
@callback
def _async_on_socket_open(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Handle socket open."""
fileno = sock.fileno()
_LOGGER.debug("%s: connection opened %s", self.config_entry.title, fileno)
if fileno > -1:
self.loop.add_reader(sock, partial(self._async_reader_callback, client))
self._async_start_misc_loop()
@callback
def _async_on_socket_close(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Handle socket close."""
fileno = sock.fileno()
_LOGGER.debug("%s: connection closed %s", self.config_entry.title, fileno)
# If socket close is called before the connect
# result is set make sure the first connection result is set
self._async_connection_result(False)
if fileno > -1:
self.loop.remove_reader(sock)
if self._misc_task is not None and not self._misc_task.done():
self._misc_task.cancel()
@callback
def _async_writer_callback(self, client: mqtt.Client) -> None:
"""Handle writing data to the socket."""
if (status := client.loop_write()) != 0:
self._async_on_disconnect(status)
def _on_socket_register_write(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Register the socket for writing."""
self.loop.call_soon_threadsafe(
self._async_on_socket_register_write, client, None, sock
)
@callback
def _async_on_socket_register_write(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Register the socket for writing."""
fileno = sock.fileno()
_LOGGER.debug("%s: register write %s", self.config_entry.title, fileno)
if fileno > -1:
self.loop.add_writer(sock, partial(self._async_writer_callback, client))
@callback
def _async_on_socket_unregister_write(
self, client: mqtt.Client, userdata: Any, sock: SocketType
) -> None:
"""Unregister the socket for writing."""
fileno = sock.fileno()
_LOGGER.debug("%s: unregister write %s", self.config_entry.title, fileno)
if fileno > -1:
self.loop.remove_writer(sock)
def _is_active_subscription(self, topic: str) -> bool:
"""Check if a topic has an active subscription."""
return topic in self._simple_subscriptions or any(
@ -485,10 +612,7 @@ class MQTT:
self, topic: str, payload: PublishPayloadType, qos: int, retain: bool
) -> None:
"""Publish a MQTT message."""
async with self._paho_lock:
msg_info = await self.hass.async_add_executor_job(
self._mqttc.publish, topic, payload, qos, retain
)
msg_info = self._mqttc.publish(topic, payload, qos, retain)
_LOGGER.debug(
"Transmitting%s message on %s: '%s', mid: %s, qos: %s",
" retained" if retain else "",
@ -500,37 +624,71 @@ class MQTT:
_raise_on_error(msg_info.rc)
await self._wait_for_mid(msg_info.mid)
async def async_connect(self) -> None:
async def async_connect(self, client_available: asyncio.Future[bool]) -> None:
"""Connect to the host. Does not process messages yet."""
# pylint: disable-next=import-outside-toplevel
import paho.mqtt.client as mqtt
result: int | None = None
self._available_future = client_available
self._should_reconnect = True
try:
result = await self.hass.async_add_executor_job(
self._mqttc.connect,
self.conf[CONF_BROKER],
self.conf.get(CONF_PORT, DEFAULT_PORT),
self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE),
)
async with self._connection_lock, self._async_connect_in_executor():
result = await self.hass.async_add_executor_job(
self._mqttc.connect,
self.conf[CONF_BROKER],
self.conf.get(CONF_PORT, DEFAULT_PORT),
self.conf.get(CONF_KEEPALIVE, DEFAULT_KEEPALIVE),
)
except OSError as err:
_LOGGER.error("Failed to connect to MQTT server due to exception: %s", err)
self._async_connection_result(False)
finally:
if result is not None and result != 0:
if result is not None:
_LOGGER.error(
"Failed to connect to MQTT server: %s",
mqtt.error_string(result),
)
self._async_connection_result(False)
if result is not None and result != 0:
_LOGGER.error(
"Failed to connect to MQTT server: %s", mqtt.error_string(result)
@callback
def _async_connection_result(self, connected: bool) -> None:
"""Handle a connection result."""
if self._available_future and not self._available_future.done():
self._available_future.set_result(connected)
if connected:
self._async_cancel_reconnect()
elif self._should_reconnect and not self._reconnect_task:
self._reconnect_task = self.config_entry.async_create_background_task(
self.hass, self._reconnect_loop(), "mqtt reconnect loop"
)
self._mqttc.loop_start()
@callback
def _async_cancel_reconnect(self) -> None:
"""Cancel the reconnect task."""
if self._reconnect_task:
self._reconnect_task.cancel()
self._reconnect_task = None
async def _reconnect_loop(self) -> None:
"""Reconnect to the MQTT server."""
while True:
if not self.connected:
try:
async with self._connection_lock, self._async_connect_in_executor():
await self.hass.async_add_executor_job(self._mqttc.reconnect)
except OSError as err:
_LOGGER.debug(
"Error re-connecting to MQTT server due to exception: %s", err
)
await asyncio.sleep(RECONNECT_INTERVAL_SECONDS)
async def async_disconnect(self) -> None:
"""Stop the MQTT client."""
def stop() -> None:
"""Stop the MQTT client."""
# Do not disconnect, we want the broker to always publish will
self._mqttc.loop_stop()
def no_more_acks() -> bool:
"""Return False if there are unprocessed ACKs."""
return not any(not op.is_set() for op in self._pending_operations.values())
@ -549,8 +707,10 @@ class MQTT:
await self._pending_operations_condition.wait_for(no_more_acks)
# stop the MQTT loop
async with self._paho_lock:
await self.hass.async_add_executor_job(stop)
async with self._connection_lock:
self._should_reconnect = False
self._async_cancel_reconnect()
self._mqttc.disconnect()
@callback
def async_restore_tracked_subscriptions(
@ -689,11 +849,8 @@ class MQTT:
subscriptions: dict[str, int] = self._pending_subscriptions
self._pending_subscriptions = {}
async with self._paho_lock:
subscription_list = list(subscriptions.items())
result, mid = await self.hass.async_add_executor_job(
self._mqttc.subscribe, subscription_list
)
subscription_list = list(subscriptions.items())
result, mid = self._mqttc.subscribe(subscription_list)
for topic, qos in subscriptions.items():
_LOGGER.debug("Subscribing to %s, mid: %s, qos: %s", topic, mid, qos)
@ -712,17 +869,15 @@ class MQTT:
topics = list(self._pending_unsubscribes)
self._pending_unsubscribes = set()
async with self._paho_lock:
result, mid = await self.hass.async_add_executor_job(
self._mqttc.unsubscribe, topics
)
result, mid = self._mqttc.unsubscribe(topics)
_raise_on_error(result)
for topic in topics:
_LOGGER.debug("Unsubscribing from %s, mid: %s", topic, mid)
await self._wait_for_mid(mid)
def _mqtt_on_connect(
@callback
def _async_mqtt_on_connect(
self,
_mqttc: mqtt.Client,
_userdata: None,
@ -746,7 +901,7 @@ class MQTT:
return
self.connected = True
dispatcher_send(self.hass, MQTT_CONNECTED)
async_dispatcher_send(self.hass, MQTT_CONNECTED)
_LOGGER.info(
"Connected to MQTT server %s:%s (%s)",
self.conf[CONF_BROKER],
@ -754,7 +909,7 @@ class MQTT:
result_code,
)
self.hass.create_task(self._async_resubscribe())
self.hass.async_create_task(self._async_resubscribe())
if birth := self.conf.get(CONF_BIRTH_MESSAGE, DEFAULT_BIRTH):
@ -771,13 +926,17 @@ class MQTT:
)
birth_message = PublishMessage(**birth)
asyncio.run_coroutine_threadsafe(
publish_birth_message(birth_message), self.hass.loop
self.config_entry.async_create_background_task(
self.hass,
publish_birth_message(birth_message),
name="mqtt birth message",
)
else:
# Update subscribe cooldown period to a shorter time
self._subscribe_debouncer.set_timeout(SUBSCRIBE_COOLDOWN)
self._async_connection_result(True)
async def _async_resubscribe(self) -> None:
"""Resubscribe on reconnect."""
self._max_qos.clear()
@ -796,16 +955,6 @@ class MQTT:
)
await self._async_perform_subscriptions()
def _mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
) -> None:
"""Message received callback."""
# MQTT messages tend to be high volume,
# and since they come in via a thread and need to be processed in the event loop,
# we want to avoid hass.add_job since most of the time is spent calling
# inspect to figure out how to run the callback.
self.loop.call_soon_threadsafe(self._mqtt_handle_message, msg)
@lru_cache(None) # pylint: disable=method-cache-max-size-none
def _matching_subscriptions(self, topic: str) -> list[Subscription]:
subscriptions: list[Subscription] = []
@ -819,7 +968,9 @@ class MQTT:
return subscriptions
@callback
def _mqtt_handle_message(self, msg: mqtt.MQTTMessage) -> None:
def _async_mqtt_on_message(
self, _mqttc: mqtt.Client, _userdata: None, msg: mqtt.MQTTMessage
) -> None:
topic = msg.topic
# msg.topic is a property that decodes the topic to a string
# every time it is accessed. Save the result to avoid
@ -878,7 +1029,8 @@ class MQTT:
self.hass.async_run_hass_job(subscription.job, receive_msg)
self._mqtt_data.state_write_requests.process_write_state_requests(msg)
def _mqtt_on_callback(
@callback
def _async_mqtt_on_callback(
self,
_mqttc: mqtt.Client,
_userdata: None,
@ -890,7 +1042,7 @@ class MQTT:
# The callback signature for on_unsubscribe is different from on_subscribe
# see https://github.com/eclipse/paho.mqtt.python/issues/687
# properties and reasoncodes are not used in Home Assistant
self.hass.create_task(self._mqtt_handle_mid(mid))
self.hass.async_create_task(self._mqtt_handle_mid(mid))
async def _mqtt_handle_mid(self, mid: int) -> None:
# Create the mid event if not created, either _mqtt_handle_mid or _wait_for_mid
@ -906,7 +1058,8 @@ class MQTT:
if mid not in self._pending_operations:
self._pending_operations[mid] = asyncio.Event()
def _mqtt_on_disconnect(
@callback
def _async_mqtt_on_disconnect(
self,
_mqttc: mqtt.Client,
_userdata: None,
@ -914,8 +1067,19 @@ class MQTT:
properties: mqtt.Properties | None = None,
) -> None:
"""Disconnected callback."""
self._async_on_disconnect(result_code)
@callback
def _async_on_disconnect(self, result_code: int) -> None:
if not self.connected:
# This function is re-entrant and may be called multiple times
# when there is a broken pipe error.
return
# If disconnect is called before the connect
# result is set make sure the first connection result is set
self._async_connection_result(False)
self.connected = False
dispatcher_send(self.hass, MQTT_DISCONNECTED)
async_dispatcher_send(self.hass, MQTT_DISCONNECTED)
_LOGGER.warning(
"Disconnected from MQTT server %s:%s (%s)",
self.conf[CONF_BROKER],

View File

@ -452,7 +452,7 @@ def async_fire_mqtt_message(
mqtt_data: MqttData = hass.data["mqtt"]
assert mqtt_data.client
mqtt_data.client._mqtt_handle_message(msg)
mqtt_data.client._async_mqtt_on_message(Mock(), None, msg)
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)

View File

@ -4,17 +4,22 @@ import asyncio
from copy import deepcopy
from datetime import datetime, timedelta
import json
import socket
import ssl
from typing import Any, TypedDict
from unittest.mock import ANY, MagicMock, call, mock_open, patch
from freezegun.api import FrozenDateTimeFactory
import paho.mqtt.client as paho_mqtt
import pytest
import voluptuous as vol
from homeassistant.components import mqtt
from homeassistant.components.mqtt import debug_info
from homeassistant.components.mqtt.client import EnsureJobAfterCooldown
from homeassistant.components.mqtt.client import (
RECONNECT_INTERVAL_SECONDS,
EnsureJobAfterCooldown,
)
from homeassistant.components.mqtt.mixins import MQTT_ENTITY_DEVICE_INFO_SCHEMA
from homeassistant.components.mqtt.models import (
MessageCallbackType,
@ -146,7 +151,7 @@ async def test_mqtt_disconnects_on_home_assistant_stop(
hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
await hass.async_block_till_done()
assert mqtt_client_mock.loop_stop.call_count == 1
assert mqtt_client_mock.disconnect.call_count == 1
async def test_mqtt_await_ack_at_disconnect(
@ -161,8 +166,14 @@ async def test_mqtt_await_ack_at_disconnect(
rc = 0
with patch("paho.mqtt.client.Client") as mock_client:
mock_client().connect = MagicMock(return_value=0)
mock_client().publish = MagicMock(return_value=FakeInfo())
mqtt_client = mock_client.return_value
mqtt_client.connect = MagicMock(
return_value=0,
side_effect=lambda *args, **kwargs: hass.loop.call_soon_threadsafe(
mqtt_client.on_connect, mqtt_client, None, 0, 0, 0
),
)
mqtt_client.publish = MagicMock(return_value=FakeInfo())
entry = MockConfigEntry(
domain=mqtt.DOMAIN,
data={"certificate": "auto", mqtt.CONF_BROKER: "test-broker"},
@ -1669,6 +1680,7 @@ async def test_not_calling_subscribe_when_unsubscribed_within_cooldown(
the subscribe cool down period has ended.
"""
mqtt_mock = await mqtt_mock_entry()
mqtt_client_mock.subscribe.reset_mock()
# Fake that the client is connected
mqtt_mock().connected = True
@ -1925,6 +1937,7 @@ async def test_canceling_debouncer_on_shutdown(
"""Test canceling the debouncer when HA shuts down."""
mqtt_mock = await mqtt_mock_entry()
mqtt_client_mock.subscribe.reset_mock()
# Fake that the client is connected
mqtt_mock().connected = True
@ -2008,7 +2021,7 @@ async def test_initial_setup_logs_error(
"""Test for setup failure if initial client connection fails."""
entry = MockConfigEntry(domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"})
entry.add_to_hass(hass)
mqtt_client_mock.connect.return_value = 1
mqtt_client_mock.connect.side_effect = MagicMock(return_value=1)
try:
assert await hass.config_entries.async_setup(entry.entry_id)
except HomeAssistantError:
@ -2230,7 +2243,12 @@ async def test_handle_mqtt_timeout_on_callback(
mock_client = mock_client.return_value
mock_client.publish.return_value = FakeInfo()
mock_client.subscribe.side_effect = _mock_ack
mock_client.connect.return_value = 0
mock_client.connect = MagicMock(
return_value=0,
side_effect=lambda *args, **kwargs: hass.loop.call_soon_threadsafe(
mock_client.on_connect, mock_client, None, 0, 0, 0
),
)
entry = MockConfigEntry(
domain=mqtt.DOMAIN, data={mqtt.CONF_BROKER: "test-broker"}
@ -4144,3 +4162,179 @@ async def test_multi_platform_discovery(
)
is not None
)
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_auto_reconnect(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test reconnection is automatically done."""
mqtt_mock = await mqtt_mock_entry()
await hass.async_block_till_done()
assert mqtt_mock.connected is True
mqtt_client_mock.reconnect.reset_mock()
mqtt_client_mock.disconnect()
mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done()
mqtt_client_mock.reconnect.side_effect = OSError("foo")
async_fire_time_changed(
hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS)
)
await hass.async_block_till_done()
assert len(mqtt_client_mock.reconnect.mock_calls) == 1
assert "Error re-connecting to MQTT server due to exception: foo" in caplog.text
mqtt_client_mock.reconnect.side_effect = None
async_fire_time_changed(
hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS)
)
await hass.async_block_till_done()
assert len(mqtt_client_mock.reconnect.mock_calls) == 2
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
mqtt_client_mock.disconnect()
mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done()
async_fire_time_changed(
hass, utcnow() + timedelta(seconds=RECONNECT_INTERVAL_SECONDS)
)
await hass.async_block_till_done()
# Should not reconnect after stop
assert len(mqtt_client_mock.reconnect.mock_calls) == 2
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_server_sock_connect_and_disconnect(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test handling the socket connected and disconnected."""
mqtt_mock = await mqtt_mock_entry()
await hass.async_block_till_done()
assert mqtt_mock.connected is True
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS
client, server = socket.socketpair(
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0
)
client.setblocking(False)
server.setblocking(False)
mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client)
mqtt_client_mock.on_socket_register_write(mqtt_client_mock, None, client)
await hass.async_block_till_done()
server.close() # mock the server closing the connection on us
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_CONN_LOST
mqtt_client_mock.on_socket_unregister_write(mqtt_client_mock, None, client)
mqtt_client_mock.on_socket_close(mqtt_client_mock, None, client)
mqtt_client_mock.on_disconnect(mqtt_client_mock, None, client)
await hass.async_block_till_done()
unsub()
# Should have failed
assert len(calls) == 0
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_client_sock_failure_after_connect(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
calls: list[ReceiveMessage],
record_calls: MessageCallbackType,
) -> None:
"""Test handling the socket connected and disconnected."""
mqtt_mock = await mqtt_mock_entry()
# Fake that the client is connected
mqtt_mock().connected = True
await hass.async_block_till_done()
assert mqtt_mock.connected is True
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS
client, server = socket.socketpair(
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0
)
client.setblocking(False)
server.setblocking(False)
mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client)
mqtt_client_mock.on_socket_register_writer(mqtt_client_mock, None, client)
await hass.async_block_till_done()
mqtt_client_mock.loop_write.side_effect = OSError("foo")
client.close() # close the client socket out from under the client
assert mqtt_mock.connected is True
unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
async_fire_time_changed(hass, utcnow() + timedelta(seconds=5))
await hass.async_block_till_done()
unsub()
# Should have failed
assert len(calls) == 0
@patch("homeassistant.components.mqtt.client.INITIAL_SUBSCRIBE_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.DISCOVERY_COOLDOWN", 0.0)
@patch("homeassistant.components.mqtt.client.SUBSCRIBE_COOLDOWN", 0.0)
async def test_loop_write_failure(
hass: HomeAssistant,
mqtt_client_mock: MqttMockPahoClient,
mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test handling the socket connected and disconnected."""
mqtt_mock = await mqtt_mock_entry()
await hass.async_block_till_done()
assert mqtt_mock.connected is True
mqtt_client_mock.loop_misc.return_value = paho_mqtt.MQTT_ERR_SUCCESS
client, server = socket.socketpair(
family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0
)
client.setblocking(False)
server.setblocking(False)
mqtt_client_mock.on_socket_open(mqtt_client_mock, None, client)
mqtt_client_mock.on_socket_register_write(mqtt_client_mock, None, client)
mqtt_client_mock.loop_write.return_value = paho_mqtt.MQTT_ERR_CONN_LOST
mqtt_client_mock.loop_read.return_value = paho_mqtt.MQTT_ERR_CONN_LOST
# Fill up the outgoing buffer to ensure that loop_write
# and loop_read are called that next time control is
# returned to the event loop
try:
for _ in range(1000):
server.send(b"long" * 100)
except BlockingIOError:
pass
server.close()
# Once for the reader callback
await hass.async_block_till_done()
# Another for the writer callback
await hass.async_block_till_done()
# Final for the disconnect callback
await hass.async_block_till_done()
assert "Disconnected from MQTT server mock-broker:1883 (7)" in caplog.text

View File

@ -163,7 +163,7 @@ async def help_test_availability_when_connection_lost(
# Disconnected from MQTT server -> state changed to unavailable
mqtt_mock.connected = False
await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0)
mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
@ -172,7 +172,7 @@ async def help_test_availability_when_connection_lost(
# Reconnected to MQTT server -> state still unavailable
mqtt_mock.connected = True
await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0)
mqtt_client_mock.on_connect(None, None, None, 0)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
@ -224,7 +224,7 @@ async def help_test_deep_sleep_availability_when_connection_lost(
# Disconnected from MQTT server -> state changed to unavailable
mqtt_mock.connected = False
await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0)
mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
@ -233,7 +233,7 @@ async def help_test_deep_sleep_availability_when_connection_lost(
# Reconnected to MQTT server -> state no longer unavailable
mqtt_mock.connected = True
await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0)
mqtt_client_mock.on_connect(None, None, None, 0)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
@ -476,7 +476,7 @@ async def help_test_availability_poll_state(
# Disconnected from MQTT server
mqtt_mock.connected = False
await hass.async_add_executor_job(mqtt_client_mock.on_disconnect, None, None, 0)
mqtt_client_mock.on_disconnect(None, None, 0)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()
@ -484,7 +484,7 @@ async def help_test_availability_poll_state(
# Reconnected to MQTT server
mqtt_mock.connected = True
await hass.async_add_executor_job(mqtt_client_mock.on_connect, None, None, None, 0)
mqtt_client_mock.on_connect(None, None, None, 0)
await hass.async_block_till_done()
await hass.async_block_till_done()
await hass.async_block_till_done()

View File

@ -904,26 +904,45 @@ def mqtt_client_mock(hass: HomeAssistant) -> Generator[MqttMockPahoClient, None,
self.rc = 0
with patch("paho.mqtt.client.Client") as mock_client:
# The below use a call_soon for the on_publish/on_subscribe/on_unsubscribe
# callbacks to simulate the behavior of the real MQTT client which will
# not be synchronous.
@ha.callback
def _async_fire_mqtt_message(topic, payload, qos, retain):
async_fire_mqtt_message(hass, topic, payload, qos, retain)
mid = get_mid()
mock_client.on_publish(0, 0, mid)
hass.loop.call_soon(mock_client.on_publish, 0, 0, mid)
return FakeInfo(mid)
def _subscribe(topic, qos=0):
mid = get_mid()
mock_client.on_subscribe(0, 0, mid)
hass.loop.call_soon(mock_client.on_subscribe, 0, 0, mid)
return (0, mid)
def _unsubscribe(topic):
mid = get_mid()
mock_client.on_unsubscribe(0, 0, mid)
hass.loop.call_soon(mock_client.on_unsubscribe, 0, 0, mid)
return (0, mid)
def _connect(*args, **kwargs):
# Connect always calls reconnect once, but we
# mock it out so we call reconnect to simulate
# the behavior.
mock_client.reconnect()
hass.loop.call_soon_threadsafe(
mock_client.on_connect, mock_client, None, 0, 0, 0
)
mock_client.on_socket_open(
mock_client, None, Mock(fileno=Mock(return_value=-1))
)
mock_client.on_socket_register_write(
mock_client, None, Mock(fileno=Mock(return_value=-1))
)
return 0
mock_client = mock_client.return_value
mock_client.connect.return_value = 0
mock_client.connect.side_effect = _connect
mock_client.subscribe.side_effect = _subscribe
mock_client.unsubscribe.side_effect = _unsubscribe
mock_client.publish.side_effect = _async_fire_mqtt_message
@ -985,6 +1004,7 @@ async def _mqtt_mock_entry(
# connected set to True to get a more realistic behavior when subscribing
mock_mqtt_instance.connected = True
mqtt_client_mock.on_connect(mqtt_client_mock, None, 0, 0, 0)
async_dispatcher_send(hass, mqtt.MQTT_CONNECTED)
await hass.async_block_till_done()