diff --git a/homeassistant/components/deconz/deconz_device.py b/homeassistant/components/deconz/deconz_device.py index 80557caeca68..b77014cc34bb 100644 --- a/homeassistant/components/deconz/deconz_device.py +++ b/homeassistant/components/deconz/deconz_device.py @@ -77,11 +77,6 @@ class DeconzDevice(DeconzBase, Entity): self.hass, self.gateway.signal_reachable, self.async_update_callback ) ) - self.listeners.append( - async_dispatcher_connect( - self.hass, self.gateway.signal_remove_entity, self.async_remove_self - ) - ) async def async_will_remove_from_hass(self) -> None: """Disconnect device object when removed.""" @@ -91,15 +86,6 @@ class DeconzDevice(DeconzBase, Entity): for unsub_dispatcher in self.listeners: unsub_dispatcher() - async def async_remove_self(self, deconz_ids: list) -> None: - """Schedule removal of this entity. - - Called by signal_remove_entity scheduled by async_added_to_hass. - """ - if self._device.deconz_id not in deconz_ids: - return - await self.async_remove() - @callback def async_update_callback(self, force_update=False, ignore_update=False): """Update the device's state.""" diff --git a/homeassistant/components/deconz/gateway.py b/homeassistant/components/deconz/gateway.py index eb83f5c15c5d..6ef68b43f643 100644 --- a/homeassistant/components/deconz/gateway.py +++ b/homeassistant/components/deconz/gateway.py @@ -164,15 +164,14 @@ class DeconzGateway: else: deconz_ids += [group.deconz_id for group in groups] - if deconz_ids: - async_dispatcher_send(self.hass, self.signal_remove_entity, deconz_ids) - entity_registry = await self.hass.helpers.entity_registry.async_get_registry() for entity_id, deconz_id in self.deconz_ids.items(): if deconz_id in deconz_ids and entity_registry.async_is_registered( entity_id ): + # Removing an entity from the entity registry will also remove them + # from Home Assistant entity_registry.async_remove(entity_id) @property @@ -197,11 +196,6 @@ class DeconzGateway: } return new_device[device_type] - @property - def signal_remove_entity(self) -> str: - """Gateway specific event to signal removal of entity.""" - return f"deconz-remove-{self.bridgeid}" - @callback def async_add_device_callback(self, device_type, device) -> None: """Handle event of new device creation in deCONZ.""" diff --git a/homeassistant/components/dyson/air_quality.py b/homeassistant/components/dyson/air_quality.py index 647fb2367074..d146ba11aa6d 100644 --- a/homeassistant/components/dyson/air_quality.py +++ b/homeassistant/components/dyson/air_quality.py @@ -27,9 +27,16 @@ def setup_platform(hass, config, add_entities, discovery_info=None): # Get Dyson Devices from parent component device_ids = [device.unique_id for device in hass.data[DYSON_AIQ_DEVICES]] + new_entities = [] for device in hass.data[DYSON_DEVICES]: + print(device.serial) if isinstance(device, DysonPureCool) and device.serial not in device_ids: - hass.data[DYSON_AIQ_DEVICES].append(DysonAirSensor(device)) + new_entities.append(DysonAirSensor(device)) + + if not new_entities: + return + + hass.data[DYSON_AIQ_DEVICES].extend(new_entities) add_entities(hass.data[DYSON_AIQ_DEVICES]) diff --git a/homeassistant/components/dyson/sensor.py b/homeassistant/components/dyson/sensor.py index 55f2ff69314c..7d73e1d43b3f 100644 --- a/homeassistant/components/dyson/sensor.py +++ b/homeassistant/components/dyson/sensor.py @@ -41,18 +41,24 @@ def setup_platform(hass, config, add_entities, discovery_info=None): # Get Dyson Devices from parent component device_ids = [device.unique_id for device in hass.data[DYSON_SENSOR_DEVICES]] + new_entities = [] for device in hass.data[DYSON_DEVICES]: if isinstance(device, DysonPureCool): if f"{device.serial}-temperature" not in device_ids: - devices.append(DysonTemperatureSensor(device, unit)) + new_entities.append(DysonTemperatureSensor(device, unit)) if f"{device.serial}-humidity" not in device_ids: - devices.append(DysonHumiditySensor(device)) + new_entities.append(DysonHumiditySensor(device)) elif isinstance(device, DysonPureCoolLink): - devices.append(DysonFilterLifeSensor(device)) - devices.append(DysonDustSensor(device)) - devices.append(DysonHumiditySensor(device)) - devices.append(DysonTemperatureSensor(device, unit)) - devices.append(DysonAirQualitySensor(device)) + new_entities.append(DysonFilterLifeSensor(device)) + new_entities.append(DysonDustSensor(device)) + new_entities.append(DysonHumiditySensor(device)) + new_entities.append(DysonTemperatureSensor(device, unit)) + new_entities.append(DysonAirQualitySensor(device)) + + if not new_entities: + return + + devices.extend(new_entities) add_entities(devices) diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index d7e2fa1ac837..f84edb1f204b 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -3,11 +3,12 @@ import asyncio import voluptuous as vol -from homeassistant.auth.permissions.const import POLICY_READ +from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ +from homeassistant.components.websocket_api.const import ERR_NOT_FOUND from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL from homeassistant.core import DOMAIN as HASS_DOMAIN, callback from homeassistant.exceptions import HomeAssistantError, ServiceNotFound, Unauthorized -from homeassistant.helpers import config_validation as cv +from homeassistant.helpers import config_validation as cv, entity from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.loader import IntegrationNotFound, async_get_integration @@ -30,6 +31,7 @@ def async_register_commands(hass, async_reg): async_reg(hass, handle_render_template) async_reg(hass, handle_manifest_list) async_reg(hass, handle_manifest_get) + async_reg(hass, handle_entity_source) def pong_message(iden): @@ -263,3 +265,46 @@ def handle_render_template(hass, connection, msg): connection.send_result(msg["id"]) state_listener() + + +@callback +@decorators.websocket_command( + {vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]} +) +def handle_entity_source(hass, connection, msg): + """Handle entity source command.""" + raw_sources = entity.entity_sources(hass) + entity_perm = connection.user.permissions.check_entity + + if "entity_id" not in msg: + if connection.user.permissions.access_all_entities("read"): + sources = raw_sources + else: + sources = { + entity_id: source + for entity_id, source in raw_sources.items() + if entity_perm(entity_id, "read") + } + + connection.send_message(messages.result_message(msg["id"], sources)) + return + + sources = {} + + for entity_id in msg["entity_id"]: + if not entity_perm(entity_id, "read"): + raise Unauthorized( + context=connection.context(msg), + permission=POLICY_READ, + perm_category=CAT_ENTITIES, + ) + + source = raw_sources.get(entity_id) + + if source is None: + connection.send_error(msg["id"], ERR_NOT_FOUND, "Entity not found") + return + + sources[entity_id] = source + + connection.send_result(msg["id"], sources) diff --git a/homeassistant/components/zha/core/const.py b/homeassistant/components/zha/core/const.py index cb6a698d72f4..63652f58f30d 100644 --- a/homeassistant/components/zha/core/const.py +++ b/homeassistant/components/zha/core/const.py @@ -274,7 +274,6 @@ SIGNAL_REMOVE = "remove" SIGNAL_SET_LEVEL = "set_level" SIGNAL_STATE_ATTR = "update_state_attribute" SIGNAL_UPDATE_DEVICE = "{}_zha_update_device" -SIGNAL_REMOVE_GROUP = "remove_group" SIGNAL_GROUP_ENTITY_REMOVED = "group_entity_removed" SIGNAL_GROUP_MEMBERSHIP_CHANGE = "group_membership_change" diff --git a/homeassistant/components/zha/core/gateway.py b/homeassistant/components/zha/core/gateway.py index ef39c408ec50..1f58dc650869 100644 --- a/homeassistant/components/zha/core/gateway.py +++ b/homeassistant/components/zha/core/gateway.py @@ -57,7 +57,6 @@ from .const import ( SIGNAL_ADD_ENTITIES, SIGNAL_GROUP_MEMBERSHIP_CHANGE, SIGNAL_REMOVE, - SIGNAL_REMOVE_GROUP, UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, ZHA_GW_MSG, @@ -298,13 +297,10 @@ class ZHAGateway: self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED) def group_removed(self, zigpy_group: ZigpyGroupType) -> None: - """Handle zigpy group added event.""" + """Handle zigpy group removed event.""" self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED) zha_group = self._groups.pop(zigpy_group.group_id, None) zha_group.info("group_removed") - async_dispatcher_send( - self._hass, f"{SIGNAL_REMOVE_GROUP}_0x{zigpy_group.group_id:04x}" - ) self._cleanup_group_entity_registry_entries(zigpy_group) def _send_group_gateway_message( @@ -619,7 +615,7 @@ class ZHAGateway: if not group: _LOGGER.debug("Group: %s:0x%04x could not be found", group.name, group_id) return - if group and group.members: + if group.members: tasks = [] for member in group.members: tasks.append(member.async_remove_from_group()) diff --git a/homeassistant/components/zha/entity.py b/homeassistant/components/zha/entity.py index 695a4f6ca6aa..309691dd3df9 100644 --- a/homeassistant/components/zha/entity.py +++ b/homeassistant/components/zha/entity.py @@ -24,7 +24,6 @@ from .core.const import ( SIGNAL_GROUP_ENTITY_REMOVED, SIGNAL_GROUP_MEMBERSHIP_CHANGE, SIGNAL_REMOVE, - SIGNAL_REMOVE_GROUP, ) from .core.helpers import LogMixin from .core.typing import CALLABLE_T, ChannelType, ZhaDeviceType @@ -217,32 +216,35 @@ class ZhaGroupEntity(BaseZhaEntity): """Initialize a light group.""" super().__init__(unique_id, zha_device, **kwargs) self._available = False - self._name = ( - f"{zha_device.gateway.groups.get(group_id).name}_zha_group_0x{group_id:04x}" - ) + self._group = zha_device.gateway.groups.get(group_id) + self._name = f"{self._group.name}_zha_group_0x{group_id:04x}" self._group_id: int = group_id self._entity_ids: List[str] = entity_ids self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None + self._handled_group_membership = False @property def available(self) -> bool: """Return entity availability.""" return self._available + async def _handle_group_membership_changed(self): + """Handle group membership changed.""" + # Make sure we don't call remove twice as members are removed + if self._handled_group_membership: + return + + self._handled_group_membership = True + await self.async_remove() + async def async_added_to_hass(self) -> None: """Register callbacks.""" await super().async_added_to_hass() - self.async_accept_signal( - None, - f"{SIGNAL_REMOVE_GROUP}_0x{self._group_id:04x}", - self.async_remove, - signal_override=True, - ) self.async_accept_signal( None, f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{self._group_id:04x}", - self.async_remove, + self._handle_group_membership_changed, signal_override=True, ) diff --git a/homeassistant/exceptions.py b/homeassistant/exceptions.py index d085c1a9021a..44587fec0432 100644 --- a/homeassistant/exceptions.py +++ b/homeassistant/exceptions.py @@ -54,6 +54,10 @@ class Unauthorized(HomeAssistantError): """Unauthorized error.""" super().__init__(self.__class__.__name__) self.context = context + + if user_id is None and context is not None: + user_id = context.user_id + self.user_id = user_id self.entity_id = entity_id self.config_entry_id = config_entry_id diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 162bba81ddd7..72cc46509785 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -25,15 +25,26 @@ from homeassistant.const import ( TEMP_FAHRENHEIT, ) from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback -from homeassistant.exceptions import NoEntitySpecifiedError +from homeassistant.exceptions import HomeAssistantError, NoEntitySpecifiedError from homeassistant.helpers.entity_platform import EntityPlatform from homeassistant.helpers.entity_registry import RegistryEntry from homeassistant.helpers.event import Event, async_track_entity_registry_updated_event from homeassistant.helpers.typing import StateType +from homeassistant.loader import bind_hass from homeassistant.util import dt as dt_util, ensure_unique_string, slugify _LOGGER = logging.getLogger(__name__) SLOW_UPDATE_WARNING = 10 +DATA_ENTITY_SOURCE = "entity_info" +SOURCE_CONFIG_ENTRY = "config_entry" +SOURCE_PLATFORM_CONFIG = "platform_config" + + +@callback +@bind_hass +def entity_sources(hass: HomeAssistant) -> Dict[str, Dict[str, str]]: + """Get the entity sources.""" + return hass.data.get(DATA_ENTITY_SOURCE, {}) def generate_entity_id( @@ -109,6 +120,9 @@ class Entity(ABC): _context: Optional[Context] = None _context_set: Optional[datetime] = None + # If entity is added to an entity platform + _added = False + @property def should_poll(self) -> bool: """Return True if entity has to be polled for state. @@ -477,10 +491,49 @@ class Entity(ABC): To be extended by integrations. """ + @callback + def add_to_platform_start( + self, + hass: HomeAssistant, + platform: EntityPlatform, + parallel_updates: Optional[asyncio.Semaphore], + ) -> None: + """Start adding an entity to a platform.""" + if self._added: + raise HomeAssistantError( + f"Entity {self.entity_id} cannot be added a second time to an entity platform" + ) + + self.hass = hass + self.platform = platform + self.parallel_updates = parallel_updates + self._added = True + + @callback + def add_to_platform_abort(self) -> None: + """Abort adding an entity to a platform.""" + self.hass = None + self.platform = None + self.parallel_updates = None + self._added = False + + async def add_to_platform_finish(self) -> None: + """Finish adding an entity to a platform.""" + await self.async_internal_added_to_hass() + await self.async_added_to_hass() + self.async_write_ha_state() + async def async_remove(self) -> None: """Remove entity from Home Assistant.""" assert self.hass is not None + if self.platform and not self._added: + raise HomeAssistantError( + f"Entity {self.entity_id} async_remove called twice" + ) + + self._added = False + if self._on_remove is not None: while self._on_remove: self._on_remove.pop()() @@ -507,8 +560,25 @@ class Entity(ABC): Not to be extended by integrations. """ + assert self.hass is not None + + if self.platform: + info = {"domain": self.platform.platform_name} + + if self.platform.config_entry: + info["source"] = SOURCE_CONFIG_ENTRY + info["config_entry"] = self.platform.config_entry.entry_id + else: + info["source"] = SOURCE_PLATFORM_CONFIG + + self.hass.data.setdefault(DATA_ENTITY_SOURCE, {})[self.entity_id] = info + if self.registry_entry is not None: - assert self.hass is not None + # This is an assert as it should never happen, but helps in tests + assert ( + not self.registry_entry.disabled_by + ), f"Entity {self.entity_id} is being added while it's disabled" + self.async_on_remove( async_track_entity_registry_updated_event( self.hass, self.entity_id, self._async_registry_updated @@ -520,6 +590,9 @@ class Entity(ABC): Not to be extended by integrations. """ + if self.platform: + assert self.hass is not None + self.hass.data[DATA_ENTITY_SOURCE].pop(self.entity_id) async def _async_registry_updated(self, event: Event) -> None: """Handle entity registry update.""" diff --git a/homeassistant/helpers/entity_platform.py b/homeassistant/helpers/entity_platform.py index 88d74a208bd6..d5fad7fd0257 100644 --- a/homeassistant/helpers/entity_platform.py +++ b/homeassistant/helpers/entity_platform.py @@ -6,7 +6,7 @@ from logging import Logger from types import ModuleType from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Iterable, List, Optional -from homeassistant.config_entries import ConfigEntry +from homeassistant import config_entries from homeassistant.const import DEVICE_DEFAULT_NAME from homeassistant.core import ( CALLBACK_TYPE, @@ -60,7 +60,7 @@ class EntityPlatform: self.platform = platform self.scan_interval = scan_interval self.entity_namespace = entity_namespace - self.config_entry: Optional[ConfigEntry] = None + self.config_entry: Optional[config_entries.ConfigEntry] = None self.entities: Dict[str, Entity] = {} # pylint: disable=used-before-assignment self._tasks: List[asyncio.Future] = [] # Method to cancel the state change listener @@ -149,7 +149,7 @@ class EntityPlatform: await self._async_setup_platform(async_create_setup_task) - async def async_setup_entry(self, config_entry: ConfigEntry) -> bool: + async def async_setup_entry(self, config_entry: config_entries.ConfigEntry) -> bool: """Set up the platform from a config entry.""" # Store it so that we can save config entry ID in entity registry self.config_entry = config_entry @@ -332,10 +332,10 @@ class EntityPlatform: if entity is None: raise ValueError("Entity cannot be None") - entity.hass = self.hass - entity.platform = self - entity.parallel_updates = self._get_parallel_updates_semaphore( - hasattr(entity, "async_update") + entity.add_to_platform_start( + self.hass, + self, + self._get_parallel_updates_semaphore(hasattr(entity, "async_update")), ) # Update properties before we generate the entity_id @@ -344,8 +344,7 @@ class EntityPlatform: await entity.async_device_update(warning=False) except Exception: # pylint: disable=broad-except self.logger.exception("%s: Error on device update!", self.platform_name) - entity.hass = None - entity.platform = None + entity.add_to_platform_abort() return requested_entity_id = None @@ -423,8 +422,7 @@ class EntityPlatform: or entity.name or f'"{self.platform_name} {entity.unique_id}"', ) - entity.hass = None - entity.platform = None + entity.add_to_platform_abort() return # We won't generate an entity ID if the platform has already set one @@ -450,8 +448,7 @@ class EntityPlatform: # Make sure it is valid in case an entity set the value themselves if not valid_entity_id(entity.entity_id): - entity.hass = None - entity.platform = None + entity.add_to_platform_abort() raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}") already_exists = entity.entity_id in self.entities @@ -472,18 +469,14 @@ class EntityPlatform: else: msg = f"Entity id already exists - ignoring: {entity.entity_id}" self.logger.error(msg) - entity.hass = None - entity.platform = None + entity.add_to_platform_abort() return entity_id = entity.entity_id self.entities[entity_id] = entity entity.async_on_remove(lambda: self.entities.pop(entity_id)) - await entity.async_internal_added_to_hass() - await entity.async_added_to_hass() - - entity.async_write_ha_state() + await entity.add_to_platform_finish() async def async_reset(self) -> None: """Remove all entities and reset data. diff --git a/tests/components/switch/test_init.py b/tests/components/switch/test_init.py index 5853e5faee22..cf2933282eaa 100644 --- a/tests/components/switch/test_init.py +++ b/tests/components/switch/test_init.py @@ -1,85 +1,55 @@ """The tests for the Switch component.""" -# pylint: disable=protected-access -import unittest +import pytest from homeassistant import core from homeassistant.components import switch from homeassistant.const import CONF_PLATFORM -from homeassistant.setup import async_setup_component, setup_component +from homeassistant.setup import async_setup_component -from tests.common import get_test_home_assistant, mock_entity_platform from tests.components.switch import common -class TestSwitch(unittest.TestCase): - """Test the switch module.""" - - # pylint: disable=invalid-name - def setUp(self): - """Set up things to be run when tests are started.""" - self.hass = get_test_home_assistant() - platform = getattr(self.hass.components, "test.switch") - platform.init() - # Switch 1 is ON, switch 2 is OFF - self.switch_1, self.switch_2, self.switch_3 = platform.ENTITIES - self.addCleanup(self.hass.stop) - - def test_methods(self): - """Test is_on, turn_on, turn_off methods.""" - assert setup_component( - self.hass, switch.DOMAIN, {switch.DOMAIN: {CONF_PLATFORM: "test"}} - ) - self.hass.block_till_done() - assert switch.is_on(self.hass, self.switch_1.entity_id) - assert not switch.is_on(self.hass, self.switch_2.entity_id) - assert not switch.is_on(self.hass, self.switch_3.entity_id) - - common.turn_off(self.hass, self.switch_1.entity_id) - common.turn_on(self.hass, self.switch_2.entity_id) - - self.hass.block_till_done() - - assert not switch.is_on(self.hass, self.switch_1.entity_id) - assert switch.is_on(self.hass, self.switch_2.entity_id) - - # Turn all off - common.turn_off(self.hass) - - self.hass.block_till_done() - - assert not switch.is_on(self.hass, self.switch_1.entity_id) - assert not switch.is_on(self.hass, self.switch_2.entity_id) - assert not switch.is_on(self.hass, self.switch_3.entity_id) - - # Turn all on - common.turn_on(self.hass) - - self.hass.block_till_done() - - assert switch.is_on(self.hass, self.switch_1.entity_id) - assert switch.is_on(self.hass, self.switch_2.entity_id) - assert switch.is_on(self.hass, self.switch_3.entity_id) - - def test_setup_two_platforms(self): - """Test with bad configuration.""" - # Test if switch component returns 0 switches - test_platform = getattr(self.hass.components, "test.switch") - test_platform.init(True) - - mock_entity_platform(self.hass, "switch.test2", test_platform) - test_platform.init(False) - - assert setup_component( - self.hass, - switch.DOMAIN, - { - switch.DOMAIN: {CONF_PLATFORM: "test"}, - f"{switch.DOMAIN} 2": {CONF_PLATFORM: "test2"}, - }, - ) +@pytest.fixture(autouse=True) +def entities(hass): + """Initialize the test switch.""" + platform = getattr(hass.components, "test.switch") + platform.init() + yield platform.ENTITIES -async def test_switch_context(hass, hass_admin_user): +async def test_methods(hass, entities): + """Test is_on, turn_on, turn_off methods.""" + switch_1, switch_2, switch_3 = entities + assert await async_setup_component( + hass, switch.DOMAIN, {switch.DOMAIN: {CONF_PLATFORM: "test"}} + ) + await hass.async_block_till_done() + assert switch.is_on(hass, switch_1.entity_id) + assert not switch.is_on(hass, switch_2.entity_id) + assert not switch.is_on(hass, switch_3.entity_id) + + await common.async_turn_off(hass, switch_1.entity_id) + await common.async_turn_on(hass, switch_2.entity_id) + + assert not switch.is_on(hass, switch_1.entity_id) + assert switch.is_on(hass, switch_2.entity_id) + + # Turn all off + await common.async_turn_off(hass) + + assert not switch.is_on(hass, switch_1.entity_id) + assert not switch.is_on(hass, switch_2.entity_id) + assert not switch.is_on(hass, switch_3.entity_id) + + # Turn all on + await common.async_turn_on(hass) + + assert switch.is_on(hass, switch_1.entity_id) + assert switch.is_on(hass, switch_2.entity_id) + assert switch.is_on(hass, switch_3.entity_id) + + +async def test_switch_context(hass, entities, hass_admin_user): """Test that switch context works.""" assert await async_setup_component(hass, "switch", {"switch": {"platform": "test"}}) diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index 57724fb54c60..c3d4c12eec68 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -10,10 +10,11 @@ from homeassistant.components.websocket_api.auth import ( from homeassistant.components.websocket_api.const import URL from homeassistant.core import callback from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers import entity from homeassistant.loader import async_get_integration from homeassistant.setup import async_setup_component -from tests.common import async_mock_service +from tests.common import MockEntity, MockEntityPlatform, async_mock_service async def test_call_service(hass, websocket_client): @@ -519,3 +520,116 @@ async def test_manifest_get(hass, websocket_client): assert msg["type"] == const.TYPE_RESULT assert not msg["success"] assert msg["error"]["code"] == "not_found" + + +async def test_entity_source_admin(hass, websocket_client, hass_admin_user): + """Check that we fetch sources correctly.""" + platform = MockEntityPlatform(hass) + + await platform.async_add_entities( + [MockEntity(name="Entity 1"), MockEntity(name="Entity 2")] + ) + + # Fetch all + await websocket_client.send_json({"id": 6, "type": "entity/source"}) + + msg = await websocket_client.receive_json() + assert msg["id"] == 6 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + assert msg["result"] == { + "test_domain.entity_1": { + "source": entity.SOURCE_PLATFORM_CONFIG, + "domain": "test_platform", + }, + "test_domain.entity_2": { + "source": entity.SOURCE_PLATFORM_CONFIG, + "domain": "test_platform", + }, + } + + # Fetch one + await websocket_client.send_json( + {"id": 7, "type": "entity/source", "entity_id": ["test_domain.entity_2"]} + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 7 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + assert msg["result"] == { + "test_domain.entity_2": { + "source": entity.SOURCE_PLATFORM_CONFIG, + "domain": "test_platform", + }, + } + + # Fetch two + await websocket_client.send_json( + { + "id": 8, + "type": "entity/source", + "entity_id": ["test_domain.entity_2", "test_domain.entity_1"], + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 8 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + assert msg["result"] == { + "test_domain.entity_1": { + "source": entity.SOURCE_PLATFORM_CONFIG, + "domain": "test_platform", + }, + "test_domain.entity_2": { + "source": entity.SOURCE_PLATFORM_CONFIG, + "domain": "test_platform", + }, + } + + # Fetch non existing + await websocket_client.send_json( + { + "id": 9, + "type": "entity/source", + "entity_id": ["test_domain.entity_2", "test_domain.non_existing"], + } + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 9 + assert msg["type"] == const.TYPE_RESULT + assert not msg["success"] + assert msg["error"]["code"] == const.ERR_NOT_FOUND + + # Mock policy + hass_admin_user.groups = [] + hass_admin_user.mock_policy( + {"entities": {"entity_ids": {"test_domain.entity_2": True}}} + ) + + # Fetch all + await websocket_client.send_json({"id": 10, "type": "entity/source"}) + + msg = await websocket_client.receive_json() + assert msg["id"] == 10 + assert msg["type"] == const.TYPE_RESULT + assert msg["success"] + assert msg["result"] == { + "test_domain.entity_2": { + "source": entity.SOURCE_PLATFORM_CONFIG, + "domain": "test_platform", + }, + } + + # Fetch unauthorized + await websocket_client.send_json( + {"id": 11, "type": "entity/source", "entity_id": ["test_domain.entity_1"]} + ) + + msg = await websocket_client.receive_json() + assert msg["id"] == 11 + assert msg["type"] == const.TYPE_RESULT + assert not msg["success"] + assert msg["error"]["code"] == const.ERR_UNAUTHORIZED diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 49f8fbdef7cc..1513d573b563 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -11,7 +11,13 @@ from homeassistant.core import Context from homeassistant.helpers import entity, entity_registry from tests.async_mock import MagicMock, PropertyMock, patch -from tests.common import get_test_home_assistant, mock_registry +from tests.common import ( + MockConfigEntry, + MockEntity, + MockEntityPlatform, + get_test_home_assistant, + mock_registry, +) def test_generate_entity_id_requires_hass_or_ids(): @@ -603,7 +609,7 @@ async def test_disabled_in_entity_registry(hass): entity_id="hello.world", unique_id="test-unique-id", platform="test-platform", - disabled_by="user", + disabled_by=None, ) registry = mock_registry(hass, {"hello.world": entry}) @@ -611,23 +617,24 @@ async def test_disabled_in_entity_registry(hass): ent.hass = hass ent.entity_id = "hello.world" ent.registry_entry = entry - ent.platform = MagicMock(platform_name="test-platform") + assert ent.enabled is True - await ent.async_internal_added_to_hass() - ent.async_write_ha_state() - assert hass.states.get("hello.world") is None + ent.add_to_platform_start(hass, MagicMock(platform_name="test-platform"), None) + await ent.add_to_platform_finish() + assert hass.states.get("hello.world") is not None - entry2 = registry.async_update_entity("hello.world", disabled_by=None) + entry2 = registry.async_update_entity("hello.world", disabled_by="user") await hass.async_block_till_done() assert entry2 != entry assert ent.registry_entry == entry2 - assert ent.enabled is True + assert ent.enabled is False + assert hass.states.get("hello.world") is None - entry3 = registry.async_update_entity("hello.world", disabled_by="user") + entry3 = registry.async_update_entity("hello.world", disabled_by=None) await hass.async_block_till_done() assert entry3 != entry2 - assert ent.registry_entry == entry3 - assert ent.enabled is False + # Entry is no longer updated, entity is no longer tracking changes + assert ent.registry_entry == entry2 async def test_capability_attrs(hass): @@ -690,3 +697,31 @@ async def test_warn_slow_write_state_custom_component(hass, caplog): "(.CustomComponentEntity'>) " "took 10.000 seconds. Please report it to the custom component author." ) in caplog.text + + +async def test_setup_source(hass): + """Check that we register sources correctly.""" + platform = MockEntityPlatform(hass) + + entity_platform = MockEntity(name="Platform Config Source") + await platform.async_add_entities([entity_platform]) + + platform.config_entry = MockConfigEntry() + entity_entry = MockEntity(name="Config Entry Source") + await platform.async_add_entities([entity_entry]) + + assert entity.entity_sources(hass) == { + "test_domain.platform_config_source": { + "source": entity.SOURCE_PLATFORM_CONFIG, + "domain": "test_platform", + }, + "test_domain.config_entry_source": { + "source": entity.SOURCE_CONFIG_ENTRY, + "config_entry": platform.config_entry.entry_id, + "domain": "test_platform", + }, + } + + await platform.async_reset() + + assert entity.entity_sources(hass) == {} diff --git a/tests/testing_config/custom_components/test/switch.py b/tests/testing_config/custom_components/test/switch.py index 7dd1862d88f0..0c8881346b07 100644 --- a/tests/testing_config/custom_components/test/switch.py +++ b/tests/testing_config/custom_components/test/switch.py @@ -29,4 +29,5 @@ async def async_setup_platform( hass, config, async_add_entities_callback, discovery_info=None ): """Return mock entities.""" + print("YOOO") async_add_entities_callback(ENTITIES)