"""Helper to track the current http request.""" from __future__ import annotations import asyncio from collections.abc import Awaitable, Callable from contextvars import ContextVar from http import HTTPStatus import logging from typing import Any, Final from aiohttp import web from aiohttp.typedefs import LooseHeaders from aiohttp.web import AppKey, Request from aiohttp.web_exceptions import ( HTTPBadRequest, HTTPInternalServerError, HTTPUnauthorized, ) from aiohttp.web_urldispatcher import AbstractResource, AbstractRoute import voluptuous as vol from homeassistant import exceptions from homeassistant.const import CONTENT_TYPE_JSON from homeassistant.core import Context, HomeAssistant, is_callback from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data from .json import find_paths_unserializable_data, json_bytes, json_dumps _LOGGER = logging.getLogger(__name__) type AllowCorsType = Callable[[AbstractRoute | AbstractResource], None] KEY_AUTHENTICATED: Final = "ha_authenticated" KEY_ALLOW_ALL_CORS = AppKey[AllowCorsType]("allow_all_cors") KEY_ALLOW_CONFIGRED_CORS = AppKey[AllowCorsType]("allow_configured_cors") KEY_HASS: AppKey[HomeAssistant] = AppKey("hass") current_request: ContextVar[Request | None] = ContextVar( "current_request", default=None ) def request_handler_factory( hass: HomeAssistant, view: HomeAssistantView, handler: Callable ) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: """Wrap the handler classes.""" is_coroutinefunction = asyncio.iscoroutinefunction(handler) assert is_coroutinefunction or is_callback( handler ), "Handler should be a coroutine or a callback." async def handle(request: web.Request) -> web.StreamResponse: """Handle incoming request.""" if hass.is_stopping: return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE) authenticated = request.get(KEY_AUTHENTICATED, False) if view.requires_auth and not authenticated: raise HTTPUnauthorized if _LOGGER.isEnabledFor(logging.DEBUG): _LOGGER.debug( "Serving %s to %s (auth: %s)", request.path, request.remote, authenticated, ) try: if is_coroutinefunction: result = await handler(request, **request.match_info) else: result = handler(request, **request.match_info) except vol.Invalid as err: raise HTTPBadRequest from err except exceptions.ServiceNotFound as err: raise HTTPInternalServerError from err except exceptions.Unauthorized as err: raise HTTPUnauthorized from err if isinstance(result, web.StreamResponse): # The method handler returned a ready-made Response, how nice of it return result status_code = HTTPStatus.OK if isinstance(result, tuple): result, status_code = result if isinstance(result, bytes): return web.Response(body=result, status=status_code) if isinstance(result, str): return web.Response(text=result, status=status_code) if result is None: return web.Response(body=b"", status=status_code) raise TypeError( f"Result should be None, string, bytes or StreamResponse. Got: {result}" ) return handle class HomeAssistantView: """Base view for all views.""" url: str | None = None extra_urls: list[str] = [] # Views inheriting from this class can override this requires_auth = True cors_allowed = False @staticmethod def context(request: web.Request) -> Context: """Generate a context from a request.""" if (user := request.get("hass_user")) is None: return Context() return Context(user_id=user.id) @staticmethod def json( result: Any, status_code: HTTPStatus | int = HTTPStatus.OK, headers: LooseHeaders | None = None, ) -> web.Response: """Return a JSON response.""" try: msg = json_bytes(result) except JSON_ENCODE_EXCEPTIONS as err: _LOGGER.error( "Unable to serialize to JSON. Bad data found at %s", format_unserializable_data( find_paths_unserializable_data(result, dump=json_dumps) ), ) raise HTTPInternalServerError from err response = web.Response( body=msg, content_type=CONTENT_TYPE_JSON, status=int(status_code), headers=headers, zlib_executor_size=32768, ) response.enable_compression() return response def json_message( self, message: str, status_code: HTTPStatus | int = HTTPStatus.OK, message_code: str | None = None, headers: LooseHeaders | None = None, ) -> web.Response: """Return a JSON message response.""" data = {"message": message} if message_code is not None: data["code"] = message_code return self.json(data, status_code, headers=headers) def register( self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher ) -> None: """Register the view with a router.""" assert self.url is not None, "No url set for view" urls = [self.url, *self.extra_urls] routes: list[AbstractRoute] = [] for method in ("get", "post", "delete", "put", "patch", "head", "options"): if not (handler := getattr(self, method, None)): continue handler = request_handler_factory(hass, self, handler) routes.extend(router.add_route(method, url, handler) for url in urls) # Use `get` because CORS middleware is not be loaded in emulated_hue if self.cors_allowed: allow_cors = app.get(KEY_ALLOW_ALL_CORS) else: allow_cors = app.get(KEY_ALLOW_CONFIGRED_CORS) if allow_cors: for route in routes: allow_cors(route)