Support blocking trusted network from new ip (#44630)

Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
Joakim Plate 2021-01-28 12:06:20 +01:00 committed by GitHub
parent e4a7692610
commit 38d2cacf7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 381 additions and 131 deletions

View File

@ -24,6 +24,14 @@ _ProviderKey = Tuple[str, Optional[str]]
_ProviderDict = Dict[_ProviderKey, AuthProvider]
class InvalidAuthError(Exception):
"""Raised when a authentication error occurs."""
class InvalidProvider(Exception):
"""Authentication provider not found."""
async def auth_manager_from_config(
hass: HomeAssistant,
provider_configs: List[Dict[str, Any]],
@ -96,7 +104,7 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
return result
# we got final result
if isinstance(result["data"], models.User):
if isinstance(result["data"], models.Credentials):
result["result"] = result["data"]
return result
@ -120,11 +128,12 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
modules = await self.auth_manager.async_get_enabled_mfa(user)
if modules:
flow.credential = credentials
flow.user = user
flow.available_mfa_modules = modules
return await flow.async_step_select_mfa_module()
result["result"] = await self.auth_manager.async_get_or_create_user(credentials)
result["result"] = credentials
return result
@ -156,7 +165,7 @@ class AuthManager:
return list(self._mfa_modules.values())
def get_auth_provider(
self, provider_type: str, provider_id: str
self, provider_type: str, provider_id: Optional[str]
) -> Optional[AuthProvider]:
"""Return an auth provider, None if not found."""
return self._providers.get((provider_type, provider_id))
@ -367,6 +376,7 @@ class AuthManager:
client_icon: Optional[str] = None,
token_type: Optional[str] = None,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
credential: Optional[models.Credentials] = None,
) -> models.RefreshToken:
"""Create a new refresh token for a user."""
if not user.is_active:
@ -415,6 +425,7 @@ class AuthManager:
client_icon,
token_type,
access_token_expiration,
credential,
)
async def async_get_refresh_token(
@ -440,6 +451,8 @@ class AuthManager:
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
) -> str:
"""Create a new access token."""
self.async_validate_refresh_token(refresh_token, remote_ip)
self._store.async_log_refresh_token_usage(refresh_token, remote_ip)
now = dt_util.utcnow()
@ -453,6 +466,40 @@ class AuthManager:
algorithm="HS256",
).decode()
@callback
def _async_resolve_provider(
self, refresh_token: models.RefreshToken
) -> Optional[AuthProvider]:
"""Get the auth provider for the given refresh token.
Raises an exception if the expected provider is no longer available or return
None if no provider was expected for this refresh token.
"""
if refresh_token.credential is None:
return None
provider = self.get_auth_provider(
refresh_token.credential.auth_provider_type,
refresh_token.credential.auth_provider_id,
)
if provider is None:
raise InvalidProvider(
f"Auth provider {refresh_token.credential.auth_provider_type}, {refresh_token.credential.auth_provider_id} not available"
)
return provider
@callback
def async_validate_refresh_token(
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
) -> None:
"""Validate that a refresh token is usable.
Will raise InvalidAuthError on errors.
"""
provider = self._async_resolve_provider(refresh_token)
if provider:
provider.async_validate_refresh_token(refresh_token, remote_ip)
async def async_validate_access_token(
self, token: str
) -> Optional[models.RefreshToken]:

View File

@ -208,6 +208,7 @@ class AuthStore:
client_icon: Optional[str] = None,
token_type: str = models.TOKEN_TYPE_NORMAL,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
credential: Optional[models.Credentials] = None,
) -> models.RefreshToken:
"""Create a new token for a user."""
kwargs: Dict[str, Any] = {
@ -215,6 +216,7 @@ class AuthStore:
"client_id": client_id,
"token_type": token_type,
"access_token_expiration": access_token_expiration,
"credential": credential,
}
if client_name:
kwargs["client_name"] = client_name
@ -309,6 +311,7 @@ class AuthStore:
users: Dict[str, models.User] = OrderedDict()
groups: Dict[str, models.Group] = OrderedDict()
credentials: Dict[str, models.Credentials] = OrderedDict()
# Soft-migrating data as we load. We are going to make sure we have a
# read only group and an admin group. There are two states that we can
@ -415,15 +418,15 @@ class AuthStore:
)
for cred_dict in data["credentials"]:
users[cred_dict["user_id"]].credentials.append(
models.Credentials(
id=cred_dict["id"],
is_new=False,
auth_provider_type=cred_dict["auth_provider_type"],
auth_provider_id=cred_dict["auth_provider_id"],
data=cred_dict["data"],
)
credential = models.Credentials(
id=cred_dict["id"],
is_new=False,
auth_provider_type=cred_dict["auth_provider_type"],
auth_provider_id=cred_dict["auth_provider_id"],
data=cred_dict["data"],
)
credentials[cred_dict["id"]] = credential
users[cred_dict["user_id"]].credentials.append(credential)
for rt_dict in data["refresh_tokens"]:
# Filter out the old keys that don't have jwt_key (pre-0.76)
@ -469,6 +472,8 @@ class AuthStore:
jwt_key=rt_dict["jwt_key"],
last_used_at=last_used_at,
last_used_ip=rt_dict.get("last_used_ip"),
credential=credentials.get(rt_dict.get("credential_id")),
version=rt_dict.get("version"),
)
users[rt_dict["user_id"]].refresh_tokens[token.id] = token
@ -542,6 +547,10 @@ class AuthStore:
if refresh_token.last_used_at
else None,
"last_used_ip": refresh_token.last_used_ip,
"credential_id": refresh_token.credential.id
if refresh_token.credential
else None,
"version": refresh_token.version,
}
for user in self._users.values()
for refresh_token in user.refresh_tokens.values()

