Improve http decorator typing (#75541)

This commit is contained in:
Marc Mueller 2022-07-21 13:07:42 +02:00 committed by GitHub
parent 1d7d2875e1
commit b1ed1543c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 37 additions and 22 deletions

View File

@ -257,7 +257,7 @@ class LoginFlowResourceView(LoginFlowBaseView):
@RequestDataValidator(vol.Schema({"client_id": str}, extra=vol.ALLOW_EXTRA))
@log_invalid_auth
async def post(self, request, flow_id, data):
async def post(self, request, data, flow_id):
"""Handle progressing a login flow request."""
client_id = data.pop("client_id")

View File

@ -2,17 +2,18 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Coroutine
from contextlib import suppress
from datetime import datetime
from http import HTTPStatus
from ipaddress import IPv4Address, IPv6Address, ip_address
import logging
from socket import gethostbyaddr, herror
from typing import Any, Final
from typing import Any, Final, TypeVar
from aiohttp.web import Application, Request, StreamResponse, middleware
from aiohttp.web import Application, Request, Response, StreamResponse, middleware
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
from typing_extensions import Concatenate, ParamSpec
import voluptuous as vol
from homeassistant.components import persistent_notification
@ -24,6 +25,9 @@ from homeassistant.util import dt as dt_util, yaml
from .view import HomeAssistantView
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")
_LOGGER: Final = logging.getLogger(__name__)
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
@ -82,13 +86,13 @@ async def ban_middleware(
def log_invalid_auth(
func: Callable[..., Awaitable[StreamResponse]]
) -> Callable[..., Awaitable[StreamResponse]]:
func: Callable[Concatenate[_HassViewT, Request, _P], Awaitable[Response]]
) -> Callable[Concatenate[_HassViewT, Request, _P], Coroutine[Any, Any, Response]]:
"""Decorate function to handle invalid auth or failed login attempts."""
async def handle_req(
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
) -> StreamResponse:
view: _HassViewT, request: Request, *args: _P.args, **kwargs: _P.kwargs
) -> Response:
"""Try to log failed login attempts if response status >= BAD_REQUEST."""
resp = await func(view, request, *args, **kwargs)
if resp.status >= HTTPStatus.BAD_REQUEST:

View File

@ -1,17 +1,21 @@
"""Decorator for view methods to help with data validation."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Coroutine
from functools import wraps
from http import HTTPStatus
import logging
from typing import Any
from typing import Any, TypeVar
from aiohttp import web
from typing_extensions import Concatenate, ParamSpec
import voluptuous as vol
from .view import HomeAssistantView
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
_P = ParamSpec("_P")
_LOGGER = logging.getLogger(__name__)
@ -33,33 +37,40 @@ class RequestDataValidator:
self._allow_empty = allow_empty
def __call__(
self, method: Callable[..., Awaitable[web.StreamResponse]]
) -> Callable:
self,
method: Callable[
Concatenate[_HassViewT, web.Request, dict[str, Any], _P],
Awaitable[web.Response],
],
) -> Callable[
Concatenate[_HassViewT, web.Request, _P],
Coroutine[Any, Any, web.Response],
]:
"""Decorate a function."""
@wraps(method)
async def wrapper(
view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any
) -> web.StreamResponse:
view: _HassViewT, request: web.Request, *args: _P.args, **kwargs: _P.kwargs
) -> web.Response:
"""Wrap a request handler with data validation."""
data = None
raw_data = None
try:
data = await request.json()
raw_data = await request.json()
except ValueError:
if not self._allow_empty or (await request.content.read()) != b"":
_LOGGER.error("Invalid JSON received")
return view.json_message("Invalid JSON.", HTTPStatus.BAD_REQUEST)
data = {}
raw_data = {}
try:
kwargs["data"] = self._schema(data)
data: dict[str, Any] = self._schema(raw_data)
except vol.Invalid as err:
_LOGGER.error("Data does not match schema: %s", err)
return view.json_message(
f"Message format incorrect: {err}", HTTPStatus.BAD_REQUEST
)
result = await method(view, request, *args, **kwargs)
result = await method(view, request, data, *args, **kwargs)
return result
return wrapper

View File

@ -113,7 +113,7 @@ class RepairsFlowIndexView(FlowManagerIndexView):
result = self._prepare_result_json(result)
return self.json(result) # pylint: disable=arguments-differ
return self.json(result)
class RepairsFlowResourceView(FlowManagerResourceView):
@ -136,4 +136,4 @@ class RepairsFlowResourceView(FlowManagerResourceView):
raise Unauthorized(permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter
return await super().post(request, flow_id) # type: ignore[no-any-return]
return await super().post(request, flow_id)

View File

@ -102,7 +102,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post(
self, request: web.Request, flow_id: str, data: dict[str, Any]
self, request: web.Request, data: dict[str, Any], flow_id: str
) -> web.Response:
"""Handle a POST request."""
try: