mirror of
https://github.com/home-assistant/core
synced 2024-10-01 05:30:36 +02:00
Support blocking trusted network from new ip (#44630)
Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io>
This commit is contained in:
parent
e4a7692610
commit
38d2cacf7a
@ -24,6 +24,14 @@ _ProviderKey = Tuple[str, Optional[str]]
|
|||||||
_ProviderDict = Dict[_ProviderKey, AuthProvider]
|
_ProviderDict = Dict[_ProviderKey, AuthProvider]
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAuthError(Exception):
|
||||||
|
"""Raised when a authentication error occurs."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidProvider(Exception):
|
||||||
|
"""Authentication provider not found."""
|
||||||
|
|
||||||
|
|
||||||
async def auth_manager_from_config(
|
async def auth_manager_from_config(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
provider_configs: List[Dict[str, Any]],
|
provider_configs: List[Dict[str, Any]],
|
||||||
@ -96,7 +104,7 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
# we got final result
|
# we got final result
|
||||||
if isinstance(result["data"], models.User):
|
if isinstance(result["data"], models.Credentials):
|
||||||
result["result"] = result["data"]
|
result["result"] = result["data"]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -120,11 +128,12 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
|
|||||||
modules = await self.auth_manager.async_get_enabled_mfa(user)
|
modules = await self.auth_manager.async_get_enabled_mfa(user)
|
||||||
|
|
||||||
if modules:
|
if modules:
|
||||||
|
flow.credential = credentials
|
||||||
flow.user = user
|
flow.user = user
|
||||||
flow.available_mfa_modules = modules
|
flow.available_mfa_modules = modules
|
||||||
return await flow.async_step_select_mfa_module()
|
return await flow.async_step_select_mfa_module()
|
||||||
|
|
||||||
result["result"] = await self.auth_manager.async_get_or_create_user(credentials)
|
result["result"] = credentials
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@ -156,7 +165,7 @@ class AuthManager:
|
|||||||
return list(self._mfa_modules.values())
|
return list(self._mfa_modules.values())
|
||||||
|
|
||||||
def get_auth_provider(
|
def get_auth_provider(
|
||||||
self, provider_type: str, provider_id: str
|
self, provider_type: str, provider_id: Optional[str]
|
||||||
) -> Optional[AuthProvider]:
|
) -> Optional[AuthProvider]:
|
||||||
"""Return an auth provider, None if not found."""
|
"""Return an auth provider, None if not found."""
|
||||||
return self._providers.get((provider_type, provider_id))
|
return self._providers.get((provider_type, provider_id))
|
||||||
@ -367,6 +376,7 @@ class AuthManager:
|
|||||||
client_icon: Optional[str] = None,
|
client_icon: Optional[str] = None,
|
||||||
token_type: Optional[str] = None,
|
token_type: Optional[str] = None,
|
||||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||||
|
credential: Optional[models.Credentials] = None,
|
||||||
) -> models.RefreshToken:
|
) -> models.RefreshToken:
|
||||||
"""Create a new refresh token for a user."""
|
"""Create a new refresh token for a user."""
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
@ -415,6 +425,7 @@ class AuthManager:
|
|||||||
client_icon,
|
client_icon,
|
||||||
token_type,
|
token_type,
|
||||||
access_token_expiration,
|
access_token_expiration,
|
||||||
|
credential,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_refresh_token(
|
async def async_get_refresh_token(
|
||||||
@ -440,6 +451,8 @@ class AuthManager:
|
|||||||
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new access token."""
|
"""Create a new access token."""
|
||||||
|
self.async_validate_refresh_token(refresh_token, remote_ip)
|
||||||
|
|
||||||
self._store.async_log_refresh_token_usage(refresh_token, remote_ip)
|
self._store.async_log_refresh_token_usage(refresh_token, remote_ip)
|
||||||
|
|
||||||
now = dt_util.utcnow()
|
now = dt_util.utcnow()
|
||||||
@ -453,6 +466,40 @@ class AuthManager:
|
|||||||
algorithm="HS256",
|
algorithm="HS256",
|
||||||
).decode()
|
).decode()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_resolve_provider(
|
||||||
|
self, refresh_token: models.RefreshToken
|
||||||
|
) -> Optional[AuthProvider]:
|
||||||
|
"""Get the auth provider for the given refresh token.
|
||||||
|
|
||||||
|
Raises an exception if the expected provider is no longer available or return
|
||||||
|
None if no provider was expected for this refresh token.
|
||||||
|
"""
|
||||||
|
if refresh_token.credential is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
provider = self.get_auth_provider(
|
||||||
|
refresh_token.credential.auth_provider_type,
|
||||||
|
refresh_token.credential.auth_provider_id,
|
||||||
|
)
|
||||||
|
if provider is None:
|
||||||
|
raise InvalidProvider(
|
||||||
|
f"Auth provider {refresh_token.credential.auth_provider_type}, {refresh_token.credential.auth_provider_id} not available"
|
||||||
|
)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_validate_refresh_token(
|
||||||
|
self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Validate that a refresh token is usable.
|
||||||
|
|
||||||
|
Will raise InvalidAuthError on errors.
|
||||||
|
"""
|
||||||
|
provider = self._async_resolve_provider(refresh_token)
|
||||||
|
if provider:
|
||||||
|
provider.async_validate_refresh_token(refresh_token, remote_ip)
|
||||||
|
|
||||||
async def async_validate_access_token(
|
async def async_validate_access_token(
|
||||||
self, token: str
|
self, token: str
|
||||||
) -> Optional[models.RefreshToken]:
|
) -> Optional[models.RefreshToken]:
|
||||||
|
@ -208,6 +208,7 @@ class AuthStore:
|
|||||||
client_icon: Optional[str] = None,
|
client_icon: Optional[str] = None,
|
||||||
token_type: str = models.TOKEN_TYPE_NORMAL,
|
token_type: str = models.TOKEN_TYPE_NORMAL,
|
||||||
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
|
||||||
|
credential: Optional[models.Credentials] = None,
|
||||||
) -> models.RefreshToken:
|
) -> models.RefreshToken:
|
||||||
"""Create a new token for a user."""
|
"""Create a new token for a user."""
|
||||||
kwargs: Dict[str, Any] = {
|
kwargs: Dict[str, Any] = {
|
||||||
@ -215,6 +216,7 @@ class AuthStore:
|
|||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"token_type": token_type,
|
"token_type": token_type,
|
||||||
"access_token_expiration": access_token_expiration,
|
"access_token_expiration": access_token_expiration,
|
||||||
|
"credential": credential,
|
||||||
}
|
}
|
||||||
if client_name:
|
if client_name:
|
||||||
kwargs["client_name"] = client_name
|
kwargs["client_name"] = client_name
|
||||||
@ -309,6 +311,7 @@ class AuthStore:
|
|||||||
|
|
||||||
users: Dict[str, models.User] = OrderedDict()
|
users: Dict[str, models.User] = OrderedDict()
|
||||||
groups: Dict[str, models.Group] = OrderedDict()
|
groups: Dict[str, models.Group] = OrderedDict()
|
||||||
|
credentials: Dict[str, models.Credentials] = OrderedDict()
|
||||||
|
|
||||||
# Soft-migrating data as we load. We are going to make sure we have a
|
# Soft-migrating data as we load. We are going to make sure we have a
|
||||||
# read only group and an admin group. There are two states that we can
|
# read only group and an admin group. There are two states that we can
|
||||||
@ -415,15 +418,15 @@ class AuthStore:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for cred_dict in data["credentials"]:
|
for cred_dict in data["credentials"]:
|
||||||
users[cred_dict["user_id"]].credentials.append(
|
credential = models.Credentials(
|
||||||
models.Credentials(
|
id=cred_dict["id"],
|
||||||
id=cred_dict["id"],
|
is_new=False,
|
||||||
is_new=False,
|
auth_provider_type=cred_dict["auth_provider_type"],
|
||||||
auth_provider_type=cred_dict["auth_provider_type"],
|
auth_provider_id=cred_dict["auth_provider_id"],
|
||||||
auth_provider_id=cred_dict["auth_provider_id"],
|
data=cred_dict["data"],
|
||||||
data=cred_dict["data"],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
credentials[cred_dict["id"]] = credential
|
||||||
|
users[cred_dict["user_id"]].credentials.append(credential)
|
||||||
|
|
||||||
for rt_dict in data["refresh_tokens"]:
|
for rt_dict in data["refresh_tokens"]:
|
||||||
# Filter out the old keys that don't have jwt_key (pre-0.76)
|
# Filter out the old keys that don't have jwt_key (pre-0.76)
|
||||||
@ -469,6 +472,8 @@ class AuthStore:
|
|||||||
jwt_key=rt_dict["jwt_key"],
|
jwt_key=rt_dict["jwt_key"],
|
||||||
last_used_at=last_used_at,
|
last_used_at=last_used_at,
|
||||||
last_used_ip=rt_dict.get("last_used_ip"),
|
last_used_ip=rt_dict.get("last_used_ip"),
|
||||||
|
credential=credentials.get(rt_dict.get("credential_id")),
|
||||||
|
version=rt_dict.get("version"),
|
||||||
)
|
)
|
||||||
users[rt_dict["user_id"]].refresh_tokens[token.id] = token
|
users[rt_dict["user_id"]].refresh_tokens[token.id] = token
|
||||||
|
|
||||||
@ -542,6 +547,10 @@ class AuthStore:
|
|||||||
if refresh_token.last_used_at
|
if refresh_token.last_used_at
|
||||||
else None,
|
else None,
|
||||||
"last_used_ip": refresh_token.last_used_ip,
|
"last_used_ip": refresh_token.last_used_ip,
|
||||||
|
"credential_id": refresh_token.credential.id
|
||||||
|
if refresh_token.credential
|
||||||
|
else None,
|
||||||
|
"version": refresh_token.version,
|
||||||
}
|
}
|
||||||
for user in self._users.values()
|
for user in self._users.values()
|
||||||
for refresh_token in user.refresh_tokens.values()
|
for refresh_token in user.refresh_tokens.values()
|
||||||
|
@ -6,6 +6,7 @@ import uuid
|
|||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
from homeassistant.const import __version__
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from . import permissions as perm_mdl
|
from . import permissions as perm_mdl
|
||||||
@ -106,6 +107,10 @@ class RefreshToken:
|
|||||||
last_used_at: Optional[datetime] = attr.ib(default=None)
|
last_used_at: Optional[datetime] = attr.ib(default=None)
|
||||||
last_used_ip: Optional[str] = attr.ib(default=None)
|
last_used_ip: Optional[str] = attr.ib(default=None)
|
||||||
|
|
||||||
|
credential: Optional["Credentials"] = attr.ib(default=None)
|
||||||
|
|
||||||
|
version: Optional[str] = attr.ib(default=__version__)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
class Credentials:
|
class Credentials:
|
||||||
|
@ -16,7 +16,7 @@ from homeassistant.util.decorator import Registry
|
|||||||
|
|
||||||
from ..auth_store import AuthStore
|
from ..auth_store import AuthStore
|
||||||
from ..const import MFA_SESSION_EXPIRATION
|
from ..const import MFA_SESSION_EXPIRATION
|
||||||
from ..models import Credentials, User, UserMeta
|
from ..models import Credentials, RefreshToken, User, UserMeta
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
DATA_REQS = "auth_prov_reqs_processed"
|
DATA_REQS = "auth_prov_reqs_processed"
|
||||||
@ -117,6 +117,16 @@ class AuthProvider:
|
|||||||
async def async_initialize(self) -> None:
|
async def async_initialize(self) -> None:
|
||||||
"""Initialize the auth provider."""
|
"""Initialize the auth provider."""
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_validate_refresh_token(
|
||||||
|
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Verify a refresh token is still valid.
|
||||||
|
|
||||||
|
Optional hook for an auth provider to verify validity of a refresh token.
|
||||||
|
Should raise InvalidAuthError on errors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
async def auth_provider_from_config(
|
async def auth_provider_from_config(
|
||||||
hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
|
hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
|
||||||
@ -182,6 +192,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||||||
self.created_at = dt_util.utcnow()
|
self.created_at = dt_util.utcnow()
|
||||||
self.invalid_mfa_times = 0
|
self.invalid_mfa_times = 0
|
||||||
self.user: Optional[User] = None
|
self.user: Optional[User] = None
|
||||||
|
self.credential: Optional[Credentials] = None
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: Optional[Dict[str, str]] = None
|
||||||
@ -222,6 +233,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||||||
self, user_input: Optional[Dict[str, str]] = None
|
self, user_input: Optional[Dict[str, str]] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Handle the step of mfa validation."""
|
"""Handle the step of mfa validation."""
|
||||||
|
assert self.credential
|
||||||
assert self.user
|
assert self.user
|
||||||
|
|
||||||
errors = {}
|
errors = {}
|
||||||
@ -257,7 +269,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
|
|||||||
return self.async_abort(reason="too_many_retry")
|
return self.async_abort(reason="too_many_retry")
|
||||||
|
|
||||||
if not errors:
|
if not errors:
|
||||||
return await self.async_finish(self.user)
|
return await self.async_finish(self.credential)
|
||||||
|
|
||||||
description_placeholders: Dict[str, Optional[str]] = {
|
description_placeholders: Dict[str, Optional[str]] = {
|
||||||
"mfa_module_name": auth_module.name,
|
"mfa_module_name": auth_module.name,
|
||||||
|
@ -8,13 +8,12 @@ from typing import Any, Dict, Optional, cast
|
|||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
|
||||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||||
from .. import AuthManager
|
from ..models import Credentials, UserMeta
|
||||||
from ..models import Credentials, User, UserMeta
|
|
||||||
|
|
||||||
AUTH_PROVIDER_TYPE = "legacy_api_password"
|
AUTH_PROVIDER_TYPE = "legacy_api_password"
|
||||||
CONF_API_PASSWORD = "api_password"
|
CONF_API_PASSWORD = "api_password"
|
||||||
@ -30,23 +29,6 @@ class InvalidAuthError(HomeAssistantError):
|
|||||||
"""Raised when submitting invalid authentication."""
|
"""Raised when submitting invalid authentication."""
|
||||||
|
|
||||||
|
|
||||||
async def async_validate_password(hass: HomeAssistant, password: str) -> Optional[User]:
|
|
||||||
"""Return a user if password is valid. None if not."""
|
|
||||||
auth = cast(AuthManager, hass.auth) # type: ignore
|
|
||||||
providers = auth.get_auth_providers(AUTH_PROVIDER_TYPE)
|
|
||||||
if not providers:
|
|
||||||
raise ValueError("Legacy API password provider not found")
|
|
||||||
|
|
||||||
try:
|
|
||||||
provider = cast(LegacyApiPasswordAuthProvider, providers[0])
|
|
||||||
provider.async_validate_login(password)
|
|
||||||
return await auth.async_get_or_create_user(
|
|
||||||
await provider.async_get_or_create_credentials({})
|
|
||||||
)
|
|
||||||
except InvalidAuthError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@AUTH_PROVIDERS.register(AUTH_PROVIDER_TYPE)
|
@AUTH_PROVIDERS.register(AUTH_PROVIDER_TYPE)
|
||||||
class LegacyApiPasswordAuthProvider(AuthProvider):
|
class LegacyApiPasswordAuthProvider(AuthProvider):
|
||||||
"""An auth provider support legacy api_password."""
|
"""An auth provider support legacy api_password."""
|
||||||
|
@ -3,7 +3,14 @@
|
|||||||
It shows list of users if access from trusted network.
|
It shows list of users if access from trusted network.
|
||||||
Abort login flow if not access from trusted network.
|
Abort login flow if not access from trusted network.
|
||||||
"""
|
"""
|
||||||
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network
|
from ipaddress import (
|
||||||
|
IPv4Address,
|
||||||
|
IPv4Network,
|
||||||
|
IPv6Address,
|
||||||
|
IPv6Network,
|
||||||
|
ip_address,
|
||||||
|
ip_network,
|
||||||
|
)
|
||||||
from typing import Any, Dict, List, Optional, Union, cast
|
from typing import Any, Dict, List, Optional, Union, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -13,7 +20,8 @@ from homeassistant.exceptions import HomeAssistantError
|
|||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
|
||||||
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
|
||||||
from ..models import Credentials, UserMeta
|
from .. import InvalidAuthError
|
||||||
|
from ..models import Credentials, RefreshToken, UserMeta
|
||||||
|
|
||||||
IPAddress = Union[IPv4Address, IPv6Address]
|
IPAddress = Union[IPv4Address, IPv6Address]
|
||||||
IPNetwork = Union[IPv4Network, IPv6Network]
|
IPNetwork = Union[IPv4Network, IPv6Network]
|
||||||
@ -46,10 +54,6 @@ CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InvalidAuthError(HomeAssistantError):
|
|
||||||
"""Raised when try to access from untrusted networks."""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidUserError(HomeAssistantError):
|
class InvalidUserError(HomeAssistantError):
|
||||||
"""Raised when try to login as invalid user."""
|
"""Raised when try to login as invalid user."""
|
||||||
|
|
||||||
@ -163,6 +167,17 @@ class TrustedNetworksAuthProvider(AuthProvider):
|
|||||||
):
|
):
|
||||||
raise InvalidAuthError("Not in trusted_networks")
|
raise InvalidAuthError("Not in trusted_networks")
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_validate_refresh_token(
|
||||||
|
self, refresh_token: RefreshToken, remote_ip: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Verify a refresh token is still valid."""
|
||||||
|
if remote_ip is None:
|
||||||
|
raise InvalidAuthError(
|
||||||
|
"Unknown remote ip can't be used for trusted network provider."
|
||||||
|
)
|
||||||
|
self.async_validate_access(ip_address(remote_ip))
|
||||||
|
|
||||||
|
|
||||||
class TrustedNetworksLoginFlow(LoginFlow):
|
class TrustedNetworksLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
@ -115,11 +115,13 @@ Result will be a long-lived access token:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Union
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.auth import InvalidAuthError
|
||||||
from homeassistant.auth.models import (
|
from homeassistant.auth.models import (
|
||||||
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN,
|
||||||
Credentials,
|
Credentials,
|
||||||
@ -180,9 +182,11 @@ RESULT_TYPE_USER = "user"
|
|||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def create_auth_code(hass, client_id: str, user: User) -> str:
|
def create_auth_code(
|
||||||
|
hass, client_id: str, credential_or_user: Union[Credentials, User]
|
||||||
|
) -> str:
|
||||||
"""Create an authorization code to fetch tokens."""
|
"""Create an authorization code to fetch tokens."""
|
||||||
return hass.data[DOMAIN](client_id, user)
|
return hass.data[DOMAIN](client_id, credential_or_user)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
@ -228,9 +232,9 @@ class TokenView(HomeAssistantView):
|
|||||||
requires_auth = False
|
requires_auth = False
|
||||||
cors_allowed = True
|
cors_allowed = True
|
||||||
|
|
||||||
def __init__(self, retrieve_user):
|
def __init__(self, retrieve_auth):
|
||||||
"""Initialize the token view."""
|
"""Initialize the token view."""
|
||||||
self._retrieve_user = retrieve_user
|
self._retrieve_auth = retrieve_auth
|
||||||
|
|
||||||
@log_invalid_auth
|
@log_invalid_auth
|
||||||
async def post(self, request):
|
async def post(self, request):
|
||||||
@ -293,16 +297,15 @@ class TokenView(HomeAssistantView):
|
|||||||
status_code=HTTP_BAD_REQUEST,
|
status_code=HTTP_BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
user = self._retrieve_user(client_id, RESULT_TYPE_USER, code)
|
credential = self._retrieve_auth(client_id, RESULT_TYPE_CREDENTIALS, code)
|
||||||
|
|
||||||
if user is None or not isinstance(user, User):
|
if credential is None or not isinstance(credential, Credentials):
|
||||||
return self.json(
|
return self.json(
|
||||||
{"error": "invalid_request", "error_description": "Invalid code"},
|
{"error": "invalid_request", "error_description": "Invalid code"},
|
||||||
status_code=HTTP_BAD_REQUEST,
|
status_code=HTTP_BAD_REQUEST,
|
||||||
)
|
)
|
||||||
|
|
||||||
# refresh user
|
user = await hass.auth.async_get_or_create_user(credential)
|
||||||
user = await hass.auth.async_get_user(user.id)
|
|
||||||
|
|
||||||
if not user.is_active:
|
if not user.is_active:
|
||||||
return self.json(
|
return self.json(
|
||||||
@ -310,8 +313,18 @@ class TokenView(HomeAssistantView):
|
|||||||
status_code=HTTP_FORBIDDEN,
|
status_code=HTTP_FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_create_refresh_token(user, client_id)
|
refresh_token = await hass.auth.async_create_refresh_token(
|
||||||
access_token = hass.auth.async_create_access_token(refresh_token, remote_addr)
|
user, client_id, credential=credential
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
access_token = hass.auth.async_create_access_token(
|
||||||
|
refresh_token, remote_addr
|
||||||
|
)
|
||||||
|
except InvalidAuthError as exc:
|
||||||
|
return self.json(
|
||||||
|
{"error": "access_denied", "error_description": str(exc)},
|
||||||
|
status_code=HTTP_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
return self.json(
|
return self.json(
|
||||||
{
|
{
|
||||||
@ -346,7 +359,15 @@ class TokenView(HomeAssistantView):
|
|||||||
if refresh_token.client_id != client_id:
|
if refresh_token.client_id != client_id:
|
||||||
return self.json({"error": "invalid_request"}, status_code=HTTP_BAD_REQUEST)
|
return self.json({"error": "invalid_request"}, status_code=HTTP_BAD_REQUEST)
|
||||||
|
|
||||||
access_token = hass.auth.async_create_access_token(refresh_token, remote_addr)
|
try:
|
||||||
|
access_token = hass.auth.async_create_access_token(
|
||||||
|
refresh_token, remote_addr
|
||||||
|
)
|
||||||
|
except InvalidAuthError as exc:
|
||||||
|
return self.json(
|
||||||
|
{"error": "access_denied", "error_description": str(exc)},
|
||||||
|
status_code=HTTP_FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
return self.json(
|
return self.json(
|
||||||
{
|
{
|
||||||
@ -482,7 +503,12 @@ async def websocket_create_long_lived_access_token(
|
|||||||
access_token_expiration=timedelta(days=msg["lifespan"]),
|
access_token_expiration=timedelta(days=msg["lifespan"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
try:
|
||||||
|
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||||
|
except InvalidAuthError as exc:
|
||||||
|
return websocket_api.error_message(
|
||||||
|
msg["id"], websocket_api.const.ERR_UNAUTHORIZED, str(exc)
|
||||||
|
)
|
||||||
|
|
||||||
connection.send_message(websocket_api.result_message(msg["id"], access_token))
|
connection.send_message(websocket_api.result_message(msg["id"], access_token))
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ import jwt
|
|||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from .const import KEY_AUTHENTICATED, KEY_HASS_USER
|
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||||
|
|
||||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||||
|
|
||||||
@ -62,6 +62,7 @@ def setup_auth(hass, app):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
request[KEY_HASS_USER] = refresh_token.user
|
request[KEY_HASS_USER] = refresh_token.user
|
||||||
|
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def async_validate_signed_request(request):
|
async def async_validate_signed_request(request):
|
||||||
@ -92,6 +93,7 @@ def setup_auth(hass, app):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
request[KEY_HASS_USER] = refresh_token.user
|
request[KEY_HASS_USER] = refresh_token.user
|
||||||
|
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@middleware
|
@middleware
|
||||||
|
@ -2,3 +2,4 @@
|
|||||||
KEY_AUTHENTICATED = "ha_authenticated"
|
KEY_AUTHENTICATED = "ha_authenticated"
|
||||||
KEY_HASS = "hass"
|
KEY_HASS = "hass"
|
||||||
KEY_HASS_USER = "hass_user"
|
KEY_HASS_USER = "hass_user"
|
||||||
|
KEY_HASS_REFRESH_TOKEN_ID = "hass_refresh_token_id"
|
||||||
|
@ -5,6 +5,7 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.auth.const import GROUP_ID_ADMIN
|
from homeassistant.auth.const import GROUP_ID_ADMIN
|
||||||
from homeassistant.components.auth import indieauth
|
from homeassistant.components.auth import indieauth
|
||||||
|
from homeassistant.components.http.const import KEY_HASS_REFRESH_TOKEN_ID
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
from homeassistant.components.http.view import HomeAssistantView
|
from homeassistant.components.http.view import HomeAssistantView
|
||||||
from homeassistant.const import HTTP_BAD_REQUEST, HTTP_FORBIDDEN
|
from homeassistant.const import HTTP_BAD_REQUEST, HTTP_FORBIDDEN
|
||||||
@ -132,7 +133,9 @@ class UserOnboardingView(_BaseOnboardingView):
|
|||||||
|
|
||||||
# Return authorization code for fetching tokens and connect
|
# Return authorization code for fetching tokens and connect
|
||||||
# during onboarding.
|
# during onboarding.
|
||||||
auth_code = hass.components.auth.create_auth_code(data["client_id"], user)
|
auth_code = hass.components.auth.create_auth_code(
|
||||||
|
data["client_id"], credentials
|
||||||
|
)
|
||||||
return self.json({"auth_code": auth_code})
|
return self.json({"auth_code": auth_code})
|
||||||
|
|
||||||
|
|
||||||
@ -183,7 +186,7 @@ class IntegrationOnboardingView(_BaseOnboardingView):
|
|||||||
async def post(self, request, data):
|
async def post(self, request, data):
|
||||||
"""Handle token creation."""
|
"""Handle token creation."""
|
||||||
hass = request.app["hass"]
|
hass = request.app["hass"]
|
||||||
user = request["hass_user"]
|
refresh_token_id = request[KEY_HASS_REFRESH_TOKEN_ID]
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
if self._async_is_done():
|
if self._async_is_done():
|
||||||
@ -201,8 +204,16 @@ class IntegrationOnboardingView(_BaseOnboardingView):
|
|||||||
"invalid client id or redirect uri", HTTP_BAD_REQUEST
|
"invalid client id or redirect uri", HTTP_BAD_REQUEST
|
||||||
)
|
)
|
||||||
|
|
||||||
|
refresh_token = await hass.auth.async_get_refresh_token(refresh_token_id)
|
||||||
|
if refresh_token is None or refresh_token.credential is None:
|
||||||
|
return self.json_message(
|
||||||
|
"Credentials for user not available", HTTP_FORBIDDEN
|
||||||
|
)
|
||||||
|
|
||||||
# Return authorization code so we can redirect user and log them in
|
# Return authorization code so we can redirect user and log them in
|
||||||
auth_code = hass.components.auth.create_auth_code(data["client_id"], user)
|
auth_code = hass.components.auth.create_auth_code(
|
||||||
|
data["client_id"], refresh_token.credential
|
||||||
|
)
|
||||||
return self.json({"auth_code": auth_code})
|
return self.json({"auth_code": auth_code})
|
||||||
|
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ async def test_login(hass):
|
|||||||
result["flow_id"], {"pin": "123456"}
|
result["flow_id"], {"pin": "123456"}
|
||||||
)
|
)
|
||||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
assert result["data"].id == "mock-user"
|
assert result["data"].id == "mock-id"
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_flow(hass):
|
async def test_setup_flow(hass):
|
||||||
|
@ -229,7 +229,7 @@ async def test_login_flow_validates_mfa(hass):
|
|||||||
result["flow_id"], {"code": MOCK_CODE}
|
result["flow_id"], {"code": MOCK_CODE}
|
||||||
)
|
)
|
||||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
assert result["data"].id == "mock-user"
|
assert result["data"].id == "mock-id"
|
||||||
|
|
||||||
|
|
||||||
async def test_setup_user_notify_service(hass):
|
async def test_setup_user_notify_service(hass):
|
||||||
|
@ -127,7 +127,7 @@ async def test_login_flow_validates_mfa(hass):
|
|||||||
result["flow_id"], {"code": MOCK_CODE}
|
result["flow_id"], {"code": MOCK_CODE}
|
||||||
)
|
)
|
||||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
assert result["data"].id == "mock-user"
|
assert result["data"].id == "mock-id"
|
||||||
|
|
||||||
|
|
||||||
async def test_race_condition_in_data_loading(hass):
|
async def test_race_condition_in_data_loading(hass):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
"""Test the Trusted Networks auth provider."""
|
"""Test the Trusted Networks auth provider."""
|
||||||
from ipaddress import ip_address, ip_network
|
from ipaddress import ip_address, ip_network
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
@ -142,6 +143,16 @@ async def test_validate_access(provider):
|
|||||||
provider.async_validate_access(ip_address("2001:db8::ff00:42:8329"))
|
provider.async_validate_access(ip_address("2001:db8::ff00:42:8329"))
|
||||||
|
|
||||||
|
|
||||||
|
async def test_validate_refresh_token(provider):
|
||||||
|
"""Verify re-validation of refresh token."""
|
||||||
|
with patch.object(provider, "async_validate_access") as mock:
|
||||||
|
with pytest.raises(tn_auth.InvalidAuthError):
|
||||||
|
provider.async_validate_refresh_token(Mock(), None)
|
||||||
|
|
||||||
|
provider.async_validate_refresh_token(Mock(), "127.0.0.1")
|
||||||
|
mock.assert_called_once_with(ip_address("127.0.0.1"))
|
||||||
|
|
||||||
|
|
||||||
async def test_login_flow(manager, provider):
|
async def test_login_flow(manager, provider):
|
||||||
"""Test login flow."""
|
"""Test login flow."""
|
||||||
owner = await manager.async_create_user("test-owner")
|
owner = await manager.async_create_user("test-owner")
|
||||||
|
@ -37,6 +37,7 @@ async def test_loading_no_group_data_format(hass, hass_storage):
|
|||||||
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
|
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
|
||||||
"token": "some-token",
|
"token": "some-token",
|
||||||
"user_id": "user-id",
|
"user_id": "user-id",
|
||||||
|
"version": "1.2.3",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"access_token_expiration": 1800.0,
|
"access_token_expiration": 1800.0,
|
||||||
@ -87,12 +88,14 @@ async def test_loading_no_group_data_format(hass, hass_storage):
|
|||||||
assert len(owner.refresh_tokens) == 1
|
assert len(owner.refresh_tokens) == 1
|
||||||
owner_token = list(owner.refresh_tokens.values())[0]
|
owner_token = list(owner.refresh_tokens.values())[0]
|
||||||
assert owner_token.id == "user-token-id"
|
assert owner_token.id == "user-token-id"
|
||||||
|
assert owner_token.version == "1.2.3"
|
||||||
|
|
||||||
assert system.system_generated is True
|
assert system.system_generated is True
|
||||||
assert system.groups == []
|
assert system.groups == []
|
||||||
assert len(system.refresh_tokens) == 1
|
assert len(system.refresh_tokens) == 1
|
||||||
system_token = list(system.refresh_tokens.values())[0]
|
system_token = list(system.refresh_tokens.values())[0]
|
||||||
assert system_token.id == "system-token-id"
|
assert system_token.id == "system-token-id"
|
||||||
|
assert system_token.version is None
|
||||||
|
|
||||||
|
|
||||||
async def test_loading_all_access_group_data_format(hass, hass_storage):
|
async def test_loading_all_access_group_data_format(hass, hass_storage):
|
||||||
@ -129,6 +132,7 @@ async def test_loading_all_access_group_data_format(hass, hass_storage):
|
|||||||
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
|
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
|
||||||
"token": "some-token",
|
"token": "some-token",
|
||||||
"user_id": "user-id",
|
"user_id": "user-id",
|
||||||
|
"version": "1.2.3",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"access_token_expiration": 1800.0,
|
"access_token_expiration": 1800.0,
|
||||||
@ -139,6 +143,7 @@ async def test_loading_all_access_group_data_format(hass, hass_storage):
|
|||||||
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
|
"last_used_at": "2018-10-03T13:43:19.774712+00:00",
|
||||||
"token": "some-token",
|
"token": "some-token",
|
||||||
"user_id": "system-id",
|
"user_id": "system-id",
|
||||||
|
"version": None,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"access_token_expiration": 1800.0,
|
"access_token_expiration": 1800.0,
|
||||||
@ -179,12 +184,14 @@ async def test_loading_all_access_group_data_format(hass, hass_storage):
|
|||||||
assert len(owner.refresh_tokens) == 1
|
assert len(owner.refresh_tokens) == 1
|
||||||
owner_token = list(owner.refresh_tokens.values())[0]
|
owner_token = list(owner.refresh_tokens.values())[0]
|
||||||
assert owner_token.id == "user-token-id"
|
assert owner_token.id == "user-token-id"
|
||||||
|
assert owner_token.version == "1.2.3"
|
||||||
|
|
||||||
assert system.system_generated is True
|
assert system.system_generated is True
|
||||||
assert system.groups == []
|
assert system.groups == []
|
||||||
assert len(system.refresh_tokens) == 1
|
assert len(system.refresh_tokens) == 1
|
||||||
system_token = list(system.refresh_tokens.values())[0]
|
system_token = list(system.refresh_tokens.values())[0]
|
||||||
assert system_token.id == "system-token-id"
|
assert system_token.id == "system-token-id"
|
||||||
|
assert system_token.version is None
|
||||||
|
|
||||||
|
|
||||||
async def test_loading_empty_data(hass, hass_storage):
|
async def test_loading_empty_data(hass, hass_storage):
|
||||||
|
@ -7,7 +7,12 @@ import pytest
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import auth, data_entry_flow
|
from homeassistant import auth, data_entry_flow
|
||||||
from homeassistant.auth import auth_store, const as auth_const, models as auth_models
|
from homeassistant.auth import (
|
||||||
|
InvalidAuthError,
|
||||||
|
auth_store,
|
||||||
|
const as auth_const,
|
||||||
|
models as auth_models,
|
||||||
|
)
|
||||||
from homeassistant.auth.const import MFA_SESSION_EXPIRATION
|
from homeassistant.auth.const import MFA_SESSION_EXPIRATION
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
@ -162,7 +167,10 @@ async def test_create_new_user(hass):
|
|||||||
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
||||||
)
|
)
|
||||||
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
user = step["result"]
|
credential = step["result"]
|
||||||
|
assert credential is not None
|
||||||
|
|
||||||
|
user = await manager.async_get_or_create_user(credential)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.is_owner is False
|
assert user.is_owner is False
|
||||||
assert user.name == "Test Name"
|
assert user.name == "Test Name"
|
||||||
@ -229,7 +237,8 @@ async def test_login_as_existing_user(mock_hass):
|
|||||||
)
|
)
|
||||||
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
|
|
||||||
user = step["result"]
|
credential = step["result"]
|
||||||
|
user = await manager.async_get_user_by_credentials(credential)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
assert user.id == "mock-user"
|
assert user.id == "mock-user"
|
||||||
assert user.is_owner is False
|
assert user.is_owner is False
|
||||||
@ -259,7 +268,8 @@ async def test_linking_user_to_two_auth_providers(hass, hass_storage):
|
|||||||
step = await manager.login_flow.async_configure(
|
step = await manager.login_flow.async_configure(
|
||||||
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
||||||
)
|
)
|
||||||
user = step["result"]
|
credential = step["result"]
|
||||||
|
user = await manager.async_get_or_create_user(credential)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
|
|
||||||
step = await manager.login_flow.async_init(
|
step = await manager.login_flow.async_init(
|
||||||
@ -293,13 +303,19 @@ async def test_saving_loading(hass, hass_storage):
|
|||||||
step = await manager.login_flow.async_configure(
|
step = await manager.login_flow.async_configure(
|
||||||
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
||||||
)
|
)
|
||||||
user = step["result"]
|
credential = step["result"]
|
||||||
|
user = await manager.async_get_or_create_user(credential)
|
||||||
|
|
||||||
await manager.async_activate_user(user)
|
await manager.async_activate_user(user)
|
||||||
# the first refresh token will be used to create access token
|
# the first refresh token will be used to create access token
|
||||||
refresh_token = await manager.async_create_refresh_token(user, CLIENT_ID)
|
refresh_token = await manager.async_create_refresh_token(
|
||||||
|
user, CLIENT_ID, credential=credential
|
||||||
|
)
|
||||||
manager.async_create_access_token(refresh_token, "192.168.0.1")
|
manager.async_create_access_token(refresh_token, "192.168.0.1")
|
||||||
# the second refresh token will not be used
|
# the second refresh token will not be used
|
||||||
await manager.async_create_refresh_token(user, "dummy-client")
|
await manager.async_create_refresh_token(
|
||||||
|
user, "dummy-client", credential=credential
|
||||||
|
)
|
||||||
|
|
||||||
await flush_store(manager._store._store)
|
await flush_store(manager._store._store)
|
||||||
|
|
||||||
@ -452,6 +468,46 @@ async def test_refresh_token_type_long_lived_access_token(hass):
|
|||||||
assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
assert token.token_type == auth_models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_provider_validation(mock_hass):
|
||||||
|
"""Test that creating access token from refresh token checks with provider."""
|
||||||
|
manager = await auth.auth_manager_from_config(
|
||||||
|
mock_hass,
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "insecure_example",
|
||||||
|
"users": [{"username": "test-user", "password": "test-pass"}],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
credential = auth_models.Credentials(
|
||||||
|
id="mock-credential-id",
|
||||||
|
auth_provider_type="insecure_example",
|
||||||
|
auth_provider_id=None,
|
||||||
|
data={"username": "test-user"},
|
||||||
|
is_new=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
user = MockUser().add_to_auth_manager(manager)
|
||||||
|
user.credentials.append(credential)
|
||||||
|
refresh_token = await manager.async_create_refresh_token(
|
||||||
|
user, CLIENT_ID, credential=credential
|
||||||
|
)
|
||||||
|
ip = "127.0.0.1"
|
||||||
|
|
||||||
|
assert manager.async_create_access_token(refresh_token, ip) is not None
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.auth.providers.insecure_example.ExampleAuthProvider.async_validate_refresh_token",
|
||||||
|
side_effect=InvalidAuthError("Invalid access"),
|
||||||
|
) as call:
|
||||||
|
with pytest.raises(InvalidAuthError):
|
||||||
|
manager.async_create_access_token(refresh_token, ip)
|
||||||
|
|
||||||
|
call.assert_called_with(refresh_token, ip)
|
||||||
|
|
||||||
|
|
||||||
async def test_cannot_deactive_owner(mock_hass):
|
async def test_cannot_deactive_owner(mock_hass):
|
||||||
"""Test that we cannot deactivate the owner."""
|
"""Test that we cannot deactivate the owner."""
|
||||||
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
manager = await auth.auth_manager_from_config(mock_hass, [], [])
|
||||||
@ -626,14 +682,10 @@ async def test_login_with_auth_module(mock_hass):
|
|||||||
step["flow_id"], {"pin": "test-pin"}
|
step["flow_id"], {"pin": "test-pin"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Finally passed, get user
|
# Finally passed, get credential
|
||||||
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
user = step["result"]
|
assert step["result"]
|
||||||
assert user is not None
|
assert step["result"].id == "mock-id"
|
||||||
assert user.id == "mock-user"
|
|
||||||
assert user.is_owner is False
|
|
||||||
assert user.is_active is False
|
|
||||||
assert user.name == "Paulus"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_login_with_multi_auth_module(mock_hass):
|
async def test_login_with_multi_auth_module(mock_hass):
|
||||||
@ -703,14 +755,10 @@ async def test_login_with_multi_auth_module(mock_hass):
|
|||||||
step["flow_id"], {"pin": "test-pin2"}
|
step["flow_id"], {"pin": "test-pin2"}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Finally passed, get user
|
# Finally passed, get credential
|
||||||
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert step["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
user = step["result"]
|
assert step["result"]
|
||||||
assert user is not None
|
assert step["result"].id == "mock-id"
|
||||||
assert user.id == "mock-user"
|
|
||||||
assert user.is_owner is False
|
|
||||||
assert user.is_active is False
|
|
||||||
assert user.name == "Paulus"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_auth_module_expired_session(mock_hass):
|
async def test_auth_module_expired_session(mock_hass):
|
||||||
@ -792,7 +840,8 @@ async def test_enable_mfa_for_user(hass, hass_storage):
|
|||||||
step = await manager.login_flow.async_configure(
|
step = await manager.login_flow.async_configure(
|
||||||
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
step["flow_id"], {"username": "test-user", "password": "test-pass"}
|
||||||
)
|
)
|
||||||
user = step["result"]
|
credential = step["result"]
|
||||||
|
user = await manager.async_get_or_create_user(credential)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
|
|
||||||
# new user don't have mfa enabled
|
# new user don't have mfa enabled
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from homeassistant.auth import InvalidAuthError
|
||||||
from homeassistant.auth.models import Credentials
|
from homeassistant.auth.models import Credentials
|
||||||
from homeassistant.components import auth
|
from homeassistant.components import auth
|
||||||
from homeassistant.components.auth import RESULT_TYPE_USER
|
from homeassistant.components.auth import RESULT_TYPE_USER
|
||||||
@ -13,6 +14,24 @@ from . import async_setup_auth
|
|||||||
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
|
from tests.common import CLIENT_ID, CLIENT_REDIRECT_URI, MockUser
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_user_refresh_token(hass):
|
||||||
|
"""Create a testing user with a connected credential."""
|
||||||
|
user = await hass.auth.async_create_user("Test User")
|
||||||
|
|
||||||
|
credential = Credentials(
|
||||||
|
id="mock-credential-id",
|
||||||
|
auth_provider_type="insecure_example",
|
||||||
|
auth_provider_id=None,
|
||||||
|
data={"username": "test-user"},
|
||||||
|
is_new=False,
|
||||||
|
)
|
||||||
|
user.credentials.append(credential)
|
||||||
|
|
||||||
|
return await hass.auth.async_create_refresh_token(
|
||||||
|
user, CLIENT_ID, credential=credential
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
||||||
"""Test logging in with new user and refreshing tokens."""
|
"""Test logging in with new user and refreshing tokens."""
|
||||||
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
|
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
|
||||||
@ -107,12 +126,6 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
|
|||||||
|
|
||||||
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||||
user = refresh_token.user
|
user = refresh_token.user
|
||||||
credential = Credentials(
|
|
||||||
auth_provider_type="homeassistant", auth_provider_id=None, data={}, id="test-id"
|
|
||||||
)
|
|
||||||
user.credentials.append(credential)
|
|
||||||
assert len(user.credentials) == 1
|
|
||||||
|
|
||||||
client = await hass_ws_client(hass, hass_access_token)
|
client = await hass_ws_client(hass, hass_access_token)
|
||||||
|
|
||||||
await client.send_json({"id": 5, "type": auth.WS_TYPE_CURRENT_USER})
|
await client.send_json({"id": 5, "type": auth.WS_TYPE_CURRENT_USER})
|
||||||
@ -185,8 +198,7 @@ async def test_refresh_token_system_generated(hass, aiohttp_client):
|
|||||||
async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
||||||
"""Test that we verify client ID."""
|
"""Test that we verify client ID."""
|
||||||
client = await async_setup_auth(hass, aiohttp_client)
|
client = await async_setup_auth(hass, aiohttp_client)
|
||||||
user = await hass.auth.async_create_user("Test User")
|
refresh_token = await async_setup_user_refresh_token(hass)
|
||||||
refresh_token = await hass.auth.async_create_refresh_token(user, CLIENT_ID)
|
|
||||||
|
|
||||||
# No client ID
|
# No client ID
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
@ -229,11 +241,37 @@ async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_provider_rejected(
|
||||||
|
hass, aiohttp_client, hass_admin_user, hass_admin_credential
|
||||||
|
):
|
||||||
|
"""Test that we verify client ID."""
|
||||||
|
client = await async_setup_auth(hass, aiohttp_client)
|
||||||
|
refresh_token = await async_setup_user_refresh_token(hass)
|
||||||
|
|
||||||
|
# Rejected by provider
|
||||||
|
with patch(
|
||||||
|
"homeassistant.auth.providers.insecure_example.ExampleAuthProvider.async_validate_refresh_token",
|
||||||
|
side_effect=InvalidAuthError("Invalid access"),
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/token",
|
||||||
|
data={
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refresh_token.token,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == 403
|
||||||
|
result = await resp.json()
|
||||||
|
assert result["error"] == "access_denied"
|
||||||
|
assert result["error_description"] == "Invalid access"
|
||||||
|
|
||||||
|
|
||||||
async def test_revoking_refresh_token(hass, aiohttp_client):
|
async def test_revoking_refresh_token(hass, aiohttp_client):
|
||||||
"""Test that we can revoke refresh tokens."""
|
"""Test that we can revoke refresh tokens."""
|
||||||
client = await async_setup_auth(hass, aiohttp_client)
|
client = await async_setup_auth(hass, aiohttp_client)
|
||||||
user = await hass.auth.async_create_user("Test User")
|
refresh_token = await async_setup_user_refresh_token(hass)
|
||||||
refresh_token = await hass.auth.async_create_refresh_token(user, CLIENT_ID)
|
|
||||||
|
|
||||||
# Test that we can create an access token
|
# Test that we can create an access token
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
|
@ -48,7 +48,9 @@ async def test_list(hass, hass_ws_client, hass_admin_user):
|
|||||||
id="hij", name="Inactive User", is_active=False, groups=[group]
|
id="hij", name="Inactive User", is_active=False, groups=[group]
|
||||||
).add_to_hass(hass)
|
).add_to_hass(hass)
|
||||||
|
|
||||||
refresh_token = await hass.auth.async_create_refresh_token(owner, CLIENT_ID)
|
refresh_token = await hass.auth.async_create_refresh_token(
|
||||||
|
owner, CLIENT_ID, credential=owner.credentials[0]
|
||||||
|
)
|
||||||
access_token = hass.auth.async_create_access_token(refresh_token)
|
access_token = hass.auth.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
client = await hass_ws_client(hass, access_token)
|
client = await hass_ws_client(hass, access_token)
|
||||||
@ -60,13 +62,13 @@ async def test_list(hass, hass_ws_client, hass_admin_user):
|
|||||||
assert len(data) == 4
|
assert len(data) == 4
|
||||||
assert data[0] == {
|
assert data[0] == {
|
||||||
"id": hass_admin_user.id,
|
"id": hass_admin_user.id,
|
||||||
"username": None,
|
"username": "admin",
|
||||||
"name": "Mock User",
|
"name": "Mock User",
|
||||||
"is_owner": False,
|
"is_owner": False,
|
||||||
"is_active": True,
|
"is_active": True,
|
||||||
"system_generated": False,
|
"system_generated": False,
|
||||||
"group_ids": [group.id for group in hass_admin_user.groups],
|
"group_ids": [group.id for group in hass_admin_user.groups],
|
||||||
"credentials": [],
|
"credentials": [{"type": "homeassistant"}],
|
||||||
}
|
}
|
||||||
assert data[1] == {
|
assert data[1] == {
|
||||||
"id": owner.id,
|
"id": owner.id,
|
||||||
|
@ -4,24 +4,19 @@ import pytest
|
|||||||
from homeassistant.auth.providers import homeassistant as prov_ha
|
from homeassistant.auth.providers import homeassistant as prov_ha
|
||||||
from homeassistant.components.config import auth_provider_homeassistant as auth_ha
|
from homeassistant.components.config import auth_provider_homeassistant as auth_ha
|
||||||
|
|
||||||
from tests.common import CLIENT_ID, MockUser, register_auth_provider
|
from tests.common import CLIENT_ID, MockUser
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup_config(hass):
|
async def setup_config(hass, local_auth):
|
||||||
"""Fixture that sets up the auth provider homeassistant module."""
|
"""Fixture that sets up the auth provider ."""
|
||||||
hass.loop.run_until_complete(
|
await auth_ha.async_setup(hass)
|
||||||
register_auth_provider(hass, {"type": "homeassistant"})
|
|
||||||
)
|
|
||||||
hass.loop.run_until_complete(auth_ha.async_setup(hass))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def auth_provider(hass):
|
async def auth_provider(local_auth):
|
||||||
"""Hass auth provider."""
|
"""Hass auth provider."""
|
||||||
provider = hass.auth.auth_providers[0]
|
return local_auth
|
||||||
await provider.async_initialize()
|
|
||||||
return provider
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -34,8 +29,8 @@ async def owner_access_token(hass, hass_owner_user):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def test_user_credential(hass, auth_provider):
|
async def hass_admin_credential(hass, auth_provider):
|
||||||
"""Add a test user."""
|
"""Overload credentials to admin user."""
|
||||||
await hass.async_add_executor_job(
|
await hass.async_add_executor_job(
|
||||||
auth_provider.data.add_auth, "test-user", "test-pass"
|
auth_provider.data.add_auth, "test-user", "test-pass"
|
||||||
)
|
)
|
||||||
@ -124,7 +119,7 @@ async def test_create_auth(hass, hass_ws_client, hass_storage):
|
|||||||
"id": 5,
|
"id": 5,
|
||||||
"type": "config/auth_provider/homeassistant/create",
|
"type": "config/auth_provider/homeassistant/create",
|
||||||
"user_id": user.id,
|
"user_id": user.id,
|
||||||
"username": "test-user",
|
"username": "test-user2",
|
||||||
"password": "test-pass",
|
"password": "test-pass",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -135,10 +130,10 @@ async def test_create_auth(hass, hass_ws_client, hass_storage):
|
|||||||
creds = user.credentials[0]
|
creds = user.credentials[0]
|
||||||
assert creds.auth_provider_type == "homeassistant"
|
assert creds.auth_provider_type == "homeassistant"
|
||||||
assert creds.auth_provider_id is None
|
assert creds.auth_provider_id is None
|
||||||
assert creds.data == {"username": "test-user"}
|
assert creds.data == {"username": "test-user2"}
|
||||||
assert prov_ha.STORAGE_KEY in hass_storage
|
assert prov_ha.STORAGE_KEY in hass_storage
|
||||||
entry = hass_storage[prov_ha.STORAGE_KEY]["data"]["users"][0]
|
entry = hass_storage[prov_ha.STORAGE_KEY]["data"]["users"][1]
|
||||||
assert entry["username"] == "test-user"
|
assert entry["username"] == "test-user2"
|
||||||
|
|
||||||
|
|
||||||
async def test_create_auth_duplicate_username(hass, hass_ws_client, hass_storage):
|
async def test_create_auth_duplicate_username(hass, hass_ws_client, hass_storage):
|
||||||
@ -242,7 +237,7 @@ async def test_delete_unknown_auth(hass, hass_ws_client):
|
|||||||
{
|
{
|
||||||
"id": 5,
|
"id": 5,
|
||||||
"type": "config/auth_provider/homeassistant/delete",
|
"type": "config/auth_provider/homeassistant/delete",
|
||||||
"username": "test-user",
|
"username": "test-user2",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -251,12 +246,8 @@ async def test_delete_unknown_auth(hass, hass_ws_client):
|
|||||||
assert result["error"]["code"] == "auth_not_found"
|
assert result["error"]["code"] == "auth_not_found"
|
||||||
|
|
||||||
|
|
||||||
async def test_change_password(
|
async def test_change_password(hass, hass_ws_client, auth_provider):
|
||||||
hass, hass_ws_client, hass_admin_user, auth_provider, test_user_credential
|
|
||||||
):
|
|
||||||
"""Test that change password succeeds with valid password."""
|
"""Test that change password succeeds with valid password."""
|
||||||
await hass.auth.async_link_user(hass_admin_user, test_user_credential)
|
|
||||||
|
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
{
|
{
|
||||||
@ -273,10 +264,9 @@ async def test_change_password(
|
|||||||
|
|
||||||
|
|
||||||
async def test_change_password_wrong_pw(
|
async def test_change_password_wrong_pw(
|
||||||
hass, hass_ws_client, hass_admin_user, auth_provider, test_user_credential
|
hass, hass_ws_client, hass_admin_user, auth_provider
|
||||||
):
|
):
|
||||||
"""Test that change password fails with invalid password."""
|
"""Test that change password fails with invalid password."""
|
||||||
await hass.auth.async_link_user(hass_admin_user, test_user_credential)
|
|
||||||
|
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
@ -295,8 +285,9 @@ async def test_change_password_wrong_pw(
|
|||||||
await auth_provider.async_validate_login("test-user", "new-pass")
|
await auth_provider.async_validate_login("test-user", "new-pass")
|
||||||
|
|
||||||
|
|
||||||
async def test_change_password_no_creds(hass, hass_ws_client):
|
async def test_change_password_no_creds(hass, hass_ws_client, hass_admin_user):
|
||||||
"""Test that change password fails with no credentials."""
|
"""Test that change password fails with no credentials."""
|
||||||
|
hass_admin_user.credentials.clear()
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
@ -313,9 +304,7 @@ async def test_change_password_no_creds(hass, hass_ws_client):
|
|||||||
assert result["error"]["code"] == "credentials_not_found"
|
assert result["error"]["code"] == "credentials_not_found"
|
||||||
|
|
||||||
|
|
||||||
async def test_admin_change_password_not_owner(
|
async def test_admin_change_password_not_owner(hass, hass_ws_client, auth_provider):
|
||||||
hass, hass_ws_client, auth_provider, test_user_credential
|
|
||||||
):
|
|
||||||
"""Test that change password fails when not owner."""
|
"""Test that change password fails when not owner."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
@ -358,6 +347,8 @@ async def test_admin_change_password_no_cred(
|
|||||||
hass, hass_ws_client, owner_access_token, hass_admin_user
|
hass, hass_ws_client, owner_access_token, hass_admin_user
|
||||||
):
|
):
|
||||||
"""Test that change password fails with unknown credential."""
|
"""Test that change password fails with unknown credential."""
|
||||||
|
|
||||||
|
hass_admin_user.credentials.clear()
|
||||||
client = await hass_ws_client(hass, owner_access_token)
|
client = await hass_ws_client(hass, owner_access_token)
|
||||||
|
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
@ -379,12 +370,9 @@ async def test_admin_change_password(
|
|||||||
hass_ws_client,
|
hass_ws_client,
|
||||||
owner_access_token,
|
owner_access_token,
|
||||||
auth_provider,
|
auth_provider,
|
||||||
test_user_credential,
|
|
||||||
hass_admin_user,
|
hass_admin_user,
|
||||||
):
|
):
|
||||||
"""Test that owners can change any password."""
|
"""Test that owners can change any password."""
|
||||||
await hass.auth.async_link_user(hass_admin_user, test_user_credential)
|
|
||||||
|
|
||||||
client = await hass_ws_client(hass, owner_access_token)
|
client = await hass_ws_client(hass, owner_access_token)
|
||||||
|
|
||||||
await client.send_json(
|
await client.send_json(
|
||||||
|
@ -247,7 +247,7 @@ async def test_onboarding_user_race(hass, hass_storage, aiohttp_client):
|
|||||||
assert sorted([res1.status, res2.status]) == [200, HTTP_FORBIDDEN]
|
assert sorted([res1.status, res2.status]) == [200, HTTP_FORBIDDEN]
|
||||||
|
|
||||||
|
|
||||||
async def test_onboarding_integration(hass, hass_storage, hass_client):
|
async def test_onboarding_integration(hass, hass_storage, hass_client, hass_admin_user):
|
||||||
"""Test finishing integration step."""
|
"""Test finishing integration step."""
|
||||||
mock_storage(hass_storage, {"done": [const.STEP_USER]})
|
mock_storage(hass_storage, {"done": [const.STEP_USER]})
|
||||||
|
|
||||||
@ -288,6 +288,28 @@ async def test_onboarding_integration(hass, hass_storage, hass_client):
|
|||||||
assert len(user.refresh_tokens) == 2, user
|
assert len(user.refresh_tokens) == 2, user
|
||||||
|
|
||||||
|
|
||||||
|
async def test_onboarding_integration_missing_credential(
|
||||||
|
hass, hass_storage, hass_client, hass_access_token
|
||||||
|
):
|
||||||
|
"""Test that we fail integration step if user is missing credentials."""
|
||||||
|
mock_storage(hass_storage, {"done": [const.STEP_USER]})
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, "onboarding", {})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
refresh_token.credential = None
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/onboarding/integration",
|
||||||
|
json={"client_id": CLIENT_ID, "redirect_uri": CLIENT_REDIRECT_URI},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == 403
|
||||||
|
|
||||||
|
|
||||||
async def test_onboarding_integration_invalid_redirect_uri(
|
async def test_onboarding_integration_invalid_redirect_uri(
|
||||||
hass, hass_storage, hass_client
|
hass, hass_storage, hass_client
|
||||||
):
|
):
|
||||||
|
@ -14,6 +14,7 @@ import requests_mock as _requests_mock
|
|||||||
|
|
||||||
from homeassistant import core as ha, loader, runner, util
|
from homeassistant import core as ha, loader, runner, util
|
||||||
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
|
from homeassistant.auth.const import GROUP_ID_ADMIN, GROUP_ID_READ_ONLY
|
||||||
|
from homeassistant.auth.models import Credentials
|
||||||
from homeassistant.auth.providers import homeassistant, legacy_api_password
|
from homeassistant.auth.providers import homeassistant, legacy_api_password
|
||||||
from homeassistant.components import mqtt
|
from homeassistant.components import mqtt
|
||||||
from homeassistant.components.websocket_api.auth import (
|
from homeassistant.components.websocket_api.auth import (
|
||||||
@ -201,10 +202,20 @@ def mock_device_tracker_conf():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def hass_access_token(hass, hass_admin_user):
|
async def hass_admin_credential(hass, local_auth):
|
||||||
|
"""Provide credentials for admin user."""
|
||||||
|
await hass.async_add_executor_job(local_auth.data.add_auth, "admin", "admin-pass")
|
||||||
|
|
||||||
|
return await local_auth.async_get_or_create_credentials({"username": "admin"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def hass_access_token(hass, hass_admin_user, hass_admin_credential):
|
||||||
"""Return an access token to access Home Assistant."""
|
"""Return an access token to access Home Assistant."""
|
||||||
refresh_token = hass.loop.run_until_complete(
|
await hass.auth.async_link_user(hass_admin_user, hass_admin_credential)
|
||||||
hass.auth.async_create_refresh_token(hass_admin_user, CLIENT_ID)
|
|
||||||
|
refresh_token = await hass.auth.async_create_refresh_token(
|
||||||
|
hass_admin_user, CLIENT_ID, credential=hass_admin_credential
|
||||||
)
|
)
|
||||||
return hass.auth.async_create_access_token(refresh_token)
|
return hass.auth.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
@ -234,10 +245,21 @@ def hass_read_only_user(hass, local_auth):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def hass_read_only_access_token(hass, hass_read_only_user):
|
def hass_read_only_access_token(hass, hass_read_only_user, local_auth):
|
||||||
"""Return a Home Assistant read only user."""
|
"""Return a Home Assistant read only user."""
|
||||||
|
credential = Credentials(
|
||||||
|
id="mock-readonly-credential-id",
|
||||||
|
auth_provider_type="homeassistant",
|
||||||
|
auth_provider_id=None,
|
||||||
|
data={"username": "readonly"},
|
||||||
|
is_new=False,
|
||||||
|
)
|
||||||
|
hass_read_only_user.credentials.append(credential)
|
||||||
|
|
||||||
refresh_token = hass.loop.run_until_complete(
|
refresh_token = hass.loop.run_until_complete(
|
||||||
hass.auth.async_create_refresh_token(hass_read_only_user, CLIENT_ID)
|
hass.auth.async_create_refresh_token(
|
||||||
|
hass_read_only_user, CLIENT_ID, credential=credential
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return hass.auth.async_create_access_token(refresh_token)
|
return hass.auth.async_create_access_token(refresh_token)
|
||||||
|
|
||||||
@ -260,6 +282,7 @@ def local_auth(hass):
|
|||||||
prv = homeassistant.HassAuthProvider(
|
prv = homeassistant.HassAuthProvider(
|
||||||
hass, hass.auth._store, {"type": "homeassistant"}
|
hass, hass.auth._store, {"type": "homeassistant"}
|
||||||
)
|
)
|
||||||
|
hass.loop.run_until_complete(prv.async_initialize())
|
||||||
hass.auth._providers[(prv.type, prv.id)] = prv
|
hass.auth._providers[(prv.type, prv.id)] = prv
|
||||||
return prv
|
return prv
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user