mirror of
https://github.com/home-assistant/core
synced 2024-08-02 23:40:32 +02:00
parent
7e81c6a591
commit
28c07f5c43
@ -298,21 +298,24 @@ class HomeAssistantHTTP:
|
||||
# Should be instance of aiohttp.web_exceptions._HTTPMove.
|
||||
raise redirect_exc(redirect_to) # type: ignore[arg-type,misc]
|
||||
|
||||
self.app.router.add_route("GET", url, redirect)
|
||||
self.app["allow_configured_cors"](
|
||||
self.app.router.add_route("GET", url, redirect)
|
||||
)
|
||||
|
||||
def register_static_path(
|
||||
self, url_path: str, path: str, cache_headers: bool = True
|
||||
) -> web.FileResponse | None:
|
||||
) -> None:
|
||||
"""Register a folder or file to serve as a static path."""
|
||||
if os.path.isdir(path):
|
||||
if cache_headers:
|
||||
resource: type[
|
||||
CachingStaticResource | web.StaticResource
|
||||
] = CachingStaticResource
|
||||
resource: CachingStaticResource | web.StaticResource = (
|
||||
CachingStaticResource(url_path, path)
|
||||
)
|
||||
else:
|
||||
resource = web.StaticResource
|
||||
self.app.router.register_resource(resource(url_path, path))
|
||||
return None
|
||||
resource = web.StaticResource(url_path, path)
|
||||
self.app.router.register_resource(resource)
|
||||
self.app["allow_configured_cors"](resource)
|
||||
return
|
||||
|
||||
async def serve_file(request: web.Request) -> web.FileResponse:
|
||||
"""Serve file from disk."""
|
||||
@ -320,8 +323,9 @@ class HomeAssistantHTTP:
|
||||
return web.FileResponse(path, headers=CACHE_HEADERS)
|
||||
return web.FileResponse(path)
|
||||
|
||||
self.app.router.add_route("GET", url_path, serve_file)
|
||||
return None
|
||||
self.app["allow_configured_cors"](
|
||||
self.app.router.add_route("GET", url_path, serve_file)
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the aiohttp server."""
|
||||
|
@ -70,7 +70,7 @@ def setup_cors(app: Application, origins: list[str]) -> None:
|
||||
cors.add(route, config)
|
||||
cors_added.add(path_str)
|
||||
|
||||
app["allow_cors"] = lambda route: _allow_cors(
|
||||
app["allow_all_cors"] = lambda route: _allow_cors(
|
||||
route,
|
||||
{
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
@ -79,12 +79,7 @@ def setup_cors(app: Application, origins: list[str]) -> None:
|
||||
},
|
||||
)
|
||||
|
||||
if not origins:
|
||||
return
|
||||
|
||||
async def cors_startup(app: Application) -> None:
|
||||
"""Initialize CORS when app starts up."""
|
||||
for resource in list(app.router.resources()):
|
||||
_allow_cors(resource)
|
||||
|
||||
app.on_startup.append(cors_startup)
|
||||
if origins:
|
||||
app["allow_configured_cors"] = _allow_cors
|
||||
else:
|
||||
app["allow_configured_cors"] = lambda _: None
|
||||
|
@ -94,11 +94,11 @@ class HomeAssistantView:
|
||||
for url in urls:
|
||||
routes.append(router.add_route(method, url, handler))
|
||||
|
||||
if not self.cors_allowed:
|
||||
return
|
||||
|
||||
allow_cors = (
|
||||
app["allow_all_cors"] if self.cors_allowed else app["allow_configured_cors"]
|
||||
)
|
||||
for route in routes:
|
||||
app["allow_cors"](route)
|
||||
allow_cors(route)
|
||||
|
||||
|
||||
def request_handler_factory(
|
||||
|
@ -52,8 +52,8 @@ async def mock_handler(request):
|
||||
def client(loop, aiohttp_client):
|
||||
"""Fixture to set up a web.Application."""
|
||||
app = web.Application()
|
||||
app.router.add_get("/", mock_handler)
|
||||
setup_cors(app, [TRUSTED_ORIGIN])
|
||||
app["allow_configured_cors"](app.router.add_get("/", mock_handler))
|
||||
return loop.run_until_complete(aiohttp_client(app))
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ async def get_client(aiohttp_client, validator):
|
||||
"""Generate a client that hits a view decorated with validator."""
|
||||
app = web.Application()
|
||||
app["hass"] = Mock(is_stopping=False)
|
||||
app["allow_configured_cors"] = lambda _: None
|
||||
|
||||
class TestView(HomeAssistantView):
|
||||
url = "/"
|
||||
|
Loading…
Reference in New Issue
Block a user