Strict type hints for MQTT integration (#82317)

* Strict type hints for MQTT integration

* Fix errors

* Additional corrections

* Use cv.template to avoid untyped calls

* Enable strict typing policy for MQTT integration

* Use ignore[no-untyped-call]

* Use # type: ignore[unreachable]

* Correct cast

* Refactor getting discovery_payload

* Remove unused type ignore comments
This commit is contained in:
Jan Bouwhuis 2022-11-24 08:25:44 +01:00 committed by GitHub
parent 697b5db3f2
commit 8a8732f0bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 87 additions and 53 deletions

View File

@ -187,6 +187,7 @@ homeassistant.components.mjpeg.*
homeassistant.components.modbus.*
homeassistant.components.modem_callerid.*
homeassistant.components.moon.*
homeassistant.components.mqtt.*
homeassistant.components.mysensors.*
homeassistant.components.nam.*
homeassistant.components.nanoleaf.*

View File

@ -249,7 +249,7 @@ class MqttAlarm(MqttEntity, alarm.AlarmControlPanelEntity):
@property
def code_arm_required(self) -> bool:
"""Whether the code is required for arm actions."""
return self._config[CONF_CODE_ARM_REQUIRED]
return bool(self._config[CONF_CODE_ARM_REQUIRED])
async def async_alarm_disarm(self, code: str | None = None) -> None:
"""Send disarm command.

View File

@ -80,7 +80,6 @@ if TYPE_CHECKING:
# because integrations should be able to optionally rely on MQTT.
import paho.mqtt.client as mqtt
_LOGGER = logging.getLogger(__name__)
DISCOVERY_COOLDOWN = 2
@ -148,16 +147,19 @@ AsyncDeprecatedMessageCallbackType = Callable[
[str, ReceivePayloadType, int], Coroutine[Any, Any, None]
]
DeprecatedMessageCallbackType = Callable[[str, ReceivePayloadType, int], None]
DeprecatedMessageCallbackTypes = Union[
AsyncDeprecatedMessageCallbackType, DeprecatedMessageCallbackType
]
def wrap_msg_callback(
msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType,
msg_callback: DeprecatedMessageCallbackTypes,
) -> AsyncMessageCallbackType | MessageCallbackType:
"""Wrap an MQTT message callback to support deprecated signature."""
# Check for partials to properly determine if coroutine function
check_func = msg_callback
while isinstance(check_func, partial):
check_func = check_func.func
check_func = check_func.func # type: ignore[unreachable]
wrapper_func: AsyncMessageCallbackType | MessageCallbackType
if asyncio.iscoroutinefunction(check_func):
@ -170,14 +172,15 @@ def wrap_msg_callback(
)
wrapper_func = async_wrapper
else:
return wrapper_func
@wraps(msg_callback)
def wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature."""
msg_callback(msg.topic, msg.payload, msg.qos)
@wraps(msg_callback)
def wrapper(msg: ReceiveMessage) -> None:
"""Call with deprecated signature."""
msg_callback(msg.topic, msg.payload, msg.qos)
wrapper_func = wrapper
wrapper_func = wrapper
return wrapper_func
@ -187,8 +190,7 @@ async def async_subscribe(
topic: str,
msg_callback: AsyncMessageCallbackType
| MessageCallbackType
| DeprecatedMessageCallbackType
| AsyncDeprecatedMessageCallbackType,
| DeprecatedMessageCallbackTypes,
qos: int = DEFAULT_QOS,
encoding: str | None = DEFAULT_ENCODING,
) -> CALLBACK_TYPE:
@ -219,7 +221,7 @@ async def async_subscribe(
msg_callback.__name__,
)
wrapped_msg_callback = wrap_msg_callback(
cast(DeprecatedMessageCallbackType, msg_callback)
cast(DeprecatedMessageCallbackTypes, msg_callback)
)
async_remove = await mqtt_data.client.async_subscribe(

View File

@ -97,7 +97,7 @@ async def async_start( # noqa: C901
mqtt_data = get_mqtt_data(hass)
mqtt_integrations = {}
async def async_discovery_message_received(msg) -> None:
async def async_discovery_message_received(msg: ReceiveMessage) -> None:
"""Process the received message."""
mqtt_data.last_discovery = time.time()
payload = msg.payload
@ -122,46 +122,50 @@ async def async_start( # noqa: C901
if payload:
try:
payload = json_loads(payload)
discovery_payload = MQTTDiscoveryPayload(json_loads(payload))
except ValueError:
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return
else:
discovery_payload = MQTTDiscoveryPayload({})
payload = MQTTDiscoveryPayload(payload)
for key in list(payload):
for key in list(discovery_payload):
abbreviated_key = key
key = ABBREVIATIONS.get(key, key)
payload[key] = payload.pop(abbreviated_key)
discovery_payload[key] = discovery_payload.pop(abbreviated_key)
if CONF_DEVICE in payload:
device = payload[CONF_DEVICE]
if CONF_DEVICE in discovery_payload:
device = discovery_payload[CONF_DEVICE]
for key in list(device):
abbreviated_key = key
key = DEVICE_ABBREVIATIONS.get(key, key)
device[key] = device.pop(abbreviated_key)
if CONF_AVAILABILITY in payload:
for availability_conf in cv.ensure_list(payload[CONF_AVAILABILITY]):
if CONF_AVAILABILITY in discovery_payload:
for availability_conf in cv.ensure_list(
discovery_payload[CONF_AVAILABILITY]
):
if isinstance(availability_conf, dict):
for key in list(availability_conf):
abbreviated_key = key
key = ABBREVIATIONS.get(key, key)
availability_conf[key] = availability_conf.pop(abbreviated_key)
if TOPIC_BASE in payload:
base = payload.pop(TOPIC_BASE)
for key, value in payload.items():
if TOPIC_BASE in discovery_payload:
base = discovery_payload.pop(TOPIC_BASE)
for key, value in discovery_payload.items():
if isinstance(value, str) and value:
if value[0] == TOPIC_BASE and key.endswith("topic"):
payload[key] = f"{base}{value[1:]}"
discovery_payload[key] = f"{base}{value[1:]}"
if value[-1] == TOPIC_BASE and key.endswith("topic"):
payload[key] = f"{value[:-1]}{base}"
if payload.get(CONF_AVAILABILITY):
for availability_conf in cv.ensure_list(payload[CONF_AVAILABILITY]):
discovery_payload[key] = f"{value[:-1]}{base}"
if discovery_payload.get(CONF_AVAILABILITY):
for availability_conf in cv.ensure_list(
discovery_payload[CONF_AVAILABILITY]
):
if not isinstance(availability_conf, dict):
continue
if topic := availability_conf.get(CONF_TOPIC):
if topic := str(availability_conf.get(CONF_TOPIC)):
if topic[0] == TOPIC_BASE:
availability_conf[CONF_TOPIC] = f"{base}{topic[1:]}"
if topic[-1] == TOPIC_BASE:
@ -171,21 +175,25 @@ async def async_start( # noqa: C901
discovery_id = " ".join((node_id, object_id)) if node_id else object_id
discovery_hash = (component, discovery_id)
if payload:
if discovery_payload:
# Attach MQTT topic to the payload, used for debug prints
setattr(payload, "__configuration_source__", f"MQTT (topic: '{topic}')")
setattr(
discovery_payload,
"__configuration_source__",
f"MQTT (topic: '{topic}')",
)
discovery_data = {
ATTR_DISCOVERY_HASH: discovery_hash,
ATTR_DISCOVERY_PAYLOAD: payload,
ATTR_DISCOVERY_PAYLOAD: discovery_payload,
ATTR_DISCOVERY_TOPIC: topic,
}
setattr(payload, "discovery_data", discovery_data)
setattr(discovery_payload, "discovery_data", discovery_data)
payload[CONF_PLATFORM] = "mqtt"
discovery_payload[CONF_PLATFORM] = "mqtt"
if discovery_hash in mqtt_data.discovery_pending_discovered:
pending = mqtt_data.discovery_pending_discovered[discovery_hash]["pending"]
pending.appendleft(payload)
pending.appendleft(discovery_payload)
_LOGGER.info(
"Component has already been discovered: %s %s, queuing update",
component,
@ -193,7 +201,9 @@ async def async_start( # noqa: C901
)
return
await async_process_discovery_payload(component, discovery_id, payload)
await async_process_discovery_payload(
component, discovery_id, discovery_payload
)
async def async_process_discovery_payload(
component: str, discovery_id: str, payload: MQTTDiscoveryPayload
@ -204,7 +214,7 @@ async def async_start( # noqa: C901
discovery_hash = (component, discovery_id)
if discovery_hash in mqtt_data.discovery_already_discovered or payload:
async def discovery_done(_) -> None:
async def discovery_done(_: Any) -> None:
pending = mqtt_data.discovery_pending_discovered[discovery_hash][
"pending"
]

View File

@ -680,7 +680,7 @@ class MqttLight(MqttEntity, LightEntity, RestoreEntity):
restore_state(ATTR_HS_COLOR, ATTR_XY_COLOR)
@property
def assumed_state(self):
def assumed_state(self) -> bool:
"""Return true if we do optimistic updates."""
return self._optimistic

View File

@ -620,7 +620,8 @@ async def cleanup_device_registry(
def get_discovery_hash(discovery_data: DiscoveryInfoType) -> tuple[str, str]:
"""Get the discovery hash from the discovery data."""
return discovery_data[ATTR_DISCOVERY_HASH]
discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
return discovery_hash
def send_discovery_done(hass: HomeAssistant, discovery_data: DiscoveryInfoType) -> None:
@ -1113,7 +1114,7 @@ class MqttEntity(
@property
def entity_registry_enabled_default(self) -> bool:
"""Return if the entity should be enabled when first added to the entity registry."""
return self._config[CONF_ENABLED_BY_DEFAULT]
return bool(self._config[CONF_ENABLED_BY_DEFAULT])
@property
def entity_category(self) -> EntityCategory | None:

View File

@ -150,7 +150,7 @@ class MqttCommandTemplate:
if self._entity:
values[ATTR_ENTITY_ID] = self._entity.entity_id
values[ATTR_NAME] = self._entity.name
if not self._template_state:
if not self._template_state and self._command_template.hass is not None:
self._template_state = template.TemplateStateFromEntityId(
self._entity.hass, self._entity.entity_id
)
@ -200,6 +200,8 @@ class MqttValueTemplate:
variables: TemplateVarsType = None,
) -> ReceivePayloadType:
"""Render with possible json value or pass-though a received MQTT value."""
rendered_payload: ReceivePayloadType
if self._value_template is None:
return payload
@ -227,9 +229,12 @@ class MqttValueTemplate:
values,
self._value_template,
)
return self._value_template.async_render_with_possible_json_value(
payload, variables=values
rendered_payload = (
self._value_template.async_render_with_possible_json_value(
payload, variables=values
)
)
return rendered_payload
_LOGGER.debug(
"Rendering incoming payload '%s' with variables %s with default value '%s' and %s",
@ -238,9 +243,10 @@ class MqttValueTemplate:
default,
self._value_template,
)
return self._value_template.async_render_with_possible_json_value(
rendered_payload = self._value_template.async_render_with_possible_json_value(
payload, default, variables=values
)
return rendered_payload
class EntityTopicState:

View File

@ -19,7 +19,7 @@ class EntitySubscription:
"""Class to hold data about an active entity topic subscription."""
hass: HomeAssistant = attr.ib()
topic: str = attr.ib()
topic: str | None = attr.ib()
message_callback: MessageCallbackType = attr.ib()
subscribe_task: Coroutine[Any, Any, Callable[[], None]] | None = attr.ib()
unsubscribe_callback: Callable[[], None] | None = attr.ib()
@ -39,7 +39,7 @@ class EntitySubscription:
other.unsubscribe_callback()
# Clear debug data if it exists
debug_info.remove_subscription(
self.hass, other.message_callback, other.topic
self.hass, other.message_callback, str(other.topic)
)
if self.topic is None:
@ -112,7 +112,7 @@ def async_prepare_subscribe_topics(
remaining.unsubscribe_callback()
# Clear debug data if it exists
debug_info.remove_subscription(
hass, remaining.message_callback, remaining.topic
hass, remaining.message_callback, str(remaining.topic)
)
return new_state

View File

@ -97,7 +97,7 @@ def valid_subscribe_topic(topic: Any) -> str:
def valid_subscribe_topic_template(value: Any) -> template.Template:
"""Validate either a jinja2 template or a valid MQTT subscription topic."""
tpl = template.Template(value)
tpl = cv.template(value)
if tpl.is_static:
valid_subscribe_topic(value)
@ -115,7 +115,8 @@ def valid_publish_topic(topic: Any) -> str:
def valid_qos_schema(qos: Any) -> int:
"""Validate that QOS value is valid."""
return _VALID_QOS_SCHEMA(qos)
validated_qos: int = _VALID_QOS_SCHEMA(qos)
return validated_qos
_MQTT_WILL_BIRTH_SCHEMA = vol.Schema(
@ -138,9 +139,12 @@ def valid_birth_will(config: ConfigType) -> ConfigType:
def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData:
"""Return typed MqttData from hass.data[DATA_MQTT]."""
mqtt_data: MqttData
if ensure_exists:
return hass.data.setdefault(DATA_MQTT, MqttData())
return hass.data[DATA_MQTT]
mqtt_data = hass.data.setdefault(DATA_MQTT, MqttData())
return mqtt_data
mqtt_data = hass.data[DATA_MQTT]
return mqtt_data
async def async_create_certificate_temp_files(

View File

@ -1623,6 +1623,16 @@ disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.mqtt.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.mysensors.*]
check_untyped_defs = true
disallow_incomplete_defs = true