Pass Message object to MQTT message callbacks (#21959)

* Pass Message object to MQTT message callbacks

* Improve method of detecting deprecated msg callback

* Fix mysensors

* Fixup

* Review comments

* Fix merge error
This commit is contained in:
emontnemery 2019-03-13 20:58:20 +01:00 committed by Paulus Schoutsen
parent 50ec3d7de5
commit 5957e4b75b
19 changed files with 203 additions and 142 deletions

View File

@ -5,6 +5,8 @@ For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/
"""
import asyncio
import inspect
from functools import partial, wraps
from itertools import groupby
import json
import logging
@ -264,7 +266,19 @@ MQTT_PUBLISH_SCHEMA = vol.Schema({
# pylint: disable=invalid-name
PublishPayloadType = Union[str, bytes, int, float, None]
SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None
MessageCallbackType = Callable[[str, SubscribePayloadType, int], None]
@attr.s(slots=True, frozen=True)
class Message:
"""MQTT Message."""
topic = attr.ib(type=str)
payload = attr.ib(type=PublishPayloadType)
qos = attr.ib(type=int)
retain = attr.ib(type=bool)
MessageCallbackType = Callable[[Message], None]
def _build_publish_data(topic: Any, qos: int, retain: bool) -> ServiceDataType:
@ -304,6 +318,30 @@ def publish_template(hass: HomeAssistantType, topic, payload_template,
hass.services.call(DOMAIN, SERVICE_PUBLISH, data)
def wrap_msg_callback(
msg_callback: MessageCallbackType) -> MessageCallbackType:
"""Wrap an MQTT message callback to support deprecated signature."""
# Check for partials to properly determine if coroutine function
check_func = msg_callback
while isinstance(check_func, partial):
check_func = check_func.func
wrapper_func = None
if asyncio.iscoroutinefunction(check_func):
@wraps(msg_callback)
async def async_wrapper(msg: Any) -> None:
"""Catch and log exception."""
await msg_callback(msg.topic, msg.payload, msg.qos)
wrapper_func = async_wrapper
else:
@wraps(msg_callback)
def wrapper(msg: Any) -> None:
"""Catch and log exception."""
msg_callback(msg.topic, msg.payload, msg.qos)
wrapper_func = wrapper
return wrapper_func
@bind_hass
async def async_subscribe(hass: HomeAssistantType, topic: str,
msg_callback: MessageCallbackType,
@ -313,11 +351,25 @@ async def async_subscribe(hass: HomeAssistantType, topic: str,
Call the return value to unsubscribe.
"""
# Count callback parameters which don't have a default value
non_default = 0
if msg_callback:
non_default = sum(p.default == inspect.Parameter.empty for _, p in
inspect.signature(msg_callback).parameters.items())
wrapped_msg_callback = msg_callback
# If we have 3 paramaters with no default value, wrap the callback
if non_default == 3:
_LOGGER.info(
"Signature of MQTT msg_callback '%s.%s' is deprecated",
inspect.getmodule(msg_callback).__name__, msg_callback.__name__)
wrapped_msg_callback = wrap_msg_callback(msg_callback)
async_remove = await hass.data[DATA_MQTT].async_subscribe(
topic, catch_log_exception(
msg_callback, lambda topic, msg, qos:
wrapped_msg_callback, lambda msg:
"Exception in {} when handling msg on '{}': '{}'".format(
msg_callback.__name__, topic, msg)),
msg_callback.__name__, msg.topic, msg.payload)),
qos, encoding)
return async_remove
@ -575,16 +627,6 @@ class Subscription:
encoding = attr.ib(type=str, default='utf-8')
@attr.s(slots=True, frozen=True)
class Message:
"""MQTT Message."""
topic = attr.ib(type=str)
payload = attr.ib(type=PublishPayloadType)
qos = attr.ib(type=int, default=0)
retain = attr.ib(type=bool, default=False)
class MQTT:
"""Home Assistant MQTT client."""
@ -770,7 +812,8 @@ class MQTT:
@callback
def _mqtt_handle_message(self, msg) -> None:
_LOGGER.debug("Received message on %s: %s", msg.topic, msg.payload)
_LOGGER.debug("Received message on %s%s: %s", msg.topic,
" (retained)" if msg.retain else "", msg.payload)
for subscription in self.subscriptions:
if not _match_topic(subscription.topic, msg.topic):
@ -787,7 +830,8 @@ class MQTT:
continue
self.hass.async_run_job(
subscription.callback, msg.topic, payload, msg.qos)
subscription.callback, Message(msg.topic, payload, msg.qos,
msg.retain))
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
"""Disconnected callback."""
@ -865,11 +909,9 @@ class MqttAttributes(Entity):
from .subscription import async_subscribe_topics
@callback
def attributes_message_received(topic: str,
payload: SubscribePayloadType,
qos: int) -> None:
def attributes_message_received(msg: Message) -> None:
try:
json_dict = json.loads(payload)
json_dict = json.loads(msg.payload)
if isinstance(json_dict, dict):
self._attributes = json_dict
self.async_write_ha_state()
@ -877,7 +919,7 @@ class MqttAttributes(Entity):
_LOGGER.warning("JSON result was not a dictionary")
self._attributes = None
except ValueError:
_LOGGER.warning("Erroneous JSON: %s", payload)
_LOGGER.warning("Erroneous JSON: %s", msg.payload)
self._attributes = None
self._attributes_sub_state = await async_subscribe_topics(
@ -927,13 +969,11 @@ class MqttAvailability(Entity):
from .subscription import async_subscribe_topics
@callback
def availability_message_received(topic: str,
payload: SubscribePayloadType,
qos: int) -> None:
def availability_message_received(msg: Message) -> None:
"""Handle a new received MQTT availability message."""
if payload == self._avail_config[CONF_PAYLOAD_AVAILABLE]:
if msg.payload == self._avail_config[CONF_PAYLOAD_AVAILABLE]:
self._available = True
elif payload == self._avail_config[CONF_PAYLOAD_NOT_AVAILABLE]:
elif msg.payload == self._avail_config[CONF_PAYLOAD_NOT_AVAILABLE]:
self._available = False
self.async_write_ha_state()
@ -1064,12 +1104,13 @@ async def websocket_subscribe(hass, connection, msg):
if not connection.user.is_admin:
raise Unauthorized
async def forward_messages(topic: str, payload: str, qos: int):
async def forward_messages(mqttmsg: Message):
"""Forward events to websocket."""
connection.send_message(websocket_api.event_message(msg['id'], {
'topic': topic,
'payload': payload,
'qos': qos,
'topic': mqttmsg.topic,
'payload': mqttmsg.payload,
'qos': mqttmsg.qos,
'retain': mqttmsg.retain,
}))
connection.subscriptions[msg['id']] = await async_subscribe(

View File

@ -126,16 +126,17 @@ class MqttAlarm(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
def message_received(topic, payload, qos):
def message_received(msg):
"""Run when new MQTT message has been received."""
if payload not in (STATE_ALARM_DISARMED, STATE_ALARM_ARMED_HOME,
STATE_ALARM_ARMED_AWAY,
STATE_ALARM_ARMED_NIGHT,
STATE_ALARM_PENDING,
STATE_ALARM_TRIGGERED):
_LOGGER.warning("Received unexpected payload: %s", payload)
if msg.payload not in (
STATE_ALARM_DISARMED, STATE_ALARM_ARMED_HOME,
STATE_ALARM_ARMED_AWAY,
STATE_ALARM_ARMED_NIGHT,
STATE_ALARM_PENDING,
STATE_ALARM_TRIGGERED):
_LOGGER.warning("Received unexpected payload: %s", msg.payload)
return
self._state = payload
self._state = msg.payload
self.async_write_ha_state()
self._sub_state = await subscription.async_subscribe_topics(

View File

@ -133,8 +133,9 @@ class MqttBinarySensor(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self.async_write_ha_state()
@callback
def state_message_received(_topic, payload, _qos):
def state_message_received(msg):
"""Handle a new received MQTT state message."""
payload = msg.payload
value_template = self._config.get(CONF_VALUE_TEMPLATE)
if value_template is not None:
payload = value_template.async_render_with_possible_json_value(

View File

@ -102,9 +102,9 @@ class MqttCamera(MqttDiscoveryUpdate, Camera):
async def _subscribe_topics(self):
"""(Re)Subscribe to topics."""
@callback
def message_received(topic, payload, qos):
def message_received(msg):
"""Handle new MQTT messages."""
self._last_image = payload
self._last_image = msg.payload
self._sub_state = await subscription.async_subscribe_topics(
self.hass, self._sub_state,

View File

@ -288,8 +288,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
qos = self._config.get(CONF_QOS)
@callback
def handle_current_temp_received(topic, payload, qos):
def handle_current_temp_received(msg):
"""Handle current temperature coming via MQTT."""
payload = msg.payload
if CONF_CURRENT_TEMPERATURE_TEMPLATE in self._value_templates:
payload =\
self._value_templates[CONF_CURRENT_TEMPERATURE_TEMPLATE].\
@ -308,8 +309,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
'qos': qos}
@callback
def handle_mode_received(topic, payload, qos):
def handle_mode_received(msg):
"""Handle receiving mode via MQTT."""
payload = msg.payload
if CONF_MODE_STATE_TEMPLATE in self._value_templates:
payload = self._value_templates[CONF_MODE_STATE_TEMPLATE].\
async_render_with_possible_json_value(payload)
@ -327,8 +329,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
'qos': qos}
@callback
def handle_temperature_received(topic, payload, qos):
def handle_temperature_received(msg):
"""Handle target temperature coming via MQTT."""
payload = msg.payload
if CONF_TEMPERATURE_STATE_TEMPLATE in self._value_templates:
payload = \
self._value_templates[CONF_TEMPERATURE_STATE_TEMPLATE].\
@ -347,8 +350,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
'qos': qos}
@callback
def handle_fan_mode_received(topic, payload, qos):
def handle_fan_mode_received(msg):
"""Handle receiving fan mode via MQTT."""
payload = msg.payload
if CONF_FAN_MODE_STATE_TEMPLATE in self._value_templates:
payload = \
self._value_templates[CONF_FAN_MODE_STATE_TEMPLATE].\
@ -367,8 +371,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
'qos': qos}
@callback
def handle_swing_mode_received(topic, payload, qos):
def handle_swing_mode_received(msg):
"""Handle receiving swing mode via MQTT."""
payload = msg.payload
if CONF_SWING_MODE_STATE_TEMPLATE in self._value_templates:
payload = \
self._value_templates[CONF_SWING_MODE_STATE_TEMPLATE].\
@ -387,8 +392,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
'qos': qos}
@callback
def handle_away_mode_received(topic, payload, qos):
def handle_away_mode_received(msg):
"""Handle receiving away mode via MQTT."""
payload = msg.payload
payload_on = self._config.get(CONF_PAYLOAD_ON)
payload_off = self._config.get(CONF_PAYLOAD_OFF)
if CONF_AWAY_MODE_STATE_TEMPLATE in self._value_templates:
@ -416,8 +422,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
'qos': qos}
@callback
def handle_aux_mode_received(topic, payload, qos):
def handle_aux_mode_received(msg):
"""Handle receiving aux mode via MQTT."""
payload = msg.payload
payload_on = self._config.get(CONF_PAYLOAD_ON)
payload_off = self._config.get(CONF_PAYLOAD_OFF)
if CONF_AUX_STATE_TEMPLATE in self._value_templates:
@ -444,8 +451,9 @@ class MqttClimate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
'qos': qos}
@callback
def handle_hold_mode_received(topic, payload, qos):
def handle_hold_mode_received(msg):
"""Handle receiving hold mode via MQTT."""
payload = msg.payload
if CONF_HOLD_STATE_TEMPLATE in self._value_templates:
payload = self._value_templates[CONF_HOLD_STATE_TEMPLATE].\
async_render_with_possible_json_value(payload)

