Set default parallel_update value should base on async_update (#22149)

* Set default parallel_update value should base on async_update

* Set default parallel_update value should base on async_update

* Delay the parallel_update_semaphore creation

* Remove outdated comment
This commit is contained in:
Jason Hu 2019-03-25 23:53:36 -07:00 committed by GitHub
parent a62c116959
commit e85b089eff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 242 additions and 121 deletions

View File

@ -27,7 +27,6 @@ class EntityPlatform:
domain: str
platform_name: str
scan_interval: timedelta
parallel_updates: int
entity_namespace: str
async_entities_added_callback: @callback method
"""
@ -52,22 +51,21 @@ class EntityPlatform:
# which powers entity_component.add_entities
if platform is None:
self.parallel_updates = None
self.parallel_updates_semaphore = None
return
# Async platforms do all updates in parallel by default
if hasattr(platform, 'async_setup_platform'):
default_parallel_updates = 0
else:
default_parallel_updates = 1
self.parallel_updates = getattr(platform, 'PARALLEL_UPDATES', None)
# semaphore will be created on demand
self.parallel_updates_semaphore = None
parallel_updates = getattr(platform, 'PARALLEL_UPDATES',
default_parallel_updates)
if parallel_updates:
self.parallel_updates = asyncio.Semaphore(
parallel_updates, loop=hass.loop)
else:
self.parallel_updates = None
def _get_parallel_updates_semaphore(self):
"""Get or create a semaphore for parallel updates."""
if self.parallel_updates_semaphore is None:
self.parallel_updates_semaphore = asyncio.Semaphore(
self.parallel_updates if self.parallel_updates else 1,
loop=self.hass.loop
)
return self.parallel_updates_semaphore
async def async_setup(self, platform_config, discovery_info=None):
"""Set up the platform from a config file."""
@ -240,7 +238,22 @@ class EntityPlatform:
entity.hass = self.hass
entity.platform = self
entity.parallel_updates = self.parallel_updates
# Async entity
# PARALLEL_UPDATE == None: entity.parallel_updates = None
# PARALLEL_UPDATE == 0: entity.parallel_updates = None
# PARALLEL_UPDATE > 0: entity.parallel_updates = Semaphore(p)
# Sync entity
# PARALLEL_UPDATE == None: entity.parallel_updates = Semaphore(1)
# PARALLEL_UPDATE == 0: entity.parallel_updates = None
# PARALLEL_UPDATE > 0: entity.parallel_updates = Semaphore(p)
if hasattr(entity, 'async_update') and not self.parallel_updates:
entity.parallel_updates = None
elif (not hasattr(entity, 'async_update')
and self.parallel_updates == 0):
entity.parallel_updates = None
else:
entity.parallel_updates = self._get_parallel_updates_semaphore()
# Update properties before we generate the entity_id
if update_before_add:

View File

@ -1,6 +1,7 @@
"""Test the entity helper."""
# pylint: disable=protected-access
import asyncio
import threading
from datetime import timedelta
from unittest.mock import MagicMock, patch, PropertyMock
@ -225,11 +226,10 @@ def test_async_schedule_update_ha_state(hass):
assert update_call is True
@asyncio.coroutine
def test_async_parallel_updates_with_zero(hass):
async def test_async_parallel_updates_with_zero(hass):
"""Test parallel updates with 0 (disabled)."""
updates = []
test_lock = asyncio.Event(loop=hass.loop)
test_lock = asyncio.Event()
class AsyncEntity(entity.Entity):
@ -239,37 +239,73 @@ def test_async_parallel_updates_with_zero(hass):
self.hass = hass
self._count = count
@asyncio.coroutine
def async_update(self):
async def async_update(self):
"""Test update."""
updates.append(self._count)
yield from test_lock.wait()
await test_lock.wait()
ent_1 = AsyncEntity("sensor.test_1", 1)
ent_2 = AsyncEntity("sensor.test_2", 2)
ent_1.async_schedule_update_ha_state(True)
ent_2.async_schedule_update_ha_state(True)
try:
ent_1.async_schedule_update_ha_state(True)
ent_2.async_schedule_update_ha_state(True)
while True:
if len(updates) == 2:
break
yield from asyncio.sleep(0, loop=hass.loop)
while True:
if len(updates) >= 2:
break
await asyncio.sleep(0)
assert len(updates) == 2
assert updates == [1, 2]
test_lock.set()
assert len(updates) == 2
assert updates == [1, 2]
finally:
test_lock.set()
@asyncio.coroutine
def test_async_parallel_updates_with_one(hass):
async def test_async_parallel_updates_with_zero_on_sync_update(hass):
"""Test parallel updates with 0 (disabled)."""
updates = []
test_lock = threading.Event()
class AsyncEntity(entity.Entity):
def __init__(self, entity_id, count):
"""Initialize Async test entity."""
self.entity_id = entity_id
self.hass = hass
self._count = count
def update(self):
"""Test update."""
updates.append(self._count)
if not test_lock.wait(timeout=1):
# if timeout populate more data to fail the test
updates.append(self._count)
ent_1 = AsyncEntity("sensor.test_1", 1)
ent_2 = AsyncEntity("sensor.test_2", 2)
try:
ent_1.async_schedule_update_ha_state(True)
ent_2.async_schedule_update_ha_state(True)
while True:
if len(updates) >= 2:
break
await asyncio.sleep(0)
assert len(updates) == 2
assert updates == [1, 2]
finally:
test_lock.set()
await asyncio.sleep(0)
async def test_async_parallel_updates_with_one(hass):
"""Test parallel updates with 1 (sequential)."""
updates = []
test_lock = asyncio.Lock(loop=hass.loop)
test_semaphore = asyncio.Semaphore(1, loop=hass.loop)
yield from test_lock.acquire()
test_lock = asyncio.Lock()
test_semaphore = asyncio.Semaphore(1)
class AsyncEntity(entity.Entity):
@ -280,59 +316,71 @@ def test_async_parallel_updates_with_one(hass):
self._count = count
self.parallel_updates = test_semaphore
@asyncio.coroutine
def async_update(self):
async def async_update(self):
"""Test update."""
updates.append(self._count)
yield from test_lock.acquire()
await test_lock.acquire()
ent_1 = AsyncEntity("sensor.test_1", 1)
ent_2 = AsyncEntity("sensor.test_2", 2)
ent_3 = AsyncEntity("sensor.test_3", 3)
ent_1.async_schedule_update_ha_state(True)
ent_2.async_schedule_update_ha_state(True)
ent_3.async_schedule_update_ha_state(True)
await test_lock.acquire()
while True:
if len(updates) == 1:
break
yield from asyncio.sleep(0, loop=hass.loop)
try:
ent_1.async_schedule_update_ha_state(True)
ent_2.async_schedule_update_ha_state(True)
ent_3.async_schedule_update_ha_state(True)
assert len(updates) == 1
assert updates == [1]
while True:
if len(updates) >= 1:
break
await asyncio.sleep(0)
test_lock.release()
assert len(updates) == 1
assert updates == [1]
while True:
if len(updates) == 2:
break
yield from asyncio.sleep(0, loop=hass.loop)
updates.clear()
test_lock.release()
await asyncio.sleep(0)
assert len(updates) == 2
assert updates == [1, 2]
while True:
if len(updates) >= 1:
break
await asyncio.sleep(0)
test_lock.release()
assert len(updates) == 1
assert updates == [2]
while True:
if len(updates) == 3:
break
yield from asyncio.sleep(0, loop=hass.loop)
updates.clear()
test_lock.release()
await asyncio.sleep(0)
assert len(updates) == 3
assert updates == [1, 2, 3]
while True:
if len(updates) >= 1:
break
await asyncio.sleep(0)
test_lock.release()
assert len(updates) == 1
assert updates == [3]
updates.clear()
test_lock.release()
await asyncio.sleep(0)
finally:
# we may have more than one lock need to release in case test failed
for _ in updates:
test_lock.release()
await asyncio.sleep(0)
test_lock.release()
@asyncio.coroutine
def test_async_parallel_updates_with_two(hass):
async def test_async_parallel_updates_with_two(hass):
"""Test parallel updates with 2 (parallel)."""
updates = []
test_lock = asyncio.Lock(loop=hass.loop)
test_semaphore = asyncio.Semaphore(2, loop=hass.loop)
yield from test_lock.acquire()
test_lock = asyncio.Lock()
test_semaphore = asyncio.Semaphore(2)
class AsyncEntity(entity.Entity):
@ -354,34 +402,48 @@ def test_async_parallel_updates_with_two(hass):
ent_3 = AsyncEntity("sensor.test_3", 3)
ent_4 = AsyncEntity("sensor.test_4", 4)
ent_1.async_schedule_update_ha_state(True)
ent_2.async_schedule_update_ha_state(True)
ent_3.async_schedule_update_ha_state(True)
ent_4.async_schedule_update_ha_state(True)
await test_lock.acquire()
while True:
if len(updates) == 2:
break
yield from asyncio.sleep(0, loop=hass.loop)
try:
assert len(updates) == 2
assert updates == [1, 2]
ent_1.async_schedule_update_ha_state(True)
ent_2.async_schedule_update_ha_state(True)
ent_3.async_schedule_update_ha_state(True)
ent_4.async_schedule_update_ha_state(True)
test_lock.release()
yield from asyncio.sleep(0, loop=hass.loop)
test_lock.release()
while True:
if len(updates) >= 2:
break
await asyncio.sleep(0)
while True:
if len(updates) == 4:
break
yield from asyncio.sleep(0, loop=hass.loop)
assert len(updates) == 2
assert updates == [1, 2]
assert len(updates) == 4
assert updates == [1, 2, 3, 4]
updates.clear()
test_lock.release()
await asyncio.sleep(0)
test_lock.release()
await asyncio.sleep(0)
test_lock.release()
yield from asyncio.sleep(0, loop=hass.loop)
test_lock.release()
while True:
if len(updates) >= 2:
break
await asyncio.sleep(0)
assert len(updates) == 2
assert updates == [3, 4]
updates.clear()
test_lock.release()
await asyncio.sleep(0)
test_lock.release()
await asyncio.sleep(0)
finally:
# we may have more than one lock need to release in case test failed
for _ in updates:
test_lock.release()
await asyncio.sleep(0)
test_lock.release()
@asyncio.coroutine

View File

@ -251,80 +251,126 @@ def test_updated_state_used_for_entity_id(hass):
assert entity_ids[0] == "test_domain.living_room"
@asyncio.coroutine
def test_parallel_updates_async_platform(hass):
"""Warn we log when platform setup takes a long time."""
async def test_parallel_updates_async_platform(hass):
"""Test async platform does not have parallel_updates limit by default."""
platform = MockPlatform()
@asyncio.coroutine
def mock_update(*args, **kwargs):
pass
platform.async_setup_platform = mock_update
loader.set_component(hass, 'test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_setup({
await component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is None
class AsyncEntity(MockEntity):
"""Mock entity that has async_update."""
@asyncio.coroutine
def test_parallel_updates_async_platform_with_constant(hass):
"""Warn we log when platform setup takes a long time."""
async def async_update(self):
pass
entity = AsyncEntity()
await handle.async_add_entities([entity])
assert entity.parallel_updates is None
async def test_parallel_updates_async_platform_with_constant(hass):
"""Test async platform can set parallel_updates limit."""
platform = MockPlatform()
platform.PARALLEL_UPDATES = 2
loader.set_component(hass, 'test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
await component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates == 2
class AsyncEntity(MockEntity):
"""Mock entity that has async_update."""
async def async_update(self):
pass
entity = AsyncEntity()
await handle.async_add_entities([entity])
assert entity.parallel_updates is not None
assert entity.parallel_updates._value == 2
async def test_parallel_updates_sync_platform(hass):
"""Test sync platform parallel_updates default set to 1."""
platform = MockPlatform()
@asyncio.coroutine
def mock_update(*args, **kwargs):
pass
platform.async_setup_platform = mock_update
platform.PARALLEL_UPDATES = 1
loader.set_component(hass, 'test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_setup({
await component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates is None
assert handle.parallel_updates is not None
class SyncEntity(MockEntity):
"""Mock entity that has update."""
async def update(self):
pass
entity = SyncEntity()
await handle.async_add_entities([entity])
assert entity.parallel_updates is not None
assert entity.parallel_updates._value == 1
@asyncio.coroutine
def test_parallel_updates_sync_platform(hass):
"""Warn we log when platform setup takes a long time."""
platform = MockPlatform(setup_platform=lambda *args: None)
async def test_parallel_updates_sync_platform_with_constant(hass):
"""Test sync platform can set parallel_updates limit."""
platform = MockPlatform()
platform.PARALLEL_UPDATES = 2
loader.set_component(hass, 'test_domain.platform', platform)
component = EntityComponent(_LOGGER, DOMAIN, hass)
component._platforms = {}
yield from component.async_setup({
await component.async_setup({
DOMAIN: {
'platform': 'platform',
}
})
handle = list(component._platforms.values())[-1]
assert handle.parallel_updates == 2
assert handle.parallel_updates is not None
class SyncEntity(MockEntity):
"""Mock entity that has update."""
async def update(self):
pass
entity = SyncEntity()
await handle.async_add_entities([entity])
assert entity.parallel_updates is not None
assert entity.parallel_updates._value == 2
@asyncio.coroutine