Properly handle errors and use consistent content type in logs endpoints

This commit is contained in:
Jan Čermák 2024-03-26 17:46:56 +01:00
parent 39339249d7
commit afeea5b5b8
No known key found for this signature in database
GPG Key ID: A78C897AA3AF012B
7 changed files with 73 additions and 19 deletions

View File

@ -7,7 +7,6 @@ from typing import Any
from aiohttp import web from aiohttp import web
from aiohttp_fast_url_dispatcher import FastUrlDispatcher, attach_fast_url_dispatcher from aiohttp_fast_url_dispatcher import FastUrlDispatcher, attach_fast_url_dispatcher
from ..addons.addon import Addon
from ..const import AddonState from ..const import AddonState
from ..coresys import CoreSys, CoreSysAttributes from ..coresys import CoreSys, CoreSysAttributes
from ..exceptions import APIAddonNotInstalled from ..exceptions import APIAddonNotInstalled
@ -17,6 +16,7 @@ from .audio import APIAudio
from .auth import APIAuth from .auth import APIAuth
from .backups import APIBackups from .backups import APIBackups
from .cli import APICli from .cli import APICli
from .const import CONTENT_TYPE_TEXT
from .discovery import APIDiscovery from .discovery import APIDiscovery
from .dns import APICoreDNS from .dns import APICoreDNS
from .docker import APIDocker from .docker import APIDocker
@ -38,7 +38,7 @@ from .security import APISecurity
from .services import APIServices from .services import APIServices
from .store import APIStore from .store import APIStore
from .supervisor import APISupervisor from .supervisor import APISupervisor
from .utils import api_process from .utils import api_process, api_process_custom
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -518,8 +518,9 @@ class RestAPI(CoreSysAttributes):
] ]
) )
@api_process_custom(CONTENT_TYPE_TEXT)
async def get_addon_logs(request, *args, **kwargs): async def get_addon_logs(request, *args, **kwargs):
addon: Addon = api_addons.get_addon_for_request(request) addon = api_addons.get_addon_for_request(request)
kwargs["identifier"] = f"addon_{addon.slug}" kwargs["identifier"] = f"addon_{addon.slug}"
return await self._api_host.advanced_logs(request, *args, **kwargs) return await self._api_host.advanced_logs(request, *args, **kwargs)

View File

@ -53,7 +53,7 @@ from .const import (
CONTENT_TYPE_TEXT, CONTENT_TYPE_TEXT,
CONTENT_TYPE_X_LOG, CONTENT_TYPE_X_LOG,
) )
from .utils import api_process, api_validate from .utils import api_process, api_process_custom, api_validate
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -163,7 +163,7 @@ class APIHost(CoreSysAttributes):
raise APIError() from err raise APIError() from err
return possible_offset return possible_offset
@api_process @api_process_custom(CONTENT_TYPE_TEXT)
async def advanced_logs( async def advanced_logs(
self, request: web.Request, identifier: str | None = None, follow: bool = False self, request: web.Request, identifier: str | None = None, follow: bool = False
) -> web.StreamResponse: ) -> web.StreamResponse:

View File

@ -49,8 +49,8 @@ from ..store.validate import repositories
from ..utils.sentry import close_sentry, init_sentry from ..utils.sentry import close_sentry, init_sentry
from ..utils.validate import validate_timezone from ..utils.validate import validate_timezone
from ..validate import version_tag, wait_boot from ..validate import version_tag, wait_boot
from .const import CONTENT_TYPE_BINARY from .const import CONTENT_TYPE_TEXT
from .utils import api_process, api_process_raw, api_validate from .utils import api_process, api_process_custom, api_validate
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -229,7 +229,7 @@ class APISupervisor(CoreSysAttributes):
"""Soft restart Supervisor.""" """Soft restart Supervisor."""
return asyncio.shield(self.sys_supervisor.restart()) return asyncio.shield(self.sys_supervisor.restart())
@api_process_raw(CONTENT_TYPE_BINARY) @api_process_custom(CONTENT_TYPE_TEXT)
def logs(self, request: web.Request) -> Awaitable[bytes]: def logs(self, request: web.Request) -> Awaitable[bytes]:
"""Return supervisor Docker logs.""" """Return supervisor Docker logs."""
return self.sys_supervisor.logs() return self.sys_supervisor.logs()

View File