View File

@ -216,19 +216,20 @@ class MqttCover(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
topics = {}
@callback
def tilt_updated(topic, payload, qos):
def tilt_updated(msg):
"""Handle tilt updates."""
if (payload.isnumeric() and
(self._config.get(CONF_TILT_MIN) <= int(payload) <=
if (msg.payload.isnumeric() and
(self._config.get(CONF_TILT_MIN) <= int(msg.payload) <=
self._config.get(CONF_TILT_MAX))):
level = self.find_percentage_in_range(float(payload))
level = self.find_percentage_in_range(float(msg.payload))
self._tilt_value = level
self.async_write_ha_state()
@callback
def state_message_received(topic, payload, qos):
def state_message_received(msg):
"""Handle new MQTT state messages."""
payload = msg.payload
if template is not None:
payload = template.async_render_with_possible_json_value(
payload)
@ -243,8 +244,9 @@ class MqttCover(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self.async_write_ha_state()
@callback
def position_message_received(topic, payload, qos):
def position_message_received(msg):
"""Handle new MQTT state messages."""
payload = msg.payload
if template is not None:
payload = template.async_render_with_possible_json_value(
payload)

View File

@ -31,10 +31,10 @@ async def async_setup_scanner(hass, config, async_see, discovery_info=None):
for dev_id, topic in devices.items():
@callback
def async_message_received(topic, payload, qos, dev_id=dev_id):
def async_message_received(msg, dev_id=dev_id):
"""Handle received MQTT message."""
hass.async_create_task(
async_see(dev_id=dev_id, location_name=payload))
async_see(dev_id=dev_id, location_name=msg.payload))
await mqtt.async_subscribe(
hass, topic, async_message_received, qos)

View File

@ -200,8 +200,10 @@ def clear_discovery_hash(hass, discovery_hash):
async def async_start(hass: HomeAssistantType, discovery_topic, hass_config,
config_entry=None) -> bool:
"""Initialize of MQTT Discovery."""
async def async_device_message_received(topic, payload, qos):
async def async_device_message_received(msg):
"""Process the received message."""
payload = msg.payload
topic = msg.topic
match = TOPIC_MATCHER.match(topic)
if not match:

View File

@ -212,9 +212,9 @@ class MqttFan(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
templates[key] = tpl.async_render_with_possible_json_value
@callback
def state_received(topic, payload, qos):
def state_received(msg):
"""Handle new received MQTT message."""
payload = templates[CONF_STATE](payload)
payload = templates[CONF_STATE](msg.payload)
if payload == self._payload[STATE_ON]:
self._state = True
elif payload == self._payload[STATE_OFF]:
@ -228,9 +228,9 @@ class MqttFan(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
'qos': self._config.get(CONF_QOS)}
@callback
def speed_received(topic, payload, qos):
def speed_received(msg):
"""Handle new received MQTT message for the speed."""
payload = templates[ATTR_SPEED](payload)
payload = templates[ATTR_SPEED](msg.payload)
if payload == self._payload[SPEED_LOW]:
self._speed = SPEED_LOW
elif payload == self._payload[SPEED_MEDIUM]:
@ -247,9 +247,9 @@ class MqttFan(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self._speed = SPEED_OFF
@callback
def oscillation_received(topic, payload, qos):
def oscillation_received(msg):
"""Handle new received MQTT message for the oscillation."""
payload = templates[OSCILLATION](payload)
payload = templates[OSCILLATION](msg.payload)
if payload == self._payload[OSCILLATE_ON_PAYLOAD]:
self._oscillation = True
elif payload == self._payload[OSCILLATE_OFF_PAYLOAD]:

View File

@ -254,11 +254,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
last_state = await self.async_get_last_state()
@callback
def state_received(topic, payload, qos):
def state_received(msg):
"""Handle new MQTT messages."""
payload = templates[CONF_STATE](payload)
payload = templates[CONF_STATE](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty state message from '%s'", topic)
_LOGGER.debug("Ignoring empty state message from '%s'",
msg.topic)
return
if payload == self._payload['on']:
@ -276,12 +277,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self._state = last_state.state == STATE_ON
@callback
def brightness_received(topic, payload, qos):
def brightness_received(msg):
"""Handle new MQTT messages for the brightness."""
payload = templates[CONF_BRIGHTNESS](payload)
payload = templates[CONF_BRIGHTNESS](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty brightness message from '%s'",
topic)
msg.topic)
return
device_value = float(payload)
@ -305,11 +306,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self._brightness = None
@callback
def rgb_received(topic, payload, qos):
def rgb_received(msg):
"""Handle new MQTT messages for RGB."""
payload = templates[CONF_RGB](payload)
payload = templates[CONF_RGB](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty rgb message from '%s'", topic)
_LOGGER.debug("Ignoring empty rgb message from '%s'",
msg.topic)
return
rgb = [int(val) for val in payload.split(',')]
@ -333,12 +335,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self._hs = (0, 0)
@callback
def color_temp_received(topic, payload, qos):
def color_temp_received(msg):
"""Handle new MQTT messages for color temperature."""
payload = templates[CONF_COLOR_TEMP](payload)
payload = templates[CONF_COLOR_TEMP](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty color temp message from '%s'",
topic)
msg.topic)
return
self._color_temp = int(payload)
@ -359,11 +361,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self._color_temp = None
@callback
def effect_received(topic, payload, qos):
def effect_received(msg):
"""Handle new MQTT messages for effect."""
payload = templates[CONF_EFFECT](payload)
payload = templates[CONF_EFFECT](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty effect message from '%s'", topic)
_LOGGER.debug("Ignoring empty effect message from '%s'",
msg.topic)
return
self._effect = payload
@ -384,11 +387,11 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self._effect = None
@callback
def hs_received(topic, payload, qos):
def hs_received(msg):
"""Handle new MQTT messages for hs color."""
payload = templates[CONF_HS](payload)
payload = templates[CONF_HS](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty hs message from '%s'", topic)
_LOGGER.debug("Ignoring empty hs message from '%s'", msg.topic)
return
try:
@ -412,12 +415,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self._hs = (0, 0)
@callback
def white_value_received(topic, payload, qos):
def white_value_received(msg):
"""Handle new MQTT messages for white value."""
payload = templates[CONF_WHITE_VALUE](payload)
payload = templates[CONF_WHITE_VALUE](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty white value message from '%s'",
topic)
msg.topic)
return
device_value = float(payload)
@ -441,12 +444,12 @@ class MqttLight(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
self._white_value = None
@callback
def xy_received(topic, payload, qos):
def xy_received(msg):
"""Handle new MQTT messages for xy color."""
payload = templates[CONF_XY](payload)
payload = templates[CONF_XY](msg.payload)
if not payload:
_LOGGER.debug("Ignoring empty xy-color message from '%s'",
topic)
msg.topic)
return
xy_color = [float(val) for val in payload.split(',')]

View File

@ -201,9 +201,9 @@ class MqttLightJson(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
last_state = await self.async_get_last_state()
@callback
def state_received(topic, payload, qos):
def state_received(msg):
"""Handle new MQTT messages."""
values = json.loads(payload)
values = json.loads(msg.payload)
if values['state'] == 'ON':
self._state = True

View File

@ -188,10 +188,10 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
last_state = await self.async_get_last_state()
@callback
def state_received(topic, payload, qos):
def state_received(msg):
"""Handle new MQTT messages."""
state = self._templates[CONF_STATE_TEMPLATE].\
async_render_with_possible_json_value(payload)
async_render_with_possible_json_value(msg.payload)
if state == STATE_ON:
self._state = True
elif state == STATE_OFF:
@ -203,7 +203,7 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
try:
self._brightness = int(
self._templates[CONF_BRIGHTNESS_TEMPLATE].
async_render_with_possible_json_value(payload)
async_render_with_possible_json_value(msg.payload)
)
except ValueError:
_LOGGER.warning("Invalid brightness value received")
@ -212,7 +212,7 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
try:
self._color_temp = int(
self._templates[CONF_COLOR_TEMP_TEMPLATE].
async_render_with_possible_json_value(payload)
async_render_with_possible_json_value(msg.payload)
)
except ValueError:
_LOGGER.warning("Invalid color temperature value received")
@ -221,13 +221,13 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
try:
red = int(
self._templates[CONF_RED_TEMPLATE].
async_render_with_possible_json_value(payload))
async_render_with_possible_json_value(msg.payload))
green = int(
self._templates[CONF_GREEN_TEMPLATE].
async_render_with_possible_json_value(payload))
async_render_with_possible_json_value(msg.payload))
blue = int(
self._templates[CONF_BLUE_TEMPLATE].
async_render_with_possible_json_value(payload))
async_render_with_possible_json_value(msg.payload))
self._hs = color_util.color_RGB_to_hs(red, green, blue)
except ValueError:
_LOGGER.warning("Invalid color value received")
@ -236,14 +236,14 @@ class MqttTemplate(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
try:
self._white_value = int(
self._templates[CONF_WHITE_VALUE_TEMPLATE].
async_render_with_possible_json_value(payload)
async_render_with_possible_json_value(msg.payload)
)
except ValueError:
_LOGGER.warning('Invalid white value received')
if self._templates[CONF_EFFECT_TEMPLATE] is not None:
effect = self._templates[CONF_EFFECT_TEMPLATE].\
async_render_with_possible_json_value(payload)
async_render_with_possible_json_value(msg.payload)
if effect in self._config.get(CONF_EFFECT_LIST):
self._effect = effect

View File

@ -120,8 +120,9 @@ class MqttLock(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
value_template.hass = self.hass
@callback
def message_received(topic, payload, qos):
def message_received(msg):
"""Handle new MQTT messages."""
payload = msg.payload
if value_template is not None:
payload = value_template.async_render_with_possible_json_value(
payload)

View File

@ -133,8 +133,9 @@ class MqttSensor(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
template.hass = self.hass
@callback
def message_received(topic, payload, qos):
def message_received(msg):
"""Handle new MQTT messages."""
payload = msg.payload
# auto-expire enabled?
expire_after = self._config.get(CONF_EXPIRE_AFTER)
if expire_after is not None and expire_after > 0:

View File

@ -143,8 +143,9 @@ class MqttSwitch(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
template.hass = self.hass
@callback
def state_message_received(topic, payload, qos):
def state_message_received(msg):
"""Handle new MQTT state messages."""
payload = msg.payload
if template is not None:
payload = template.async_render_with_possible_json_value(
payload)

View File

@ -284,45 +284,45 @@ class MqttVacuum(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
tpl.hass = self.hass
@callback
def message_received(topic, payload, qos):
def message_received(msg):
"""Handle new MQTT message."""
if topic == self._state_topics[CONF_BATTERY_LEVEL_TOPIC] and \
if msg.topic == self._state_topics[CONF_BATTERY_LEVEL_TOPIC] and \
self._templates[CONF_BATTERY_LEVEL_TEMPLATE]:
battery_level = self._templates[CONF_BATTERY_LEVEL_TEMPLATE]\
.async_render_with_possible_json_value(
payload, error_value=None)
msg.payload, error_value=None)
if battery_level is not None:
self._battery_level = int(battery_level)
if topic == self._state_topics[CONF_CHARGING_TOPIC] and \
if msg.topic == self._state_topics[CONF_CHARGING_TOPIC] and \
self._templates[CONF_CHARGING_TEMPLATE]:
charging = self._templates[CONF_CHARGING_TEMPLATE]\
.async_render_with_possible_json_value(
payload, error_value=None)
msg.payload, error_value=None)
if charging is not None:
self._charging = cv.boolean(charging)
if topic == self._state_topics[CONF_CLEANING_TOPIC] and \
if msg.topic == self._state_topics[CONF_CLEANING_TOPIC] and \
self._templates[CONF_CLEANING_TEMPLATE]:
cleaning = self._templates[CONF_CLEANING_TEMPLATE]\
.async_render_with_possible_json_value(
payload, error_value=None)
msg.payload, error_value=None)
if cleaning is not None:
self._cleaning = cv.boolean(cleaning)
if topic == self._state_topics[CONF_DOCKED_TOPIC] and \
if msg.topic == self._state_topics[CONF_DOCKED_TOPIC] and \
self._templates[CONF_DOCKED_TEMPLATE]:
docked = self._templates[CONF_DOCKED_TEMPLATE]\
.async_render_with_possible_json_value(
payload, error_value=None)
msg.payload, error_value=None)
if docked is not None:
self._docked = cv.boolean(docked)
if topic == self._state_topics[CONF_ERROR_TOPIC] and \
if msg.topic == self._state_topics[CONF_ERROR_TOPIC] and \
self._templates[CONF_ERROR_TEMPLATE]:
error = self._templates[CONF_ERROR_TEMPLATE]\
.async_render_with_possible_json_value(
payload, error_value=None)
msg.payload, error_value=None)
if error is not None:
self._error = cv.string(error)
@ -338,11 +338,11 @@ class MqttVacuum(MqttAttributes, MqttAvailability, MqttDiscoveryUpdate,
else:
self._status = "Stopped"
if topic == self._state_topics[CONF_FAN_SPEED_TOPIC] and \
if msg.topic == self._state_topics[CONF_FAN_SPEED_TOPIC] and \
self._templates[CONF_FAN_SPEED_TEMPLATE]:
fan_speed = self._templates[CONF_FAN_SPEED_TEMPLATE]\
.async_render_with_possible_json_value(
payload, error_value=None)
msg.payload, error_value=None)
if fan_speed is not None:
self._fan_speed = fan_speed

View File

@ -98,9 +98,9 @@ async def _get_gateway(hass, config, gateway_conf, persistence_file):
def sub_callback(topic, sub_cb, qos):
"""Call MQTT subscribe function."""
@callback
def internal_callback(*args):
def internal_callback(msg):
"""Call callback."""
sub_cb(*args)
sub_cb(msg.topic, msg.payload, msg.qos)
hass.async_create_task(
mqtt.async_subscribe(topic, internal_callback, qos))

View File

@ -316,8 +316,8 @@ class TestMQTTCallbacks(unittest.TestCase):
self.hass.block_till_done()
assert 1 == len(self.calls)
assert 'test-topic' == self.calls[0][0]
assert 'test-payload' == self.calls[0][1]
assert 'test-topic' == self.calls[0][0].topic
assert 'test-payload' == self.calls[0][0].payload
unsub()
@ -343,8 +343,8 @@ class TestMQTTCallbacks(unittest.TestCase):
self.hass.block_till_done()
assert 1 == len(self.calls)
assert 'test-topic/bier/on' == self.calls[0][0]
assert 'test-payload' == self.calls[0][1]
assert 'test-topic/bier/on' == self.calls[0][0].topic
assert 'test-payload' == self.calls[0][0].payload
def test_subscribe_topic_level_wildcard_no_subtree_match(self):
"""Test the subscription of wildcard topics."""
@ -372,8 +372,8 @@ class TestMQTTCallbacks(unittest.TestCase):
self.hass.block_till_done()
assert 1 == len(self.calls)
assert 'test-topic/bier/on' == self.calls[0][0]
assert 'test-payload' == self.calls[0][1]
assert 'test-topic/bier/on' == self.calls[0][0].topic
assert 'test-payload' == self.calls[0][0].payload
def test_subscribe_topic_subtree_wildcard_root_topic(self):
"""Test the subscription of wildcard topics."""
@ -383,8 +383,8 @@ class TestMQTTCallbacks(unittest.TestCase):
self.hass.block_till_done()
assert 1 == len(self.calls)
assert 'test-topic' == self.calls[0][0]
assert 'test-payload' == self.calls[0][1]
assert 'test-topic' == self.calls[0][0].topic
assert 'test-payload' == self.calls[0][0].payload
def test_subscribe_topic_subtree_wildcard_no_match(self):
"""Test the subscription of wildcard topics."""
@ -403,8 +403,8 @@ class TestMQTTCallbacks(unittest.TestCase):
self.hass.block_till_done()
assert 1 == len(self.calls)
assert 'hi/test-topic' == self.calls[0][0]
assert 'test-payload' == self.calls[0][1]
assert 'hi/test-topic' == self.calls[0][0].topic
assert 'test-payload' == self.calls[0][0].payload
def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(self):
"""Test the subscription of wildcard topics."""
@ -414,8 +414,8 @@ class TestMQTTCallbacks(unittest.TestCase):
self.hass.block_till_done()
assert 1 == len(self.calls)
assert 'hi/test-topic/here-iam' == self.calls[0][0]
assert 'test-payload' == self.calls[0][1]
assert 'hi/test-topic/here-iam' == self.calls[0][0].topic
assert 'test-payload' == self.calls[0][0].payload
def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(self):
"""Test the subscription of wildcard topics."""
@ -443,8 +443,8 @@ class TestMQTTCallbacks(unittest.TestCase):
self.hass.block_till_done()
assert 1 == len(self.calls)
assert '$test-topic/subtree/on' == self.calls[0][0]
assert 'test-payload' == self.calls[0][1]
assert '$test-topic/subtree/on' == self.calls[0][0].topic
assert 'test-payload' == self.calls[0][0].payload
def test_subscribe_topic_sys_root_and_wildcard_topic(self):
"""Test the subscription of $ root and wildcard topics."""
@ -454,8 +454,8 @@ class TestMQTTCallbacks(unittest.TestCase):
self.hass.block_till_done()
assert 1 == len(self.calls)
assert '$test-topic/some-topic' == self.calls[0][0]
assert 'test-payload' == self.calls[0][1]
assert '$test-topic/some-topic' == self.calls[0][0].topic
assert 'test-payload' == self.calls[0][0].payload
def test_subscribe_topic_sys_root_and_wildcard_subtree_topic(self):
"""Test the subscription of $ root and wildcard subtree topics."""
@ -466,8 +466,8 @@ class TestMQTTCallbacks(unittest.TestCase):
self.hass.block_till_done()
assert 1 == len(self.calls)
assert '$test-topic/subtree/some-topic' == self.calls[0][0]
assert 'test-payload' == self.calls[0][1]
assert '$test-topic/subtree/some-topic' == self.calls[0][0].topic
assert 'test-payload' == self.calls[0][0].payload
def test_subscribe_special_characters(self):
"""Test the subscription to topics with special characters."""
@ -479,8 +479,8 @@ class TestMQTTCallbacks(unittest.TestCase):
fire_mqtt_message(self.hass, topic, payload)
self.hass.block_till_done()
assert 1 == len(self.calls)
assert topic == self.calls[0][0]
assert payload == self.calls[0][1]
assert topic == self.calls[0][0].topic
assert payload == self.calls[0][0].payload
def test_mqtt_failed_connection_results_in_disconnect(self):
"""Test if connection failure leads to disconnect."""

View File

@ -35,8 +35,8 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog):
async_fire_mqtt_message(hass, 'test-topic1', 'test-payload1')
await hass.async_block_till_done()
assert 1 == len(calls1)
assert 'test-topic1' == calls1[0][0]
assert 'test-payload1' == calls1[0][1]
assert 'test-topic1' == calls1[0][0].topic
assert 'test-payload1' == calls1[0][0].payload
assert 0 == len(calls2)
async_fire_mqtt_message(hass, 'test-topic2', 'test-payload2')
@ -44,8 +44,8 @@ async def test_subscribe_topics(hass, mqtt_mock, caplog):
await hass.async_block_till_done()
assert 1 == len(calls1)
assert 1 == len(calls2)
assert 'test-topic2' == calls2[0][0]
assert 'test-payload2' == calls2[0][1]
assert 'test-topic2' == calls2[0][0].topic
assert 'test-payload2' == calls2[0][0].payload
await async_unsubscribe_topics(hass, sub_state)
@ -108,8 +108,8 @@ async def test_modify_topics(hass, mqtt_mock, caplog):
await hass.async_block_till_done()
await hass.async_block_till_done()
assert 2 == len(calls1)
assert 'test-topic1_1' == calls1[1][0]
assert 'test-payload' == calls1[1][1]
assert 'test-topic1_1' == calls1[1][0].topic
assert 'test-payload' == calls1[1][0].payload
assert 1 == len(calls2)
await async_unsubscribe_topics(hass, sub_state)