Add strict connection for cloud (#115814)

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Robert Resch 2024-04-24 09:57:38 +02:00 committed by GitHub
parent b520efb87a
commit a4829330f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 644 additions and 58 deletions

View File

@ -7,11 +7,14 @@ from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
from enum import Enum
from typing import cast
from urllib.parse import quote_plus, urljoin
from hass_nabucasa import Cloud
import voluptuous as vol
from homeassistant.components import alexa, google_assistant
from homeassistant.components import alexa, google_assistant, http
from homeassistant.components.auth import STRICT_CONNECTION_URL
from homeassistant.components.http.auth import async_sign_path
from homeassistant.config_entries import SOURCE_SYSTEM, ConfigEntry
from homeassistant.const import (
CONF_DESCRIPTION,
@ -21,8 +24,21 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP,
Platform,
)
from homeassistant.core import Event, HassJob, HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.core import (
Event,
HassJob,
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
callback,
)
from homeassistant.exceptions import (
HomeAssistantError,
ServiceValidationError,
Unauthorized,
UnknownUser,
)
from homeassistant.helpers import config_validation as cv, entityfilter
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.discovery import async_load_platform
@ -31,6 +47,7 @@ from homeassistant.helpers.dispatcher import (
async_dispatcher_send,
)
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import NoURLAvailableError, get_url
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
@ -265,18 +282,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _shutdown)
_remote_handle_prefs_updated(cloud)
async def _service_handler(service: ServiceCall) -> None:
"""Handle service for cloud."""
if service.service == SERVICE_REMOTE_CONNECT:
await prefs.async_update(remote_enabled=True)
elif service.service == SERVICE_REMOTE_DISCONNECT:
await prefs.async_update(remote_enabled=False)
async_register_admin_service(hass, DOMAIN, SERVICE_REMOTE_CONNECT, _service_handler)
async_register_admin_service(
hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler
)
_setup_services(hass, prefs)
async def async_startup_repairs(_: datetime) -> None:
"""Create repair issues after startup."""
@ -395,3 +401,67 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
@callback
def _setup_services(hass: HomeAssistant, prefs: CloudPreferences) -> None:
"""Set up services for cloud component."""
async def _service_handler(service: ServiceCall) -> None:
"""Handle service for cloud."""
if service.service == SERVICE_REMOTE_CONNECT:
await prefs.async_update(remote_enabled=True)
elif service.service == SERVICE_REMOTE_DISCONNECT:
await prefs.async_update(remote_enabled=False)
async_register_admin_service(hass, DOMAIN, SERVICE_REMOTE_CONNECT, _service_handler)
async_register_admin_service(
hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler
)
async def create_temporary_strict_connection_url(
call: ServiceCall,
) -> ServiceResponse:
"""Create a strict connection url and return it."""
# Copied form homeassistant/helpers/service.py#_async_admin_handler
# as the helper supports no responses yet
if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id)
if user is None:
raise UnknownUser(context=call.context)
if not user.is_admin:
raise Unauthorized(context=call.context)
if prefs.strict_connection is http.const.StrictConnectionMode.DISABLED:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="strict_connection_not_enabled",
)
try:
url = get_url(hass, require_cloud=True)
except NoURLAvailableError as ex:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="no_url_available",
) from ex
path = async_sign_path(
hass,
STRICT_CONNECTION_URL,
timedelta(hours=1),
use_content_user=True,
)
url = urljoin(url, path)
return {
"url": f"https://login.home-assistant.io?u={quote_plus(url)}",
"direct_url": url,
}
hass.services.async_register(
DOMAIN,
"create_temporary_strict_connection_url",
create_temporary_strict_connection_url,
supports_response=SupportsResponse.ONLY,
)

View File

@ -250,6 +250,7 @@ class CloudClient(Interface):
"enabled": self._prefs.remote_enabled,
"instance_domain": self.cloud.remote.instance_domain,
"alias": self.cloud.remote.alias,
"strict_connection": self._prefs.strict_connection,
},
"version": HA_VERSION,
"instance_id": self.prefs.instance_id,

View File

@ -33,6 +33,7 @@ PREF_GOOGLE_SETTINGS_VERSION = "google_settings_version"
PREF_TTS_DEFAULT_VOICE = "tts_default_voice"
PREF_GOOGLE_CONNECTED = "google_connected"
PREF_REMOTE_ALLOW_REMOTE_ENABLE = "remote_allow_remote_enable"
PREF_STRICT_CONNECTION = "strict_connection"
DEFAULT_TTS_DEFAULT_VOICE = ("en-US", "JennyNeural")
DEFAULT_DISABLE_2FA = False
DEFAULT_ALEXA_REPORT_STATE = True

