Return specific group state if there is one (#115866)

* Return specific group state if there is one

* Refactor

* Additional test cases

* Refactor

* Break out if more than one on state

* tweaks

* Remove log, add comment

* add comment

* Apply suggestions from code review

Co-authored-by: J. Nick Koston <nick@koston.org>

* Refactor and improve comments

* Refactor to class method

* More filtering

* Apply suggestions from code review

* Only active if not excluded

* Do not use a set

* Apply suggestions from code review

Co-authored-by: Erik Montnemery <erik@montnemery.com>

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Jan Bouwhuis 2024-04-24 15:12:29 +02:00 committed by GitHub
parent 1f4585cc9e
commit 350ca48d4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 109 additions and 24 deletions

View File

@ -8,7 +8,7 @@ from collections.abc import Callable, Collection, Mapping
import logging
from typing import Any
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_ON
from homeassistant.const import ATTR_ASSUMED_STATE, ATTR_ENTITY_ID, STATE_OFF, STATE_ON
from homeassistant.core import (
CALLBACK_TYPE,
Event,
@ -131,6 +131,9 @@ class Group(Entity):
_unrecorded_attributes = frozenset({ATTR_ENTITY_ID, ATTR_ORDER, ATTR_AUTO})
_attr_should_poll = False
# In case there is only one active domain we use specific ON or OFF
# values, if all ON or OFF states are equal
single_active_domain: str | None
tracking: tuple[str, ...]
trackable: tuple[str, ...]
@ -287,6 +290,7 @@ class Group(Entity):
if not entity_ids:
self.tracking = ()
self.trackable = ()
self.single_active_domain = None
return
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
@ -294,12 +298,22 @@ class Group(Entity):
tracking: list[str] = []
trackable: list[str] = []
self.single_active_domain = None
multiple_domains: bool = False
for ent_id in entity_ids:
ent_id_lower = ent_id.lower()
domain = split_entity_id(ent_id_lower)[0]
tracking.append(ent_id_lower)
if domain not in excluded_domains:
trackable.append(ent_id_lower)
if domain in excluded_domains:
continue
trackable.append(ent_id_lower)
if not multiple_domains and self.single_active_domain is None:
self.single_active_domain = domain
if self.single_active_domain != domain:
multiple_domains = True
self.single_active_domain = None
self.trackable = tuple(trackable)
self.tracking = tuple(tracking)
@ -395,10 +409,36 @@ class Group(Entity):
self._on_off[entity_id] = state in registry.on_off_mapping
else:
entity_on_state = registry.on_states_by_domain[domain]
if domain in registry.on_states_by_domain:
self._on_states.update(entity_on_state)
self._on_states.update(entity_on_state)
self._on_off[entity_id] = state in entity_on_state
def _detect_specific_on_off_state(self, group_is_on: bool) -> set[str]:
"""Check if a specific ON or OFF state is possible."""
# In case the group contains entities of the same domain with the same ON
# or an OFF state (one or more domains), we want to use that specific state.
# If we have more then one ON or OFF state we default to STATE_ON or STATE_OFF.
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
active_on_states: set[str] = set()
active_off_states: set[str] = set()
for entity_id in self.trackable:
if (state := self.hass.states.get(entity_id)) is None:
continue
current_state = state.state
if (
group_is_on
and (domain_on_states := registry.on_states_by_domain.get(state.domain))
and current_state in domain_on_states
):
active_on_states.add(current_state)
# If we have more than one on state, the group state
# will result in STATE_ON and we can stop checking
if len(active_on_states) > 1:
break
elif current_state in registry.off_on_mapping:
active_off_states.add(current_state)
return active_on_states if group_is_on else active_off_states
@callback
def _async_update_group_state(self, tr_state: State | None = None) -> None:
"""Update group state.
@ -425,27 +465,48 @@ class Group(Entity):
elif tr_state.attributes.get(ATTR_ASSUMED_STATE):
self._assumed_state = True
num_on_states = len(self._on_states)
# If we do not have an on state for any domains
# we use None (which will be STATE_UNKNOWN)
if (num_on_states := len(self._on_states)) == 0:
self._state = None
return
group_is_on = self.mode(self._on_off.values())
# If all the entity domains we are tracking
# have the same on state we use this state
# and its hass.data[REG_KEY].on_off_mapping to off
if num_on_states == 1:
on_state = list(self._on_states)[0]
# If we do not have an on state for any domains
# we use None (which will be STATE_UNKNOWN)
elif num_on_states == 0:
self._state = None
return
on_state = next(iter(self._on_states))
# If the entity domains have more than one
# on state, we use STATE_ON/STATE_OFF
else:
# on state, we use STATE_ON/STATE_OFF, unless there is
# only one specific `on` state in use for one specific domain
elif self.single_active_domain and num_on_states:
active_on_states = self._detect_specific_on_off_state(True)
on_state = (
list(active_on_states)[0] if len(active_on_states) == 1 else STATE_ON
)
elif group_is_on:
on_state = STATE_ON
group_is_on = self.mode(self._on_off.values())
if group_is_on:
self._state = on_state
return
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
if (
active_domain := self.single_active_domain
) and active_domain in registry.off_state_by_domain:
# If there is only one domain used,
# then we return the off state for that domain.s
self._state = registry.off_state_by_domain[active_domain]
else:
registry: GroupIntegrationRegistry = self.hass.data[REG_KEY]
self._state = registry.on_off_mapping[on_state]
active_off_states = self._detect_specific_on_off_state(False)
# If there is one off state in use then we return that specific state,
# also if there a multiple domains involved, e.g.
# person and device_tracker, with a shared state.
self._state = (
list(active_off_states)[0] if len(active_off_states) == 1 else STATE_OFF
)
def async_get_component(hass: HomeAssistant) -> EntityComponent[Group]:

View File

@ -49,9 +49,12 @@ class GroupIntegrationRegistry:
def __init__(self) -> None:
"""Imitialize registry."""
self.on_off_mapping: dict[str, str] = {STATE_ON: STATE_OFF}
self.on_off_mapping: dict[str, dict[str | None, str]] = {
STATE_ON: {None: STATE_OFF}
}
self.off_on_mapping: dict[str, str] = {STATE_OFF: STATE_ON}
self.on_states_by_domain: dict[str, set[str]] = {}
self.off_state_by_domain: dict[str, str] = {}
self.exclude_domains: set[str] = set()
def exclude_domain(self) -> None:
@ -60,11 +63,14 @@ class GroupIntegrationRegistry:
def on_off_states(self, on_states: set, off_state: str) -> None:
"""Register on and off states for the current domain."""
domain = current_domain.get()
for on_state in on_states:
if on_state not in self.on_off_mapping:
self.on_off_mapping[on_state] = off_state
self.on_off_mapping[on_state] = {domain: off_state}
else:
self.on_off_mapping[on_state][domain] = off_state
if len(on_states) == 1 and off_state not in self.off_on_mapping:
self.off_on_mapping[off_state] = list(on_states)[0]
self.on_states_by_domain[current_domain.get()] = set(on_states)
self.on_states_by_domain[domain] = set(on_states)
self.off_state_by_domain[domain] = off_state

View File

@ -9,7 +9,7 @@ from unittest.mock import patch
import pytest
from homeassistant.components import group
from homeassistant.components import group, vacuum
from homeassistant.const import (
ATTR_ASSUMED_STATE,
ATTR_FRIENDLY_NAME,
@ -659,6 +659,24 @@ async def test_is_on(hass: HomeAssistant) -> None:
(STATE_ON, True),
(STATE_OFF, False),
),
(
("vacuum", "vacuum"),
# Cleaning is the only on state
(vacuum.STATE_DOCKED, vacuum.STATE_CLEANING),
# Returning is the only on state
(vacuum.STATE_RETURNING, vacuum.STATE_PAUSED),
(vacuum.STATE_CLEANING, True),
(vacuum.STATE_RETURNING, True),
),
(
("vacuum", "vacuum"),
# Multiple on states, so group state will be STATE_ON
(vacuum.STATE_RETURNING, vacuum.STATE_CLEANING),
# Only off states, so group state will be off
(vacuum.STATE_PAUSED, vacuum.STATE_IDLE),
(STATE_ON, True),
(STATE_OFF, False),
),
],
)
async def test_is_on_and_state_mixed_domains(
@ -1220,7 +1238,7 @@ async def test_group_climate_all_cool(hass: HomeAssistant) -> None:
)
await hass.async_block_till_done()
assert hass.states.get("group.group_zero").state == STATE_ON
assert hass.states.get("group.group_zero").state == "cool"
async def test_group_climate_all_off(hass: HomeAssistant) -> None:
@ -1334,7 +1352,7 @@ async def test_group_vacuum_on(hass: HomeAssistant) -> None:
)
await hass.async_block_till_done()
assert hass.states.get("group.group_zero").state == STATE_ON
assert hass.states.get("group.group_zero").state == "cleaning"
async def test_device_tracker_not_home(hass: HomeAssistant) -> None: