"""The tests for the Entity component helper.""" # pylint: disable=protected-access from collections import OrderedDict from datetime import timedelta import logging import pytest import voluptuous as vol from homeassistant.const import ENTITY_MATCH_ALL, ENTITY_MATCH_NONE import homeassistant.core as ha from homeassistant.exceptions import PlatformNotReady from homeassistant.helpers import discovery from homeassistant.helpers.entity_component import EntityComponent from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util from tests.async_mock import AsyncMock, Mock, patch from tests.common import ( MockConfigEntry, MockEntity, MockModule, MockPlatform, async_fire_time_changed, mock_entity_platform, mock_integration, ) _LOGGER = logging.getLogger(__name__) DOMAIN = "test_domain" async def test_setup_loads_platforms(hass): """Test the loading of the platforms.""" component_setup = Mock(return_value=True) platform_setup = Mock(return_value=None) mock_integration(hass, MockModule("test_component", setup=component_setup)) # mock the dependencies mock_integration(hass, MockModule("mod2", dependencies=["test_component"])) mock_entity_platform(hass, "test_domain.mod2", MockPlatform(platform_setup)) component = EntityComponent(_LOGGER, DOMAIN, hass) assert not component_setup.called assert not platform_setup.called component.setup({DOMAIN: {"platform": "mod2"}}) await hass.async_block_till_done() assert component_setup.called assert platform_setup.called async def test_setup_recovers_when_setup_raises(hass): """Test the setup if exceptions are happening.""" platform1_setup = Mock(side_effect=Exception("Broken")) platform2_setup = Mock(return_value=None) mock_entity_platform(hass, "test_domain.mod1", MockPlatform(platform1_setup)) mock_entity_platform(hass, "test_domain.mod2", MockPlatform(platform2_setup)) component = EntityComponent(_LOGGER, DOMAIN, hass) assert not platform1_setup.called assert not platform2_setup.called component.setup( OrderedDict( [ (DOMAIN, {"platform": "mod1"}), (f"{DOMAIN} 2", {"platform": "non_exist"}), (f"{DOMAIN} 3", {"platform": "mod2"}), ] ) ) await hass.async_block_till_done() assert platform1_setup.called assert platform2_setup.called @patch( "homeassistant.helpers.entity_component.EntityComponent.async_setup_platform", ) @patch("homeassistant.setup.async_setup_component", return_value=True) async def test_setup_does_discovery(mock_setup_component, mock_setup, hass): """Test setup for discovery.""" component = EntityComponent(_LOGGER, DOMAIN, hass) component.setup({}) discovery.load_platform( hass, DOMAIN, "platform_test", {"msg": "discovery_info"}, {DOMAIN: {}} ) await hass.async_block_till_done() assert mock_setup.called assert ("platform_test", {}, {"msg": "discovery_info"}) == mock_setup.call_args[0] @patch("homeassistant.helpers.entity_platform.async_track_time_interval") async def test_set_scan_interval_via_config(mock_track, hass): """Test the setting of the scan interval via configuration.""" def platform_setup(hass, config, add_entities, discovery_info=None): """Test the platform setup.""" add_entities([MockEntity(should_poll=True)]) mock_entity_platform(hass, "test_domain.platform", MockPlatform(platform_setup)) component = EntityComponent(_LOGGER, DOMAIN, hass) component.setup( {DOMAIN: {"platform": "platform", "scan_interval": timedelta(seconds=30)}} ) await hass.async_block_till_done() assert mock_track.called assert timedelta(seconds=30) == mock_track.call_args[0][2] async def test_set_entity_namespace_via_config(hass): """Test setting an entity namespace.""" def platform_setup(hass, config, add_entities, discovery_info=None): """Test the platform setup.""" add_entities([MockEntity(name="beer"), MockEntity(name=None)]) platform = MockPlatform(platform_setup) mock_entity_platform(hass, "test_domain.platform", platform) component = EntityComponent(_LOGGER, DOMAIN, hass) component.setup({DOMAIN: {"platform": "platform", "entity_namespace": "yummy"}}) await hass.async_block_till_done() assert sorted(hass.states.async_entity_ids()) == [ "test_domain.yummy_beer", "test_domain.yummy_unnamed_device", ] async def test_extract_from_service_available_device(hass): """Test the extraction of entity from service and device is available.""" component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_add_entities( [ MockEntity(name="test_1"), MockEntity(name="test_2", available=False), MockEntity(name="test_3"), MockEntity(name="test_4", available=False), ] ) call_1 = ha.ServiceCall("test", "service", data={"entity_id": ENTITY_MATCH_ALL}) assert ["test_domain.test_1", "test_domain.test_3"] == sorted( ent.entity_id for ent in (await component.async_extract_from_service(call_1)) ) call_2 = ha.ServiceCall( "test", "service", data={"entity_id": ["test_domain.test_3", "test_domain.test_4"]}, ) assert ["test_domain.test_3"] == sorted( ent.entity_id for ent in (await component.async_extract_from_service(call_2)) ) async def test_platform_not_ready(hass, legacy_patchable_time): """Test that we retry when platform not ready.""" platform1_setup = Mock(side_effect=[PlatformNotReady, PlatformNotReady, None]) mock_integration(hass, MockModule("mod1")) mock_entity_platform(hass, "test_domain.mod1", MockPlatform(platform1_setup)) component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_setup({DOMAIN: {"platform": "mod1"}}) await hass.async_block_till_done() assert len(platform1_setup.mock_calls) == 1 assert "test_domain.mod1" not in hass.config.components utcnow = dt_util.utcnow() with patch("homeassistant.util.dt.utcnow", return_value=utcnow): # Should not trigger attempt 2 async_fire_time_changed(hass, utcnow + timedelta(seconds=29)) await hass.async_block_till_done() assert len(platform1_setup.mock_calls) == 1 # Should trigger attempt 2 async_fire_time_changed(hass, utcnow + timedelta(seconds=30)) await hass.async_block_till_done() assert len(platform1_setup.mock_calls) == 2 assert "test_domain.mod1" not in hass.config.components # This should not trigger attempt 3 async_fire_time_changed(hass, utcnow + timedelta(seconds=59)) await hass.async_block_till_done() assert len(platform1_setup.mock_calls) == 2 # Trigger attempt 3, which succeeds async_fire_time_changed(hass, utcnow + timedelta(seconds=60)) await hass.async_block_till_done() assert len(platform1_setup.mock_calls) == 3 assert "test_domain.mod1" in hass.config.components async def test_extract_from_service_fails_if_no_entity_id(hass): """Test the extraction of everything from service.""" component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_add_entities( [MockEntity(name="test_1"), MockEntity(name="test_2")] ) assert ( await component.async_extract_from_service(ha.ServiceCall("test", "service")) == [] ) assert ( await component.async_extract_from_service( ha.ServiceCall("test", "service", {"entity_id": ENTITY_MATCH_NONE}) ) == [] ) assert ( await component.async_extract_from_service( ha.ServiceCall("test", "service", {"area_id": ENTITY_MATCH_NONE}) ) == [] ) async def test_extract_from_service_filter_out_non_existing_entities(hass): """Test the extraction of non existing entities from service.""" component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_add_entities( [MockEntity(name="test_1"), MockEntity(name="test_2")] ) call = ha.ServiceCall( "test", "service", {"entity_id": ["test_domain.test_2", "test_domain.non_exist"]}, ) assert ["test_domain.test_2"] == [ ent.entity_id for ent in await component.async_extract_from_service(call) ] async def test_extract_from_service_no_group_expand(hass): """Test not expanding a group.""" component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_add_entities([MockEntity(entity_id="group.test_group")]) call = ha.ServiceCall("test", "service", {"entity_id": ["group.test_group"]}) extracted = await component.async_extract_from_service(call, expand_group=False) assert len(extracted) == 1 assert extracted[0].entity_id == "group.test_group" async def test_setup_dependencies_platform(hass): """Test we setup the dependencies of a platform. We're explicitly testing that we process dependencies even if a component with the same name has already been loaded. """ mock_integration( hass, MockModule("test_component", dependencies=["test_component2"]) ) mock_integration(hass, MockModule("test_component2")) mock_entity_platform(hass, "test_domain.test_component", MockPlatform()) component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_setup({DOMAIN: {"platform": "test_component"}}) await hass.async_block_till_done() assert "test_component" in hass.config.components assert "test_component2" in hass.config.components assert "test_domain.test_component" in hass.config.components async def test_setup_entry(hass): """Test setup entry calls async_setup_entry on platform.""" mock_setup_entry = AsyncMock(return_value=True) mock_entity_platform( hass, "test_domain.entry_domain", MockPlatform( async_setup_entry=mock_setup_entry, scan_interval=timedelta(seconds=5) ), ) component = EntityComponent(_LOGGER, DOMAIN, hass) entry = MockConfigEntry(domain="entry_domain") assert await component.async_setup_entry(entry) assert len(mock_setup_entry.mock_calls) == 1 p_hass, p_entry, _ = mock_setup_entry.mock_calls[0][1] assert p_hass is hass assert p_entry is entry assert component._platforms[entry.entry_id].scan_interval == timedelta(seconds=5) async def test_setup_entry_platform_not_exist(hass): """Test setup entry fails if platform does not exist.""" component = EntityComponent(_LOGGER, DOMAIN, hass) entry = MockConfigEntry(domain="non_existing") assert (await component.async_setup_entry(entry)) is False async def test_setup_entry_fails_duplicate(hass): """Test we don't allow setting up a config entry twice.""" mock_setup_entry = AsyncMock(return_value=True) mock_entity_platform( hass, "test_domain.entry_domain", MockPlatform(async_setup_entry=mock_setup_entry), ) component = EntityComponent(_LOGGER, DOMAIN, hass) entry = MockConfigEntry(domain="entry_domain") assert await component.async_setup_entry(entry) with pytest.raises(ValueError): await component.async_setup_entry(entry) async def test_unload_entry_resets_platform(hass): """Test unloading an entry removes all entities.""" mock_setup_entry = AsyncMock(return_value=True) mock_entity_platform( hass, "test_domain.entry_domain", MockPlatform(async_setup_entry=mock_setup_entry), ) component = EntityComponent(_LOGGER, DOMAIN, hass) entry = MockConfigEntry(domain="entry_domain") assert await component.async_setup_entry(entry) assert len(mock_setup_entry.mock_calls) == 1 add_entities = mock_setup_entry.mock_calls[0][1][2] add_entities([MockEntity()]) await hass.async_block_till_done() assert len(hass.states.async_entity_ids()) == 1 assert await component.async_unload_entry(entry) assert len(hass.states.async_entity_ids()) == 0 async def test_unload_entry_fails_if_never_loaded(hass): """.""" component = EntityComponent(_LOGGER, DOMAIN, hass) entry = MockConfigEntry(domain="entry_domain") with pytest.raises(ValueError): await component.async_unload_entry(entry) async def test_update_entity(hass): """Test that we can update an entity with the helper.""" component = EntityComponent(_LOGGER, DOMAIN, hass) entity = MockEntity() entity.async_write_ha_state = Mock() entity.async_update_ha_state = AsyncMock(return_value=None) await component.async_add_entities([entity]) # Called as part of async_add_entities assert len(entity.async_write_ha_state.mock_calls) == 1 await hass.helpers.entity_component.async_update_entity(entity.entity_id) assert len(entity.async_update_ha_state.mock_calls) == 1 assert entity.async_update_ha_state.mock_calls[-1][1][0] is True async def test_set_service_race(hass): """Test race condition on setting service.""" exception = False def async_loop_exception_handler(_, _2) -> None: """Handle all exception inside the core loop.""" nonlocal exception exception = True hass.loop.set_exception_handler(async_loop_exception_handler) await async_setup_component(hass, "group", {}) component = EntityComponent(_LOGGER, DOMAIN, hass) for _ in range(2): hass.async_create_task(component.async_add_entities([MockEntity()])) await hass.async_block_till_done() assert not exception async def test_extract_all_omit_entity_id(hass, caplog): """Test extract all with None and *.""" component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_add_entities( [MockEntity(name="test_1"), MockEntity(name="test_2")] ) call = ha.ServiceCall("test", "service") assert [] == sorted( ent.entity_id for ent in await component.async_extract_from_service(call) ) async def test_extract_all_use_match_all(hass, caplog): """Test extract all with None and *.""" component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_add_entities( [MockEntity(name="test_1"), MockEntity(name="test_2")] ) call = ha.ServiceCall("test", "service", {"entity_id": "all"}) assert ["test_domain.test_1", "test_domain.test_2"] == sorted( ent.entity_id for ent in await component.async_extract_from_service(call) ) assert ( "Not passing an entity ID to a service to target all entities is deprecated" ) not in caplog.text async def test_register_entity_service(hass): """Test not expanding a group.""" entity = MockEntity(entity_id=f"{DOMAIN}.entity") calls = [] @ha.callback def appender(**kwargs): calls.append(kwargs) entity.async_called_by_service = appender component = EntityComponent(_LOGGER, DOMAIN, hass) await component.async_add_entities([entity]) component.async_register_entity_service( "hello", {"some": str}, "async_called_by_service" ) with pytest.raises(vol.Invalid): await hass.services.async_call( DOMAIN, "hello", {"entity_id": entity.entity_id, "invalid": "data"}, blocking=True, ) assert len(calls) == 0 await hass.services.async_call( DOMAIN, "hello", {"entity_id": entity.entity_id, "some": "data"}, blocking=True ) assert len(calls) == 1 assert calls[0] == {"some": "data"} await hass.services.async_call( DOMAIN, "hello", {"entity_id": ENTITY_MATCH_ALL, "some": "data"}, blocking=True ) assert len(calls) == 2 assert calls[1] == {"some": "data"} await hass.services.async_call( DOMAIN, "hello", {"entity_id": ENTITY_MATCH_NONE, "some": "data"}, blocking=True ) assert len(calls) == 2 await hass.services.async_call( DOMAIN, "hello", {"area_id": ENTITY_MATCH_NONE, "some": "data"}, blocking=True ) assert len(calls) == 2