Prevent deleting blueprints which are in use (#78444)

This commit is contained in:
Erik Montnemery 2022-09-14 16:47:08 +02:00 committed by GitHub
parent 855b0dfdba
commit 2ba0f42acc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 192 additions and 8 deletions

View File

@ -9,6 +9,7 @@ import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.components import blueprint
from homeassistant.components.blueprint import CONF_USE_BLUEPRINT
from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_MODE,
@ -20,6 +21,7 @@ from homeassistant.const import (
CONF_EVENT_DATA,
CONF_ID,
CONF_MODE,
CONF_PATH,
CONF_PLATFORM,
CONF_VARIABLES,
CONF_ZONE,
@ -233,6 +235,21 @@ def areas_in_automation(hass: HomeAssistant, entity_id: str) -> list[str]:
return list(cast(AutomationEntity, automation_entity).referenced_areas)
@callback
def automations_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str]:
"""Return all automations that reference the blueprint."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
return [
automation_entity.entity_id
for automation_entity in component.entities
if automation_entity.referenced_blueprint == blueprint_path
]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up all automations."""
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass)
@ -356,6 +373,13 @@ class AutomationEntity(ToggleEntity, RestoreEntity):
"""Return a set of referenced areas."""
return self.action_script.referenced_areas
@property
def referenced_blueprint(self) -> str | None:
"""Return referenced blueprint or None."""
if self._blueprint_inputs is None:
return None
return cast(str, self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH])
@property
def referenced_devices(self) -> set[str]:
"""Return a set of referenced devices."""

View File

@ -8,8 +8,15 @@ from .const import DOMAIN, LOGGER
DATA_BLUEPRINTS = "automation_blueprints"
def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
"""Return True if any automation references the blueprint."""
from . import automations_with_blueprint # pylint: disable=import-outside-toplevel
return len(automations_with_blueprint(hass, blueprint_path)) > 0
@singleton(DATA_BLUEPRINTS)
@callback
def async_get_blueprints(hass: HomeAssistant) -> blueprint.DomainBlueprints:
"""Get automation blueprints."""
return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER)
return blueprint.DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use)

View File