View File

@ -19,7 +19,7 @@ from hass_nabucasa.const import STATE_DISCONNECTED
from hass_nabucasa.voice import TTS_VOICES
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components import http, websocket_api
from homeassistant.components.alexa import (
entities as alexa_entities,
errors as alexa_errors,
@ -46,6 +46,7 @@ from .const import (
PREF_GOOGLE_REPORT_STATE,
PREF_GOOGLE_SECURE_DEVICES_PIN,
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
PREF_STRICT_CONNECTION,
PREF_TTS_DEFAULT_VOICE,
REQUEST_TIMEOUT,
)
@ -452,6 +453,9 @@ def validate_language_voice(value: tuple[str, str]) -> tuple[str, str]:
vol.Coerce(tuple), validate_language_voice
),
vol.Optional(PREF_REMOTE_ALLOW_REMOTE_ENABLE): bool,
vol.Optional(PREF_STRICT_CONNECTION): vol.Coerce(
http.const.StrictConnectionMode
),
}
)
@websocket_api.async_response

View File

@ -1,5 +1,6 @@
{
"services": {
"create_temporary_strict_connection_url": "mdi:login-variant",
"remote_connect": "mdi:cloud",
"remote_disconnect": "mdi:cloud-off"
}

View File

@ -3,7 +3,7 @@
"name": "Home Assistant Cloud",
"after_dependencies": ["assist_pipeline", "google_assistant", "alexa"],
"codeowners": ["@home-assistant/cloud"],
"dependencies": ["http", "repairs", "webhook"],
"dependencies": ["auth", "http", "repairs", "webhook"],
"documentation": "https://www.home-assistant.io/integrations/cloud",
"integration_type": "system",
"iot_class": "cloud_push",

View File

@ -10,7 +10,7 @@ from hass_nabucasa.voice import MAP_VOICE
from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.auth.models import User
from homeassistant.components import webhook
from homeassistant.components import http, webhook
from homeassistant.components.google_assistant.http import (
async_get_users as async_get_google_assistant_users,
)
@ -44,6 +44,7 @@ from .const import (
PREF_INSTANCE_ID,
PREF_REMOTE_ALLOW_REMOTE_ENABLE,
PREF_REMOTE_DOMAIN,
PREF_STRICT_CONNECTION,
PREF_TTS_DEFAULT_VOICE,
PREF_USERNAME,
)
@ -176,6 +177,7 @@ class CloudPreferences:
google_settings_version: int | UndefinedType = UNDEFINED,
google_connected: bool | UndefinedType = UNDEFINED,
remote_allow_remote_enable: bool | UndefinedType = UNDEFINED,
strict_connection: http.const.StrictConnectionMode | UndefinedType = UNDEFINED,
) -> None:
"""Update user preferences."""
prefs = {**self._prefs}
@ -195,6 +197,7 @@ class CloudPreferences:
(PREF_REMOTE_DOMAIN, remote_domain),
(PREF_GOOGLE_CONNECTED, google_connected),
(PREF_REMOTE_ALLOW_REMOTE_ENABLE, remote_allow_remote_enable),
(PREF_STRICT_CONNECTION, strict_connection),
):
if value is not UNDEFINED:
prefs[key] = value
@ -242,6 +245,7 @@ class CloudPreferences:
PREF_GOOGLE_SECURE_DEVICES_PIN: self.google_secure_devices_pin,
PREF_REMOTE_ALLOW_REMOTE_ENABLE: self.remote_allow_remote_enable,
PREF_TTS_DEFAULT_VOICE: self.tts_default_voice,
PREF_STRICT_CONNECTION: self.strict_connection,
}
@property
@ -358,6 +362,17 @@ class CloudPreferences:
"""
return self._prefs.get(PREF_TTS_DEFAULT_VOICE, DEFAULT_TTS_DEFAULT_VOICE) # type: ignore[no-any-return]
@property
def strict_connection(self) -> http.const.StrictConnectionMode:
"""Return the strict connection mode."""
mode = self._prefs.get(
PREF_STRICT_CONNECTION, http.const.StrictConnectionMode.DISABLED
)
if not isinstance(mode, http.const.StrictConnectionMode):
mode = http.const.StrictConnectionMode(mode)
return mode # type: ignore[no-any-return]
async def get_cloud_user(self) -> str:
"""Return ID of Home Assistant Cloud system user."""
user = await self._load_cloud_user()
@ -415,4 +430,5 @@ class CloudPreferences:
PREF_REMOTE_DOMAIN: None,
PREF_REMOTE_ALLOW_REMOTE_ENABLE: True,
PREF_USERNAME: username,
PREF_STRICT_CONNECTION: http.const.StrictConnectionMode.DISABLED,
}

View File

@ -5,6 +5,14 @@
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
}
},
"exceptions": {
"strict_connection_not_enabled": {
"message": "Strict connection is not enabled for cloud requests"
},
"no_url_available": {
"message": "No cloud URL available.\nPlease mark sure you have a working Remote UI."
}
},
"system_health": {
"info": {
"can_reach_cert_server": "Reach Certificate Server",
@ -73,6 +81,10 @@
}
},
"services": {
"create_temporary_strict_connection_url": {
"name": "Create a temporary strict connection URL",
"description": "Create a temporary strict connection URL, which can be used to login on another device."
},
"remote_connect": {
"name": "Remote connect",
"description": "Makes the instance UI accessible from outside of the local network by using Home Assistant Cloud."

View File

@ -0,0 +1,15 @@
"""Cloud util functions."""
from hass_nabucasa import Cloud
from homeassistant.components import http
from homeassistant.core import HomeAssistant
from .client import CloudClient
from .const import DOMAIN
def get_strict_connection_mode(hass: HomeAssistant) -> http.const.StrictConnectionMode:
"""Get the strict connection mode."""
cloud: Cloud[CloudClient] = hass.data[DOMAIN]
return cloud.client.prefs.strict_connection

View File

@ -69,6 +69,7 @@ from homeassistant.util.json import json_loads
from .auth import async_setup_auth, async_sign_path
from .ban import setup_bans
from .const import ( # noqa: F401
DOMAIN,
KEY_HASS_REFRESH_TOKEN_ID,
KEY_HASS_USER,
StrictConnectionMode,
@ -82,8 +83,6 @@ from .security_filter import setup_security_filter
from .static import CACHE_HEADERS, CachingStaticResource
from .web_runner import HomeAssistantTCPSite
DOMAIN: Final = "http"
CONF_SERVER_HOST: Final = "server_host"
CONF_SERVER_PORT: Final = "server_port"
CONF_BASE_URL: Final = "base_url"
@ -149,7 +148,7 @@ HTTP_SCHEMA: Final = vol.All(
vol.Optional(CONF_USE_X_FRAME_OPTIONS, default=True): cv.boolean,
vol.Optional(
CONF_STRICT_CONNECTION, default=StrictConnectionMode.DISABLED
): vol.In([e.value for e in StrictConnectionMode]),
): vol.Coerce(StrictConnectionMode),
}
),
)
@ -628,7 +627,9 @@ def _setup_services(hass: HomeAssistant, conf: ConfData) -> None:
)
try:
url = get_url(hass, prefer_external=True, allow_internal=False)
url = get_url(
hass, prefer_external=True, allow_internal=False, allow_cloud=False
)
except NoURLAvailableError as ex:
raise ServiceValidationError(
translation_domain=DOMAIN,

View File

@ -25,6 +25,7 @@ from homeassistant.auth.const import GROUP_ID_READ_ONLY
from homeassistant.auth.models import User
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import singleton
from homeassistant.helpers.http import current_request
from homeassistant.helpers.json import json_bytes
from homeassistant.helpers.network import is_cloud_connection
@ -32,6 +33,7 @@ from homeassistant.helpers.storage import Store
from homeassistant.util.network import is_local
from .const import (
DOMAIN,
KEY_AUTHENTICATED,
KEY_HASS_REFRESH_TOKEN_ID,
KEY_HASS_USER,
@ -50,8 +52,9 @@ STORAGE_VERSION = 1
STORAGE_KEY = "http.auth"
CONTENT_USER_NAME = "Home Assistant Content"
STRICT_CONNECTION_EXCLUDED_PATH = "/api/webhook/"
STRICT_CONNECTION_STATIC_PAGE_NAME = "strict_connection_static_page.html"
STRICT_CONNECTION_STATIC_PAGE = os.path.join(
os.path.dirname(__file__), "strict_connection_static_page.html"
os.path.dirname(__file__), STRICT_CONNECTION_STATIC_PAGE_NAME
)
@ -156,16 +159,10 @@ async def async_setup_auth(
await store.async_save(data)
hass.data[STORAGE_KEY] = refresh_token.id
strict_connection_static_file_content = None
if strict_connection_mode_non_cloud is StrictConnectionMode.STATIC_PAGE:
def read_static_page() -> str:
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
return file.read()
strict_connection_static_file_content = await hass.async_add_executor_job(
read_static_page
)
# Load the static page content on setup
await _read_strict_connection_static_page(hass)
@callback
def async_validate_auth_header(request: Request) -> bool:
@ -255,21 +252,36 @@ async def async_setup_auth(
authenticated = True
auth_type = "signed request"
if (
not authenticated
and strict_connection_mode_non_cloud is not StrictConnectionMode.DISABLED
and not request.path.startswith(STRICT_CONNECTION_EXCLUDED_PATH)
and not await hass.auth.session.async_validate_request_for_strict_connection_session(
request
)
and (
resp := _async_perform_action_on_non_local(
request, strict_connection_static_file_content
)
)
is not None
if not authenticated and not request.path.startswith(
STRICT_CONNECTION_EXCLUDED_PATH
):
return resp
strict_connection_mode = strict_connection_mode_non_cloud
strict_connection_func = (
_async_perform_strict_connection_action_on_non_local
)
if is_cloud_connection(hass):
from homeassistant.components.cloud.util import ( # pylint: disable=import-outside-toplevel
get_strict_connection_mode,
)
strict_connection_mode = get_strict_connection_mode(hass)
strict_connection_func = _async_perform_strict_connection_action
if (
strict_connection_mode is not StrictConnectionMode.DISABLED
and not await hass.auth.session.async_validate_request_for_strict_connection_session(
request
)
and (
resp := await strict_connection_func(
hass,
request,
strict_connection_mode is StrictConnectionMode.STATIC_PAGE,
)
)
is not None
):
return resp
if authenticated and _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
@ -286,17 +298,17 @@ async def async_setup_auth(
app.middlewares.append(auth_middleware)
@callback
def _async_perform_action_on_non_local(
async def _async_perform_strict_connection_action_on_non_local(
hass: HomeAssistant,
request: Request,
strict_connection_static_file_content: str | None,
static_page: bool,
) -> StreamResponse | None:
"""Perform strict connection mode action if the request is not local.
The function does the following:
- Try to get the IP address of the request. If it fails, assume it's not local
- If the request is local, return None (allow the request to continue)
- If strict_connection_static_file_content is set, return a response with the content
- If static_page is True, return a response with the content
- Otherwise close the connection and raise an exception
"""
try:
@ -308,10 +320,25 @@ def _async_perform_action_on_non_local(
if ip_address_ and is_local(ip_address_):
return None
_LOGGER.debug("Perform strict connection action for %s", ip_address_)
if strict_connection_static_file_content:
return await _async_perform_strict_connection_action(hass, request, static_page)
async def _async_perform_strict_connection_action(
hass: HomeAssistant,
request: Request,
static_page: bool,
) -> StreamResponse | None:
"""Perform strict connection mode action.
The function does the following:
- If static_page is True, return a response with the content
- Otherwise close the connection and raise an exception
"""
_LOGGER.debug("Perform strict connection action for %s", request.remote)
if static_page:
return Response(
text=strict_connection_static_file_content,
text=await _read_strict_connection_static_page(hass),
content_type="text/html",
status=HTTPStatus.IM_A_TEAPOT,
)
@ -322,3 +349,14 @@ def _async_perform_action_on_non_local(
# We need to raise an exception to stop processing the request
raise HTTPBadRequest
@singleton.singleton(f"{DOMAIN}_{STRICT_CONNECTION_STATIC_PAGE_NAME}")
async def _read_strict_connection_static_page(hass: HomeAssistant) -> str:
"""Read the strict connection static page from disk via executor."""
def read_static_page() -> str:
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
return file.read()
return await hass.async_add_executor_job(read_static_page)

View File

@ -5,6 +5,8 @@ from typing import Final
from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
DOMAIN: Final = "http"
KEY_HASS_USER: Final = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"

View File

@ -122,6 +122,7 @@ def get_url(
require_current_request: bool = False,
require_ssl: bool = False,
require_standard_port: bool = False,
require_cloud: bool = False,
allow_internal: bool = True,
allow_external: bool = True,
allow_cloud: bool = True,
@ -145,7 +146,7 @@ def get_url(
# Try finding an URL in the order specified
for url_type in order:
if allow_internal and url_type == TYPE_URL_INTERNAL:
if allow_internal and url_type == TYPE_URL_INTERNAL and not require_cloud:
with suppress(NoURLAvailableError):
return _get_internal_url(
hass,
@ -155,7 +156,7 @@ def get_url(
require_standard_port=require_standard_port,
)
if allow_external and url_type == TYPE_URL_EXTERNAL:
if require_cloud or (allow_external and url_type == TYPE_URL_EXTERNAL):
with suppress(NoURLAvailableError):
return _get_external_url(
hass,
@ -165,7 +166,10 @@ def get_url(
require_current_request=require_current_request,
require_ssl=require_ssl,
require_standard_port=require_standard_port,
require_cloud=require_cloud,
)
if require_cloud:
raise NoURLAvailableError
# For current request, we accept loopback interfaces (e.g., 127.0.0.1),
# the Supervisor hostname and localhost transparently
@ -263,8 +267,12 @@ def _get_external_url(
require_current_request: bool = False,
require_ssl: bool = False,
require_standard_port: bool = False,
require_cloud: bool = False,
) -> str:
"""Get external URL of this instance."""
if require_cloud:
return _get_cloud_url(hass, require_current_request=require_current_request)
if prefer_cloud and allow_cloud:
with suppress(NoURLAvailableError):
return _get_cloud_url(hass)

View File

@ -152,6 +152,7 @@ IGNORE_VIOLATIONS = {
("demo", "manual"),
# This would be a circular dep
("http", "network"),
("http", "cloud"),
# This would be a circular dep
("zha", "homeassistant_hardware"),
("zha", "homeassistant_sky_connect"),

View File

@ -24,6 +24,7 @@ from homeassistant.components.homeassistant.exposed_entities import (
ExposedEntities,
async_expose_entity,
)
from homeassistant.components.http.const import StrictConnectionMode
from homeassistant.const import CONTENT_TYPE_JSON, __version__ as HA_VERSION
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers import entity_registry as er
@ -387,6 +388,7 @@ async def test_cloud_connection_info(hass: HomeAssistant) -> None:
"connected": False,
"enabled": False,
"instance_domain": None,
"strict_connection": StrictConnectionMode.DISABLED,
},
"version": HA_VERSION,
}

View File

@ -19,6 +19,7 @@ from homeassistant.components.assist_pipeline.pipeline import STORAGE_KEY
from homeassistant.components.cloud.const import DEFAULT_EXPOSED_DOMAINS, DOMAIN
from homeassistant.components.google_assistant.helpers import GoogleEntity
from homeassistant.components.homeassistant import exposed_entities
from homeassistant.components.http.const import StrictConnectionMode
from homeassistant.components.websocket_api import ERR_INVALID_FORMAT
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers import entity_registry as er
@ -782,6 +783,7 @@ async def test_websocket_status(
"google_report_state": True,
"remote_allow_remote_enable": True,
"remote_enabled": False,
"strict_connection": "disabled",
"tts_default_voice": ["en-US", "JennyNeural"],
},
"alexa_entities": {
@ -901,6 +903,7 @@ async def test_websocket_update_preferences(
assert cloud.client.prefs.alexa_enabled
assert cloud.client.prefs.google_secure_devices_pin is None
assert cloud.client.prefs.remote_allow_remote_enable is True
assert cloud.client.prefs.strict_connection is StrictConnectionMode.DISABLED
client = await hass_ws_client(hass)
@ -912,6 +915,7 @@ async def test_websocket_update_preferences(
"google_secure_devices_pin": "1234",
"tts_default_voice": ["en-GB", "RyanNeural"],
"remote_allow_remote_enable": False,
"strict_connection": StrictConnectionMode.DROP_CONNECTION,
}
)
response = await client.receive_json()
@ -922,6 +926,7 @@ async def test_websocket_update_preferences(
assert cloud.client.prefs.google_secure_devices_pin == "1234"
assert cloud.client.prefs.remote_allow_remote_enable is False
assert cloud.client.prefs.tts_default_voice == ("en-GB", "RyanNeural")
assert cloud.client.prefs.strict_connection is StrictConnectionMode.DROP_CONNECTION
@pytest.mark.parametrize(

View File

@ -3,6 +3,7 @@
from collections.abc import Callable, Coroutine
from typing import Any
from unittest.mock import MagicMock, patch
from urllib.parse import quote_plus
from hass_nabucasa import Cloud
import pytest
@ -13,11 +14,16 @@ from homeassistant.components.cloud import (
CloudNotConnected,
async_get_or_create_cloudhook,
)
from homeassistant.components.cloud.const import DOMAIN, PREF_CLOUDHOOKS
from homeassistant.components.cloud.const import (
DOMAIN,
PREF_CLOUDHOOKS,
PREF_STRICT_CONNECTION,
)
from homeassistant.components.cloud.prefs import STORAGE_KEY
from homeassistant.components.http.const import StrictConnectionMode
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import Unauthorized
from homeassistant.exceptions import ServiceValidationError, Unauthorized
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry, MockUser
@ -295,3 +301,77 @@ async def test_cloud_logout(
await hass.async_block_till_done()
assert cloud.is_logged_in is False
async def test_service_create_temporary_strict_connection_url_strict_connection_disabled(
hass: HomeAssistant,
) -> None:
"""Test service create_temporary_strict_connection_url with strict_connection not enabled."""
mock_config_entry = MockConfigEntry(domain=DOMAIN)
mock_config_entry.add_to_hass(hass)
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
with pytest.raises(
ServiceValidationError,
match="Strict connection is not enabled for cloud requests",
):
await hass.services.async_call(
cloud.DOMAIN,
"create_temporary_strict_connection_url",
blocking=True,
return_response=True,
)
@pytest.mark.parametrize(
("mode"),
[
StrictConnectionMode.DROP_CONNECTION,
StrictConnectionMode.STATIC_PAGE,
],
)
async def test_service_create_temporary_strict_connection(
hass: HomeAssistant,
set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
mode: StrictConnectionMode,
) -> None:
"""Test service create_temporary_strict_connection_url."""
mock_config_entry = MockConfigEntry(domain=DOMAIN)
mock_config_entry.add_to_hass(hass)
assert await async_setup_component(hass, DOMAIN, {"cloud": {}})
await hass.async_block_till_done()
await set_cloud_prefs(
{
PREF_STRICT_CONNECTION: mode,
}
)
# No cloud url set
with pytest.raises(ServiceValidationError, match="No cloud URL available"):
await hass.services.async_call(
cloud.DOMAIN,
"create_temporary_strict_connection_url",
blocking=True,
return_response=True,
)
# Patch cloud url
url = "https://example.com"
with patch(
"homeassistant.helpers.network._get_cloud_url",
return_value=url,
):
response = await hass.services.async_call(
cloud.DOMAIN,
"create_temporary_strict_connection_url",
blocking=True,
return_response=True,
)
assert isinstance(response, dict)
direct_url_prefix = f"{url}/auth/strict_connection/temp_token?authSig="
assert response.pop("direct_url").startswith(direct_url_prefix)
assert response.pop("url").startswith(
f"https://login.home-assistant.io?u={quote_plus(direct_url_prefix)}"
)
assert response == {} # No more keys in response

View File

@ -6,8 +6,13 @@ from unittest.mock import ANY, MagicMock, patch
import pytest
from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.components.cloud.const import DOMAIN, PREF_TTS_DEFAULT_VOICE
from homeassistant.components.cloud.const import (
DOMAIN,
PREF_STRICT_CONNECTION,
PREF_TTS_DEFAULT_VOICE,
)
from homeassistant.components.cloud.prefs import STORAGE_KEY, CloudPreferences
from homeassistant.components.http.const import StrictConnectionMode
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
@ -174,3 +179,21 @@ async def test_tts_default_voice_legacy_gender(
await hass.async_block_till_done()
assert cloud.client.prefs.tts_default_voice == (expected_language, voice)
@pytest.mark.parametrize("mode", list(StrictConnectionMode))
async def test_strict_connection_convertion(
hass: HomeAssistant,
cloud: MagicMock,
hass_storage: dict[str, Any],
mode: StrictConnectionMode,
) -> None:
"""Test strict connection string value will be converted to the enum."""
hass_storage[STORAGE_KEY] = {
"version": 1,
"data": {PREF_STRICT_CONNECTION: mode.value},
}
assert await async_setup_component(hass, DOMAIN, {DOMAIN: {}})
await hass.async_block_till_done()
assert cloud.client.prefs.strict_connection is mode

View File

@ -0,0 +1,294 @@
"""Test strict connection mode for cloud."""
from collections.abc import Awaitable, Callable, Coroutine, Generator
from contextlib import contextmanager
from datetime import timedelta
from http import HTTPStatus
from typing import Any
from unittest.mock import MagicMock, Mock, patch
from aiohttp import ServerDisconnectedError, web
from aiohttp.test_utils import TestClient
from aiohttp_session import get_session
import pytest
from yarl import URL
from homeassistant.auth.models import RefreshToken
from homeassistant.auth.session import SESSION_ID, TEMP_TIMEOUT
from homeassistant.components.cloud.const import PREF_STRICT_CONNECTION
from homeassistant.components.http import KEY_HASS
from homeassistant.components.http.auth import (
STRICT_CONNECTION_STATIC_PAGE,
async_setup_auth,
async_sign_path,
)
from homeassistant.components.http.const import KEY_AUTHENTICATED, StrictConnectionMode
from homeassistant.components.http.session import COOKIE_NAME, PREFIXED_COOKIE_NAME
from homeassistant.core import HomeAssistant
from homeassistant.helpers.network import is_cloud_connection
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow
from tests.common import async_fire_time_changed
from tests.typing import ClientSessionGenerator
@pytest.fixture
async def refresh_token(hass: HomeAssistant, hass_access_token: str) -> RefreshToken:
"""Return a refresh token."""
refresh_token = hass.auth.async_validate_access_token(hass_access_token)
assert refresh_token
session = hass.auth.session
assert session._strict_connection_sessions == {}
assert session._temp_sessions == {}
return refresh_token
@contextmanager
def simulate_cloud_request() -> Generator[None, None, None]:
"""Simulate a cloud request."""
with patch(
"hass_nabucasa.remote.is_cloud_request", Mock(get=Mock(return_value=True))
):
yield
@pytest.fixture
def app_strict_connection(
hass: HomeAssistant, refresh_token: RefreshToken
) -> web.Application:
"""Fixture to set up a web.Application."""
async def handler(request):
"""Return if request was authenticated."""
return web.json_response(data={"authenticated": request[KEY_AUTHENTICATED]})
app = web.Application()
app[KEY_HASS] = hass
app.router.add_get("/", handler)
async def set_cookie(request: web.Request) -> web.Response:
hass = request.app[KEY_HASS]
# Clear all sessions
hass.auth.session._temp_sessions.clear()
hass.auth.session._strict_connection_sessions.clear()
if request.query["token"] == "refresh":
await hass.auth.session.async_create_session(request, refresh_token)
else:
await hass.auth.session.async_create_temp_unauthorized_session(request)
session = await get_session(request)
return web.Response(text=session[SESSION_ID])
app.router.add_get("/test/cookie", set_cookie)
return app
@pytest.fixture(name="client")
async def set_up_fixture(
hass: HomeAssistant,
aiohttp_client: ClientSessionGenerator,
app_strict_connection: web.Application,
cloud: MagicMock,
socket_enabled: None,
) -> TestClient:
"""Set up the fixture."""
await async_setup_auth(hass, app_strict_connection, StrictConnectionMode.DISABLED)
assert await async_setup_component(hass, "cloud", {"cloud": {}})
await hass.async_block_till_done()
return await aiohttp_client(app_strict_connection)
@pytest.mark.parametrize(
"strict_connection_mode", [e.value for e in StrictConnectionMode]
)
async def test_strict_connection_cloud_authenticated_requests(
hass: HomeAssistant,
client: TestClient,
hass_access_token: str,
set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
refresh_token: RefreshToken,
strict_connection_mode: StrictConnectionMode,
) -> None:
"""Test authenticated requests with strict connection."""
assert hass.auth.session._strict_connection_sessions == {}
signed_path = async_sign_path(
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
)
await set_cloud_prefs(
{
PREF_STRICT_CONNECTION: strict_connection_mode,
}
)
with simulate_cloud_request():
assert is_cloud_connection(hass)
req = await client.get(
"/", headers={"Authorization": f"Bearer {hass_access_token}"}
)
assert req.status == HTTPStatus.OK
assert await req.json() == {"authenticated": True}
req = await client.get(signed_path)
assert req.status == HTTPStatus.OK
assert await req.json() == {"authenticated": True}
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests(
hass: HomeAssistant,
client: TestClient,
perform_unauthenticated_request: Callable[
[HomeAssistant, TestClient], Awaitable[None]
],
_: RefreshToken,
) -> None:
"""Test external unauthenticated requests with strict connection cloud enabled."""
with simulate_cloud_request():
assert is_cloud_connection(hass)
await perform_unauthenticated_request(hass, client)
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests_refresh_token(
hass: HomeAssistant,
client: TestClient,
perform_unauthenticated_request: Callable[
[HomeAssistant, TestClient], Awaitable[None]
],
refresh_token: RefreshToken,
) -> None:
"""Test external unauthenticated requests with strict connection cloud enabled and refresh token cookie."""
session = hass.auth.session
# set strict connection cookie with refresh token
session_id = await _modify_cookie_for_cloud(client, "refresh")
assert session._strict_connection_sessions == {session_id: refresh_token.id}
with simulate_cloud_request():
assert is_cloud_connection(hass)
req = await client.get("/")
assert req.status == HTTPStatus.OK
assert await req.json() == {"authenticated": False}
# Invalidate refresh token, which should also invalidate session
hass.auth.async_remove_refresh_token(refresh_token)
assert session._strict_connection_sessions == {}
await perform_unauthenticated_request(hass, client)
async def _test_strict_connection_cloud_enabled_external_unauthenticated_requests_temp_session(
hass: HomeAssistant,
client: TestClient,
perform_unauthenticated_request: Callable[
[HomeAssistant, TestClient], Awaitable[None]
],
_: RefreshToken,
) -> None:
"""Test external unauthenticated requests with strict connection cloud enabled and temp cookie."""
session = hass.auth.session
# set strict connection cookie with temp session
assert session._temp_sessions == {}
session_id = await _modify_cookie_for_cloud(client, "temp")
assert session_id in session._temp_sessions
with simulate_cloud_request():
assert is_cloud_connection(hass)
resp = await client.get("/")
assert resp.status == HTTPStatus.OK
assert await resp.json() == {"authenticated": False}
async_fire_time_changed(hass, utcnow() + TEMP_TIMEOUT + timedelta(minutes=1))
await hass.async_block_till_done(wait_background_tasks=True)
assert session._temp_sessions == {}
await perform_unauthenticated_request(hass, client)
async def _drop_connection_unauthorized_request(
_: HomeAssistant, client: TestClient
) -> None:
with pytest.raises(ServerDisconnectedError):
# unauthorized requests should raise ServerDisconnectedError
await client.get("/")
async def _static_page_unauthorized_request(
hass: HomeAssistant, client: TestClient
) -> None:
req = await client.get("/")
assert req.status == HTTPStatus.IM_A_TEAPOT
def read_static_page() -> str:
with open(STRICT_CONNECTION_STATIC_PAGE, encoding="utf-8") as file:
return file.read()
assert await req.text() == await hass.async_add_executor_job(read_static_page)
@pytest.mark.parametrize(
"test_func",
[
_test_strict_connection_cloud_enabled_external_unauthenticated_requests,
_test_strict_connection_cloud_enabled_external_unauthenticated_requests_refresh_token,
_test_strict_connection_cloud_enabled_external_unauthenticated_requests_temp_session,
],
ids=[
"no cookie",
"refresh token cookie",
"temp session cookie",
],
)
@pytest.mark.parametrize(
("strict_connection_mode", "request_func"),
[
(StrictConnectionMode.DROP_CONNECTION, _drop_connection_unauthorized_request),
(StrictConnectionMode.STATIC_PAGE, _static_page_unauthorized_request),
],
ids=["drop connection", "static page"],
)
async def test_strict_connection_cloud_external_unauthenticated_requests(
hass: HomeAssistant,
client: TestClient,
refresh_token: RefreshToken,
set_cloud_prefs: Callable[[dict[str, Any]], Coroutine[Any, Any, None]],
test_func: Callable[
[
HomeAssistant,
TestClient,
Callable[[HomeAssistant, TestClient], Awaitable[None]],
RefreshToken,
],
Awaitable[None],
],
strict_connection_mode: StrictConnectionMode,
request_func: Callable[[HomeAssistant, TestClient], Awaitable[None]],
) -> None:
"""Test external unauthenticated requests with strict connection cloud."""
await set_cloud_prefs(
{
PREF_STRICT_CONNECTION: strict_connection_mode,
}
)
await test_func(
hass,
client,
request_func,
refresh_token,
)
async def _modify_cookie_for_cloud(client: TestClient, token_type: str) -> str:
"""Modify cookie for cloud."""
# Cloud cookie has set secure=true and will not set on unsecure connection
# As we test with unsecure connection, we need to set it manually
# We get the session via http and modify the cookie name to the secure one
session_id = await (await client.get(f"/test/cookie?token={token_type}")).text()
cookie_jar = client.session.cookie_jar
localhost = URL("http://127.0.0.1")
cookie = cookie_jar.filter_cookies(localhost)[COOKIE_NAME].value
assert cookie
cookie_jar.clear()
cookie_jar.update_cookies({PREFIXED_COOKIE_NAME: cookie}, localhost)
return session_id

View File

@ -362,6 +362,18 @@ async def test_get_url_external(hass: HomeAssistant) -> None:
with pytest.raises(NoURLAvailableError):
_get_external_url(hass, require_current_request=True, require_ssl=True)
with pytest.raises(NoURLAvailableError):
_get_external_url(hass, require_cloud=True)
with patch(
"homeassistant.components.cloud.async_remote_ui_url",
return_value="https://example.nabu.casa",
):
hass.config.components.add("cloud")
assert (
_get_external_url(hass, require_cloud=True) == "https://example.nabu.casa"
)
async def test_get_cloud_url(hass: HomeAssistant) -> None:
"""Test getting an instance URL when the user has set an external URL."""