@ -91,7 +91,7 @@ def require_home_assistant(method):
return wrap_api return wrap_api
def api_process_raw(content): def api_process_raw(content, *, error_type=None):
"""Wrap content_type into function.""" """Wrap content_type into function."""
def wrap_method(method): def wrap_method(method):
@ -101,13 +101,15 @@ def api_process_raw(content):
"""Return api information.""" """Return api information."""
try: try:
msg_data = await method(api, *args, **kwargs) msg_data = await method(api, *args, **kwargs)
if isinstance(msg_data, (web.Response, web.StreamResponse)):
return msg_data
msg_type = content msg_type = content
except (APIError, APIForbidden) as err: except (APIError, APIForbidden, HassioError) as err:
msg_data = str(err).encode() msg_data = str(err).encode()
msg_type = CONTENT_TYPE_BINARY msg_type = error_type or CONTENT_TYPE_BINARY
except HassioError: except HassioError:
msg_data = b"" msg_data = b""
msg_type = CONTENT_TYPE_BINARY msg_type = error_type or CONTENT_TYPE_BINARY
return web.Response(body=msg_data, content_type=msg_type) return web.Response(body=msg_data, content_type=msg_type)
@ -116,6 +118,28 @@ def api_process_raw(content):
return wrap_method return wrap_method
def api_process_custom(content_type):
"""Ensure errors are handled and returned with specified content_type."""
def decorator(method):
async def wrapper(api, *args, **kwargs):
status = 200
try:
response = await method(api, *args, **kwargs)
except HassioError as err:
response = str(err)
status = 400
if isinstance(response, (web.Response, web.StreamResponse)):
return response
return web.Response(body=response, status=status, content_type=content_type)
return wrapper
return decorator
def api_return_error( def api_return_error(
error: Exception | None = None, message: str | None = None error: Exception | None = None, message: str | None = None
) -> web.Response: ) -> web.Response:

View File

@ -13,6 +13,7 @@ from supervisor.coresys import CoreSys
from supervisor.docker.addon import DockerAddon from supervisor.docker.addon import DockerAddon
from supervisor.docker.const import ContainerState from supervisor.docker.const import ContainerState
from supervisor.docker.monitor import DockerContainerStateEvent from supervisor.docker.monitor import DockerContainerStateEvent
from supervisor.exceptions import HassioError
from supervisor.store.repository import Repository from supervisor.store.repository import Repository
from ..const import TEST_ADDON_SLUG from ..const import TEST_ADDON_SLUG
@ -76,6 +77,32 @@ async def test_api_addon_logs(
) )
async def test_api_addon_logs_not_installed(api_client: TestClient):
"""Test error is returned for non-existing add-on."""
resp = await api_client.get("/addons/hic_sunt_leones/logs")
assert resp.status == 400
assert resp.content_type == "text/plain"
content = await resp.text()
assert content == "Addon hic_sunt_leones does not exist"
async def test_api_addon_logs_error(
api_client: TestClient,
journald_logs: MagicMock,
docker_logs: MagicMock,
install_addon_ssh: Addon,
):
"""Test errors are properly handled for add-on logs."""
journald_logs.side_effect = HassioError("Something bad happened!")
resp = await api_client.get("/addons/local_ssh/logs")
assert resp.status == 400
assert resp.content_type == "text/plain"
content = await resp.text()
assert content == "Something bad happened!"
async def test_api_addon_start_healthcheck( async def test_api_addon_start_healthcheck(
api_client: TestClient, api_client: TestClient,
coresys: CoreSys, coresys: CoreSys,

View File

@ -310,15 +310,17 @@ async def test_advanced_logs_errors(api_client: TestClient):
"""Test advanced logging API errors.""" """Test advanced logging API errors."""
# coresys = coresys_logs_control # coresys = coresys_logs_control
resp = await api_client.get("/host/logs") resp = await api_client.get("/host/logs")
result = await resp.json() assert resp.content_type == "text/plain"
assert result["result"] == "error" assert resp.status == 400
assert result["message"] == "No systemd-journal-gatewayd Unix socket available" content = await resp.text()
assert content == "No systemd-journal-gatewayd Unix socket available"
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
resp = await api_client.get("/host/logs", headers=headers) resp = await api_client.get("/host/logs", headers=headers)
result = await resp.json() assert resp.content_type == "text/plain"
assert result["result"] == "error" assert resp.status == 400
content = await resp.text()
assert ( assert (
result["message"] content
== "Invalid content type requested. Only text/plain and text/x-log supported for now." == "Invalid content type requested. Only text/plain and text/x-log supported for now."
) )

View File

@ -169,7 +169,7 @@ async def test_api_supervisor_fallback(
) )
assert resp.status == 200 assert resp.status == 200
assert resp.content_type == "application/octet-stream" assert resp.content_type == "text/plain"
content = await resp.read() content = await resp.read()
assert content.split(b"\n")[0:2] == [ assert content.split(b"\n")[0:2] == [
b"\x1b[36m22-10-11 14:04:23 DEBUG (MainThread) [supervisor.utils.dbus] D-Bus call - org.freedesktop.DBus.Properties.call_get_all on /io/hass/os\x1b[0m", b"\x1b[36m22-10-11 14:04:23 DEBUG (MainThread) [supervisor.utils.dbus] D-Bus call - org.freedesktop.DBus.Properties.call_get_all on /io/hass/os\x1b[0m",