Track entity sources (#37258)

Co-authored-by: David Mulcahey <david.mulcahey@me.com>
This commit is contained in:
Paulus Schoutsen 2020-08-19 14:57:38 +02:00 committed by GitHub
parent 24a16ff8fe
commit 3dc79aa60a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 379 additions and 154 deletions

View File

@ -77,11 +77,6 @@ class DeconzDevice(DeconzBase, Entity):
self.hass, self.gateway.signal_reachable, self.async_update_callback
)
)
self.listeners.append(
async_dispatcher_connect(
self.hass, self.gateway.signal_remove_entity, self.async_remove_self
)
)
async def async_will_remove_from_hass(self) -> None:
"""Disconnect device object when removed."""
@ -91,15 +86,6 @@ class DeconzDevice(DeconzBase, Entity):
for unsub_dispatcher in self.listeners:
unsub_dispatcher()
async def async_remove_self(self, deconz_ids: list) -> None:
"""Schedule removal of this entity.
Called by signal_remove_entity scheduled by async_added_to_hass.
"""
if self._device.deconz_id not in deconz_ids:
return
await self.async_remove()
@callback
def async_update_callback(self, force_update=False, ignore_update=False):
"""Update the device's state."""

View File

@ -164,15 +164,14 @@ class DeconzGateway:
else:
deconz_ids += [group.deconz_id for group in groups]
if deconz_ids:
async_dispatcher_send(self.hass, self.signal_remove_entity, deconz_ids)
entity_registry = await self.hass.helpers.entity_registry.async_get_registry()
for entity_id, deconz_id in self.deconz_ids.items():
if deconz_id in deconz_ids and entity_registry.async_is_registered(
entity_id
):
# Removing an entity from the entity registry will also remove them
# from Home Assistant
entity_registry.async_remove(entity_id)
@property
@ -197,11 +196,6 @@ class DeconzGateway:
}
return new_device[device_type]
@property
def signal_remove_entity(self) -> str:
"""Gateway specific event to signal removal of entity."""
return f"deconz-remove-{self.bridgeid}"
@callback
def async_add_device_callback(self, device_type, device) -> None:
"""Handle event of new device creation in deCONZ."""

View File

@ -27,9 +27,16 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
# Get Dyson Devices from parent component
device_ids = [device.unique_id for device in hass.data[DYSON_AIQ_DEVICES]]
new_entities = []
for device in hass.data[DYSON_DEVICES]:
print(device.serial)
if isinstance(device, DysonPureCool) and device.serial not in device_ids:
hass.data[DYSON_AIQ_DEVICES].append(DysonAirSensor(device))
new_entities.append(DysonAirSensor(device))
if not new_entities:
return
hass.data[DYSON_AIQ_DEVICES].extend(new_entities)
add_entities(hass.data[DYSON_AIQ_DEVICES])

View File

@ -41,18 +41,24 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
# Get Dyson Devices from parent component
device_ids = [device.unique_id for device in hass.data[DYSON_SENSOR_DEVICES]]
new_entities = []
for device in hass.data[DYSON_DEVICES]:
if isinstance(device, DysonPureCool):
if f"{device.serial}-temperature" not in device_ids:
devices.append(DysonTemperatureSensor(device, unit))
new_entities.append(DysonTemperatureSensor(device, unit))
if f"{device.serial}-humidity" not in device_ids:
devices.append(DysonHumiditySensor(device))
new_entities.append(DysonHumiditySensor(device))
elif isinstance(device, DysonPureCoolLink):
devices.append(DysonFilterLifeSensor(device))
devices.append(DysonDustSensor(device))
devices.append(DysonHumiditySensor(device))
devices.append(DysonTemperatureSensor(device, unit))
devices.append(DysonAirQualitySensor(device))
new_entities.append(DysonFilterLifeSensor(device))
new_entities.append(DysonDustSensor(device))
new_entities.append(DysonHumiditySensor(device))
new_entities.append(DysonTemperatureSensor(device, unit))
new_entities.append(DysonAirQualitySensor(device))
if not new_entities:
return
devices.extend(new_entities)
add_entities(devices)

