From 6223b1f599941b8865000d5c122d9adfd04d2368 Mon Sep 17 00:00:00 2001 From: Robert Resch Date: Tue, 29 Aug 2023 15:57:54 +0200 Subject: [PATCH] Add ws endpoint "auth/delete_all_refresh_tokens" (#98976) Co-authored-by: Martin Hjelmare --- homeassistant/components/auth/__init__.py | 48 ++++++++++ tests/components/auth/test_init.py | 101 ++++++++++++++++++++++ 2 files changed, 149 insertions(+) diff --git a/homeassistant/components/auth/__init__.py b/homeassistant/components/auth/__init__.py index deaf3b7892d1..78a1383012d6 100644 --- a/homeassistant/components/auth/__init__.py +++ b/homeassistant/components/auth/__init__.py @@ -124,9 +124,11 @@ as part of a config flow. """ from __future__ import annotations +import asyncio from collections.abc import Callable from datetime import datetime, timedelta from http import HTTPStatus +from logging import getLogger from typing import Any, cast import uuid @@ -138,6 +140,7 @@ from homeassistant.auth import InvalidAuthError from homeassistant.auth.models import ( TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, Credentials, + RefreshToken, User, ) from homeassistant.components import websocket_api @@ -188,6 +191,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: websocket_api.async_register_command(hass, websocket_create_long_lived_access_token) websocket_api.async_register_command(hass, websocket_refresh_tokens) websocket_api.async_register_command(hass, websocket_delete_refresh_token) + websocket_api.async_register_command(hass, websocket_delete_all_refresh_tokens) websocket_api.async_register_command(hass, websocket_sign_path) await login_flow.async_setup(hass, store_result) @@ -598,6 +602,50 @@ async def websocket_delete_refresh_token( connection.send_result(msg["id"], {}) +@websocket_api.websocket_command( + { + vol.Required("type"): "auth/delete_all_refresh_tokens", + } +) +@websocket_api.ws_require_user() +@websocket_api.async_response +async def websocket_delete_all_refresh_tokens( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle delete all refresh tokens request.""" + tasks = [] + current_refresh_token: RefreshToken + for token in connection.user.refresh_tokens.values(): + if token.id == connection.refresh_token_id: + # Skip the current refresh token as it has revoke_callback, + # which cancels/closes the connection. + # It will be removed after sending the result. + current_refresh_token = token + continue + tasks.append( + hass.async_create_task(hass.auth.async_remove_refresh_token(token)) + ) + + remove_failed = False + if tasks: + for result in await asyncio.gather(*tasks, return_exceptions=True): + if isinstance(result, Exception): + getLogger(__name__).exception( + "During refresh token removal, the following error occurred: %s", + result, + ) + remove_failed = True + + if remove_failed: + connection.send_error( + msg["id"], "token_removing_error", "During removal, an error was raised." + ) + else: + connection.send_result(msg["id"], {}) + + hass.async_create_task(hass.auth.async_remove_refresh_token(current_refresh_token)) + + @websocket_api.websocket_command( { vol.Required("type"): "auth/sign_path", diff --git a/tests/components/auth/test_init.py b/tests/components/auth/test_init.py index 923a633e76a9..a33ca702bcf7 100644 --- a/tests/components/auth/test_init.py +++ b/tests/components/auth/test_init.py @@ -1,6 +1,7 @@ """Integration tests for the auth component.""" from datetime import timedelta from http import HTTPStatus +import logging from unittest.mock import patch import pytest @@ -519,6 +520,106 @@ async def test_ws_delete_refresh_token( assert refresh_token is None +async def test_ws_delete_all_refresh_tokens_error( + hass: HomeAssistant, + hass_admin_user: MockUser, + hass_admin_credential: Credentials, + hass_ws_client: WebSocketGenerator, + hass_access_token: str, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test deleting all refresh tokens, where a revoke callback raises an error.""" + assert await async_setup_component(hass, "auth", {"http": {}}) + + # one token already exists + await hass.auth.async_create_refresh_token( + hass_admin_user, CLIENT_ID, credential=hass_admin_credential + ) + token = await hass.auth.async_create_refresh_token( + hass_admin_user, CLIENT_ID + "_1", credential=hass_admin_credential + ) + + def cb(): + raise RuntimeError("I'm bad") + + hass.auth.async_register_revoke_token_callback(token.id, cb) + + ws_client = await hass_ws_client(hass, hass_access_token) + + # get all tokens + await ws_client.send_json({"id": 5, "type": "auth/refresh_tokens"}) + result = await ws_client.receive_json() + assert result["success"], result + + tokens = result["result"] + + await ws_client.send_json( + { + "id": 6, + "type": "auth/delete_all_refresh_tokens", + } + ) + + caplog.clear() + result = await ws_client.receive_json() + assert result, result["success"] is False + assert result["error"] == { + "code": "token_removing_error", + "message": "During removal, an error was raised.", + } + + assert ( + "homeassistant.components.auth", + logging.ERROR, + "During refresh token removal, the following error occurred: I'm bad", + ) in caplog.record_tuples + + for token in tokens: + refresh_token = await hass.auth.async_get_refresh_token(token["id"]) + assert refresh_token is None + + +async def test_ws_delete_all_refresh_tokens( + hass: HomeAssistant, + hass_admin_user: MockUser, + hass_admin_credential: Credentials, + hass_ws_client: WebSocketGenerator, + hass_access_token: str, +) -> None: + """Test deleting all refresh tokens.""" + assert await async_setup_component(hass, "auth", {"http": {}}) + + # one token already exists + await hass.auth.async_create_refresh_token( + hass_admin_user, CLIENT_ID, credential=hass_admin_credential + ) + await hass.auth.async_create_refresh_token( + hass_admin_user, CLIENT_ID + "_1", credential=hass_admin_credential + ) + + ws_client = await hass_ws_client(hass, hass_access_token) + + # get all tokens + await ws_client.send_json({"id": 5, "type": "auth/refresh_tokens"}) + result = await ws_client.receive_json() + assert result["success"], result + + tokens = result["result"] + + await ws_client.send_json( + { + "id": 6, + "type": "auth/delete_all_refresh_tokens", + } + ) + + result = await ws_client.receive_json() + assert result, result["success"] + for token in tokens: + refresh_token = await hass.auth.async_get_refresh_token(token["id"]) + assert refresh_token is None + + async def test_ws_sign_path( hass: HomeAssistant, hass_ws_client: WebSocketGenerator, hass_access_token: str ) -> None: