1
mirror of https://github.com/home-assistant/core synced 2024-09-15 17:29:45 +02:00

Improve ZHA startup performance (#70111)

* Remove semaphores and background mains init

* additional logging

* correct cache usage and update tests
This commit is contained in:
David F. Mulcahey 2022-04-27 11:24:26 -04:00 committed by GitHub
parent 02ddfd513a
commit 361119d5c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 76 additions and 45 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
from collections.abc import Coroutine
from typing import TYPE_CHECKING, Any, TypeVar
import zigpy.endpoint
@ -50,7 +49,6 @@ class Channels:
self._pools: list[ChannelPool] = []
self._power_config: base.ZigbeeChannel | None = None
self._identify: base.ZigbeeChannel | None = None
self._semaphore = asyncio.Semaphore(3)
self._unique_id = str(zha_device.ieee)
self._zdo_channel = base.ZDOChannel(zha_device.device.endpoints[0], zha_device)
self._zha_device = zha_device
@ -82,11 +80,6 @@ class Channels:
if self._identify is None:
self._identify = channel
@property
def semaphore(self) -> asyncio.Semaphore:
"""Return semaphore for concurrent tasks."""
return self._semaphore
@property
def zdo_channel(self) -> base.ZDOChannel:
"""Return ZDO channel."""
@ -336,13 +329,8 @@ class ChannelPool:
async def _execute_channel_tasks(self, func_name: str, *args: Any) -> None:
"""Add a throttled channel task and swallow exceptions."""
async def _throttle(coro: Coroutine[Any, Any, None]) -> None:
async with self._channels.semaphore:
return await coro
channels = [*self.claimed_channels.values(), *self.client_channels.values()]
tasks = [_throttle(getattr(ch, func_name)(*args)) for ch in channels]
tasks = [getattr(ch, func_name)(*args) for ch in channels]
results = await asyncio.gather(*tasks, return_exceptions=True)
for channel, outcome in zip(channels, results):
if isinstance(outcome, Exception):

View File

@ -310,11 +310,14 @@ class ZigbeeChannel(LogMixin):
"""Set cluster binding and attribute reporting."""
if not self._ch_pool.skip_configuration:
if self.BIND:
self.debug("Performing cluster binding")
await self.bind()
if self.cluster.is_server:
self.debug("Configuring cluster attribute reporting")
await self.configure_reporting()
ch_specific_cfg = getattr(self, "async_configure_channel_specific", None)
if ch_specific_cfg:
self.debug("Performing channel specific configuration")
await ch_specific_cfg()
self.debug("finished channel configuration")
else:
@ -325,6 +328,7 @@ class ZigbeeChannel(LogMixin):
async def async_initialize(self, from_cache: bool) -> None:
"""Initialize channel."""
if not from_cache and self._ch_pool.skip_configuration:
self.debug("Skipping channel initialization")
self._status = ChannelStatus.INITIALIZED
return
@ -334,12 +338,23 @@ class ZigbeeChannel(LogMixin):
uncached.extend([cfg["attr"] for cfg in self.REPORT_CONFIG])
if cached:
await self._get_attributes(True, cached, from_cache=True)
self.debug("initializing cached channel attributes: %s", cached)
await self._get_attributes(
True, cached, from_cache=True, only_cache=from_cache
)
if uncached:
await self._get_attributes(True, uncached, from_cache=from_cache)
self.debug(
"initializing uncached channel attributes: %s - from cache[%s]",
uncached,
from_cache,
)
await self._get_attributes(
True, uncached, from_cache=from_cache, only_cache=from_cache
)
ch_specific_init = getattr(self, "async_initialize_channel_specific", None)
if ch_specific_init:
self.debug("Performing channel specific initialization: %s", uncached)
await ch_specific_init(from_cache=from_cache)
self.debug("finished channel initialization")
@ -407,7 +422,7 @@ class ZigbeeChannel(LogMixin):
self._cluster,
[attribute],
allow_cache=from_cache,
only_cache=from_cache and not self._ch_pool.is_mains_powered,
only_cache=from_cache,
manufacturer=manufacturer,
)
return result.get(attribute)
@ -417,6 +432,7 @@ class ZigbeeChannel(LogMixin):
raise_exceptions: bool,
attributes: list[int | str],
from_cache: bool = True,
only_cache: bool = True,
) -> dict[int | str, Any]:
"""Get the values for a list of attributes."""
manufacturer = None
@ -428,17 +444,18 @@ class ZigbeeChannel(LogMixin):
result = {}
while chunk:
try:
self.debug("Reading attributes in chunks: %s", chunk)
read, _ = await self.cluster.read_attributes(
attributes,
allow_cache=from_cache,
only_cache=from_cache and not self._ch_pool.is_mains_powered,
only_cache=only_cache,
manufacturer=manufacturer,
)
result.update(read)
except (asyncio.TimeoutError, zigpy.exceptions.ZigbeeException) as ex:
self.debug(
"failed to get attributes '%s' on '%s' cluster: %s",
attributes,
chunk,
self.cluster.ep_attribute,
str(ex),
)

View File

@ -463,7 +463,9 @@ class PowerConfigurationChannel(ZigbeeChannel):
"battery_size",
"battery_quantity",
]
return self.get_attributes(attributes, from_cache=from_cache)
return self.get_attributes(
attributes, from_cache=from_cache, only_cache=from_cache
)
@registries.ZIGBEE_CHANNEL_REGISTRY.register(general.PowerProfile.cluster_id)

View File

@ -97,7 +97,7 @@ class ElectricalMeasurementChannel(ZigbeeChannel):
for a in self.REPORT_CONFIG
if a["attr"] not in self.cluster.unsupported_attributes
]
result = await self.get_attributes(attrs, from_cache=False)
result = await self.get_attributes(attrs, from_cache=False, only_cache=False)
if result:
for attr, value in result.items():
self.async_send_signal(

View File

@ -351,11 +351,15 @@ class ZHADevice(LogMixin):
if self.is_coordinator:
return
if self.last_seen is None:
self.debug("last_seen is None, marking the device unavailable")
self.update_available(False)
return
difference = time.time() - self.last_seen
if difference < self.consider_unavailable_time:
self.debug(
"Device seen - marking the device available and resetting counter"
)
self.update_available(True)
self._checkins_missed_count = 0
return
@ -365,6 +369,10 @@ class ZHADevice(LogMixin):
or self.manufacturer == "LUMI"
or not self._channels.pools
):
self.debug(
"last_seen is %s seconds ago and ping attempts have been exhausted, marking the device unavailable",
difference,
)
self.update_available(False)
return
@ -386,13 +394,23 @@ class ZHADevice(LogMixin):
def update_available(self, available: bool) -> None:
"""Update device availability and signal entities."""
self.debug(
"Update device availability - device available: %s - new availability: %s - changed: %s",
self.available,
available,
self.available ^ available,
)
availability_changed = self.available ^ available
self.available = available
if availability_changed and available:
# reinit channels then signal entities
self.debug(
"Device availability changed and device became available, reinitializing channels"
)
self.hass.async_create_task(self._async_became_available())
return
if availability_changed and not available:
self.debug("Device availability changed and device became unavailable")
self._channels.zha_send_event(
{
"device_event_type": "device_offline",

View File

@ -239,29 +239,25 @@ class ZHAGateway:
async def async_initialize_devices_and_entities(self) -> None:
"""Initialize devices and load entities."""
semaphore = asyncio.Semaphore(2)
async def _throttle(zha_device: ZHADevice, cached: bool) -> None:
async with semaphore:
await zha_device.async_initialize(from_cache=cached)
_LOGGER.debug("Loading battery powered devices")
_LOGGER.warning("Loading all devices")
await asyncio.gather(
*(
_throttle(dev, cached=True)
for dev in self.devices.values()
if not dev.is_mains_powered
)
*(dev.async_initialize(from_cache=True) for dev in self.devices.values())
)
_LOGGER.debug("Loading mains powered devices")
await asyncio.gather(
*(
_throttle(dev, cached=False)
for dev in self.devices.values()
if dev.is_mains_powered
async def fetch_updated_state() -> None:
"""Fetch updated state for mains powered devices."""
_LOGGER.warning("Fetching current state for mains powered devices")
await asyncio.gather(
*(
dev.async_initialize(from_cache=False)
for dev in self.devices.values()
if dev.is_mains_powered
)
)
)
# background the fetching of state for mains powered devices
asyncio.create_task(fetch_updated_state())
def device_joined(self, device: zigpy.device.Device) -> None:
"""Handle device joined.

View File

@ -488,7 +488,7 @@ class Light(BaseLight, ZhaEntity):
]
results = await self._color_channel.get_attributes(
attributes, from_cache=False
attributes, from_cache=False, only_cache=False
)
if (color_mode := results.get("color_mode")) is not None:

View File

@ -177,6 +177,7 @@ def zha_device_joined(hass, setup_zha):
"""Return a newly joined ZHA device."""
async def _zha_device(zigpy_dev):
zigpy_dev.last_seen = time.time()
await setup_zha()
zha_gateway = common.get_zha_gateway(hass)
await zha_gateway.async_device_initialized(zigpy_dev)

View File

@ -106,7 +106,7 @@ def _send_time_changed(hass, seconds):
@patch(
"homeassistant.components.zha.core.channels.general.BasicChannel.async_initialize",
new=mock.MagicMock(),
new=mock.AsyncMock(),
)
async def test_check_available_success(
hass, device_with_basic_channel, zha_device_restored
@ -160,7 +160,7 @@ async def test_check_available_success(
@patch(
"homeassistant.components.zha.core.channels.general.BasicChannel.async_initialize",
new=mock.MagicMock(),
new=mock.AsyncMock(),
)
async def test_check_available_unsuccessful(
hass, device_with_basic_channel, zha_device_restored
@ -203,7 +203,7 @@ async def test_check_available_unsuccessful(
@patch(
"homeassistant.components.zha.core.channels.general.BasicChannel.async_initialize",
new=mock.MagicMock(),
new=mock.AsyncMock(),
)
async def test_check_available_no_basic_channel(
hass, device_without_basic_channel, zha_device_restored, caplog

View File

@ -471,7 +471,10 @@ async def test_fan_update_entity(
assert hass.states.get(entity_id).attributes[ATTR_PERCENTAGE] == 0
assert hass.states.get(entity_id).attributes[ATTR_PRESET_MODE] is None
assert hass.states.get(entity_id).attributes[ATTR_PERCENTAGE_STEP] == 100 / 3
assert cluster.read_attributes.await_count == 2
if zha_device_joined_restored.name == "zha_device_joined":
assert cluster.read_attributes.await_count == 2
else:
assert cluster.read_attributes.await_count == 4
await async_setup_component(hass, "homeassistant", {})
await hass.async_block_till_done()
@ -480,7 +483,10 @@ async def test_fan_update_entity(
"homeassistant", "update_entity", {"entity_id": entity_id}, blocking=True
)
assert hass.states.get(entity_id).state == STATE_OFF
assert cluster.read_attributes.await_count == 3
if zha_device_joined_restored.name == "zha_device_joined":
assert cluster.read_attributes.await_count == 3
else:
assert cluster.read_attributes.await_count == 5
cluster.PLUGGED_ATTR_READS = {"fan_mode": 1}
await hass.services.async_call(
@ -490,4 +496,7 @@ async def test_fan_update_entity(
assert hass.states.get(entity_id).attributes[ATTR_PERCENTAGE] == 33
assert hass.states.get(entity_id).attributes[ATTR_PRESET_MODE] is None
assert hass.states.get(entity_id).attributes[ATTR_PERCENTAGE_STEP] == 100 / 3
assert cluster.read_attributes.await_count == 4
if zha_device_joined_restored.name == "zha_device_joined":
assert cluster.read_attributes.await_count == 4
else:
assert cluster.read_attributes.await_count == 6