mirror of
https://github.com/home-assistant/core
synced 2024-09-15 17:29:45 +02:00
Add entities for ZHA fan groups (#33291)
* start of fan groups * update fan classes * update group entity domains * add set speed * update discovery for multiple entities for groups * add fan group entity tests * cleanup const * cleanup entity_domain usage * remove bad super call * remove bad update line * fix set speed on fan group * change comparison * pythonic list * discovery guards * Update homeassistant/components/zha/core/discovery.py Co-Authored-By: Alexei Chetroi <lexoid@gmail.com> Co-authored-by: Alexei Chetroi <lexoid@gmail.com>
This commit is contained in:
parent
c89975adf6
commit
4f767dd3ef
@ -23,7 +23,6 @@ ATTR_COMMAND_TYPE = "command_type"
|
||||
ATTR_DEVICE_IEEE = "device_ieee"
|
||||
ATTR_DEVICE_TYPE = "device_type"
|
||||
ATTR_ENDPOINT_ID = "endpoint_id"
|
||||
ATTR_ENTITY_DOMAIN = "entity_domain"
|
||||
ATTR_IEEE = "ieee"
|
||||
ATTR_LAST_SEEN = "last_seen"
|
||||
ATTR_LEVEL = "level"
|
||||
|
@ -6,6 +6,7 @@ from typing import Callable, List, Tuple
|
||||
|
||||
from homeassistant import const as ha_const
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.entity_registry import async_entries_for_device
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
|
||||
@ -182,59 +183,48 @@ class GroupProbe:
|
||||
)
|
||||
return
|
||||
|
||||
if group.entity_domain is None:
|
||||
_LOGGER.debug(
|
||||
"Group: %s:0x%04x has no user set entity domain - attempting entity domain discovery",
|
||||
group.name,
|
||||
group.group_id,
|
||||
)
|
||||
group.entity_domain = GroupProbe.determine_default_entity_domain(
|
||||
self._hass, group
|
||||
)
|
||||
entity_domains = GroupProbe.determine_entity_domains(self._hass, group)
|
||||
|
||||
if group.entity_domain is None:
|
||||
if not entity_domains:
|
||||
return
|
||||
|
||||
_LOGGER.debug(
|
||||
"Group: %s:0x%04x has an entity domain of: %s after discovery",
|
||||
group.name,
|
||||
group.group_id,
|
||||
group.entity_domain,
|
||||
)
|
||||
|
||||
zha_gateway = self._hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
|
||||
entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(group.entity_domain)
|
||||
if entity_class is None:
|
||||
return
|
||||
|
||||
self._hass.data[zha_const.DATA_ZHA][group.entity_domain].append(
|
||||
(
|
||||
entity_class,
|
||||
for domain in entity_domains:
|
||||
entity_class = zha_regs.ZHA_ENTITIES.get_group_entity(domain)
|
||||
if entity_class is None:
|
||||
continue
|
||||
self._hass.data[zha_const.DATA_ZHA][domain].append(
|
||||
(
|
||||
group.domain_entity_ids,
|
||||
f"{group.entity_domain}_group_{group.group_id}",
|
||||
group.group_id,
|
||||
zha_gateway.coordinator_zha_device,
|
||||
),
|
||||
entity_class,
|
||||
(
|
||||
group.get_domain_entity_ids(domain),
|
||||
f"{domain}_group_{group.group_id}",
|
||||
group.group_id,
|
||||
zha_gateway.coordinator_zha_device,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
async_dispatcher_send(self._hass, zha_const.SIGNAL_ADD_ENTITIES)
|
||||
|
||||
@staticmethod
|
||||
def determine_default_entity_domain(
|
||||
def determine_entity_domains(
|
||||
hass: HomeAssistantType, group: zha_typing.ZhaGroupType
|
||||
):
|
||||
"""Determine the default entity domain for this group."""
|
||||
) -> List[str]:
|
||||
"""Determine the entity domains for this group."""
|
||||
entity_domains: List[str] = []
|
||||
if len(group.members) < 2:
|
||||
_LOGGER.debug(
|
||||
"Group: %s:0x%04x has less than 2 members so cannot default an entity domain",
|
||||
group.name,
|
||||
group.group_id,
|
||||
)
|
||||
return None
|
||||
return entity_domains
|
||||
|
||||
zha_gateway = hass.data[zha_const.DATA_ZHA][zha_const.DATA_ZHA_GATEWAY]
|
||||
all_domain_occurrences = []
|
||||
for device in group.members:
|
||||
if device.is_coordinator:
|
||||
continue
|
||||
entities = async_entries_for_device(
|
||||
zha_gateway.ha_entity_registry, device.device_id
|
||||
)
|
||||
@ -245,15 +235,18 @@ class GroupProbe:
|
||||
if entity.domain in zha_regs.GROUP_ENTITY_DOMAINS
|
||||
]
|
||||
)
|
||||
if not all_domain_occurrences:
|
||||
return entity_domains
|
||||
# get all domains we care about if there are more than 2 entities of this domain
|
||||
counts = Counter(all_domain_occurrences)
|
||||
domain = counts.most_common(1)[0][0]
|
||||
entity_domains = [domain[0] for domain in counts.items() if domain[1] >= 2]
|
||||
_LOGGER.debug(
|
||||
"The default entity domain is: %s for group: %s:0x%04x",
|
||||
domain,
|
||||
"The entity domains are: %s for group: %s:0x%04x",
|
||||
entity_domains,
|
||||
group.name,
|
||||
group.group_id,
|
||||
)
|
||||
return domain
|
||||
return entity_domains
|
||||
|
||||
|
||||
PROBE = ProbeEndpoint()
|
||||
|
@ -445,8 +445,6 @@ class ZHAGateway:
|
||||
if zha_group is None:
|
||||
zha_group = ZHAGroup(self._hass, self, zigpy_group)
|
||||
self._groups[zigpy_group.group_id] = zha_group
|
||||
group_entry = self.zha_storage.async_get_or_create_group(zha_group)
|
||||
zha_group.entity_domain = group_entry.entity_domain
|
||||
return zha_group
|
||||
|
||||
@callback
|
||||
@ -469,8 +467,6 @@ class ZHAGateway:
|
||||
"""Update the devices in the store."""
|
||||
for device in self.devices.values():
|
||||
self.zha_storage.async_update_device(device)
|
||||
for group in self.groups.values():
|
||||
self.zha_storage.async_update_group(group)
|
||||
await self.zha_storage.async_save()
|
||||
|
||||
async def async_device_initialized(self, device: zha_typing.ZigpyDeviceType):
|
||||
@ -559,9 +555,7 @@ class ZHAGateway:
|
||||
zha_group.group_id,
|
||||
)
|
||||
discovery.GROUP_PROBE.discover_group_entities(zha_group)
|
||||
if zha_group.entity_domain is not None:
|
||||
self.zha_storage.async_update_group(zha_group)
|
||||
async_dispatcher_send(self._hass, SIGNAL_ADD_ENTITIES)
|
||||
|
||||
return zha_group
|
||||
|
||||
async def async_remove_zigpy_group(self, group_id: int) -> None:
|
||||
@ -577,7 +571,6 @@ class ZHAGateway:
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
self.application_controller.groups.pop(group_id)
|
||||
self.zha_storage.async_delete_group(group)
|
||||
|
||||
async def shutdown(self):
|
||||
"""Stop ZHA Controller Application."""
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Group for Zigbee Home Automation."""
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from zigpy.types.named import EUI64
|
||||
|
||||
@ -28,7 +28,6 @@ class ZHAGroup(LogMixin):
|
||||
self.hass: HomeAssistantType = hass
|
||||
self._zigpy_group: ZigpyGroupType = zigpy_group
|
||||
self._zha_gateway: ZhaGatewayType = zha_gateway
|
||||
self._entity_domain: str = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -45,16 +44,6 @@ class ZHAGroup(LogMixin):
|
||||
"""Return the endpoint for this group."""
|
||||
return self._zigpy_group.endpoint
|
||||
|
||||
@property
|
||||
def entity_domain(self) -> Optional[str]:
|
||||
"""Return the domain that will be used for the entity representing this group."""
|
||||
return self._entity_domain
|
||||
|
||||
@entity_domain.setter
|
||||
def entity_domain(self, domain: Optional[str]) -> None:
|
||||
"""Set the domain that will be used for the entity representing this group."""
|
||||
self._entity_domain = domain
|
||||
|
||||
@property
|
||||
def members(self) -> List[ZhaDeviceType]:
|
||||
"""Return the ZHA devices that are members of this group."""
|
||||
@ -106,22 +95,15 @@ class ZHAGroup(LogMixin):
|
||||
all_entity_ids.append(entity.entity_id)
|
||||
return all_entity_ids
|
||||
|
||||
@property
|
||||
def domain_entity_ids(self) -> List[str]:
|
||||
def get_domain_entity_ids(self, domain) -> List[str]:
|
||||
"""Return entity ids from the entity domain for this group."""
|
||||
if self.entity_domain is None:
|
||||
return
|
||||
domain_entity_ids: List[str] = []
|
||||
for device in self.members:
|
||||
entities = async_entries_for_device(
|
||||
self._zha_gateway.ha_entity_registry, device.device_id
|
||||
)
|
||||
domain_entity_ids.extend(
|
||||
[
|
||||
entity.entity_id
|
||||
for entity in entities
|
||||
if entity.domain == self.entity_domain
|
||||
]
|
||||
[entity.entity_id for entity in entities if entity.domain == domain]
|
||||
)
|
||||
return domain_entity_ids
|
||||
|
||||
@ -130,7 +112,6 @@ class ZHAGroup(LogMixin):
|
||||
"""Get ZHA group info."""
|
||||
group_info: Dict[str, Any] = {}
|
||||
group_info["group_id"] = self.group_id
|
||||
group_info["entity_domain"] = self.entity_domain
|
||||
group_info["name"] = self.name
|
||||
group_info["members"] = [
|
||||
zha_device.async_get_info() for zha_device in self.members
|
||||
|
@ -32,7 +32,7 @@ from .const import CONTROLLER, ZHA_GW_RADIO, ZHA_GW_RADIO_DESCRIPTION, RadioType
|
||||
from .decorators import CALLABLE_T, DictRegistry, SetRegistry
|
||||
from .typing import ChannelType
|
||||
|
||||
GROUP_ENTITY_DOMAINS = [LIGHT, SWITCH]
|
||||
GROUP_ENTITY_DOMAINS = [LIGHT, SWITCH, FAN]
|
||||
|
||||
SMARTTHINGS_ACCELERATION_CLUSTER = 0xFC02
|
||||
SMARTTHINGS_ARRIVAL_SENSOR_DEVICE_TYPE = 0x8000
|
||||
|
@ -10,7 +10,7 @@ from homeassistant.core import callback
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
from .typing import ZhaDeviceType, ZhaGroupType
|
||||
from .typing import ZhaDeviceType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -30,15 +30,6 @@ class ZhaDeviceEntry:
|
||||
last_seen = attr.ib(type=float, default=None)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class ZhaGroupEntry:
|
||||
"""Zha Group storage Entry."""
|
||||
|
||||
name = attr.ib(type=str, default=None)
|
||||
group_id = attr.ib(type=int, default=None)
|
||||
entity_domain = attr.ib(type=float, default=None)
|
||||
|
||||
|
||||
class ZhaStorage:
|
||||
"""Class to hold a registry of zha devices."""
|
||||
|
||||
@ -46,7 +37,6 @@ class ZhaStorage:
|
||||
"""Initialize the zha device storage."""
|
||||
self.hass: HomeAssistantType = hass
|
||||
self.devices: MutableMapping[str, ZhaDeviceEntry] = {}
|
||||
self.groups: MutableMapping[str, ZhaGroupEntry] = {}
|
||||
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
|
||||
|
||||
@callback
|
||||
@ -59,17 +49,6 @@ class ZhaStorage:
|
||||
|
||||
return self.async_update_device(device)
|
||||
|
||||
@callback
|
||||
def async_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
|
||||
"""Create a new ZhaGroupEntry."""
|
||||
group_entry: ZhaGroupEntry = ZhaGroupEntry(
|
||||
name=group.name,
|
||||
group_id=str(group.group_id),
|
||||
entity_domain=group.entity_domain,
|
||||
)
|
||||
self.groups[str(group.group_id)] = group_entry
|
||||
return self.async_update_group(group)
|
||||
|
||||
@callback
|
||||
def async_get_or_create_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
|
||||
"""Create a new ZhaDeviceEntry."""
|
||||
@ -78,14 +57,6 @@ class ZhaStorage:
|
||||
return self.devices[ieee_str]
|
||||
return self.async_create_device(device)
|
||||
|
||||
@callback
|
||||
def async_get_or_create_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
|
||||
"""Create a new ZhaGroupEntry."""
|
||||
group_id: str = str(group.group_id)
|
||||
if group_id in self.groups:
|
||||
return self.groups[group_id]
|
||||
return self.async_create_group(group)
|
||||
|
||||
@callback
|
||||
def async_create_or_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
|
||||
"""Create or update a ZhaDeviceEntry."""
|
||||
@ -93,13 +64,6 @@ class ZhaStorage:
|
||||
return self.async_update_device(device)
|
||||
return self.async_create_device(device)
|
||||
|
||||
@callback
|
||||
def async_create_or_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
|
||||
"""Create or update a ZhaGroupEntry."""
|
||||
if str(group.group_id) in self.groups:
|
||||
return self.async_update_group(group)
|
||||
return self.async_create_group(group)
|
||||
|
||||
@callback
|
||||
def async_delete_device(self, device: ZhaDeviceType) -> None:
|
||||
"""Delete ZhaDeviceEntry."""
|
||||
@ -108,14 +72,6 @@ class ZhaStorage:
|
||||
del self.devices[ieee_str]
|
||||
self.async_schedule_save()
|
||||
|
||||
@callback
|
||||
def async_delete_group(self, group: ZhaGroupType) -> None:
|
||||
"""Delete ZhaGroupEntry."""
|
||||
group_id: str = str(group.group_id)
|
||||
if group_id in self.groups:
|
||||
del self.groups[group_id]
|
||||
self.async_schedule_save()
|
||||
|
||||
@callback
|
||||
def async_update_device(self, device: ZhaDeviceType) -> ZhaDeviceEntry:
|
||||
"""Update name of ZhaDeviceEntry."""
|
||||
@ -129,25 +85,11 @@ class ZhaStorage:
|
||||
self.async_schedule_save()
|
||||
return new
|
||||
|
||||
@callback
|
||||
def async_update_group(self, group: ZhaGroupType) -> ZhaGroupEntry:
|
||||
"""Update name of ZhaGroupEntry."""
|
||||
group_id: str = str(group.group_id)
|
||||
old = self.groups[group_id]
|
||||
|
||||
changes = {}
|
||||
changes["entity_domain"] = group.entity_domain
|
||||
|
||||
new = self.groups[group_id] = attr.evolve(old, **changes)
|
||||
self.async_schedule_save()
|
||||
return new
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load the registry of zha device entries."""
|
||||
data = await self._store.async_load()
|
||||
|
||||
devices: "OrderedDict[str, ZhaDeviceEntry]" = OrderedDict()
|
||||
groups: "OrderedDict[str, ZhaGroupEntry]" = OrderedDict()
|
||||
|
||||
if data is not None:
|
||||
for device in data["devices"]:
|
||||
@ -157,18 +99,7 @@ class ZhaStorage:
|
||||
last_seen=device["last_seen"] if "last_seen" in device else None,
|
||||
)
|
||||
|
||||
if "groups" in data:
|
||||
for group in data["groups"]:
|
||||
groups[group["group_id"]] = ZhaGroupEntry(
|
||||
name=group["name"],
|
||||
group_id=group["group_id"],
|
||||
entity_domain=group["entity_domain"]
|
||||
if "entity_domain" in group
|
||||
else None,
|
||||
)
|
||||
|
||||
self.devices = devices
|
||||
self.groups = groups
|
||||
|
||||
@callback
|
||||
def async_schedule_save(self) -> None:
|
||||
@ -189,14 +120,6 @@ class ZhaStorage:
|
||||
for entry in self.devices.values()
|
||||
]
|
||||
|
||||
data["groups"] = [
|
||||
{
|
||||
"name": entry.name,
|
||||
"group_id": entry.group_id,
|
||||
"entity_domain": entry.entity_domain,
|
||||
}
|
||||
for entry in self.groups.values()
|
||||
]
|
||||
return data
|
||||
|
||||
|
||||
|
@ -1,6 +1,10 @@
|
||||
"""Fans on Zigbee Home Automation networks."""
|
||||
import functools
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from zigpy.exceptions import DeliveryError
|
||||
import zigpy.zcl.clusters.hvac as hvac
|
||||
|
||||
from homeassistant.components.fan import (
|
||||
DOMAIN,
|
||||
@ -11,8 +15,10 @@ from homeassistant.components.fan import (
|
||||
SUPPORT_SET_SPEED,
|
||||
FanEntity,
|
||||
)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import STATE_UNAVAILABLE
|
||||
from homeassistant.core import CALLBACK_TYPE, State, callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.event import async_track_state_change
|
||||
|
||||
from .core import discovery
|
||||
from .core.const import (
|
||||
@ -21,9 +27,10 @@ from .core.const import (
|
||||
DATA_ZHA_DISPATCHERS,
|
||||
SIGNAL_ADD_ENTITIES,
|
||||
SIGNAL_ATTR_UPDATED,
|
||||
SIGNAL_REMOVE_GROUP,
|
||||
)
|
||||
from .core.registries import ZHA_ENTITIES
|
||||
from .entity import ZhaEntity
|
||||
from .entity import BaseZhaEntity, ZhaEntity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -49,6 +56,7 @@ SPEED_LIST = [
|
||||
VALUE_TO_SPEED = dict(enumerate(SPEED_LIST))
|
||||
SPEED_TO_VALUE = {speed: i for i, speed in enumerate(SPEED_LIST)}
|
||||
STRICT_MATCH = functools.partial(ZHA_ENTITIES.strict_match, DOMAIN)
|
||||
GROUP_MATCH = functools.partial(ZHA_ENTITIES.group_match, DOMAIN)
|
||||
|
||||
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
@ -65,31 +73,14 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
hass.data[DATA_ZHA][DATA_ZHA_DISPATCHERS].append(unsub)
|
||||
|
||||
|
||||
@STRICT_MATCH(channel_names=CHANNEL_FAN)
|
||||
class ZhaFan(ZhaEntity, FanEntity):
|
||||
"""Representation of a ZHA fan."""
|
||||
class BaseFan(BaseZhaEntity, FanEntity):
|
||||
"""Base representation of a ZHA fan."""
|
||||
|
||||
def __init__(self, unique_id, zha_device, channels, **kwargs):
|
||||
"""Init this sensor."""
|
||||
super().__init__(unique_id, zha_device, channels, **kwargs)
|
||||
self._fan_channel = self.cluster_channels.get(CHANNEL_FAN)
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
await self.async_accept_signal(
|
||||
self._fan_channel, SIGNAL_ATTR_UPDATED, self.async_set_state
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_restore_last_state(self, last_state):
|
||||
"""Restore previous state."""
|
||||
self._state = VALUE_TO_SPEED.get(last_state.state, last_state.state)
|
||||
|
||||
@property
|
||||
def supported_features(self) -> int:
|
||||
"""Flag supported features."""
|
||||
return SUPPORT_SET_SPEED
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the fan."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._state = None
|
||||
self._fan_channel = None
|
||||
|
||||
@property
|
||||
def speed_list(self) -> list:
|
||||
@ -109,15 +100,9 @@ class ZhaFan(ZhaEntity, FanEntity):
|
||||
return self._state != SPEED_OFF
|
||||
|
||||
@property
|
||||
def device_state_attributes(self):
|
||||
"""Return state attributes."""
|
||||
return self.state_attributes
|
||||
|
||||
@callback
|
||||
def async_set_state(self, attr_id, attr_name, value):
|
||||
"""Handle state update from channel."""
|
||||
self._state = VALUE_TO_SPEED.get(value, self._state)
|
||||
self.async_write_ha_state()
|
||||
def supported_features(self) -> int:
|
||||
"""Flag supported features."""
|
||||
return SUPPORT_SET_SPEED
|
||||
|
||||
async def async_turn_on(self, speed: str = None, **kwargs) -> None:
|
||||
"""Turn the entity on."""
|
||||
@ -135,6 +120,34 @@ class ZhaFan(ZhaEntity, FanEntity):
|
||||
await self._fan_channel.async_set_speed(SPEED_TO_VALUE[speed])
|
||||
self.async_set_state(0, "fan_mode", speed)
|
||||
|
||||
|
||||
@STRICT_MATCH(channel_names=CHANNEL_FAN)
|
||||
class ZhaFan(ZhaEntity, BaseFan):
|
||||
"""Representation of a ZHA fan."""
|
||||
|
||||
def __init__(self, unique_id, zha_device, channels, **kwargs):
|
||||
"""Init this sensor."""
|
||||
super().__init__(unique_id, zha_device, channels, **kwargs)
|
||||
self._fan_channel = self.cluster_channels.get(CHANNEL_FAN)
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Run when about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
await self.async_accept_signal(
|
||||
self._fan_channel, SIGNAL_ATTR_UPDATED, self.async_set_state
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_restore_last_state(self, last_state):
|
||||
"""Restore previous state."""
|
||||
self._state = VALUE_TO_SPEED.get(last_state.state, last_state.state)
|
||||
|
||||
@callback
|
||||
def async_set_state(self, attr_id, attr_name, value):
|
||||
"""Handle state update from channel."""
|
||||
self._state = VALUE_TO_SPEED.get(value, self._state)
|
||||
self.async_write_ha_state()
|
||||
|
||||
async def async_update(self):
|
||||
"""Attempt to retrieve on off state from the fan."""
|
||||
await super().async_update()
|
||||
@ -142,3 +155,73 @@ class ZhaFan(ZhaEntity, FanEntity):
|
||||
state = await self._fan_channel.get_attribute_value("fan_mode")
|
||||
if state is not None:
|
||||
self._state = VALUE_TO_SPEED.get(state, self._state)
|
||||
|
||||
|
||||
@GROUP_MATCH()
|
||||
class FanGroup(BaseFan):
|
||||
"""Representation of a fan group."""
|
||||
|
||||
def __init__(
|
||||
self, entity_ids: List[str], unique_id: str, group_id: int, zha_device, **kwargs
|
||||
) -> None:
|
||||
"""Initialize a fan group."""
|
||||
super().__init__(unique_id, zha_device, **kwargs)
|
||||
self._name: str = f"{zha_device.gateway.groups.get(group_id).name}_group_{group_id}"
|
||||
self._group_id: int = group_id
|
||||
self._available: bool = False
|
||||
self._entity_ids: List[str] = entity_ids
|
||||
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
|
||||
group = self.zha_device.gateway.get_group(self._group_id)
|
||||
self._fan_channel = group.endpoint[hvac.Fan.cluster_id]
|
||||
|
||||
# what should we do with this hack?
|
||||
async def async_set_speed(value) -> None:
|
||||
"""Set the speed of the fan."""
|
||||
try:
|
||||
await self._fan_channel.write_attributes({"fan_mode": value})
|
||||
except DeliveryError as ex:
|
||||
self.error("Could not set speed: %s", ex)
|
||||
return
|
||||
|
||||
self._fan_channel.async_set_speed = async_set_speed
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Register callbacks."""
|
||||
await super().async_added_to_hass()
|
||||
await self.async_accept_signal(
|
||||
None,
|
||||
f"{SIGNAL_REMOVE_GROUP}_{self._group_id}",
|
||||
self.async_remove,
|
||||
signal_override=True,
|
||||
)
|
||||
|
||||
@callback
|
||||
def async_state_changed_listener(
|
||||
entity_id: str, old_state: State, new_state: State
|
||||
):
|
||||
"""Handle child updates."""
|
||||
self.async_schedule_update_ha_state(True)
|
||||
|
||||
self._async_unsub_state_changed = async_track_state_change(
|
||||
self.hass, self._entity_ids, async_state_changed_listener
|
||||
)
|
||||
await self.async_update()
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Handle removal from Home Assistant."""
|
||||
await super().async_will_remove_from_hass()
|
||||
if self._async_unsub_state_changed is not None:
|
||||
self._async_unsub_state_changed()
|
||||
self._async_unsub_state_changed = None
|
||||
|
||||
async def async_update(self):
|
||||
"""Attempt to retrieve on off state from the fan."""
|
||||
all_states = [self.hass.states.get(x) for x in self._entity_ids]
|
||||
states: List[State] = list(filter(None, all_states))
|
||||
on_states: List[State] = [state for state in states if state.state != SPEED_OFF]
|
||||
self._available = any(state.state != STATE_UNAVAILABLE for state in states)
|
||||
# for now just use first non off state since its kind of arbitrary
|
||||
if not on_states:
|
||||
self._state = SPEED_OFF
|
||||
else:
|
||||
self._state = states[0].state
|
||||
|
@ -2,10 +2,21 @@
|
||||
from unittest.mock import call
|
||||
|
||||
import pytest
|
||||
import zigpy.profiles.zha as zha
|
||||
import zigpy.zcl.clusters.general as general
|
||||
import zigpy.zcl.clusters.hvac as hvac
|
||||
|
||||
from homeassistant.components import fan
|
||||
from homeassistant.components.fan import ATTR_SPEED, DOMAIN, SERVICE_SET_SPEED
|
||||
from homeassistant.components.fan import (
|
||||
ATTR_SPEED,
|
||||
DOMAIN,
|
||||
SERVICE_SET_SPEED,
|
||||
SPEED_HIGH,
|
||||
SPEED_MEDIUM,
|
||||
SPEED_OFF,
|
||||
)
|
||||
from homeassistant.components.light import DOMAIN as LIGHT_DOMAIN
|
||||
from homeassistant.components.zha.core.discovery import GROUP_PROBE
|
||||
from homeassistant.const import (
|
||||
ATTR_ENTITY_ID,
|
||||
SERVICE_TURN_OFF,
|
||||
@ -17,11 +28,16 @@ from homeassistant.const import (
|
||||
|
||||
from .common import (
|
||||
async_enable_traffic,
|
||||
async_find_group_entity_id,
|
||||
async_test_rejoin,
|
||||
find_entity_id,
|
||||
get_zha_gateway,
|
||||
send_attributes_report,
|
||||
)
|
||||
|
||||
IEEE_GROUPABLE_DEVICE = "01:2d:6f:00:0a:90:69:e8"
|
||||
IEEE_GROUPABLE_DEVICE2 = "02:2d:6f:00:0a:90:69:e8"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zigpy_device(zigpy_device_mock):
|
||||
@ -32,6 +48,66 @@ def zigpy_device(zigpy_device_mock):
|
||||
return zigpy_device_mock(endpoints)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def coordinator(hass, zigpy_device_mock, zha_device_joined):
|
||||
"""Test zha fan platform."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
"in_clusters": [],
|
||||
"out_clusters": [],
|
||||
"device_type": zha.DeviceType.COLOR_DIMMABLE_LIGHT,
|
||||
}
|
||||
},
|
||||
ieee="00:15:8d:00:02:32:4f:32",
|
||||
nwk=0x0000,
|
||||
)
|
||||
zha_device = await zha_device_joined(zigpy_device)
|
||||
zha_device.set_available(True)
|
||||
return zha_device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def device_fan_1(hass, zigpy_device_mock, zha_device_joined):
|
||||
"""Test zha fan platform."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
"in_clusters": [general.OnOff.cluster_id, hvac.Fan.cluster_id],
|
||||
"out_clusters": [],
|
||||
}
|
||||
},
|
||||
ieee=IEEE_GROUPABLE_DEVICE,
|
||||
)
|
||||
zha_device = await zha_device_joined(zigpy_device)
|
||||
zha_device.set_available(True)
|
||||
return zha_device
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def device_fan_2(hass, zigpy_device_mock, zha_device_joined):
|
||||
"""Test zha fan platform."""
|
||||
|
||||
zigpy_device = zigpy_device_mock(
|
||||
{
|
||||
1: {
|
||||
"in_clusters": [
|
||||
general.OnOff.cluster_id,
|
||||
hvac.Fan.cluster_id,
|
||||
general.LevelControl.cluster_id,
|
||||
],
|
||||
"out_clusters": [],
|
||||
}
|
||||
},
|
||||
ieee=IEEE_GROUPABLE_DEVICE2,
|
||||
)
|
||||
zha_device = await zha_device_joined(zigpy_device)
|
||||
zha_device.set_available(True)
|
||||
return zha_device
|
||||
|
||||
|
||||
async def test_fan(hass, zha_device_joined_restored, zigpy_device):
|
||||
"""Test zha fan platform."""
|
||||
|
||||
@ -106,3 +182,87 @@ async def async_set_speed(hass, entity_id, speed=None):
|
||||
}
|
||||
|
||||
await hass.services.async_call(DOMAIN, SERVICE_SET_SPEED, data, blocking=True)
|
||||
|
||||
|
||||
async def async_test_zha_group_fan_entity(
|
||||
hass, device_fan_1, device_fan_2, coordinator
|
||||
):
|
||||
"""Test the fan entity for a ZHA group."""
|
||||
zha_gateway = get_zha_gateway(hass)
|
||||
assert zha_gateway is not None
|
||||
zha_gateway.coordinator_zha_device = coordinator
|
||||
coordinator._zha_gateway = zha_gateway
|
||||
device_fan_1._zha_gateway = zha_gateway
|
||||
device_fan_2._zha_gateway = zha_gateway
|
||||
member_ieee_addresses = [device_fan_1.ieee, device_fan_2.ieee]
|
||||
|
||||
# test creating a group with 2 members
|
||||
zha_group = await zha_gateway.async_create_zigpy_group(
|
||||
"Test Group", member_ieee_addresses
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert zha_group is not None
|
||||
assert len(zha_group.members) == 2
|
||||
for member in zha_group.members:
|
||||
assert member.ieee in member_ieee_addresses
|
||||
|
||||
entity_domains = GROUP_PROBE.determine_entity_domains(zha_group)
|
||||
assert len(entity_domains) == 2
|
||||
|
||||
assert LIGHT_DOMAIN in entity_domains
|
||||
assert DOMAIN in entity_domains
|
||||
|
||||
entity_id = async_find_group_entity_id(hass, DOMAIN, zha_group)
|
||||
assert hass.states.get(entity_id) is not None
|
||||
|
||||
group_fan_cluster = zha_group.endpoint[hvac.Fan.cluster_id]
|
||||
dev1_fan_cluster = device_fan_1.endpoints[1].fan
|
||||
dev2_fan_cluster = device_fan_2.endpoints[1].fan
|
||||
|
||||
# test that the lights were created and that they are unavailable
|
||||
assert hass.states.get(entity_id).state == STATE_UNAVAILABLE
|
||||
|
||||
# allow traffic to flow through the gateway and device
|
||||
await async_enable_traffic(hass, zha_group.members)
|
||||
|
||||
# test that the fan group entity was created and is off
|
||||
assert hass.states.get(entity_id).state == STATE_OFF
|
||||
|
||||
# turn on from HA
|
||||
group_fan_cluster.write_attributes.reset_mock()
|
||||
await async_turn_on(hass, entity_id)
|
||||
assert len(group_fan_cluster.write_attributes.mock_calls) == 1
|
||||
assert group_fan_cluster.write_attributes.call_args == call({"fan_mode": 2})
|
||||
assert hass.states.get(entity_id).state == SPEED_MEDIUM
|
||||
|
||||
# turn off from HA
|
||||
group_fan_cluster.write_attributes.reset_mock()
|
||||
await async_turn_off(hass, entity_id)
|
||||
assert len(group_fan_cluster.write_attributes.mock_calls) == 1
|
||||
assert group_fan_cluster.write_attributes.call_args == call({"fan_mode": 0})
|
||||
assert hass.states.get(entity_id).state == STATE_OFF
|
||||
|
||||
# change speed from HA
|
||||
group_fan_cluster.write_attributes.reset_mock()
|
||||
await async_set_speed(hass, entity_id, speed=fan.SPEED_HIGH)
|
||||
assert len(group_fan_cluster.write_attributes.mock_calls) == 1
|
||||
assert group_fan_cluster.write_attributes.call_args == call({"fan_mode": 3})
|
||||
assert hass.states.get(entity_id).state == SPEED_HIGH
|
||||
|
||||
# test some of the group logic to make sure we key off states correctly
|
||||
await dev1_fan_cluster.async_set_speed(SPEED_OFF)
|
||||
await dev2_fan_cluster.async_set_speed(SPEED_OFF)
|
||||
|
||||
# test that group fan is off
|
||||
assert hass.states.get(entity_id).state == STATE_OFF
|
||||
|
||||
await dev1_fan_cluster.async_set_speed(SPEED_MEDIUM)
|
||||
|
||||
# test that group fan is speed medium
|
||||
assert hass.states.get(entity_id).state == SPEED_MEDIUM
|
||||
|
||||
await dev1_fan_cluster.async_set_speed(SPEED_OFF)
|
||||
|
||||
# test that group fan is now off
|
||||
assert hass.states.get(entity_id).state == STATE_OFF
|
||||
|
@ -134,7 +134,6 @@ async def test_gateway_group_methods(hass, device_light_1, device_light_2, coord
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert zha_group is not None
|
||||
assert zha_group.entity_domain == LIGHT_DOMAIN
|
||||
assert len(zha_group.members) == 2
|
||||
for member in zha_group.members:
|
||||
assert member.ieee in member_ieee_addresses
|
||||
@ -162,7 +161,6 @@ async def test_gateway_group_methods(hass, device_light_1, device_light_2, coord
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert zha_group is not None
|
||||
assert zha_group.entity_domain is None
|
||||
assert len(zha_group.members) == 1
|
||||
for member in zha_group.members:
|
||||
assert member.ieee in [device_light_1.ieee]
|
||||
|
@ -432,7 +432,6 @@ async def async_test_zha_group_light_entity(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert zha_group is not None
|
||||
assert zha_group.entity_domain == DOMAIN
|
||||
assert len(zha_group.members) == 2
|
||||
for member in zha_group.members:
|
||||
assert member.ieee in member_ieee_addresses
|
||||
|
@ -173,7 +173,6 @@ async def async_test_zha_group_switch_entity(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert zha_group is not None
|
||||
assert zha_group.entity_domain == DOMAIN
|
||||
assert len(zha_group.members) == 2
|
||||
for member in zha_group.members:
|
||||
assert member.ieee in member_ieee_addresses
|
||||
|
Loading…
Reference in New Issue
Block a user