View File

@ -3,11 +3,12 @@ import asyncio
import voluptuous as vol
from homeassistant.auth.permissions.const import POLICY_READ
from homeassistant.auth.permissions.const import CAT_ENTITIES, POLICY_READ
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
from homeassistant.const import EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL
from homeassistant.core import DOMAIN as HASS_DOMAIN, callback
from homeassistant.exceptions import HomeAssistantError, ServiceNotFound, Unauthorized
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import config_validation as cv, entity
from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.loader import IntegrationNotFound, async_get_integration
@ -30,6 +31,7 @@ def async_register_commands(hass, async_reg):
async_reg(hass, handle_render_template)
async_reg(hass, handle_manifest_list)
async_reg(hass, handle_manifest_get)
async_reg(hass, handle_entity_source)
def pong_message(iden):
@ -263,3 +265,46 @@ def handle_render_template(hass, connection, msg):
connection.send_result(msg["id"])
state_listener()
@callback
@decorators.websocket_command(
{vol.Required("type"): "entity/source", vol.Optional("entity_id"): [cv.entity_id]}
)
def handle_entity_source(hass, connection, msg):
"""Handle entity source command."""
raw_sources = entity.entity_sources(hass)
entity_perm = connection.user.permissions.check_entity
if "entity_id" not in msg:
if connection.user.permissions.access_all_entities("read"):
sources = raw_sources
else:
sources = {
entity_id: source
for entity_id, source in raw_sources.items()
if entity_perm(entity_id, "read")
}
connection.send_message(messages.result_message(msg["id"], sources))
return
sources = {}
for entity_id in msg["entity_id"]:
if not entity_perm(entity_id, "read"):
raise Unauthorized(
context=connection.context(msg),
permission=POLICY_READ,
perm_category=CAT_ENTITIES,
)
source = raw_sources.get(entity_id)
if source is None:
connection.send_error(msg["id"], ERR_NOT_FOUND, "Entity not found")
return
sources[entity_id] = source
connection.send_result(msg["id"], sources)

View File

@ -274,7 +274,6 @@ SIGNAL_REMOVE = "remove"
SIGNAL_SET_LEVEL = "set_level"
SIGNAL_STATE_ATTR = "update_state_attribute"
SIGNAL_UPDATE_DEVICE = "{}_zha_update_device"
SIGNAL_REMOVE_GROUP = "remove_group"
SIGNAL_GROUP_ENTITY_REMOVED = "group_entity_removed"
SIGNAL_GROUP_MEMBERSHIP_CHANGE = "group_membership_change"

View File

