1
mirror of https://github.com/home-assistant/core synced 2024-07-12 07:21:24 +02:00

Add ws endpoint "auth/delete_all_refresh_tokens" (#98976)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Robert Resch 2023-08-29 15:57:54 +02:00 committed by GitHub
parent 691bbedfc8
commit 6223b1f599
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 149 additions and 0 deletions

View File

@ -124,9 +124,11 @@ as part of a config flow.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
from collections.abc import Callable from collections.abc import Callable
from datetime import datetime, timedelta from datetime import datetime, timedelta
from http import HTTPStatus from http import HTTPStatus
from logging import getLogger
from typing import Any, cast from typing import Any, cast
import uuid import uuid
@ -138,6 +140,7 @@ from homeassistant.auth import InvalidAuthError
from homeassistant.auth.models import ( from homeassistant.auth.models import (
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
Credentials, Credentials,
RefreshToken,
User, User,
) )
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
@ -188,6 +191,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
websocket_api.async_register_command(hass, websocket_create_long_lived_access_token) websocket_api.async_register_command(hass, websocket_create_long_lived_access_token)
websocket_api.async_register_command(hass, websocket_refresh_tokens) websocket_api.async_register_command(hass, websocket_refresh_tokens)
websocket_api.async_register_command(hass, websocket_delete_refresh_token) websocket_api.async_register_command(hass, websocket_delete_refresh_token)
websocket_api.async_register_command(hass, websocket_delete_all_refresh_tokens)
websocket_api.async_register_command(hass, websocket_sign_path) websocket_api.async_register_command(hass, websocket_sign_path)
await login_flow.async_setup(hass, store_result) await login_flow.async_setup(hass, store_result)
@ -598,6 +602,50 @@ async def websocket_delete_refresh_token(
connection.send_result(msg["id"], {}) connection.send_result(msg["id"], {})
@websocket_api.websocket_command(
{
vol.Required("type"): "auth/delete_all_refresh_tokens",
}
)
@websocket_api.ws_require_user()
@websocket_api.async_response
async def websocket_delete_all_refresh_tokens(
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
) -> None:
"""Handle delete all refresh tokens request."""
tasks = []
current_refresh_token: RefreshToken
for token in connection.user.refresh_tokens.values():
if token.id == connection.refresh_token_id:
# Skip the current refresh token as it has revoke_callback,
# which cancels/closes the connection.
# It will be removed after sending the result.
current_refresh_token = token
continue
tasks.append(
hass.async_create_task(hass.auth.async_remove_refresh_token(token))
)
remove_failed = False
if tasks:
for result in await asyncio.gather(*tasks, return_exceptions=True):
if isinstance(result, Exception):
getLogger(__name__).exception(
"During refresh token removal, the following error occurred: %s",
result,
)
remove_failed = True
if remove_failed:
connection.send_error(
msg["id"], "token_removing_error", "During removal, an error was raised."
)
else:
connection.send_result(msg["id"], {})
hass.async_create_task(hass.auth.async_remove_refresh_token(current_refresh_token))
@websocket_api.websocket_command( @websocket_api.websocket_command(
{ {
vol.Required("type"): "auth/sign_path", vol.Required("type"): "auth/sign_path",

View File

@ -1,6 +1,7 @@
"""Integration tests for the auth component.""" """Integration tests for the auth component."""
from datetime import timedelta from datetime import timedelta
from http import HTTPStatus from http import HTTPStatus
import logging
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -519,6 +520,106 @@ async def test_ws_delete_refresh_token(
assert refresh_token is None assert refresh_token is None
async def test_ws_delete_all_refresh_tokens_error(
hass: HomeAssistant,
hass_admin_user: MockUser,
hass_admin_credential: Credentials,
hass_ws_client: WebSocketGenerator,
hass_access_token: str,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test deleting all refresh tokens, where a revoke callback raises an error."""
assert await async_setup_component(hass, "auth", {"http": {}})
# one token already exists
await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
)
token = await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID + "_1", credential=hass_admin_credential
)
def cb():
raise RuntimeError("I'm bad")
hass.auth.async_register_revoke_token_callback(token.id, cb)
ws_client = await hass_ws_client(hass, hass_access_token)
# get all tokens
await ws_client.send_json({"id": 5, "type": "auth/refresh_tokens"})
result = await ws_client.receive_json()
assert result["success"], result
tokens = result["result"]
await ws_client.send_json(
{
"id": 6,
"type": "auth/delete_all_refresh_tokens",
}
)
caplog.clear()
result = await ws_client.receive_json()
assert result, result["success"] is False
assert result["error"] == {
"code": "token_removing_error",
"message": "During removal, an error was raised.",
}
assert (
"homeassistant.components.auth",
logging.ERROR,
"During refresh token removal, the following error occurred: I'm bad",
) in caplog.record_tuples
for token in tokens:
refresh_token = await hass.auth.async_get_refresh_token(token["id"])
assert refresh_token is None
async def test_ws_delete_all_refresh_tokens(
hass: HomeAssistant,
hass_admin_user: MockUser,
hass_admin_credential: Credentials,
hass_ws_client: WebSocketGenerator,
hass_access_token: str,
) -> None:
"""Test deleting all refresh tokens."""
assert await async_setup_component(hass, "auth", {"http": {}})
# one token already exists
await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
)
await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID + "_1", credential=hass_admin_credential
)
ws_client = await hass_ws_client(hass, hass_access_token)
# get all tokens
await ws_client.send_json({"id": 5, "type": "auth/refresh_tokens"})
result = await ws_client.receive_json()
assert result["success"], result
tokens = result["result"]
await ws_client.send_json(
{
"id": 6,
"type": "auth/delete_all_refresh_tokens",
}
)
result = await ws_client.receive_json()
assert result, result["success"]
for token in tokens:
refresh_token = await hass.auth.async_get_refresh_token(token["id"])
assert refresh_token is None
async def test_ws_sign_path( async def test_ws_sign_path(
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, hass_access_token: str hass: HomeAssistant, hass_ws_client: WebSocketGenerator, hass_access_token: str
) -> None: ) -> None: