From 8a8732f0bc2a7cd891a3ddaff3edbe9c246d6ebf Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Thu, 24 Nov 2022 08:25:44 +0100 Subject: [PATCH] Strict type hints for MQTT integration (#82317) * Strict type hints for MQTT integration * Fix errors * Additional corrections * Use cv.template to avoid untyped calls * Enable strict typing policy for MQTT integration * Use ignore[no-untyped-call] * Use # type: ignore[unreachable] * Correct cast * Refactor getting discovery_payload * Remove unused type ignore comments --- .strict-typing | 1 + .../components/mqtt/alarm_control_panel.py | 2 +- homeassistant/components/mqtt/client.py | 26 ++++---- homeassistant/components/mqtt/discovery.py | 62 +++++++++++-------- .../components/mqtt/light/schema_basic.py | 2 +- homeassistant/components/mqtt/mixins.py | 5 +- homeassistant/components/mqtt/models.py | 14 +++-- homeassistant/components/mqtt/subscription.py | 6 +- homeassistant/components/mqtt/util.py | 12 ++-- mypy.ini | 10 +++ 10 files changed, 87 insertions(+), 53 deletions(-) diff --git a/.strict-typing b/.strict-typing index 559ee64ddd7..7ad69dc477f 100644 --- a/.strict-typing +++ b/.strict-typing @@ -187,6 +187,7 @@ homeassistant.components.mjpeg.* homeassistant.components.modbus.* homeassistant.components.modem_callerid.* homeassistant.components.moon.* +homeassistant.components.mqtt.* homeassistant.components.mysensors.* homeassistant.components.nam.* homeassistant.components.nanoleaf.* diff --git a/homeassistant/components/mqtt/alarm_control_panel.py b/homeassistant/components/mqtt/alarm_control_panel.py index 41021fe0d37..a0d065cc7fe 100644 --- a/homeassistant/components/mqtt/alarm_control_panel.py +++ b/homeassistant/components/mqtt/alarm_control_panel.py @@ -249,7 +249,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity): @property def code_arm_required(self) -> bool: """Whether the code is required for arm actions.""" - return self._config[CONF_CODE_ARM_REQUIRED] + return bool(self._config[CONF_CODE_ARM_REQUIRED]) async def async_alarm_disarm(self, code: str | None = None) -> None: """Send disarm command. diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 66f1e130ff7..414027776a3 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -80,7 +80,6 @@ if TYPE_CHECKING: # because integrations should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt - _LOGGER = logging.getLogger(__name__) DISCOVERY_COOLDOWN = 2 @@ -148,16 +147,19 @@ AsyncDeprecatedMessageCallbackType = Callable[ [str, ReceivePayloadType, int], Coroutine[Any, Any, None] ] DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None] +DeprecatedMessageCallbackTypes = Union[ + AsyncDeprecatedMessageCallbackType, DeprecatedMessageCallbackType +] def wrap_msg_callback( - msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType, + msg_callback: DeprecatedMessageCallbackTypes, ) -> AsyncMessageCallbackType | MessageCallbackType: """Wrap an MQTT message callback to support deprecated signature.""" # Check for partials to properly determine if coroutine function check_func = msg_callback while isinstance(check_func, partial): - check_func = check_func.func + check_func = check_func.func # type: ignore[unreachable] wrapper_func: AsyncMessageCallbackType | MessageCallbackType if asyncio.iscoroutinefunction(check_func): @@ -170,14 +172,15 @@ def wrap_msg_callback( ) wrapper_func = async_wrapper - else: + return wrapper_func - @wraps(msg_callback) - def wrapper(msg: ReceiveMessage) -> None: - """Call with deprecated signature.""" - msg_callback(msg.topic, msg.payload, msg.qos) + @wraps(msg_callback) + def wrapper(msg: ReceiveMessage) -> None: + """Call with deprecated signature.""" + msg_callback(msg.topic, msg.payload, msg.qos) + + wrapper_func = wrapper - wrapper_func = wrapper return wrapper_func @@ -187,8 +190,7 @@ async def async_subscribe( topic: str, msg_callback: AsyncMessageCallbackType | MessageCallbackType - | DeprecatedMessageCallbackType - | AsyncDeprecatedMessageCallbackType, + | DeprecatedMessageCallbackTypes, qos: int = DEFAULT_QOS, encoding: str | None = DEFAULT_ENCODING, ) -> CALLBACK_TYPE: @@ -219,7 +221,7 @@ async def async_subscribe( msg_callback.__name__, ) wrapped_msg_callback = wrap_msg_callback( - cast(DeprecatedMessageCallbackType, msg_callback) + cast(DeprecatedMessageCallbackTypes, msg_callback) ) async_remove = await mqtt_data.client.async_subscribe( diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 84f14d26146..9907de18ee9 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -97,7 +97,7 @@ async def async_start( # noqa: C901 mqtt_data = get_mqtt_data(hass) mqtt_integrations = {} - async def async_discovery_message_received(msg) -> None: + async def async_discovery_message_received(msg: ReceiveMessage) -> None: """Process the received message.""" mqtt_data.last_discovery = time.time() payload = msg.payload @@ -122,46 +122,50 @@ async def async_start( # noqa: C901 if payload: try: - payload = json_loads(payload) + discovery_payload = MQTTDiscoveryPayload(json_loads(payload)) except ValueError: _LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload) return + else: + discovery_payload = MQTTDiscoveryPayload({}) - payload = MQTTDiscoveryPayload(payload) - - for key in list(payload): + for key in list(discovery_payload): abbreviated_key = key key = ABBREVIATIONS.get(key, key) - payload[key] = payload.pop(abbreviated_key) + discovery_payload[key] = discovery_payload.pop(abbreviated_key) - if CONF_DEVICE in payload: - device = payload[CONF_DEVICE] + if CONF_DEVICE in discovery_payload: + device = discovery_payload[CONF_DEVICE] for key in list(device): abbreviated_key = key key = DEVICE_ABBREVIATIONS.get(key, key) device[key] = device.pop(abbreviated_key) - if CONF_AVAILABILITY in payload: - for availability_conf in cv.ensure_list(payload[CONF_AVAILABILITY]): + if CONF_AVAILABILITY in discovery_payload: + for availability_conf in cv.ensure_list( + discovery_payload[CONF_AVAILABILITY] + ): if isinstance(availability_conf, dict): for key in list(availability_conf): abbreviated_key = key key = ABBREVIATIONS.get(key, key) availability_conf[key] = availability_conf.pop(abbreviated_key) - if TOPIC_BASE in payload: - base = payload.pop(TOPIC_BASE) - for key, value in payload.items(): + if TOPIC_BASE in discovery_payload: + base = discovery_payload.pop(TOPIC_BASE) + for key, value in discovery_payload.items(): if isinstance(value, str) and value: if value[0] == TOPIC_BASE and key.endswith("topic"): - payload[key] = f"{base}{value[1:]}" + discovery_payload[key] = f"{base}{value[1:]}" if value[-1] == TOPIC_BASE and key.endswith("topic"): - payload[key] = f"{value[:-1]}{base}" - if payload.get(CONF_AVAILABILITY): - for availability_conf in cv.ensure_list(payload[CONF_AVAILABILITY]): + discovery_payload[key] = f"{value[:-1]}{base}" + if discovery_payload.get(CONF_AVAILABILITY): + for availability_conf in cv.ensure_list( + discovery_payload[CONF_AVAILABILITY] + ): if not isinstance(availability_conf, dict): continue - if topic := availability_conf.get(CONF_TOPIC): + if topic := str(availability_conf.get(CONF_TOPIC)): if topic[0] == TOPIC_BASE: availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}" if topic[-1] == TOPIC_BASE: @@ -171,21 +175,25 @@ async def async_start( # noqa: C901 discovery_id = " ".join((node_id, object_id)) if node_id else object_id discovery_hash = (component, discovery_id) - if payload: + if discovery_payload: # Attach MQTT topic to the payload, used for debug prints - setattr(payload, "__configuration_source__", f"MQTT (topic: '{topic}')") + setattr( + discovery_payload, + "__configuration_source__", + f"MQTT (topic: '{topic}')", + ) discovery_data = { ATTR_DISCOVERY_HASH: discovery_hash, - ATTR_DISCOVERY_PAYLOAD: payload, + ATTR_DISCOVERY_PAYLOAD: discovery_payload, ATTR_DISCOVERY_TOPIC: topic, } - setattr(payload, "discovery_data", discovery_data) + setattr(discovery_payload, "discovery_data", discovery_data) - payload[CONF_PLATFORM] = "mqtt" + discovery_payload[CONF_PLATFORM] = "mqtt" if discovery_hash in mqtt_data.discovery_pending_discovered: pending = mqtt_data.discovery_pending_discovered[discovery_hash]["pending"] - pending.appendleft(payload) + pending.appendleft(discovery_payload) _LOGGER.info( "Component has already been discovered: %s %s, queuing update", component, @@ -193,7 +201,9 @@ async def async_start( # noqa: C901 ) return - await async_process_discovery_payload(component, discovery_id, payload) + await async_process_discovery_payload( + component, discovery_id, discovery_payload + ) async def async_process_discovery_payload( component: str, discovery_id: str, payload: MQTTDiscoveryPayload @@ -204,7 +214,7 @@ async def async_start( # noqa: C901 discovery_hash = (component, discovery_id) if discovery_hash in mqtt_data.discovery_already_discovered or payload: - async def discovery_done(_) -> None: + async def discovery_done(_: Any) -> None: pending = mqtt_data.discovery_pending_discovered[discovery_hash][ "pending" ] diff --git a/homeassistant/components/mqtt/light/schema_basic.py b/homeassistant/components/mqtt/light/schema_basic.py index dfa8af0097d..689d388b92a 100644 --- a/homeassistant/components/mqtt/light/schema_basic.py +++ b/homeassistant/components/mqtt/light/schema_basic.py @@ -680,7 +680,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity): restore_state(ATTR_HS_COLOR, ATTR_XY_COLOR) @property - def assumed_state(self): + def assumed_state(self) -> bool: """Return true if we do optimistic updates.""" return self._optimistic diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 3cb208f3adc..6df545d7508 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -620,7 +620,8 @@ async def cleanup_device_registry( def get_discovery_hash(discovery_data: DiscoveryInfoType) -> tuple[str, str]: """Get the discovery hash from the discovery data.""" - return discovery_data[ATTR_DISCOVERY_HASH] + discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] + return discovery_hash def send_discovery_done(hass: HomeAssistant, discovery_data: DiscoveryInfoType) -> None: @@ -1113,7 +1114,7 @@ class MqttEntity( @property def entity_registry_enabled_default(self) -> bool: """Return if the entity should be enabled when first added to the entity registry.""" - return self._config[CONF_ENABLED_BY_DEFAULT] + return bool(self._config[CONF_ENABLED_BY_DEFAULT]) @property def entity_category(self) -> EntityCategory | None: diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index c0b299c7582..aaef5e3e3e8 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -150,7 +150,7 @@ class MqttCommandTemplate: if self._entity: values[ATTR_ENTITY_ID] = self._entity.entity_id values[ATTR_NAME] = self._entity.name - if not self._template_state: + if not self._template_state and self._command_template.hass is not None: self._template_state = template.TemplateStateFromEntityId( self._entity.hass, self._entity.entity_id ) @@ -200,6 +200,8 @@ class MqttValueTemplate: variables: TemplateVarsType = None, ) -> ReceivePayloadType: """Render with possible json value or pass-though a received MQTT value.""" + rendered_payload: ReceivePayloadType + if self._value_template is None: return payload @@ -227,9 +229,12 @@ class MqttValueTemplate: values, self._value_template, ) - return self._value_template.async_render_with_possible_json_value( - payload, variables=values + rendered_payload = ( + self._value_template.async_render_with_possible_json_value( + payload, variables=values + ) ) + return rendered_payload _LOGGER.debug( "Rendering incoming payload '%s' with variables %s with default value '%s' and %s", @@ -238,9 +243,10 @@ class MqttValueTemplate: default, self._value_template, ) - return self._value_template.async_render_with_possible_json_value( + rendered_payload = self._value_template.async_render_with_possible_json_value( payload, default, variables=values ) + return rendered_payload class EntityTopicState: diff --git a/homeassistant/components/mqtt/subscription.py b/homeassistant/components/mqtt/subscription.py index 87f5d3882bb..e3fd5e50093 100644 --- a/homeassistant/components/mqtt/subscription.py +++ b/homeassistant/components/mqtt/subscription.py @@ -19,7 +19,7 @@ class EntitySubscription: """Class to hold data about an active entity topic subscription.""" hass: HomeAssistant = attr.ib() - topic: str = attr.ib() + topic: str | None = attr.ib() message_callback: MessageCallbackType = attr.ib() subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None = attr.ib() unsubscribe_callback: Callable[[], None] | None = attr.ib() @@ -39,7 +39,7 @@ class EntitySubscription: other.unsubscribe_callback() # Clear debug data if it exists debug_info.remove_subscription( - self.hass, other.message_callback, other.topic + self.hass, other.message_callback, str(other.topic) ) if self.topic is None: @@ -112,7 +112,7 @@ def async_prepare_subscribe_topics( remaining.unsubscribe_callback() # Clear debug data if it exists debug_info.remove_subscription( - hass, remaining.message_callback, remaining.topic + hass, remaining.message_callback, str(remaining.topic) ) return new_state diff --git a/homeassistant/components/mqtt/util.py b/homeassistant/components/mqtt/util.py index bfd961871d3..97bb120f842 100644 --- a/homeassistant/components/mqtt/util.py +++ b/homeassistant/components/mqtt/util.py @@ -97,7 +97,7 @@ def valid_subscribe_topic(topic: Any) -> str: def valid_subscribe_topic_template(value: Any) -> template.Template: """Validate either a jinja2 template or a valid MQTT subscription topic.""" - tpl = template.Template(value) + tpl = cv.template(value) if tpl.is_static: valid_subscribe_topic(value) @@ -115,7 +115,8 @@ def valid_publish_topic(topic: Any) -> str: def valid_qos_schema(qos: Any) -> int: """Validate that QOS value is valid.""" - return _VALID_QOS_SCHEMA(qos) + validated_qos: int = _VALID_QOS_SCHEMA(qos) + return validated_qos _MQTT_WILL_BIRTH_SCHEMA = vol.Schema( @@ -138,9 +139,12 @@ def valid_birth_will(config: ConfigType) -> ConfigType: def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData: """Return typed MqttData from hass.data[DATA_MQTT].""" + mqtt_data: MqttData if ensure_exists: - return hass.data.setdefault(DATA_MQTT, MqttData()) - return hass.data[DATA_MQTT] + mqtt_data = hass.data.setdefault(DATA_MQTT, MqttData()) + return mqtt_data + mqtt_data = hass.data[DATA_MQTT] + return mqtt_data async def async_create_certificate_temp_files( diff --git a/mypy.ini b/mypy.ini index a347218cb70..737761ae5a5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1623,6 +1623,16 @@ disallow_untyped_defs = true warn_return_any = true warn_unreachable = true +[mypy-homeassistant.components.mqtt.*] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +warn_return_any = true +warn_unreachable = true + [mypy-homeassistant.components.mysensors.*] check_untyped_defs = true disallow_incomplete_defs = true