@ -57,7 +57,6 @@ from .const import (
SIGNAL_ADD_ENTITIES,
SIGNAL_GROUP_MEMBERSHIP_CHANGE,
SIGNAL_REMOVE,
SIGNAL_REMOVE_GROUP,
UNKNOWN_MANUFACTURER,
UNKNOWN_MODEL,
ZHA_GW_MSG,
@ -298,13 +297,10 @@ class ZHAGateway:
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_ADDED)
def group_removed(self, zigpy_group: ZigpyGroupType) -> None:
"""Handle zigpy group added event."""
"""Handle zigpy group removed event."""
self._send_group_gateway_message(zigpy_group, ZHA_GW_MSG_GROUP_REMOVED)
zha_group = self._groups.pop(zigpy_group.group_id, None)
zha_group.info("group_removed")
async_dispatcher_send(
self._hass, f"{SIGNAL_REMOVE_GROUP}_0x{zigpy_group.group_id:04x}"
)
self._cleanup_group_entity_registry_entries(zigpy_group)
def _send_group_gateway_message(
@ -619,7 +615,7 @@ class ZHAGateway:
if not group:
_LOGGER.debug("Group: %s:0x%04x could not be found", group.name, group_id)
return
if group and group.members:
if group.members:
tasks = []
for member in group.members:
tasks.append(member.async_remove_from_group())

View File

@ -24,7 +24,6 @@ from .core.const import (
SIGNAL_GROUP_ENTITY_REMOVED,
SIGNAL_GROUP_MEMBERSHIP_CHANGE,
SIGNAL_REMOVE,
SIGNAL_REMOVE_GROUP,
)
from .core.helpers import LogMixin
from .core.typing import CALLABLE_T, ChannelType, ZhaDeviceType
@ -217,32 +216,35 @@ class ZhaGroupEntity(BaseZhaEntity):
"""Initialize a light group."""
super().__init__(unique_id, zha_device, **kwargs)
self._available = False
self._name = (
f"{zha_device.gateway.groups.get(group_id).name}_zha_group_0x{group_id:04x}"
)
self._group = zha_device.gateway.groups.get(group_id)
self._name = f"{self._group.name}_zha_group_0x{group_id:04x}"
self._group_id: int = group_id
self._entity_ids: List[str] = entity_ids
self._async_unsub_state_changed: Optional[CALLBACK_TYPE] = None
self._handled_group_membership = False
@property
def available(self) -> bool:
"""Return entity availability."""
return self._available
async def _handle_group_membership_changed(self):
"""Handle group membership changed."""
# Make sure we don't call remove twice as members are removed
if self._handled_group_membership:
return
self._handled_group_membership = True
await self.async_remove()
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
await super().async_added_to_hass()
self.async_accept_signal(
None,
f"{SIGNAL_REMOVE_GROUP}_0x{self._group_id:04x}",
self.async_remove,
signal_override=True,
)
self.async_accept_signal(
None,
f"{SIGNAL_GROUP_MEMBERSHIP_CHANGE}_0x{self._group_id:04x}",
self.async_remove,
self._handle_group_membership_changed,
signal_override=True,
)

View File

@ -54,6 +54,10 @@ class Unauthorized(HomeAssistantError):
"""Unauthorized error."""
super().__init__(self.__class__.__name__)
self.context = context
if user_id is None and context is not None:
user_id = context.user_id
self.user_id = user_id
self.entity_id = entity_id
self.config_entry_id = config_entry_id

View File

@ -25,15 +25,26 @@ from homeassistant.const import (
TEMP_FAHRENHEIT,
)
from homeassistant.core import CALLBACK_TYPE, Context, HomeAssistant, callback
from homeassistant.exceptions import NoEntitySpecifiedError
from homeassistant.exceptions import HomeAssistantError, NoEntitySpecifiedError
from homeassistant.helpers.entity_platform import EntityPlatform
from homeassistant.helpers.entity_registry import RegistryEntry
from homeassistant.helpers.event import Event, async_track_entity_registry_updated_event
from homeassistant.helpers.typing import StateType
from homeassistant.loader import bind_hass
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
_LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10
DATA_ENTITY_SOURCE = "entity_info"
SOURCE_CONFIG_ENTRY = "config_entry"
SOURCE_PLATFORM_CONFIG = "platform_config"
@callback
@bind_hass
def entity_sources(hass: HomeAssistant) -> Dict[str, Dict[str, str]]:
"""Get the entity sources."""
return hass.data.get(DATA_ENTITY_SOURCE, {})
def generate_entity_id(
@ -109,6 +120,9 @@ class Entity(ABC):
_context: Optional[Context] = None
_context_set: Optional[datetime] = None
# If entity is added to an entity platform
_added = False
@property
def should_poll(self) -> bool:
"""Return True if entity has to be polled for state.
@ -477,10 +491,49 @@ class Entity(ABC):
To be extended by integrations.
"""
@callback
def add_to_platform_start(
self,
hass: HomeAssistant,
platform: EntityPlatform,
parallel_updates: Optional[asyncio.Semaphore],
) -> None:
"""Start adding an entity to a platform."""
if self._added:
raise HomeAssistantError(
f"Entity {self.entity_id} cannot be added a second time to an entity platform"
)
self.hass = hass
self.platform = platform
self.parallel_updates = parallel_updates
self._added = True
@callback
def add_to_platform_abort(self) -> None:
"""Abort adding an entity to a platform."""
self.hass = None
self.platform = None
self.parallel_updates = None
self._added = False
async def add_to_platform_finish(self) -> None:
"""Finish adding an entity to a platform."""
await self.async_internal_added_to_hass()
await self.async_added_to_hass()
self.async_write_ha_state()
async def async_remove(self) -> None:
"""Remove entity from Home Assistant."""
assert self.hass is not None
if self.platform and not self._added:
raise HomeAssistantError(
f"Entity {self.entity_id} async_remove called twice"
)
self._added = False
if self._on_remove is not None:
while self._on_remove:
self._on_remove.pop()()
@ -507,8 +560,25 @@ class Entity(ABC):
Not to be extended by integrations.
"""
assert self.hass is not None
if self.platform:
info = {"domain": self.platform.platform_name}
if self.platform.config_entry:
info["source"] = SOURCE_CONFIG_ENTRY
info["config_entry"] = self.platform.config_entry.entry_id
else:
info["source"] = SOURCE_PLATFORM_CONFIG
self.hass.data.setdefault(DATA_ENTITY_SOURCE, {})[self.entity_id] = info
if self.registry_entry is not None:
assert self.hass is not None
# This is an assert as it should never happen, but helps in tests
assert (
not self.registry_entry.disabled_by
), f"Entity {self.entity_id} is being added while it's disabled"
self.async_on_remove(
async_track_entity_registry_updated_event(
self.hass, self.entity_id, self._async_registry_updated
@ -520,6 +590,9 @@ class Entity(ABC):
Not to be extended by integrations.
"""
if self.platform:
assert self.hass is not None
self.hass.data[DATA_ENTITY_SOURCE].pop(self.entity_id)
async def _async_registry_updated(self, event: Event) -> None:
"""Handle entity registry update."""

View File

@ -6,7 +6,7 @@ from logging import Logger
from types import ModuleType
from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Iterable, List, Optional
from homeassistant.config_entries import ConfigEntry
from homeassistant import config_entries
from homeassistant.const import DEVICE_DEFAULT_NAME
from homeassistant.core import (
CALLBACK_TYPE,
@ -60,7 +60,7 @@ class EntityPlatform:
self.platform = platform
self.scan_interval = scan_interval
self.entity_namespace = entity_namespace
self.config_entry: Optional[ConfigEntry] = None
self.config_entry: Optional[config_entries.ConfigEntry] = None
self.entities: Dict[str, Entity] = {} # pylint: disable=used-before-assignment
self._tasks: List[asyncio.Future] = []
# Method to cancel the state change listener
@ -149,7 +149,7 @@ class EntityPlatform:
await self._async_setup_platform(async_create_setup_task)
async def async_setup_entry(self, config_entry: ConfigEntry) -> bool:
async def async_setup_entry(self, config_entry: config_entries.ConfigEntry) -> bool:
"""Set up the platform from a config entry."""
# Store it so that we can save config entry ID in entity registry
self.config_entry = config_entry
@ -332,10 +332,10 @@ class EntityPlatform:
if entity is None:
raise ValueError("Entity cannot be None")
entity.hass = self.hass
entity.platform = self
entity.parallel_updates = self._get_parallel_updates_semaphore(
hasattr(entity, "async_update")
entity.add_to_platform_start(
self.hass,
self,
self._get_parallel_updates_semaphore(hasattr(entity, "async_update")),
)
# Update properties before we generate the entity_id
@ -344,8 +344,7 @@ class EntityPlatform:
await entity.async_device_update(warning=False)
except Exception: # pylint: disable=broad-except
self.logger.exception("%s: Error on device update!", self.platform_name)
entity.hass = None
entity.platform = None
entity.add_to_platform_abort()
return
requested_entity_id = None
@ -423,8 +422,7 @@ class EntityPlatform:
or entity.name
or f'"{self.platform_name} {entity.unique_id}"',
)
entity.hass = None
entity.platform = None
entity.add_to_platform_abort()
return
# We won't generate an entity ID if the platform has already set one
@ -450,8 +448,7 @@ class EntityPlatform:
# Make sure it is valid in case an entity set the value themselves
if not valid_entity_id(entity.entity_id):
entity.hass = None
entity.platform = None
entity.add_to_platform_abort()
raise HomeAssistantError(f"Invalid entity id: {entity.entity_id}")
already_exists = entity.entity_id in self.entities
@ -472,18 +469,14 @@ class EntityPlatform:
else:
msg = f"Entity id already exists - ignoring: {entity.entity_id}"
self.logger.error(msg)
entity.hass = None
entity.platform = None
entity.add_to_platform_abort()
return
entity_id = entity.entity_id
self.entities[entity_id] = entity
entity.async_on_remove(lambda: self.entities.pop(entity_id))
await entity.async_internal_added_to_hass()
await entity.async_added_to_hass()
entity.async_write_ha_state()
await entity.add_to_platform_finish()
async def async_reset(self) -> None:
"""Remove all entities and reset data.

View File

@ -1,85 +1,55 @@
"""The tests for the Switch component."""
# pylint: disable=protected-access
import unittest
import pytest
from homeassistant import core
from homeassistant.components import switch
from homeassistant.const import CONF_PLATFORM
from homeassistant.setup import async_setup_component, setup_component
from homeassistant.setup import async_setup_component
from tests.common import get_test_home_assistant, mock_entity_platform
from tests.components.switch import common
class TestSwitch(unittest.TestCase):
"""Test the switch module."""
# pylint: disable=invalid-name
def setUp(self):
"""Set up things to be run when tests are started."""
self.hass = get_test_home_assistant()
platform = getattr(self.hass.components, "test.switch")
platform.init()
# Switch 1 is ON, switch 2 is OFF
self.switch_1, self.switch_2, self.switch_3 = platform.ENTITIES
self.addCleanup(self.hass.stop)
def test_methods(self):
"""Test is_on, turn_on, turn_off methods."""
assert setup_component(
self.hass, switch.DOMAIN, {switch.DOMAIN: {CONF_PLATFORM: "test"}}
)
self.hass.block_till_done()
assert switch.is_on(self.hass, self.switch_1.entity_id)
assert not switch.is_on(self.hass, self.switch_2.entity_id)
assert not switch.is_on(self.hass, self.switch_3.entity_id)
common.turn_off(self.hass, self.switch_1.entity_id)
common.turn_on(self.hass, self.switch_2.entity_id)
self.hass.block_till_done()
assert not switch.is_on(self.hass, self.switch_1.entity_id)
assert switch.is_on(self.hass, self.switch_2.entity_id)
# Turn all off
common.turn_off(self.hass)
self.hass.block_till_done()
assert not switch.is_on(self.hass, self.switch_1.entity_id)
assert not switch.is_on(self.hass, self.switch_2.entity_id)
assert not switch.is_on(self.hass, self.switch_3.entity_id)
# Turn all on
common.turn_on(self.hass)
self.hass.block_till_done()
assert switch.is_on(self.hass, self.switch_1.entity_id)
assert switch.is_on(self.hass, self.switch_2.entity_id)
assert switch.is_on(self.hass, self.switch_3.entity_id)
def test_setup_two_platforms(self):
"""Test with bad configuration."""
# Test if switch component returns 0 switches
test_platform = getattr(self.hass.components, "test.switch")
test_platform.init(True)
mock_entity_platform(self.hass, "switch.test2", test_platform)
test_platform.init(False)
assert setup_component(
self.hass,
switch.DOMAIN,
{
switch.DOMAIN: {CONF_PLATFORM: "test"},
f"{switch.DOMAIN} 2": {CONF_PLATFORM: "test2"},
},
)
@pytest.fixture(autouse=True)
def entities(hass):
"""Initialize the test switch."""
platform = getattr(hass.components, "test.switch")
platform.init()
yield platform.ENTITIES
async def test_switch_context(hass, hass_admin_user):
async def test_methods(hass, entities):
"""Test is_on, turn_on, turn_off methods."""
switch_1, switch_2, switch_3 = entities
assert await async_setup_component(
hass, switch.DOMAIN, {switch.DOMAIN: {CONF_PLATFORM: "test"}}
)
await hass.async_block_till_done()
assert switch.is_on(hass, switch_1.entity_id)
assert not switch.is_on(hass, switch_2.entity_id)
assert not switch.is_on(hass, switch_3.entity_id)
await common.async_turn_off(hass, switch_1.entity_id)
await common.async_turn_on(hass, switch_2.entity_id)
assert not switch.is_on(hass, switch_1.entity_id)
assert switch.is_on(hass, switch_2.entity_id)
# Turn all off
await common.async_turn_off(hass)
assert not switch.is_on(hass, switch_1.entity_id)
assert not switch.is_on(hass, switch_2.entity_id)
assert not switch.is_on(hass, switch_3.entity_id)
# Turn all on
await common.async_turn_on(hass)
assert switch.is_on(hass, switch_1.entity_id)
assert switch.is_on(hass, switch_2.entity_id)
assert switch.is_on(hass, switch_3.entity_id)
async def test_switch_context(hass, entities, hass_admin_user):
"""Test that switch context works."""
assert await async_setup_component(hass, "switch", {"switch": {"platform": "test"}})

View File

@ -10,10 +10,11 @@ from homeassistant.components.websocket_api.auth import (
from homeassistant.components.websocket_api.const import URL
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import entity
from homeassistant.loader import async_get_integration
from homeassistant.setup import async_setup_component
from tests.common import async_mock_service
from tests.common import MockEntity, MockEntityPlatform, async_mock_service
async def test_call_service(hass, websocket_client):
@ -519,3 +520,116 @@ async def test_manifest_get(hass, websocket_client):
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert msg["error"]["code"] == "not_found"
async def test_entity_source_admin(hass, websocket_client, hass_admin_user):
"""Check that we fetch sources correctly."""
platform = MockEntityPlatform(hass)
await platform.async_add_entities(
[MockEntity(name="Entity 1"), MockEntity(name="Entity 2")]
)
# Fetch all
await websocket_client.send_json({"id": 6, "type": "entity/source"})
msg = await websocket_client.receive_json()
assert msg["id"] == 6
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {
"test_domain.entity_1": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "test_platform",
},
"test_domain.entity_2": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "test_platform",
},
}
# Fetch one
await websocket_client.send_json(
{"id": 7, "type": "entity/source", "entity_id": ["test_domain.entity_2"]}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 7
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {
"test_domain.entity_2": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "test_platform",
},
}
# Fetch two
await websocket_client.send_json(
{
"id": 8,
"type": "entity/source",
"entity_id": ["test_domain.entity_2", "test_domain.entity_1"],
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 8
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {
"test_domain.entity_1": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "test_platform",
},
"test_domain.entity_2": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "test_platform",
},
}
# Fetch non existing
await websocket_client.send_json(
{
"id": 9,
"type": "entity/source",
"entity_id": ["test_domain.entity_2", "test_domain.non_existing"],
}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 9
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert msg["error"]["code"] == const.ERR_NOT_FOUND
# Mock policy
hass_admin_user.groups = []
hass_admin_user.mock_policy(
{"entities": {"entity_ids": {"test_domain.entity_2": True}}}
)
# Fetch all
await websocket_client.send_json({"id": 10, "type": "entity/source"})
msg = await websocket_client.receive_json()
assert msg["id"] == 10
assert msg["type"] == const.TYPE_RESULT
assert msg["success"]
assert msg["result"] == {
"test_domain.entity_2": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "test_platform",
},
}
# Fetch unauthorized
await websocket_client.send_json(
{"id": 11, "type": "entity/source", "entity_id": ["test_domain.entity_1"]}
)
msg = await websocket_client.receive_json()
assert msg["id"] == 11
assert msg["type"] == const.TYPE_RESULT
assert not msg["success"]
assert msg["error"]["code"] == const.ERR_UNAUTHORIZED

View File

@ -11,7 +11,13 @@ from homeassistant.core import Context
from homeassistant.helpers import entity, entity_registry
from tests.async_mock import MagicMock, PropertyMock, patch
from tests.common import get_test_home_assistant, mock_registry
from tests.common import (
MockConfigEntry,
MockEntity,
MockEntityPlatform,
get_test_home_assistant,
mock_registry,
)
def test_generate_entity_id_requires_hass_or_ids():
@ -603,7 +609,7 @@ async def test_disabled_in_entity_registry(hass):
entity_id="hello.world",
unique_id="test-unique-id",
platform="test-platform",
disabled_by="user",
disabled_by=None,
)
registry = mock_registry(hass, {"hello.world": entry})
@ -611,23 +617,24 @@ async def test_disabled_in_entity_registry(hass):
ent.hass = hass
ent.entity_id = "hello.world"
ent.registry_entry = entry
ent.platform = MagicMock(platform_name="test-platform")
assert ent.enabled is True
await ent.async_internal_added_to_hass()
ent.async_write_ha_state()
assert hass.states.get("hello.world") is None
ent.add_to_platform_start(hass, MagicMock(platform_name="test-platform"), None)
await ent.add_to_platform_finish()
assert hass.states.get("hello.world") is not None
entry2 = registry.async_update_entity("hello.world", disabled_by=None)
entry2 = registry.async_update_entity("hello.world", disabled_by="user")
await hass.async_block_till_done()
assert entry2 != entry
assert ent.registry_entry == entry2
assert ent.enabled is True
assert ent.enabled is False
assert hass.states.get("hello.world") is None
entry3 = registry.async_update_entity("hello.world", disabled_by="user")
entry3 = registry.async_update_entity("hello.world", disabled_by=None)
await hass.async_block_till_done()
assert entry3 != entry2
assert ent.registry_entry == entry3
assert ent.enabled is False
# Entry is no longer updated, entity is no longer tracking changes
assert ent.registry_entry == entry2
async def test_capability_attrs(hass):
@ -690,3 +697,31 @@ async def test_warn_slow_write_state_custom_component(hass, caplog):
"(<class 'custom_components.bla.sensor.test_warn_slow_write_state_custom_component.<locals>.CustomComponentEntity'>) "
"took 10.000 seconds. Please report it to the custom component author."
) in caplog.text
async def test_setup_source(hass):
"""Check that we register sources correctly."""
platform = MockEntityPlatform(hass)
entity_platform = MockEntity(name="Platform Config Source")
await platform.async_add_entities([entity_platform])
platform.config_entry = MockConfigEntry()
entity_entry = MockEntity(name="Config Entry Source")
await platform.async_add_entities([entity_entry])
assert entity.entity_sources(hass) == {
"test_domain.platform_config_source": {
"source": entity.SOURCE_PLATFORM_CONFIG,
"domain": "test_platform",
},
"test_domain.config_entry_source": {
"source": entity.SOURCE_CONFIG_ENTRY,
"config_entry": platform.config_entry.entry_id,
"domain": "test_platform",
},
}
await platform.async_reset()
assert entity.entity_sources(hass) == {}

View File

@ -29,4 +29,5 @@ async def async_setup_platform(
hass, config, async_add_entities_callback, discovery_info=None
):
"""Return mock entities."""
print("YOOO")
async_add_entities_callback(ENTITIES)