mirror of https://github.com/home-assistant/core
Make yaml file writes safer (#59384)
This commit is contained in:
parent
751098c220
commit
ebb25ab0e6
|
@ -11,6 +11,7 @@ from homeassistant.const import CONF_ID, EVENT_COMPONENT_LOADED
|
|||
from homeassistant.core import callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.setup import ATTR_COMPONENT
|
||||
from homeassistant.util.file import write_utf8_file
|
||||
from homeassistant.util.yaml import dump, load_yaml
|
||||
|
||||
DOMAIN = "config"
|
||||
|
@ -252,6 +253,5 @@ def _write(path, data):
|
|||
"""Write YAML helper."""
|
||||
# Do it before opening file. If dump causes error it will now not
|
||||
# truncate the file.
|
||||
data = dump(data)
|
||||
with open(path, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(data)
|
||||
contents = dump(data)
|
||||
write_utf8_file(path, contents)
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
"""File utility functions."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WriteError(HomeAssistantError):
|
||||
"""Error writing the data."""
|
||||
|
||||
|
||||
def write_utf8_file(
|
||||
filename: str,
|
||||
utf8_data: str,
|
||||
private: bool = False,
|
||||
) -> None:
|
||||
"""Write a file and rename it into place.
|
||||
|
||||
Writes all or nothing.
|
||||
"""
|
||||
|
||||
tmp_filename = ""
|
||||
tmp_path = os.path.split(filename)[0]
|
||||
try:
|
||||
# Modern versions of Python tempfile create this file with mode 0o600
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", encoding="utf-8", dir=tmp_path, delete=False
|
||||
) as fdesc:
|
||||
fdesc.write(utf8_data)
|
||||
tmp_filename = fdesc.name
|
||||
if not private:
|
||||
os.chmod(tmp_filename, 0o644)
|
||||
os.replace(tmp_filename, filename)
|
||||
except OSError as error:
|
||||
_LOGGER.exception("Saving file failed: %s", filename)
|
||||
raise WriteError(error) from error
|
||||
finally:
|
||||
if os.path.exists(tmp_filename):
|
||||
try:
|
||||
os.remove(tmp_filename)
|
||||
except OSError as err:
|
||||
# If we are cleaning up then something else went wrong, so
|
||||
# we should suppress likely follow-on errors in the cleanup
|
||||
_LOGGER.error(
|
||||
"File replacement cleanup failed for %s while saving %s: %s",
|
||||
tmp_filename,
|
||||
filename,
|
||||
err,
|
||||
)
|
|
@ -5,13 +5,13 @@ from collections import deque
|
|||
from collections.abc import Callable
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.core import Event, State
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .file import write_utf8_file
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -61,29 +61,7 @@ def save_json(
|
|||
_LOGGER.error(msg)
|
||||
raise SerializationError(msg) from error
|
||||
|
||||
tmp_filename = ""
|
||||
tmp_path = os.path.split(filename)[0]
|
||||
try:
|
||||
# Modern versions of Python tempfile create this file with mode 0o600
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", encoding="utf-8", dir=tmp_path, delete=False
|
||||
) as fdesc:
|
||||
fdesc.write(json_data)
|
||||
tmp_filename = fdesc.name
|
||||
if not private:
|
||||
os.chmod(tmp_filename, 0o644)
|
||||
os.replace(tmp_filename, filename)
|
||||
except OSError as error:
|
||||
_LOGGER.exception("Saving JSON file failed: %s", filename)
|
||||
raise WriteError(error) from error
|
||||
finally:
|
||||
if os.path.exists(tmp_filename):
|
||||
try:
|
||||
os.remove(tmp_filename)
|
||||
except OSError as err:
|
||||
# If we are cleaning up then something else went wrong, so
|
||||
# we should suppress likely follow-on errors in the cleanup
|
||||
_LOGGER.error("JSON replacement cleanup failed: %s", err)
|
||||
write_utf8_file(filename, json_data, private)
|
||||
|
||||
|
||||
def format_unserializable_data(data: dict[str, Any]) -> str:
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
"""Test Group config panel."""
|
||||
from http import HTTPStatus
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from homeassistant.bootstrap import async_setup_component
|
||||
from homeassistant.components import config
|
||||
from homeassistant.components.config import group
|
||||
from homeassistant.util.file import write_utf8_file
|
||||
from homeassistant.util.yaml import dump, load_yaml
|
||||
|
||||
VIEW_NAME = "api:config:group:config"
|
||||
|
||||
|
@ -113,3 +117,49 @@ async def test_update_device_config_invalid_json(hass, hass_client):
|
|||
resp = await client.post("/api/config/group/config/hello_beer", data="not json")
|
||||
|
||||
assert resp.status == HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
async def test_update_config_write_to_temp_file(hass, hass_client, tmpdir):
|
||||
"""Test config with a temp file."""
|
||||
test_dir = await hass.async_add_executor_job(tmpdir.mkdir, "files")
|
||||
group_yaml = Path(test_dir / "group.yaml")
|
||||
|
||||
with patch.object(group, "GROUP_CONFIG_PATH", group_yaml), patch.object(
|
||||
config, "SECTIONS", ["group"]
|
||||
):
|
||||
await async_setup_component(hass, "config", {})
|
||||
|
||||
client = await hass_client()
|
||||
|
||||
orig_data = {
|
||||
"hello.beer": {"ignored": True},
|
||||
"other.entity": {"polling_intensity": 2},
|
||||
}
|
||||
contents = dump(orig_data)
|
||||
await hass.async_add_executor_job(write_utf8_file, group_yaml, contents)
|
||||
|
||||
mock_call = AsyncMock()
|
||||
|
||||
with patch.object(hass.services, "async_call", mock_call):
|
||||
resp = await client.post(
|
||||
"/api/config/group/config/hello_beer",
|
||||
data=json.dumps(
|
||||
{"name": "Beer", "entities": ["light.top", "light.bottom"]}
|
||||
),
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert resp.status == HTTPStatus.OK
|
||||
result = await resp.json()
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
new_data = await hass.async_add_executor_job(load_yaml, group_yaml)
|
||||
|
||||
assert new_data == {
|
||||
**orig_data,
|
||||
"hello_beer": {
|
||||
"name": "Beer",
|
||||
"entities": ["light.top", "light.bottom"],
|
||||
},
|
||||
}
|
||||
mock_call.assert_called_once_with("group", "reload")
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
"""Test Home Assistant file utility functions."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.util.file import WriteError, write_utf8_file
|
||||
|
||||
|
||||
def test_write_utf8_file_private(tmpdir):
|
||||
"""Test files can be written as 0o600 or 0o644."""
|
||||
test_dir = tmpdir.mkdir("files")
|
||||
test_file = Path(test_dir / "test.json")
|
||||
|
||||
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||
with open(test_file) as fh:
|
||||
assert fh.read() == '{"some":"data"}'
|
||||
assert os.stat(test_file).st_mode & 0o777 == 0o644
|
||||
|
||||
write_utf8_file(test_file, '{"some":"data"}', True)
|
||||
with open(test_file) as fh:
|
||||
assert fh.read() == '{"some":"data"}'
|
||||
assert os.stat(test_file).st_mode & 0o777 == 0o600
|
||||
|
||||
|
||||
def test_write_utf8_file_fails_at_creation(tmpdir):
|
||||
"""Test that failed creation of the temp file does not create an empty file."""
|
||||
test_dir = tmpdir.mkdir("files")
|
||||
test_file = Path(test_dir / "test.json")
|
||||
|
||||
with pytest.raises(WriteError), patch(
|
||||
"homeassistant.util.file.tempfile.NamedTemporaryFile", side_effect=OSError
|
||||
):
|
||||
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||
|
||||
assert not os.path.exists(test_file)
|
||||
|
||||
|
||||
def test_write_utf8_file_fails_at_rename(tmpdir, caplog):
|
||||
"""Test that if rename fails not not remove, we do not log the failed cleanup."""
|
||||
test_dir = tmpdir.mkdir("files")
|
||||
test_file = Path(test_dir / "test.json")
|
||||
|
||||
with pytest.raises(WriteError), patch(
|
||||
"homeassistant.util.file.os.replace", side_effect=OSError
|
||||
):
|
||||
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||
|
||||
assert not os.path.exists(test_file)
|
||||
|
||||
assert "File replacement cleanup failed" not in caplog.text
|
||||
|
||||
|
||||
def test_write_utf8_file_fails_at_rename_and_remove(tmpdir, caplog):
|
||||
"""Test that if rename and remove both fail, we log the failed cleanup."""
|
||||
test_dir = tmpdir.mkdir("files")
|
||||
test_file = Path(test_dir / "test.json")
|
||||
|
||||
with pytest.raises(WriteError), patch(
|
||||
"homeassistant.util.file.os.remove", side_effect=OSError
|
||||
), patch("homeassistant.util.file.os.replace", side_effect=OSError):
|
||||
write_utf8_file(test_file, '{"some":"data"}', False)
|
||||
|
||||
assert "File replacement cleanup failed" in caplog.text
|
Loading…
Reference in New Issue