diff --git a/homeassistant/components/config/device_registry.py b/homeassistant/components/config/device_registry.py index 1cc63297352e..4d4b8333d701 100644 --- a/homeassistant/components/config/device_registry.py +++ b/homeassistant/components/config/device_registry.py @@ -7,7 +7,10 @@ from homeassistant.components.websocket_api.decorators import ( require_admin, ) from homeassistant.core import callback -from homeassistant.helpers.device_registry import DISABLED_USER, async_get_registry +from homeassistant.helpers.device_registry import ( + DeviceEntryDisabler, + async_get_registry, +) WS_TYPE_LIST = "config/device_registry/list" SCHEMA_WS_LIST = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( @@ -22,7 +25,8 @@ SCHEMA_WS_UPDATE = websocket_api.BASE_COMMAND_MESSAGE_SCHEMA.extend( vol.Optional("area_id"): vol.Any(str, None), vol.Optional("name_by_user"): vol.Any(str, None), # We only allow setting disabled_by user via API. - vol.Optional("disabled_by"): vol.Any(DISABLED_USER, None), + # No Enum support like this in voluptuous, use .value + vol.Optional("disabled_by"): vol.Any(DeviceEntryDisabler.USER.value, None), } ) diff --git a/homeassistant/helpers/device_registry.py b/homeassistant/helpers/device_registry.py index b99e80e197a0..45a3d2bc7958 100644 --- a/homeassistant/helpers/device_registry.py +++ b/homeassistant/helpers/device_registry.py @@ -41,10 +41,6 @@ CONNECTION_NETWORK_MAC = "mac" CONNECTION_UPNP = "upnp" CONNECTION_ZIGBEE = "zigbee" -DISABLED_CONFIG_ENTRY = "config_entry" -DISABLED_INTEGRATION = "integration" -DISABLED_USER = "user" - ORPHANED_DEVICE_KEEP_SECONDS = 86400 * 30 @@ -53,6 +49,14 @@ class _DeviceIndex(NamedTuple): connections: dict[tuple[str, str], str] +class DeviceEntryDisabler(StrEnum): + """What disabled a device entry.""" + + CONFIG_ENTRY = "config_entry" + INTEGRATION = "integration" + USER = "user" + + class DeviceEntryType(StrEnum): """Device entry type.""" @@ -67,17 +71,7 @@ class DeviceEntry: config_entries: set[str] = attr.ib(converter=set, factory=set) configuration_url: str | None = attr.ib(default=None) connections: set[tuple[str, str]] = attr.ib(converter=set, factory=set) - disabled_by: str | None = attr.ib( - default=None, - validator=attr.validators.in_( - ( - DISABLED_CONFIG_ENTRY, - DISABLED_INTEGRATION, - DISABLED_USER, - None, - ) - ), - ) + disabled_by: DeviceEntryDisabler | None = attr.ib(default=None) entry_type: DeviceEntryType | None = attr.ib(default=None) id: str = attr.ib(factory=uuid_util.random_uuid_hex) identifiers: set[tuple[str, str]] = attr.ib(converter=set, factory=set) @@ -302,7 +296,7 @@ class DeviceRegistry: default_model: str | None | UndefinedType = UNDEFINED, default_name: str | None | UndefinedType = UNDEFINED, # To disable a device if it gets created - disabled_by: str | None | UndefinedType = UNDEFINED, + disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED, entry_type: DeviceEntryType | None | UndefinedType = UNDEFINED, identifiers: set[tuple[str, str]] | None = None, manufacturer: str | None | UndefinedType = UNDEFINED, @@ -389,7 +383,7 @@ class DeviceRegistry: add_config_entry_id: str | UndefinedType = UNDEFINED, area_id: str | None | UndefinedType = UNDEFINED, configuration_url: str | None | UndefinedType = UNDEFINED, - disabled_by: str | None | UndefinedType = UNDEFINED, + disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED, manufacturer: str | None | UndefinedType = UNDEFINED, model: str | None | UndefinedType = UNDEFINED, name_by_user: str | None | UndefinedType = UNDEFINED, @@ -426,7 +420,7 @@ class DeviceRegistry: add_config_entry_id: str | UndefinedType = UNDEFINED, area_id: str | None | UndefinedType = UNDEFINED, configuration_url: str | None | UndefinedType = UNDEFINED, - disabled_by: str | None | UndefinedType = UNDEFINED, + disabled_by: DeviceEntryDisabler | None | UndefinedType = UNDEFINED, entry_type: DeviceEntryType | None | UndefinedType = UNDEFINED, manufacturer: str | None | UndefinedType = UNDEFINED, merge_connections: set[tuple[str, str]] | UndefinedType = UNDEFINED, @@ -447,6 +441,16 @@ class DeviceRegistry: config_entries = old.config_entries + if isinstance(disabled_by, str) and not isinstance( + disabled_by, DeviceEntryDisabler + ): + report( # type: ignore[unreachable] + "uses str for device registry disabled_by. This is deprecated, " + "it should be updated to use DeviceEntryDisabler instead", + error_if_core=False, + ) + disabled_by = DeviceEntryDisabler(disabled_by) + if ( suggested_area not in (UNDEFINED, None, "") and area_id is UNDEFINED @@ -737,7 +741,7 @@ def async_config_entry_disabled_by_changed( Disable devices in the registry that are associated with a config entry when the config entry is disabled, enable devices in the registry that are associated with a config entry when the config entry is enabled and the devices are marked - DISABLED_CONFIG_ENTRY. + DeviceEntryDisabler.CONFIG_ENTRY. Only disable a device if all associated config entries are disabled. """ @@ -745,7 +749,7 @@ def async_config_entry_disabled_by_changed( if not config_entry.disabled_by: for device in devices: - if device.disabled_by != DISABLED_CONFIG_ENTRY: + if device.disabled_by is not DeviceEntryDisabler.CONFIG_ENTRY: continue registry.async_update_device(device.id, disabled_by=None) return @@ -764,7 +768,9 @@ def async_config_entry_disabled_by_changed( enabled_config_entries ): continue - registry.async_update_device(device.id, disabled_by=DISABLED_CONFIG_ENTRY) + registry.async_update_device( + device.id, disabled_by=DeviceEntryDisabler.CONFIG_ENTRY + ) @callback diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index f4b733eac560..036f235e132c 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -390,7 +390,7 @@ class EntityRegistry: self.async_update_entity(entity.entity_id, disabled_by=None) return - if device.disabled_by == dr.DISABLED_CONFIG_ENTRY: + if device.disabled_by is dr.DeviceEntryDisabler.CONFIG_ENTRY: # Handled by async_config_entry_disabled return diff --git a/tests/components/config/test_device_registry.py b/tests/components/config/test_device_registry.py index edc167405ac9..ee8c933f7619 100644 --- a/tests/components/config/test_device_registry.py +++ b/tests/components/config/test_device_registry.py @@ -97,7 +97,7 @@ async def test_update_device(hass, client, registry): "device_id": device.id, "area_id": "12345A", "name_by_user": "Test Friendly Name", - "disabled_by": helpers_dr.DISABLED_USER, + "disabled_by": helpers_dr.DeviceEntryDisabler.USER, "type": "config/device_registry/update", } ) @@ -107,5 +107,5 @@ async def test_update_device(hass, client, registry): assert msg["result"]["id"] == device.id assert msg["result"]["area_id"] == "12345A" assert msg["result"]["name_by_user"] == "Test Friendly Name" - assert msg["result"]["disabled_by"] == helpers_dr.DISABLED_USER + assert msg["result"]["disabled_by"] == helpers_dr.DeviceEntryDisabler.USER assert len(registry.devices) == 1 diff --git a/tests/components/config/test_entity_registry.py b/tests/components/config/test_entity_registry.py index 074d8e223dae..17762f20df34 100644 --- a/tests/components/config/test_entity_registry.py +++ b/tests/components/config/test_entity_registry.py @@ -3,6 +3,7 @@ import pytest from homeassistant.components.config import entity_registry from homeassistant.const import ATTR_ICON +from homeassistant.helpers.device_registry import DeviceEntryDisabler from homeassistant.helpers.entity_registry import DISABLED_USER, RegistryEntry from tests.common import ( @@ -325,7 +326,7 @@ async def test_enable_entity_disabled_device(hass, client, device_registry): identifiers={("bridgeid", "0123")}, manufacturer="manufacturer", model="model", - disabled_by=DISABLED_USER, + disabled_by=DeviceEntryDisabler.USER, ) mock_registry( diff --git a/tests/helpers/test_device_registry.py b/tests/helpers/test_device_registry.py index 2955c02345f7..a689cc9ac3d0 100644 --- a/tests/helpers/test_device_registry.py +++ b/tests/helpers/test_device_registry.py @@ -176,7 +176,7 @@ async def test_loading_from_storage(hass, hass_storage): "config_entries": ["1234"], "configuration_url": None, "connections": [["Zigbee", "01.23.45.67.89"]], - "disabled_by": device_registry.DISABLED_USER, + "disabled_by": device_registry.DeviceEntryDisabler.USER, "entry_type": device_registry.DeviceEntryType.SERVICE, "id": "abcdefghijklm", "identifiers": [["serial", "12:34:56:AB:CD:EF"]], @@ -216,7 +216,7 @@ async def test_loading_from_storage(hass, hass_storage): assert entry.area_id == "12345A" assert entry.name_by_user == "Test Friendly Name" assert entry.entry_type is device_registry.DeviceEntryType.SERVICE - assert entry.disabled_by == device_registry.DISABLED_USER + assert entry.disabled_by is device_registry.DeviceEntryDisabler.USER assert isinstance(entry.config_entries, set) assert isinstance(entry.connections, set) assert isinstance(entry.identifiers, set) @@ -574,7 +574,7 @@ async def test_loading_saving_data(hass, registry, area_registry): manufacturer="manufacturer", model="light", via_device=("hue", "0123"), - disabled_by=device_registry.DISABLED_USER, + disabled_by=device_registry.DeviceEntryDisabler.USER, ) orig_light2 = registry.async_get_or_create( @@ -623,7 +623,7 @@ async def test_loading_saving_data(hass, registry, area_registry): manufacturer="manufacturer", model="light", via_device=("hue", "0123"), - disabled_by=device_registry.DISABLED_USER, + disabled_by=device_registry.DeviceEntryDisabler.USER, suggested_area="Kitchen", ) @@ -732,7 +732,7 @@ async def test_update(registry): name_by_user="Test Friendly Name", new_identifiers=new_identifiers, via_device_id="98765B", - disabled_by=device_registry.DISABLED_USER, + disabled_by=device_registry.DeviceEntryDisabler.USER, ) assert mock_save.call_count == 1 @@ -743,7 +743,7 @@ async def test_update(registry): assert updated_entry.name_by_user == "Test Friendly Name" assert updated_entry.identifiers == new_identifiers assert updated_entry.via_device_id == "98765B" - assert updated_entry.disabled_by == device_registry.DISABLED_USER + assert updated_entry.disabled_by is device_registry.DeviceEntryDisabler.USER assert registry.async_get_device({("hue", "456")}) is None assert registry.async_get_device({("bla", "123")}) is None @@ -1307,7 +1307,7 @@ async def test_disable_config_entry_disables_devices(hass, registry): entry2 = registry.async_get_or_create( config_entry_id=config_entry.entry_id, connections={(device_registry.CONNECTION_NETWORK_MAC, "34:56:AB:CD:EF:12")}, - disabled_by=device_registry.DISABLED_USER, + disabled_by=device_registry.DeviceEntryDisabler.USER, ) assert not entry1.disabled @@ -1320,10 +1320,10 @@ async def test_disable_config_entry_disables_devices(hass, registry): entry1 = registry.async_get(entry1.id) assert entry1.disabled - assert entry1.disabled_by == device_registry.DISABLED_CONFIG_ENTRY + assert entry1.disabled_by is device_registry.DeviceEntryDisabler.CONFIG_ENTRY entry2 = registry.async_get(entry2.id) assert entry2.disabled - assert entry2.disabled_by == device_registry.DISABLED_USER + assert entry2.disabled_by is device_registry.DeviceEntryDisabler.USER await hass.config_entries.async_set_disabled_by(config_entry.entry_id, None) await hass.async_block_till_done() @@ -1332,7 +1332,7 @@ async def test_disable_config_entry_disables_devices(hass, registry): assert not entry1.disabled entry2 = registry.async_get(entry2.id) assert entry2.disabled - assert entry2.disabled_by == device_registry.DISABLED_USER + assert entry2.disabled_by is device_registry.DeviceEntryDisabler.USER async def test_only_disable_device_if_all_config_entries_are_disabled(hass, registry): @@ -1368,7 +1368,7 @@ async def test_only_disable_device_if_all_config_entries_are_disabled(hass, regi entry1 = registry.async_get(entry1.id) assert entry1.disabled - assert entry1.disabled_by == device_registry.DISABLED_CONFIG_ENTRY + assert entry1.disabled_by is device_registry.DeviceEntryDisabler.CONFIG_ENTRY await hass.config_entries.async_set_disabled_by(config_entry1.entry_id, None) await hass.async_block_till_done() diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index ad88a23f76ff..3dc9cf775c44 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -850,7 +850,9 @@ async def test_disable_device_disables_entities(hass, registry): assert entry2.disabled assert entry3.disabled - device_registry.async_update_device(device_entry.id, disabled_by=er.DISABLED_USER) + device_registry.async_update_device( + device_entry.id, disabled_by=dr.DeviceEntryDisabler.USER + ) await hass.async_block_till_done() entry1 = registry.async_get(entry1.entity_id)