@ -3,7 +3,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.typing import ConfigType
from . import websocket_api
from .const import DOMAIN # noqa: F401
from .const import CONF_USE_BLUEPRINT, DOMAIN # noqa: F401
from .errors import ( # noqa: F401
BlueprintException,
BlueprintWithNameException,

View File

@ -91,3 +91,11 @@ class FileAlreadyExists(BlueprintWithNameException):
def __init__(self, domain: str, blueprint_name: str) -> None:
"""Initialize blueprint exception."""
super().__init__(domain, blueprint_name, "Blueprint already exists")
class BlueprintInUse(BlueprintWithNameException):
"""Error when a blueprint is in use."""
def __init__(self, domain: str, blueprint_name: str) -> None:
"""Initialize blueprint exception."""
super().__init__(domain, blueprint_name, "Blueprint in use")

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
import logging
import pathlib
import shutil
@ -35,6 +36,7 @@ from .const import (
)
from .errors import (
BlueprintException,
BlueprintInUse,
FailedToLoad,
FileAlreadyExists,
InvalidBlueprint,
@ -183,11 +185,13 @@ class DomainBlueprints:
hass: HomeAssistant,
domain: str,
logger: logging.Logger,
blueprint_in_use: Callable[[HomeAssistant, str], bool],
) -> None:
"""Initialize a domain blueprints instance."""
self.hass = hass
self.domain = domain
self.logger = logger
self._blueprint_in_use = blueprint_in_use
self._blueprints: dict[str, Blueprint | None] = {}
self._load_lock = asyncio.Lock()
@ -302,6 +306,8 @@ class DomainBlueprints:
async def async_remove_blueprint(self, blueprint_path: str) -> None:
"""Remove a blueprint file."""
if self._blueprint_in_use(self.hass, blueprint_path):
raise BlueprintInUse(self.domain, blueprint_path)
path = self.blueprint_folder / blueprint_path
await self.hass.async_add_executor_job(path.unlink)
self._blueprints[blueprint_path] = None

View File

@ -8,7 +8,7 @@ from typing import Any, cast
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.components.blueprint import BlueprintInputs
from homeassistant.components.blueprint import CONF_USE_BLUEPRINT, BlueprintInputs
from homeassistant.const import (
ATTR_ENTITY_ID,
ATTR_MODE,
@ -18,6 +18,7 @@ from homeassistant.const import (
CONF_ICON,
CONF_MODE,
CONF_NAME,
CONF_PATH,
CONF_SEQUENCE,
CONF_VARIABLES,
SERVICE_RELOAD,
@ -165,6 +166,21 @@ def areas_in_script(hass: HomeAssistant, entity_id: str) -> list[str]:
return list(script_entity.script.referenced_areas)
@callback
def scripts_with_blueprint(hass: HomeAssistant, blueprint_path: str) -> list[str]:
"""Return all scripts that reference the blueprint."""
if DOMAIN not in hass.data:
return []
component = hass.data[DOMAIN]
return [
script_entity.entity_id
for script_entity in component.entities
if script_entity.referenced_blueprint == blueprint_path
]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Load the scripts from the configuration."""
hass.data[DOMAIN] = component = EntityComponent(LOGGER, DOMAIN, hass)
@ -372,6 +388,13 @@ class ScriptEntity(ToggleEntity, RestoreEntity):
"""Return true if script is on."""
return self.script.is_running
@property
def referenced_blueprint(self):
"""Return referenced blueprint or None."""
if self._blueprint_inputs is None:
return None
return self._blueprint_inputs[CONF_USE_BLUEPRINT][CONF_PATH]
@callback
def async_change_listener(self):
"""Update state."""

View File

@ -8,8 +8,15 @@ from .const import DOMAIN, LOGGER
DATA_BLUEPRINTS = "script_blueprints"
def _blueprint_in_use(hass: HomeAssistant, blueprint_path: str) -> bool:
"""Return True if any script references the blueprint."""
from . import scripts_with_blueprint # pylint: disable=import-outside-toplevel
return len(scripts_with_blueprint(hass, blueprint_path)) > 0
@singleton(DATA_BLUEPRINTS)
@callback
def async_get_blueprints(hass: HomeAssistant) -> DomainBlueprints:
"""Get script blueprints."""
return DomainBlueprints(hass, DOMAIN, LOGGER)
return DomainBlueprints(hass, DOMAIN, LOGGER, _blueprint_in_use)

View File

@ -47,7 +47,9 @@ def blueprint_2():
@pytest.fixture
def domain_bps(hass):
"""Domain blueprints fixture."""
return models.DomainBlueprints(hass, "automation", logging.getLogger(__name__))
return models.DomainBlueprints(
hass, "automation", logging.getLogger(__name__), None
)
def test_blueprint_model_init():

View File

@ -8,13 +8,26 @@ from homeassistant.setup import async_setup_component
from homeassistant.util.yaml import parse_yaml
@pytest.fixture
def automation_config():
"""Automation config."""
return {}
@pytest.fixture
def script_config():
"""Script config."""
return {}
@pytest.fixture(autouse=True)
async def setup_bp(hass):
async def setup_bp(hass, automation_config, script_config):
"""Fixture to set up the blueprint component."""
assert await async_setup_component(hass, "blueprint", {})
# Trigger registration of automation blueprints
await async_setup_component(hass, "automation", {})
# Trigger registration of automation and script blueprints
await async_setup_component(hass, "automation", automation_config)
await async_setup_component(hass, "script", script_config)
async def test_list_blueprints(hass, hass_ws_client):
@ -251,3 +264,89 @@ async def test_delete_non_exist_file_blueprint(hass, aioclient_mock, hass_ws_cli
assert msg["id"] == 9
assert not msg["success"]
@pytest.mark.parametrize(
"automation_config",
(
{
"automation": {
"use_blueprint": {
"path": "test_event_service.yaml",
"input": {
"trigger_event": "blueprint_event",
"service_to_call": "test.automation",
"a_number": 5,
},
}
}
},
),
)
async def test_delete_blueprint_in_use_by_automation(
hass, aioclient_mock, hass_ws_client
):
"""Test deleting a blueprint which is in use."""
with patch("pathlib.Path.unlink", return_value=Mock()) as unlink_mock:
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 9,
"type": "blueprint/delete",
"path": "test_event_service.yaml",
"domain": "automation",
}
)
msg = await client.receive_json()
assert not unlink_mock.mock_calls
assert msg["id"] == 9
assert not msg["success"]
assert msg["error"] == {
"code": "unknown_error",
"message": "Blueprint in use",
}
@pytest.mark.parametrize(
"script_config",
(
{
"script": {
"test_script": {
"use_blueprint": {
"path": "test_service.yaml",
"input": {
"service_to_call": "test.automation",
},
}
}
}
},
),
)
async def test_delete_blueprint_in_use_by_script(hass, aioclient_mock, hass_ws_client):
"""Test deleting a blueprint which is in use."""
with patch("pathlib.Path.unlink", return_value=Mock()) as unlink_mock:
client = await hass_ws_client(hass)
await client.send_json(
{
"id": 9,
"type": "blueprint/delete",
"path": "test_service.yaml",
"domain": "script",
}
)
msg = await client.receive_json()
assert not unlink_mock.mock_calls
assert msg["id"] == 9
assert not msg["success"]
assert msg["error"] == {
"code": "unknown_error",
"message": "Blueprint in use",
}

View File

@ -0,0 +1,8 @@
blueprint:
name: "Call service"
domain: script
input:
service_to_call:
sequence:
service: !input service_to_call
entity_id: light.kitchen