ha-supervisor/supervisor/api/utils.py

196 lines
5.7 KiB
Python

"""Init file for Supervisor util for RESTful API."""
import json
from typing import Any
from aiohttp import web
from aiohttp.hdrs import AUTHORIZATION
from aiohttp.web_exceptions import HTTPUnauthorized
from aiohttp.web_request import Request
import voluptuous as vol
from voluptuous.humanize import humanize_error
from ..const import (
HEADER_TOKEN,
HEADER_TOKEN_OLD,
JSON_DATA,
JSON_JOB_ID,
JSON_MESSAGE,
JSON_RESULT,
REQUEST_FROM,
RESULT_ERROR,
RESULT_OK,
)
from ..coresys import CoreSys
from ..exceptions import APIError, APIForbidden, DockerAPIError, HassioError
from ..utils import check_exception_chain, get_message_from_exception_chain
from ..utils.json import json_dumps, json_loads as json_loads_util
from ..utils.log_format import format_message
from .const import CONTENT_TYPE_BINARY
def excract_supervisor_token(request: web.Request) -> str | None:
"""Extract Supervisor token from request."""
if supervisor_token := request.headers.get(HEADER_TOKEN):
return supervisor_token
# Old Supervisor fallback
if supervisor_token := request.headers.get(HEADER_TOKEN_OLD):
return supervisor_token
# API access only
if supervisor_token := request.headers.get(AUTHORIZATION):
return supervisor_token.split(" ")[-1]
return None
def json_loads(data: Any) -> dict[str, Any]:
"""Extract json from string with support for '' and None."""
if not data:
return {}
try:
return json_loads_util(data)
except json.JSONDecodeError as err:
raise APIError("Invalid json") from err
def api_process(method):
"""Wrap function with true/false calls to rest api."""
async def wrap_api(api, *args, **kwargs):
"""Return API information."""
try:
answer = await method(api, *args, **kwargs)
except (APIError, APIForbidden, HassioError) as err:
return api_return_error(error=err)
if isinstance(answer, (dict, list)):
return api_return_ok(data=answer)
if isinstance(answer, web.Response):
return answer
if isinstance(answer, web.StreamResponse):
return answer
elif isinstance(answer, bool) and not answer:
return api_return_error()
return api_return_ok()
return wrap_api
def require_home_assistant(method):
"""Ensure that the request comes from Home Assistant."""
async def wrap_api(api, *args, **kwargs):
"""Return API information."""
coresys: CoreSys = api.coresys
request: Request = args[0]
if request[REQUEST_FROM] != coresys.homeassistant:
raise HTTPUnauthorized()
return await method(api, *args, **kwargs)
return wrap_api
def api_process_raw(content, *, error_type=None):
"""Wrap content_type into function."""
def wrap_method(method):
"""Wrap function with raw output to rest api."""
async def wrap_api(api, *args, **kwargs):
"""Return api information."""
try:
msg_data = await method(api, *args, **kwargs)
if isinstance(msg_data, (web.Response, web.StreamResponse)):
return msg_data
msg_type = content
except (APIError, APIForbidden, HassioError) as err:
msg_data = str(err).encode()
msg_type = error_type or CONTENT_TYPE_BINARY
except HassioError:
msg_data = b""
msg_type = error_type or CONTENT_TYPE_BINARY
return web.Response(body=msg_data, content_type=msg_type)
return wrap_api
return wrap_method
def api_process_custom(content_type):
"""Ensure errors are handled and returned with specified content_type."""
def decorator(method):
async def wrapper(api, *args, **kwargs):
status = 200
try:
response = await method(api, *args, **kwargs)
except HassioError as err:
response = str(err)
status = 400
if isinstance(response, (web.Response, web.StreamResponse)):
return response
return web.Response(body=response, status=status, content_type=content_type)
return wrapper
return decorator
def api_return_error(
error: Exception | None = None, message: str | None = None
) -> web.Response:
"""Return an API error message."""
if error and not message:
message = get_message_from_exception_chain(error)
if check_exception_chain(error, DockerAPIError):
message = format_message(message)
result = {
JSON_RESULT: RESULT_ERROR,
JSON_MESSAGE: message or "Unknown error, see supervisor",
}
status = 400
if isinstance(error, APIError):
status = error.status
if error.job_id:
result[JSON_JOB_ID] = error.job_id
return web.json_response(
result,
status=status,
dumps=json_dumps,
)
def api_return_ok(data: dict[str, Any] | None = None) -> web.Response:
"""Return an API ok answer."""
return web.json_response(
{JSON_RESULT: RESULT_OK, JSON_DATA: data or {}},
dumps=json_dumps,
)
async def api_validate(
schema: vol.Schema, request: web.Request, origin: list[str] | None = None
) -> dict[str, Any]:
"""Validate request data with schema."""
data: dict[str, Any] = await request.json(loads=json_loads)
try:
data_validated = schema(data)
except vol.Invalid as ex:
raise APIError(humanize_error(data, ex)) from None
if not origin:
return data_validated
for origin_value in origin:
if origin_value not in data_validated:
continue
data_validated[origin_value] = data[origin_value]
return data_validated