1
mirror of https://github.com/home-assistant/core synced 2024-07-30 21:18:57 +02:00

Update new values coming in for dev registry (#16852)

* Update new values coming in for dev registry

* fix Lint+Test;2C
This commit is contained in:
Paulus Schoutsen 2018-09-27 11:26:58 +02:00 committed by GitHub
parent 29db43edb2
commit da3342f1aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 39 deletions

View File

@ -26,11 +26,12 @@ CONNECTION_ZIGBEE = 'zigbee'
class DeviceEntry:
"""Device Registry Entry."""
config_entries = attr.ib(type=set, converter=set)
connections = attr.ib(type=set, converter=set)
identifiers = attr.ib(type=set, converter=set)
manufacturer = attr.ib(type=str)
model = attr.ib(type=str)
config_entries = attr.ib(type=set, converter=set,
default=attr.Factory(set))
connections = attr.ib(type=set, converter=set, default=attr.Factory(set))
identifiers = attr.ib(type=set, converter=set, default=attr.Factory(set))
manufacturer = attr.ib(type=str, default=None)
model = attr.ib(type=str, default=None)
name = attr.ib(type=str, default=None)
sw_version = attr.ib(type=str, default=None)
hub_device_id = attr.ib(type=str, default=None)
@ -56,46 +57,53 @@ class DeviceRegistry:
return None
@callback
def async_get_or_create(self, *, config_entry_id, connections, identifiers,
manufacturer, model, name=None, sw_version=None,
def async_get_or_create(self, *, config_entry_id, connections=None,
identifiers=None, manufacturer=_UNDEF,
model=_UNDEF, name=_UNDEF, sw_version=_UNDEF,
via_hub=None):
"""Get device. Create if it doesn't exist."""
if not identifiers and not connections:
return None
if identifiers is None:
identifiers = set()
if connections is None:
connections = set()
device = self.async_get_device(identifiers, connections)
if device is None:
device = DeviceEntry()
self.devices[device.id] = device
if via_hub is not None:
hub_device = self.async_get_device({via_hub}, set())
hub_device_id = hub_device.id if hub_device else None
hub_device_id = hub_device.id if hub_device else _UNDEF
else:
hub_device_id = None
hub_device_id = _UNDEF
if device is not None:
return self._async_update_device(
device.id, config_entry_id=config_entry_id,
hub_device_id=hub_device_id
)
device = DeviceEntry(
config_entries={config_entry_id},
connections=connections,
identifiers=identifiers,
return self._async_update_device(
device.id,
add_config_entry_id=config_entry_id,
hub_device_id=hub_device_id,
merge_connections=connections,
merge_identifiers=identifiers,
manufacturer=manufacturer,
model=model,
name=name,
sw_version=sw_version,
hub_device_id=hub_device_id
)
self.devices[device.id] = device
self.async_schedule_save()
return device
@callback
def _async_update_device(self, device_id, *, config_entry_id=_UNDEF,
def _async_update_device(self, device_id, *, add_config_entry_id=_UNDEF,
remove_config_entry_id=_UNDEF,
merge_connections=_UNDEF,
merge_identifiers=_UNDEF,
manufacturer=_UNDEF,
model=_UNDEF,
name=_UNDEF,
sw_version=_UNDEF,
hub_device_id=_UNDEF):
"""Update device attributes."""
old = self.devices[device_id]
@ -104,21 +112,34 @@ class DeviceRegistry:
config_entries = old.config_entries
if (config_entry_id is not _UNDEF and
config_entry_id not in old.config_entries):
config_entries = old.config_entries | {config_entry_id}
if (add_config_entry_id is not _UNDEF and
add_config_entry_id not in old.config_entries):
config_entries = old.config_entries | {add_config_entry_id}
if (remove_config_entry_id is not _UNDEF and
remove_config_entry_id in config_entries):
config_entries = set(config_entries)
config_entries.remove(remove_config_entry_id)
config_entries = config_entries - {remove_config_entry_id}
if config_entries is not old.config_entries:
changes['config_entries'] = config_entries
if (hub_device_id is not _UNDEF and
hub_device_id != old.hub_device_id):
changes['hub_device_id'] = hub_device_id
for attr_name, value in (
('connections', merge_connections),
('identifiers', merge_identifiers),
):
old_value = getattr(old, attr_name)
if value is not _UNDEF and value != old_value:
changes[attr_name] = old_value | value
for attr_name, value in (
('manufacturer', manufacturer),
('model', model),
('name', name),
('sw_version', sw_version),
('hub_device_id', hub_device_id),
):
if value is not _UNDEF and value != getattr(old, attr_name):
changes[attr_name] = value
if not changes:
return old

View File

@ -27,7 +27,6 @@ async def test_list_devices(hass, client, registry):
manufacturer='manufacturer', model='model')
registry.async_get_or_create(
config_entry_id='1234',
connections={},
identifiers={('bridgeid', '1234')},
manufacturer='manufacturer', model='model',
via_hub=('bridgeid', '0123'))

View File

@ -17,7 +17,10 @@ async def test_get_or_create_returns_same_entry(registry):
config_entry_id='1234',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '0123')},
manufacturer='manufacturer', model='model')
sw_version='sw-version',
name='name',
manufacturer='manufacturer',
model='model')
entry2 = registry.async_get_or_create(
config_entry_id='1234',
connections={('ethernet', '11:22:33:44:55:66:77:88')},
@ -25,15 +28,19 @@ async def test_get_or_create_returns_same_entry(registry):
manufacturer='manufacturer', model='model')
entry3 = registry.async_get_or_create(
config_entry_id='1234',
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')},
identifiers={('bridgeid', '1234')},
manufacturer='manufacturer', model='model')
connections={('ethernet', '12:34:56:78:90:AB:CD:EF')}
)
assert len(registry.devices) == 1
assert entry.id == entry2.id
assert entry.id == entry3.id
assert entry.identifiers == {('bridgeid', '0123')}
assert entry3.manufacturer == 'manufacturer'
assert entry3.model == 'model'
assert entry3.name == 'name'
assert entry3.sw_version == 'sw-version'
async def test_requirement_for_identifier_or_connection(registry):
"""Make sure we do require some descriptor of device."""