1
mirror of https://github.com/home-assistant/core synced 2024-08-15 18:25:44 +02:00

Fix overriding a script's entity_id (#78765)

This commit is contained in:
Erik Montnemery 2022-09-28 10:37:34 +02:00 committed by GitHub
parent 4a432db611
commit 6ef33b1d39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 18 deletions

View File

@ -28,7 +28,7 @@ from homeassistant.const import (
STATE_ON,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.helpers import extract_domain_configs
from homeassistant.helpers import entity_registry as er, extract_domain_configs
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.config_validation import make_entity_service_schema
from homeassistant.helpers.entity import ToggleEntity
@ -237,7 +237,7 @@ async def _async_process_config(hass, config, component) -> bool:
for config_key in extract_domain_configs(config, DOMAIN):
conf: dict[str, dict[str, Any] | BlueprintInputs] = config[config_key]
for object_id, config_block in conf.items():
for key, config_block in conf.items():
raw_blueprint_inputs = None
raw_config = None
@ -264,16 +264,15 @@ async def _async_process_config(hass, config, component) -> bool:
raw_config = cast(ScriptConfig, config_block).raw_config
entities.append(
ScriptEntity(
hass, object_id, config_block, raw_config, raw_blueprint_inputs
)
ScriptEntity(hass, key, config_block, raw_config, raw_blueprint_inputs)
)
await component.async_add_entities(entities)
async def service_handler(service: ServiceCall) -> None:
"""Execute a service call to script.<script name>."""
entity_id = ENTITY_ID_FORMAT.format(service.service)
entity_registry = er.async_get(hass)
entity_id = entity_registry.async_get_entity_id(DOMAIN, DOMAIN, service.service)
script_entity = component.get_entity(entity_id)
await script_entity.async_turn_on(
variables=service.data, context=service.context
@ -282,7 +281,7 @@ async def _async_process_config(hass, config, component) -> bool:
# Register services for all entities that were created successfully.
for entity in entities:
hass.services.async_register(
DOMAIN, entity.object_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA
DOMAIN, entity.unique_id, service_handler, schema=SCRIPT_SERVICE_SCHEMA
)
# Register the service description
@ -291,7 +290,8 @@ async def _async_process_config(hass, config, component) -> bool:
CONF_DESCRIPTION: entity.description,
CONF_FIELDS: entity.fields,
}
async_set_service_schema(hass, DOMAIN, entity.object_id, service_desc)
unique_id = cast(str, entity.unique_id)
async_set_service_schema(hass, DOMAIN, unique_id, service_desc)
return blueprints_used
@ -301,29 +301,27 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
icon = None
def __init__(self, hass, object_id, cfg, raw_config, blueprint_inputs):
def __init__(self, hass, key, cfg, raw_config, blueprint_inputs):
"""Initialize the script."""
self.object_id = object_id
self.icon = cfg.get(CONF_ICON)
self.description = cfg[CONF_DESCRIPTION]
self.fields = cfg[CONF_FIELDS]
# The object ID of scripts need / are unique already
# they cannot be changed from the UI after creating
self._attr_unique_id = object_id
# The key of scripts are unique and cannot be changed from the UI after creating
self._attr_unique_id = key
self.entity_id = ENTITY_ID_FORMAT.format(object_id)
self.entity_id = ENTITY_ID_FORMAT.format(key)
self.script = Script(
hass,
cfg[CONF_SEQUENCE],
cfg.get(CONF_ALIAS, object_id),
cfg.get(CONF_ALIAS, key),
DOMAIN,
running_description="script sequence",
change_listener=self.async_change_listener,
script_mode=cfg[CONF_MODE],
max_runs=cfg[CONF_MAX],
max_exceeded=cfg[CONF_MAX_EXCEEDED],
logger=logging.getLogger(f"{__name__}.{object_id}"),
logger=logging.getLogger(f"{__name__}.{key}"),
variables=cfg.get(CONF_VARIABLES),
)
self._changed = asyncio.Event()
@ -407,7 +405,7 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
async def _async_run(self, variables, context):
with trace_script(
self.hass,
self.object_id,
self.unique_id,
self._raw_config,
self._blueprint_inputs,
context,
@ -440,4 +438,4 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
await self.script.async_stop()
# remove service
self.hass.services.async_remove(DOMAIN, self.object_id)
self.hass.services.async_remove(DOMAIN, self.unique_id)

View File

@ -1023,3 +1023,44 @@ async def test_setup_with_duplicate_scripts(
)
assert "Duplicate script detected with name: 'duplicate'" in caplog.text
assert len(hass.states.async_entity_ids("script")) == 1
async def test_script_service_changed_entity_id(hass: HomeAssistant) -> None:
"""Test the script service works for scripts with overridden entity_id."""
entity_reg = er.async_get(hass)
entry = entity_reg.async_get_or_create("script", "script", "test")
entry = entity_reg.async_update_entity(
entry.entity_id, new_entity_id="script.custom_entity_id"
)
assert entry.entity_id == "script.custom_entity_id"
calls = []
@callback
def record_call(service):
"""Add recorded event to set."""
calls.append(service)
hass.services.async_register("test", "script", record_call)
assert await async_setup_component(
hass,
"script",
{
"script": {
"test": {
"sequence": {
"service": "test.script",
"data_template": {"entity_id": "{{ this.entity_id }}"},
}
}
}
},
)
await hass.services.async_call(DOMAIN, "test", {"greeting": "world"})
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].data["entity_id"] == "script.custom_entity_id"