1
mirror of https://github.com/home-assistant/core synced 2024-10-01 05:30:36 +02:00
ha-core/homeassistant/auth/auth_store.py

621 lines
21 KiB
Python
Raw Normal View History

2018-07-13 11:43:08 +02:00
"""Storage for auth models."""
2021-03-17 21:46:07 +01:00
from __future__ import annotations
import asyncio
from collections import OrderedDict
2018-07-13 11:43:08 +02:00
from datetime import timedelta
import hmac
from logging import getLogger
2021-03-17 21:46:07 +01:00
from typing import Any
2018-07-13 11:43:08 +02:00
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.storage import Store
2018-07-13 11:43:08 +02:00
from homeassistant.util import dt as dt_util
from . import models
from .const import (
ACCESS_TOKEN_EXPIRATION,
GROUP_ID_ADMIN,
GROUP_ID_READ_ONLY,
GROUP_ID_USER,
)
2022-01-10 17:10:46 +01:00
from .permissions import system_policies
from .permissions.models import PermissionLookup
from .permissions.types import PolicyType
2018-07-13 11:43:08 +02:00
# Key and schema version of the persisted auth data handled by Store.
STORAGE_KEY = "auth"
STORAGE_VERSION = 1

# Names given to the built-in (system generated) groups.
GROUP_NAME_ADMIN = "Administrators"
GROUP_NAME_USER = "Users"
GROUP_NAME_READ_ONLY = "Read Only"
class AuthStore:
    """Stores authentication info.

    Any mutation to an object should happen inside the auth store.

    The auth store is lazy. It won't load the data from disk until a method is
    called that needs it.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the auth store."""
        self.hass = hass
        # Users, groups and the permission lookup are all populated by the
        # first successful _async_load(); None means "not loaded yet".
        self._users: dict[str, models.User] | None = None
        self._groups: dict[str, models.Group] | None = None
        self._perm_lookup: PermissionLookup | None = None
        self._store = Store[dict[str, list[dict[str, Any]]]](
            hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
        )
        # Serializes loading so concurrent first callers read storage once.
        self._lock = asyncio.Lock()

    async def async_get_groups(self) -> list[models.Group]:
        """Retrieve all groups."""
        if self._groups is None:
            await self._async_load()
        assert self._groups is not None

        return list(self._groups.values())

    async def async_get_group(self, group_id: str) -> models.Group | None:
        """Retrieve a group by id, or None if no such group exists."""
        if self._groups is None:
            await self._async_load()
        assert self._groups is not None

        return self._groups.get(group_id)

    async def async_get_users(self) -> list[models.User]:
        """Retrieve all users."""
        if self._users is None:
            await self._async_load()
        assert self._users is not None

        return list(self._users.values())

    async def async_get_user(self, user_id: str) -> models.User | None:
        """Retrieve a user by id, or None if no such user exists."""
        if self._users is None:
            await self._async_load()
        assert self._users is not None

        return self._users.get(user_id)

    async def async_create_user(
        self,
        name: str | None,
        is_owner: bool | None = None,
        is_active: bool | None = None,
        system_generated: bool | None = None,
        credentials: models.Credentials | None = None,
        group_ids: list[str] | None = None,
        local_only: bool | None = None,
    ) -> models.User:
        """Create a new user.

        Optional flags left as None are not passed to models.User, so the
        model's own defaults apply. If credentials are given they are linked
        to the new user (which also schedules the save).

        Raises:
            ValueError: if any id in group_ids is not a known group.
        """
        if self._users is None:
            await self._async_load()

        assert self._users is not None
        assert self._groups is not None

        # Resolve the requested group ids; unknown ids are rejected.
        groups = []
        for group_id in group_ids or []:
            if (group := self._groups.get(group_id)) is None:
                raise ValueError(f"Invalid group specified {group_id}")
            groups.append(group)

        kwargs: dict[str, Any] = {
            "name": name,
            "groups": groups,
            "perm_lookup": self._perm_lookup,
        }

        for attr_name, value in (
            ("is_owner", is_owner),
            ("is_active", is_active),
            ("local_only", local_only),
            ("system_generated", system_generated),
        ):
            if value is not None:
                kwargs[attr_name] = value

        new_user = models.User(**kwargs)

        self._users[new_user.id] = new_user

        if credentials is None:
            self._async_schedule_save()
            return new_user

        # Saving is done inside the link.
        await self.async_link_user(new_user, credentials)
        return new_user

    async def async_link_user(
        self, user: models.User, credentials: models.Credentials
    ) -> None:
        """Add credentials to an existing user and mark them as not new."""
        user.credentials.append(credentials)
        self._async_schedule_save()
        credentials.is_new = False

    async def async_remove_user(self, user: models.User) -> None:
        """Remove a user.

        Raises KeyError if the user is not in the store.
        """
        if self._users is None:
            await self._async_load()
        assert self._users is not None

        self._users.pop(user.id)
        self._async_schedule_save()

    async def async_update_user(
        self,
        user: models.User,
        name: str | None = None,
        is_active: bool | None = None,
        group_ids: list[str] | None = None,
        local_only: bool | None = None,
    ) -> None:
        """Update a user.

        Only arguments that are not None are applied. Changing the groups
        invalidates the user's cached permissions.

        Raises:
            ValueError: if any id in group_ids is not a known group.
        """
        # Load lazily like every other accessor instead of relying on a
        # previous call having populated the groups.
        if self._groups is None:
            await self._async_load()
        assert self._groups is not None

        if group_ids is not None:
            groups = []
            for grid in group_ids:
                if (group := self._groups.get(grid)) is None:
                    raise ValueError("Invalid group specified.")
                groups.append(group)

            user.groups = groups
            user.invalidate_permission_cache()

        for attr_name, value in (
            ("name", name),
            ("is_active", is_active),
            ("local_only", local_only),
        ):
            if value is not None:
                setattr(user, attr_name, value)

        self._async_schedule_save()

    async def async_activate_user(self, user: models.User) -> None:
        """Activate a user."""
        user.is_active = True
        self._async_schedule_save()

    async def async_deactivate_user(self, user: models.User) -> None:
        """Deactivate a user."""
        user.is_active = False
        self._async_schedule_save()

    async def async_remove_credentials(self, credentials: models.Credentials) -> None:
        """Remove credentials from whichever user holds them (by identity)."""
        if self._users is None:
            await self._async_load()
        assert self._users is not None

        for user in self._users.values():
            found = None

            for index, cred in enumerate(user.credentials):
                if cred is credentials:
                    found = index
                    break

            if found is not None:
                user.credentials.pop(found)
                break

        self._async_schedule_save()

    async def async_create_refresh_token(
        self,
        user: models.User,
        client_id: str | None = None,
        client_name: str | None = None,
        client_icon: str | None = None,
        token_type: str = models.TOKEN_TYPE_NORMAL,
        access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
        credential: models.Credentials | None = None,
    ) -> models.RefreshToken:
        """Create a new refresh token for a user and schedule a save."""
        kwargs: dict[str, Any] = {
            "user": user,
            "client_id": client_id,
            "token_type": token_type,
            "access_token_expiration": access_token_expiration,
            "credential": credential,
        }
        # Empty/None name and icon are omitted so the model defaults apply.
        if client_name:
            kwargs["client_name"] = client_name
        if client_icon:
            kwargs["client_icon"] = client_icon

        refresh_token = models.RefreshToken(**kwargs)
        user.refresh_tokens[refresh_token.id] = refresh_token

        self._async_schedule_save()
        return refresh_token

    async def async_remove_refresh_token(
        self, refresh_token: models.RefreshToken
    ) -> None:
        """Remove a refresh token; a save is only scheduled if it was found."""
        if self._users is None:
            await self._async_load()
        assert self._users is not None

        for user in self._users.values():
            if user.refresh_tokens.pop(refresh_token.id, None):
                self._async_schedule_save()
                break

    async def async_get_refresh_token(
        self, token_id: str
    ) -> models.RefreshToken | None:
        """Get refresh token by id, or None if not found."""
        if self._users is None:
            await self._async_load()
        assert self._users is not None

        for user in self._users.values():
            refresh_token = user.refresh_tokens.get(token_id)
            if refresh_token is not None:
                return refresh_token

        return None

    async def async_get_refresh_token_by_token(
        self, token: str
    ) -> models.RefreshToken | None:
        """Get refresh token by its token string, or None if not found.

        The scan does not stop at the first match and uses
        hmac.compare_digest for each comparison, so the work done does not
        depend on which token matches (presumably intentional — keep it).
        """
        if self._users is None:
            await self._async_load()
        assert self._users is not None

        # hmac.compare_digest only accepts ASCII-only str arguments and raises
        # TypeError otherwise. The candidate token may come from a request, so
        # compare UTF-8 bytes to treat non-ASCII input as a simple mismatch.
        token_bytes = token.encode("utf-8")
        found = None

        for user in self._users.values():
            for refresh_token in user.refresh_tokens.values():
                if hmac.compare_digest(
                    refresh_token.token.encode("utf-8"), token_bytes
                ):
                    found = refresh_token

        return found

    @callback
    def async_log_refresh_token_usage(
        self, refresh_token: models.RefreshToken, remote_ip: str | None = None
    ) -> None:
        """Update refresh token last used information."""
        refresh_token.last_used_at = dt_util.utcnow()
        refresh_token.last_used_ip = remote_ip
        self._async_schedule_save()

    async def _async_load(self) -> None:
        """Load the users, once, under the load lock."""
        async with self._lock:
            if self._users is not None:
                return
            await self._async_load_task()

    async def _async_load_task(self) -> None:
        """Load users, groups, credentials and refresh tokens from storage.

        Missing or non-dict stored data falls back to the default groups and
        no users. Group data is soft-migrated while loading (see below).
        """
        dev_reg = dr.async_get(self.hass)
        ent_reg = er.async_get(self.hass)
        data = await self._store.async_load()

        # Make sure that we're not overriding data if 2 loads happened at the
        # same time
        if self._users is not None:
            return

        self._perm_lookup = perm_lookup = PermissionLookup(ent_reg, dev_reg)

        if data is None or not isinstance(data, dict):
            self._set_defaults()
            return

        users: dict[str, models.User] = OrderedDict()
        groups: dict[str, models.Group] = OrderedDict()
        credentials: dict[str, models.Credentials] = OrderedDict()

        # Soft-migrating data as we load. We are going to make sure we have a
        # read only group and an admin group. There are two states that we can
        # migrate from:
        # 1. Data from a recent version which has a single group without policy
        # 2. Data from old version which has no groups
        has_admin_group = False
        has_user_group = False
        has_read_only_group = False
        group_without_policy = None

        # When creating objects we mention each attribute explicitly. This
        # prevents crashing if user rolls back HA version after a new property
        # was added.

        for group_dict in data.get("groups", []):
            policy: PolicyType | None = None

            # System groups always get the built-in name and policy; the
            # stored name/policy are ignored for them.
            if group_dict["id"] == GROUP_ID_ADMIN:
                has_admin_group = True

                name = GROUP_NAME_ADMIN
                policy = system_policies.ADMIN_POLICY
                system_generated = True

            elif group_dict["id"] == GROUP_ID_USER:
                has_user_group = True

                name = GROUP_NAME_USER
                policy = system_policies.USER_POLICY
                system_generated = True

            elif group_dict["id"] == GROUP_ID_READ_ONLY:
                has_read_only_group = True

                name = GROUP_NAME_READ_ONLY
                policy = system_policies.READ_ONLY_POLICY
                system_generated = True

            else:
                name = group_dict["name"]
                policy = group_dict.get("policy")
                system_generated = False

            # We don't want groups without a policy that are not system groups
            # This is part of migrating from state 1
            if policy is None:
                group_without_policy = group_dict["id"]
                continue

            groups[group_dict["id"]] = models.Group(
                id=group_dict["id"],
                name=name,
                policy=policy,
                system_generated=system_generated,
            )

        # If there are no groups, add all existing users to the admin group.
        # This is part of migrating from state 2
        migrate_users_to_admin_group = not groups and group_without_policy is None

        # If we find a no_policy_group, we need to migrate all users to the
        # admin group. We only do this if there are no other groups, as is
        # the expected state. If not expected state, not marking people admin.
        # This is part of migrating from state 1
        if groups and group_without_policy is not None:
            group_without_policy = None

        # This is part of migrating from state 1 and 2
        if not has_admin_group:
            admin_group = _system_admin_group()
            groups[admin_group.id] = admin_group

        # This is part of migrating from state 1 and 2
        if not has_read_only_group:
            read_only_group = _system_read_only_group()
            groups[read_only_group.id] = read_only_group

        if not has_user_group:
            user_group = _system_user_group()
            groups[user_group.id] = user_group

        for user_dict in data["users"]:
            # Collect the users group.
            user_groups = []
            for group_id in user_dict.get("group_ids", []):
                # This is part of migrating from state 1
                if group_id == group_without_policy:
                    group_id = GROUP_ID_ADMIN
                user_groups.append(groups[group_id])

            # This is part of migrating from state 2
            if not user_dict["system_generated"] and migrate_users_to_admin_group:
                user_groups.append(groups[GROUP_ID_ADMIN])

            users[user_dict["id"]] = models.User(
                name=user_dict["name"],
                groups=user_groups,
                id=user_dict["id"],
                is_owner=user_dict["is_owner"],
                is_active=user_dict["is_active"],
                system_generated=user_dict["system_generated"],
                perm_lookup=perm_lookup,
                # New in 2021.11
                local_only=user_dict.get("local_only", False),
            )

        for cred_dict in data["credentials"]:
            credential = models.Credentials(
                id=cred_dict["id"],
                is_new=False,
                auth_provider_type=cred_dict["auth_provider_type"],
                auth_provider_id=cred_dict["auth_provider_id"],
                data=cred_dict["data"],
            )
            credentials[cred_dict["id"]] = credential
            users[cred_dict["user_id"]].credentials.append(credential)

        for rt_dict in data["refresh_tokens"]:
            # Filter out the old keys that don't have jwt_key (pre-0.76)
            if "jwt_key" not in rt_dict:
                continue

            created_at = dt_util.parse_datetime(rt_dict["created_at"])
            if created_at is None:
                getLogger(__name__).error(
                    "Ignoring refresh token %(id)s with invalid created_at "
                    "%(created_at)s for user_id %(user_id)s",
                    rt_dict,
                )
                continue

            # Tokens stored without a type are inferred from client_id.
            if (token_type := rt_dict.get("token_type")) is None:
                if rt_dict["client_id"] is None:
                    token_type = models.TOKEN_TYPE_SYSTEM
                else:
                    token_type = models.TOKEN_TYPE_NORMAL

            # old refresh_token don't have last_used_at (pre-0.78)
            if last_used_at_str := rt_dict.get("last_used_at"):
                last_used_at = dt_util.parse_datetime(last_used_at_str)
            else:
                last_used_at = None

            token = models.RefreshToken(
                id=rt_dict["id"],
                user=users[rt_dict["user_id"]],
                client_id=rt_dict["client_id"],
                # use dict.get to keep backward compatibility
                client_name=rt_dict.get("client_name"),
                client_icon=rt_dict.get("client_icon"),
                token_type=token_type,
                created_at=created_at,
                access_token_expiration=timedelta(
                    seconds=rt_dict["access_token_expiration"]
                ),
                token=rt_dict["token"],
                jwt_key=rt_dict["jwt_key"],
                last_used_at=last_used_at,
                last_used_ip=rt_dict.get("last_used_ip"),
                version=rt_dict.get("version"),
            )
            if "credential_id" in rt_dict:
                token.credential = credentials.get(rt_dict["credential_id"])
            users[rt_dict["user_id"]].refresh_tokens[token.id] = token

        self._groups = groups
        self._users = users

    @callback
    def _async_schedule_save(self) -> None:
        """Schedule a delayed (1 second) save; no-op if nothing is loaded."""
        if self._users is None:
            return

        self._store.async_delay_save(self._data_to_save, 1)

    @callback
    def _data_to_save(self) -> dict[str, list[dict[str, Any]]]:
        """Return the data to store."""
        assert self._users is not None
        assert self._groups is not None

        users = [
            {
                "id": user.id,
                "group_ids": [group.id for group in user.groups],
                "is_owner": user.is_owner,
                "is_active": user.is_active,
                "name": user.name,
                "system_generated": user.system_generated,
                "local_only": user.local_only,
            }
            for user in self._users.values()
        ]

        groups = []
        for group in self._groups.values():
            g_dict: dict[str, Any] = {
                "id": group.id,
                # Name not read for sys groups. Kept here for backwards compat
                "name": group.name,
            }

            # System group policies are rebuilt on load, so only custom
            # groups persist their policy.
            if not group.system_generated:
                g_dict["policy"] = group.policy

            groups.append(g_dict)

        credentials = [
            {
                "id": credential.id,
                "user_id": user.id,
                "auth_provider_type": credential.auth_provider_type,
                "auth_provider_id": credential.auth_provider_id,
                "data": credential.data,
            }
            for user in self._users.values()
            for credential in user.credentials
        ]

        refresh_tokens = [
            {
                "id": refresh_token.id,
                "user_id": user.id,
                "client_id": refresh_token.client_id,
                "client_name": refresh_token.client_name,
                "client_icon": refresh_token.client_icon,
                "token_type": refresh_token.token_type,
                "created_at": refresh_token.created_at.isoformat(),
                "access_token_expiration": refresh_token.access_token_expiration.total_seconds(),
                "token": refresh_token.token,
                "jwt_key": refresh_token.jwt_key,
                "last_used_at": refresh_token.last_used_at.isoformat()
                if refresh_token.last_used_at
                else None,
                "last_used_ip": refresh_token.last_used_ip,
                "credential_id": refresh_token.credential.id
                if refresh_token.credential
                else None,
                "version": refresh_token.version,
            }
            for user in self._users.values()
            for refresh_token in user.refresh_tokens.values()
        ]

        return {
            "users": users,
            "groups": groups,
            "credentials": credentials,
            "refresh_tokens": refresh_tokens,
        }

    def _set_defaults(self) -> None:
        """Set default values for auth store: no users, the system groups."""
        self._users = OrderedDict()

        groups: dict[str, models.Group] = OrderedDict()
        admin_group = _system_admin_group()
        groups[admin_group.id] = admin_group
        user_group = _system_user_group()
        groups[user_group.id] = user_group
        read_only_group = _system_read_only_group()
        groups[read_only_group.id] = read_only_group
        self._groups = groups
def _system_admin_group() -> models.Group:
    """Build the built-in administrators group carrying the admin policy."""
    group = models.Group(
        id=GROUP_ID_ADMIN,
        name=GROUP_NAME_ADMIN,
        system_generated=True,
        policy=system_policies.ADMIN_POLICY,
    )
    return group
def _system_user_group() -> models.Group:
    """Build the built-in users group carrying the user policy."""
    group = models.Group(
        id=GROUP_ID_USER,
        name=GROUP_NAME_USER,
        system_generated=True,
        policy=system_policies.USER_POLICY,
    )
    return group
def _system_read_only_group() -> models.Group:
    """Build the built-in read-only group carrying the read-only policy."""
    group = models.Group(
        id=GROUP_ID_READ_ONLY,
        name=GROUP_NAME_READ_ONLY,
        system_generated=True,
        policy=system_policies.READ_ONLY_POLICY,
    )
    return group