View File

@ -6,6 +6,7 @@ import uuid
import attr
from homeassistant.const import __version__
from homeassistant.util import dt as dt_util
from . import permissions as perm_mdl
@ -106,6 +107,10 @@ class RefreshToken:
last_used_at: Optional[datetime] = attr.ib(default=None)
last_used_ip: Optional[str] = attr.ib(default=None)
credential: Optional["Credentials"] = attr.ib(default=None)
version: Optional[str] = attr.ib(default=__version__)
@attr.s(slots=True)
class Credentials:

View File

@ -16,7 +16,7 @@ from homeassistant.util.decorator import Registry
from ..auth_store import AuthStore
from ..const import MFA_SESSION_EXPIRATION
from ..models import Credentials, User, UserMeta
from ..models import Credentials, RefreshToken, User, UserMeta
_LOGGER = logging.getLogger(__name__)
DATA_REQS = "auth_prov_reqs_processed"
@ -117,6 +117,16 @@ class AuthProvider:
async def async_initialize(self) -> None:
"""Initialize the auth provider."""
@callback
def async_validate_refresh_token(
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
) -> None:
"""Verify a refresh token is still valid.
Optional hook for an auth provider to verify validity of a refresh token.
Should raise InvalidAuthError on errors.
"""
async def auth_provider_from_config(
hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
@ -182,6 +192,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
self.created_at = dt_util.utcnow()
self.invalid_mfa_times = 0
self.user: Optional[User] = None
self.credential: Optional[Credentials] = None
async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None
@ -222,6 +233,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]:
"""Handle the step of mfa validation."""
assert self.credential
assert self.user
errors = {}
@ -257,7 +269,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
return self.async_abort(reason="too_many_retry")
if not errors:
return await self.async_finish(self.user)
return await self.async_finish(self.credential)
description_placeholders: Dict[str, Optional[str]] = {
"mfa_module_name": auth_module.name,

View File

@ -8,13 +8,12 @@ from typing import Any, Dict, Optional, cast
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
from .. import AuthManager
from ..models import Credentials, User, UserMeta
from ..models import Credentials, UserMeta
AUTH_PROVIDER_TYPE = "legacy_api_password"
CONF_API_PASSWORD = "api_password"
@ -30,23 +29,6 @@ class InvalidAuthError(HomeAssistantError):
"""Raised when submitting invalid authentication."""
async def async_validate_password(hass: HomeAssistant, password: str) -> Optional[User]:
"""Return a user if password is valid. None if not."""
auth = cast(AuthManager, hass.auth) # type: ignore
providers = auth.get_auth_providers(AUTH_PROVIDER_TYPE)
if not providers:
raise ValueError("Legacy API password provider not found")
try:
provider = cast(LegacyApiPasswordAuthProvider, providers[0])
provider.async_validate_login(password)
return await auth.async_get_or_create_user(
await provider.async_get_or_create_credentials({})
)
except InvalidAuthError:
return None
@AUTH_PROVIDERS.register(AUTH_PROVIDER_TYPE)
class LegacyApiPasswordAuthProvider(AuthProvider):
"""An auth provider support legacy api_password."""

View File

@ -3,7 +3,14 @@
It shows list of users if access from trusted network.
Abort login flow if not access from trusted network.
"""
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network
from ipaddress import (
IPv4Address,
IPv4Network,
IPv6Address,
IPv6Network,
ip_address,
ip_network,
)
from typing import Any, Dict, List, Optional, Union, cast
import voluptuous as vol
@ -13,7 +20,8 @@ from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
from ..models import Credentials, UserMeta
from .. import InvalidAuthError
from ..models import Credentials, RefreshToken, UserMeta
IPAddress = Union[IPv4Address, IPv6Address]
IPNetwork = Union[IPv4Network, IPv6Network]
@ -46,10 +54,6 @@ CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
)
class InvalidAuthError(HomeAssistantError):
"""Raised when try to access from untrusted networks."""
class InvalidUserError(HomeAssistantError):
"""Raised when try to login as invalid user."""
@ -163,6 +167,17 @@ class TrustedNetworksAuthProvider(AuthProvider):
):
raise InvalidAuthError("Not in trusted_networks")
@callback
def async_validate_refresh_token(
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
) -> None:
"""Verify a refresh token is still valid."""
if remote_ip is None:
raise InvalidAuthError(
"Unknown remote ip can't be used for trusted network provider."
)
self.async_validate_access(ip_address(remote_ip))
class TrustedNetworksLoginFlow(LoginFlow):
"""Handler for the login flow."""

View File

@ -115,11 +115,13 @@ Result will be a long-lived access token:
"""
from datetime import timedelta
from typing import Union
import uuid
from aiohttp import web
import voluptuous as vol
from homeassistant.auth import InvalidAuthError
from homeassistant.auth.models import (
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
Credentials,
@ -180,9 +182,11 @@ RESULT_TYPE_USER = "user"
@bind_hass
def create_auth_code(hass, client_id: str, user: User) -> str:
def create_auth_code(
hass, client_id: str, credential_or_user: Union[Credentials, User]
) -> str:
"""Create an authorization code to fetch tokens."""
return hass.data[DOMAIN](client_id, user)
return hass.data[DOMAIN](client_id, credential_or_user)
async def async_setup(hass, config):
@ -228,9 +232,9 @@ class TokenView(HomeAssistantView):
requires_auth = False
cors_allowed = True
def __init__(self, retrieve_user):
def __init__(self, retrieve_auth):
"""Initialize the token view."""
self._retrieve_user = retrieve_user
self._retrieve_auth = retrieve_auth
@log_invalid_auth
async def post(self, request):
@ -293,16 +297,15 @@ class TokenView(HomeAssistantView):
status_code=HTTP_BAD_REQUEST,
)
user = self._retrieve_user(client_id, RESULT_TYPE_USER, code)
credential = self._retrieve_auth(client_id, RESULT_TYPE_CREDENTIALS, code)
if user is None or not isinstance(user, User):
if credential is None or not isinstance(credential, Credentials):
return self.json(
{"error": "invalid_request", "error_description": "Invalid code"},
status_code=HTTP_BAD_REQUEST,
)
# refresh user
user = await hass.auth.async_get_user(user.id)
user = await hass.auth.async_get_or_create_user(credential)
if not user.is_active:
return self.json(
@ -310,8 +313,18 @@ class TokenView(HomeAssistantView):
status_code=HTTP_FORBIDDEN,
)
refresh_token = await hass.auth.async_create_refresh_token(user, client_id)
access_token = hass.auth.async_create_access_token(refresh_token, remote_addr)
refresh_token = await hass.auth.async_create_refresh_token(
user, client_id, credential=credential
)
try:
access_token = hass.auth.async_create_access_token(
refresh_token, remote_addr
)
except InvalidAuthError as exc:
return self.json(
{"error": "access_denied", "error_description": str(exc)},
status_code=HTTP_FORBIDDEN,
)
return self.json(
{
@ -346,7 +359,15 @@ class TokenView(HomeAssistantView):
if refresh_token.client_id != client_id:
return self.json({"error": "invalid_request"}, status_code=HTTP_BAD_REQUEST)
access_token = hass.auth.async_create_access_token(refresh_token, remote_addr)
try:
access_token = hass.auth.async_create_access_token(
refresh_token, remote_addr
)
except InvalidAuthError as exc:
return self.json(
{"error": "access_denied", "error_description": str(exc)},
status_code=HTTP_FORBIDDEN,
)
return self.json(
{
@ -482,7 +503,12 @@ async def websocket_create_long_lived_access_token(
access_token_expiration=timedelta(days=msg["lifespan"]),
)
access_token = hass.auth.async_create_access_token(refresh_token)
try:
access_token = hass.auth.async_create_access_token(refresh_token)
except InvalidAuthError as exc:
return websocket_api.error_message(
msg["id"], websocket_api.const.ERR_UNAUTHORIZED, str(exc)
)
connection.send_message(websocket_api.result_message(msg["id"], access_token))

View File

@ -9,7 +9,7 @@ import jwt
from homeassistant.core import callback
from homeassistant.util import dt as dt_util
from .const import KEY_AUTHENTICATED, KEY_HASS_USER
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
# mypy: allow-untyped-defs, no-check-untyped-defs
@ -62,6 +62,7 @@ def setup_auth(hass, app):
return False
request[KEY_HASS_USER] = refresh_token.user
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
return True
async def async_validate_signed_request(request):
@ -92,6 +93,7 @@ def setup_auth(hass, app):
return False
request[KEY_HASS_USER] = refresh_token.user
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
return True
@middleware

View File

@ -2,3 +2,4 @@
KEY_AUTHENTICATED = "ha_authenticated"
KEY_HASS = "hass"
KEY_HASS_USER = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID = "hass_refresh_token_id"

View File

@ -5,6 +5,7 @@ import voluptuous as vol
from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.components.auth import indieauth
from homeassistant.components.http.const import KEY_HASS_REFRESH_TOKEN_ID
from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.components.http.view import HomeAssistantView
from homeassistant.const import HTTP_BAD_REQUEST, HTTP_FORBIDDEN
@ -132,7 +133,9 @@ class UserOnboardingView(_BaseOnboardingView):
# Return authorization code for fetching tokens and connect
# during onboarding.
auth_code = hass.components.auth.create_auth_code(data["client_id"], user)
auth_code = hass.components.auth.create_auth_code(
data["client_id"], credentials
)
return self.json({"auth_code": auth_code})
@ -183,7 +186,7 @@ class IntegrationOnboardingView(_BaseOnboardingView):
async def post(self, request, data):
"""Handle token creation."""
hass = request.app["hass"]
user = request["hass_user"]
refresh_token_id = request[KEY_HASS_REFRESH_TOKEN_ID]
async with self._lock:
if self._async_is_done():
@ -201,8 +204,16 @@ class IntegrationOnboardingView(_BaseOnboardingView):
"invalid client id or redirect uri", HTTP_BAD_REQUEST
)
refresh_token = await hass.auth.async_get_refresh_token(refresh_token_id)
if refresh_token is None or refresh_token.credential is None:
return self.json_message(
"Credentials for user not available", HTTP_FORBIDDEN
)
# Return authorization code so we can redirect user and log them in
auth_code = hass.components.auth.create_auth_code(data["client_id"], user)
auth_code = hass.components.auth.create_auth_code(
data["client_id"], refresh_token.credential
)
return self.json({"auth_code": auth_code})

View File

@ -131,7 +131,7 @@ async def test_login(hass):
result["flow_id"], {"pin": "123456"}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["data"].id == "mock-user"
assert result["data"].id == "mock-id"
async def test_setup_flow(hass):

View File

@ -229,7 +229,7 @@ async def test_login_flow_validates_mfa(hass):
result["flow_id"], {"code": MOCK_CODE}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["data"].id == "mock-user"
assert result["data"].id == "mock-id"
async def test_setup_user_notify_service(hass):

View File

@ -127,7 +127,7 @@ async def test_login_flow_validates_mfa(hass):
result["flow_id"], {"code": MOCK_CODE}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["data"].id == "mock-user"
assert result["data"].id == "mock-id"
async def test_race_condition_in_data_loading(hass):

View File

@ -1,5 +1,6 @@
"""Test the Trusted Networks auth provider."""
from ipaddress import ip_address, ip_network
from unittest.mock import Mock, patch
import pytest
import voluptuous as vol
@ -142,6 +143,16 @@ async def test_validate_access(provider):
provider.async_validate_access(ip_address("2001:db8::ff00:42:8329"))
async def test_validate_refresh_token(provider):
"""Verify re-validation of refresh token."""
with patch.object(provider, "async_validate_access") as mock:
with pytest.raises(tn_auth.InvalidAuthError):
provider.async_validate_refresh_token(Mock(), None)
provider.async_validate_refresh_token(Mock(), "127.0.0.1")
mock.assert_called_once_with(ip_address("127.0.0.1"))
async def test_login_flow(manager, provider):
"""Test login flow."""
owner = await manager.async_create_user("test-owner")

View File

@ -37,6 +37,7 @@ async def test_loading_no_group_data_format(hass, hass_storage):
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
"token": "some-token",
"user_id": "user-id",
"version": "1.2.3",
},
{
"access_token_expiration": 1800.0,
@ -87,12 +88,14 @@ async def test_loading_no_group_data_format(hass, hass_storage):
assert len(owner.refresh_tokens) == 1
owner_token = list(owner.refresh_tokens.values())[0]
assert owner_token.id == "user-token-id"
assert owner_token.version == "1.2.3"
assert system.system_generated is True
assert system.groups == []
assert len(system.refresh_tokens) == 1
system_token = list(system.refresh_tokens.values())[0]
assert system_token.id == "system-token-id"
assert system_token.version is None
async def test_loading_all_access_group_data_format(hass, hass_storage):
@ -129,6 +132,7 @@ async def test_loading_all_access_group_data_format(hass, hass_storage):
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
"token": "some-token",
"user_id": "user-id",
"version": "1.2.3",
},
{
"access_token_expiration": 1800.0,
@ -139,6 +143,7 @@ async def test_loading_all_access_group_data_format(hass, hass_storage):
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
"token": "some-token",
"user_id": "system-id",
"version": None,
},
{
"access_token_expiration": 1800.0,
@ -179,12 +184,14 @@ async def test_loading_all_access_group_data_format(hass, hass_storage):
assert len(owner.refresh_tokens) == 1
owner_token = list(owner.refresh_tokens.values())[0]
assert owner_token.id == "user-token-id"
assert owner_token.version == "1.2.3"
assert system.system_generated is True
assert system.groups == []
assert len(system.refresh_tokens) == 1
system_token = list(system.refresh_tokens.values())[0]
assert system_token.id == "system-token-id"
assert system_token.version is None
async def test_loading_empty_data(hass, hass_storage):

View File

@ -7,7 +7,12 @@ import pytest
import voluptuous as vol
from homeassistant import auth, data_entry_flow
from homeassistant.auth import auth_store, const as auth_const, models as auth_models
from homeassistant.auth import (
InvalidAuthError,
auth_store,
const as auth_const,
models as auth_models,
)
from homeassistant.auth.const import MFA_SESSION_EXPIRATION
from homeassistant.core import callback
from homeassistant.util import dt as dt_util
@ -162,7 +167,10 @@ async def test_create_new_user(hass):
step["flow_id"], {"username": "test-user", "password": "test-pass"}
)
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
user = step["result"]
credential = step["result"]
assert credential is not None
user = await manager.async_get_or_create_user(credential)
assert user is not None
assert user.is_owner is False
assert user.name == "Test Name"
@ -229,7 +237,8 @@ async def test_login_as_existing_user(mock_hass):
)
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
user = step["result"]
credential = step["result"]
user = await manager.async_get_user_by_credentials(credential)
assert user is not None
assert user.id == "mock-user"
assert user.is_owner is False
@ -259,7 +268,8 @@ async def test_linking_user_to_two_auth_providers(hass, hass_storage):
step = await manager.login_flow.async_configure(
step["flow_id"], {"username": "test-user", "password": "test-pass"}
)
user = step["result"]
credential = step["result"]
user = await manager.async_get_or_create_user(credential)
assert user is not None
step = await manager.login_flow.async_init(
@ -293,13 +303,19 @@ async def test_saving_loading(hass, hass_storage):
step = await manager.login_flow.async_configure(
step["flow_id"], {"username": "test-user", "password": "test-pass"}
)
user = step["result"]
credential = step["result"]
user = await manager.async_get_or_create_user(credential)
await manager.async_activate_user(user)
# the first refresh token will be used to create access token
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
refresh_token = await manager.async_create_refresh_token(
user, CLIENT_ID, credential=credential
)
manager.async_create_access_token(refresh_token, "192.168.0.1")
# the second refresh token will not be used
await manager.async_create_refresh_token(user, "dummy-client")
await manager.async_create_refresh_token(
user, "dummy-client", credential=credential
)
await flush_store(manager._store._store)
@ -452,6 +468,46 @@ async def test_refresh_token_type_long_lived_access_token(hass):
assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
async def test_refresh_token_provider_validation(mock_hass):
"""Test that creating access token from refresh token checks with provider."""
manager = await auth.auth_manager_from_config(
mock_hass,
[
{
"type": "insecure_example",
"users": [{"username": "test-user", "password": "test-pass"}],
}
],
[],
)
credential = auth_models.Credentials(
id="mock-credential-id",
auth_provider_type="insecure_example",
auth_provider_id=None,
data={"username": "test-user"},
is_new=False,
)
user = MockUser().add_to_auth_manager(manager)
user.credentials.append(credential)
refresh_token = await manager.async_create_refresh_token(
user, CLIENT_ID, credential=credential
)
ip = "127.0.0.1"
assert manager.async_create_access_token(refresh_token, ip) is not None
with patch(
"homeassistant.auth.providers.insecure_example.ExampleAuthProvider.async_validate_refresh_token",
side_effect=InvalidAuthError("Invalid access"),
) as call:
with pytest.raises(InvalidAuthError):
manager.async_create_access_token(refresh_token, ip)
call.assert_called_with(refresh_token, ip)
async def test_cannot_deactive_owner(mock_hass):
"""Test that we cannot deactivate the owner."""
manager = await auth.auth_manager_from_config(mock_hass, [], [])
@ -626,14 +682,10 @@ async def test_login_with_auth_module(mock_hass):
step["flow_id"], {"pin": "test-pin"}
)
# Finally passed, get user
# Finally passed, get credential
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
user = step["result"]
assert user is not None
assert user.id == "mock-user"
assert user.is_owner is False
assert user.is_active is False
assert user.name == "Paulus"
assert step["result"]
assert step["result"].id == "mock-id"
async def test_login_with_multi_auth_module(mock_hass):
@ -703,14 +755,10 @@ async def test_login_with_multi_auth_module(mock_hass):
step["flow_id"], {"pin": "test-pin2"}
)
# Finally passed, get user
# Finally passed, get credential
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
user = step["result"]
assert user is not None
assert user.id == "mock-user"
assert user.is_owner is False
assert user.is_active is False
assert user.name == "Paulus"
assert step["result"]
assert step["result"].id == "mock-id"
async def test_auth_module_expired_session(mock_hass):
@ -792,7 +840,8 @@ async def test_enable_mfa_for_user(hass, hass_storage):
step = await manager.login_flow.async_configure(
step["flow_id"], {"username": "test-user", "password": "test-pass"}
)
user = step["result"]
credential = step["result"]
user = await manager.async_get_or_create_user(credential)
assert user is not None
# new user don't have mfa enabled

View File

@ -2,6 +2,7 @@
from datetime import timedelta
from unittest.mock import patch
from homeassistant.auth import InvalidAuthError
from homeassistant.auth.models import Credentials
from homeassistant.components import auth
from homeassistant.components.auth import RESULT_TYPE_USER
@ -13,6 +14,24 @@ from . import async_setup_auth
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
async def async_setup_user_refresh_token(hass):
"""Create a testing user with a connected credential."""
user = await hass.auth.async_create_user("Test User")
credential = Credentials(
id="mock-credential-id",
auth_provider_type="insecure_example",
auth_provider_id=None,
data={"username": "test-user"},
is_new=False,
)
user.credentials.append(credential)
return await hass.auth.async_create_refresh_token(
user, CLIENT_ID, credential=credential
)
async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
"""Test logging in with new user and refreshing tokens."""
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
@ -107,12 +126,6 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
user = refresh_token.user
credential = Credentials(
auth_provider_type="homeassistant", auth_provider_id=None, data={}, id="test-id"
)
user.credentials.append(credential)
assert len(user.credentials) == 1
client = await hass_ws_client(hass, hass_access_token)
await client.send_json({"id": 5, "type": auth.WS_TYPE_CURRENT_USER})
@ -185,8 +198,7 @@ async def test_refresh_token_system_generated(hass, aiohttp_client):
async def test_refresh_token_different_client_id(hass, aiohttp_client):
"""Test that we verify client ID."""
client = await async_setup_auth(hass, aiohttp_client)
user = await hass.auth.async_create_user("Test User")
refresh_token = await hass.auth.async_create_refresh_token(user, CLIENT_ID)
refresh_token = await async_setup_user_refresh_token(hass)
# No client ID
resp = await client.post(
@ -229,11 +241,37 @@ async def test_refresh_token_different_client_id(hass, aiohttp_client):
)
async def test_refresh_token_provider_rejected(
hass, aiohttp_client, hass_admin_user, hass_admin_credential
):
"""Test that we verify client ID."""
client = await async_setup_auth(hass, aiohttp_client)
refresh_token = await async_setup_user_refresh_token(hass)
# Rejected by provider
with patch(
"homeassistant.auth.providers.insecure_example.ExampleAuthProvider.async_validate_refresh_token",
side_effect=InvalidAuthError("Invalid access"),
):
resp = await client.post(
"/auth/token",
data={
"client_id": CLIENT_ID,
"grant_type": "refresh_token",
"refresh_token": refresh_token.token,
},
)
assert resp.status == 403
result = await resp.json()
assert result["error"] == "access_denied"
assert result["error_description"] == "Invalid access"
async def test_revoking_refresh_token(hass, aiohttp_client):
"""Test that we can revoke refresh tokens."""
client = await async_setup_auth(hass, aiohttp_client)
user = await hass.auth.async_create_user("Test User")
refresh_token = await hass.auth.async_create_refresh_token(user, CLIENT_ID)
refresh_token = await async_setup_user_refresh_token(hass)
# Test that we can create an access token
resp = await client.post(

View File

@ -48,7 +48,9 @@ async def test_list(hass, hass_ws_client, hass_admin_user):
id="hij", name="Inactive User", is_active=False, groups=[group]
).add_to_hass(hass)
refresh_token = await hass.auth.async_create_refresh_token(owner, CLIENT_ID)
refresh_token = await hass.auth.async_create_refresh_token(
owner, CLIENT_ID, credential=owner.credentials[0]
)
access_token = hass.auth.async_create_access_token(refresh_token)
client = await hass_ws_client(hass, access_token)
@ -60,13 +62,13 @@ async def test_list(hass, hass_ws_client, hass_admin_user):
assert len(data) == 4
assert data[0] == {
"id": hass_admin_user.id,
"username": None,
"username": "admin",
"name": "Mock User",
"is_owner": False,
"is_active": True,
"system_generated": False,
"group_ids": [group.id for group in hass_admin_user.groups],
"credentials": [],
"credentials": [{"type": "homeassistant"}],
}
assert data[1] == {
"id": owner.id,

View File

@ -4,24 +4,19 @@ import pytest
from homeassistant.auth.providers import homeassistant as prov_ha
from homeassistant.components.config import auth_provider_homeassistant as auth_ha
from tests.common import CLIENT_ID, MockUser, register_auth_provider
from tests.common import CLIENT_ID, MockUser
@pytest.fixture(autouse=True)
def setup_config(hass):
"""Fixture that sets up the auth provider homeassistant module."""
hass.loop.run_until_complete(
register_auth_provider(hass, {"type": "homeassistant"})
)
hass.loop.run_until_complete(auth_ha.async_setup(hass))
async def setup_config(hass, local_auth):
"""Fixture that sets up the auth provider ."""
await auth_ha.async_setup(hass)
@pytest.fixture
async def auth_provider(hass):
async def auth_provider(local_auth):
"""Hass auth provider."""
provider = hass.auth.auth_providers[0]
await provider.async_initialize()
return provider
return local_auth
@pytest.fixture
@ -34,8 +29,8 @@ async def owner_access_token(hass, hass_owner_user):
@pytest.fixture
async def test_user_credential(hass, auth_provider):
"""Add a test user."""
async def hass_admin_credential(hass, auth_provider):
"""Overload credentials to admin user."""
await hass.async_add_executor_job(
auth_provider.data.add_auth, "test-user", "test-pass"
)
@ -124,7 +119,7 @@ async def test_create_auth(hass, hass_ws_client, hass_storage):
"id": 5,
"type": "config/auth_provider/homeassistant/create",
"user_id": user.id,
"username": "test-user",
"username": "test-user2",
"password": "test-pass",
}
)
@ -135,10 +130,10 @@ async def test_create_auth(hass, hass_ws_client, hass_storage):
creds = user.credentials[0]
assert creds.auth_provider_type == "homeassistant"
assert creds.auth_provider_id is None
assert creds.data == {"username": "test-user"}
assert creds.data == {"username": "test-user2"}
assert prov_ha.STORAGE_KEY in hass_storage
entry = hass_storage[prov_ha.STORAGE_KEY]["data"]["users"][0]
assert entry["username"] == "test-user"
entry = hass_storage[prov_ha.STORAGE_KEY]["data"]["users"][1]
assert entry["username"] == "test-user2"
async def test_create_auth_duplicate_username(hass, hass_ws_client, hass_storage):
@ -242,7 +237,7 @@ async def test_delete_unknown_auth(hass, hass_ws_client):
{
"id": 5,
"type": "config/auth_provider/homeassistant/delete",
"username": "test-user",
"username": "test-user2",
}
)
@ -251,12 +246,8 @@ async def test_delete_unknown_auth(hass, hass_ws_client):
assert result["error"]["code"] == "auth_not_found"
async def test_change_password(
hass, hass_ws_client, hass_admin_user, auth_provider, test_user_credential
):
async def test_change_password(hass, hass_ws_client, auth_provider):
"""Test that change password succeeds with valid password."""
await hass.auth.async_link_user(hass_admin_user, test_user_credential)
client = await hass_ws_client(hass)
await client.send_json(
{
@ -273,10 +264,9 @@ async def test_change_password(
async def test_change_password_wrong_pw(
hass, hass_ws_client, hass_admin_user, auth_provider, test_user_credential
hass, hass_ws_client, hass_admin_user, auth_provider
):
"""Test that change password fails with invalid password."""
await hass.auth.async_link_user(hass_admin_user, test_user_credential)
client = await hass_ws_client(hass)
await client.send_json(
@ -295,8 +285,9 @@ async def test_change_password_wrong_pw(
await auth_provider.async_validate_login("test-user", "new-pass")
async def test_change_password_no_creds(hass, hass_ws_client):
async def test_change_password_no_creds(hass, hass_ws_client, hass_admin_user):
"""Test that change password fails with no credentials."""
hass_admin_user.credentials.clear()
client = await hass_ws_client(hass)
await client.send_json(
@ -313,9 +304,7 @@ async def test_change_password_no_creds(hass, hass_ws_client):
assert result["error"]["code"] == "credentials_not_found"
async def test_admin_change_password_not_owner(
hass, hass_ws_client, auth_provider, test_user_credential
):
async def test_admin_change_password_not_owner(hass, hass_ws_client, auth_provider):
"""Test that change password fails when not owner."""
client = await hass_ws_client(hass)
@ -358,6 +347,8 @@ async def test_admin_change_password_no_cred(
hass, hass_ws_client, owner_access_token, hass_admin_user
):
"""Test that change password fails with unknown credential."""
hass_admin_user.credentials.clear()
client = await hass_ws_client(hass, owner_access_token)
await client.send_json(
@ -379,12 +370,9 @@ async def test_admin_change_password(
hass_ws_client,
owner_access_token,
auth_provider,
test_user_credential,
hass_admin_user,
):
"""Test that owners can change any password."""
await hass.auth.async_link_user(hass_admin_user, test_user_credential)
client = await hass_ws_client(hass, owner_access_token)
await client.send_json(

View File

@ -247,7 +247,7 @@ async def test_onboarding_user_race(hass, hass_storage, aiohttp_client):
assert sorted([res1.status, res2.status]) == [200, HTTP_FORBIDDEN]
async def test_onboarding_integration(hass, hass_storage, hass_client):
async def test_onboarding_integration(hass, hass_storage, hass_client, hass_admin_user):
"""Test finishing integration step."""
mock_storage(hass_storage, {"done": [const.STEP_USER]})
@ -288,6 +288,28 @@ async def test_onboarding_integration(hass, hass_storage, hass_client):
assert len(user.refresh_tokens) == 2, user
async def test_onboarding_integration_missing_credential(
hass, hass_storage, hass_client, hass_access_token
):
"""Test that we fail integration step if user is missing credentials."""
mock_storage(hass_storage, {"done": [const.STEP_USER]})
assert await async_setup_component(hass, "onboarding", {})
await hass.async_block_till_done()
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
refresh_token.credential = None
client = await hass_client()
resp = await client.post(
"/api/onboarding/integration",
json={"client_id": CLIENT_ID, "redirect_uri": CLIENT_REDIRECT_URI},
)
assert resp.status == 403
async def test_onboarding_integration_invalid_redirect_uri(
hass, hass_storage, hass_client
):

View File

@ -14,6 +14,7 @@ import requests_mock as _requests_mock
from homeassistant import core as ha, loader, runner, util
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
from homeassistant.auth.models import Credentials
from homeassistant.auth.providers import homeassistant, legacy_api_password
from homeassistant.components import mqtt
from homeassistant.components.websocket_api.auth import (
@ -201,10 +202,20 @@ def mock_device_tracker_conf():
@pytest.fixture
def hass_access_token(hass, hass_admin_user):
async def hass_admin_credential(hass, local_auth):
"""Provide credentials for admin user."""
await hass.async_add_executor_job(local_auth.data.add_auth, "admin", "admin-pass")
return await local_auth.async_get_or_create_credentials({"username": "admin"})
@pytest.fixture
async def hass_access_token(hass, hass_admin_user, hass_admin_credential):
"""Return an access token to access Home Assistant."""
refresh_token = hass.loop.run_until_complete(
hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID)
await hass.auth.async_link_user(hass_admin_user, hass_admin_credential)
refresh_token = await hass.auth.async_create_refresh_token(
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
)
return hass.auth.async_create_access_token(refresh_token)
@ -234,10 +245,21 @@ def hass_read_only_user(hass, local_auth):
@pytest.fixture
def hass_read_only_access_token(hass, hass_read_only_user):
def hass_read_only_access_token(hass, hass_read_only_user, local_auth):
"""Return a Home Assistant read only user."""
credential = Credentials(
id="mock-readonly-credential-id",
auth_provider_type="homeassistant",
auth_provider_id=None,
data={"username": "readonly"},
is_new=False,
)
hass_read_only_user.credentials.append(credential)
refresh_token = hass.loop.run_until_complete(
hass.auth.async_create_refresh_token(hass_read_only_user, CLIENT_ID)
hass.auth.async_create_refresh_token(
hass_read_only_user, CLIENT_ID, credential=credential
)
)
return hass.auth.async_create_access_token(refresh_token)
@ -260,6 +282,7 @@ def local_auth(hass):
prv = homeassistant.HassAuthProvider(
hass, hass.auth._store, {"type": "homeassistant"}
)
hass.loop.run_until_complete(prv.async_initialize())
hass.auth._providers[(prv.type, prv.id)] = prv
return prv