1
mirror of https://github.com/home-assistant/core synced 2024-09-15 17:29:45 +02:00

Keep cloud tokens always valid (#20762)

* Keep auth token always valid

* Remove unused refresh_auth message

* Capture EndpointConnectionError

* Lint
This commit is contained in:
Paulus Schoutsen 2019-02-05 01:45:03 -08:00 committed by Pascal Vizeli
parent b1faad0a50
commit 2733919cd8
5 changed files with 122 additions and 36 deletions

View File

@ -106,6 +106,7 @@ async def async_setup(hass, config):
)
cloud = hass.data[DOMAIN] = Cloud(hass, **kwargs)
await auth_api.async_setup(hass, cloud)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, cloud.async_start)
await http_api.async_setup(hass)
return True
@ -263,7 +264,7 @@ class Cloud:
self.access_token = info['access_token']
self.refresh_token = info['refresh_token']
self.hass.add_job(self.iot.connect())
self.hass.async_create_task(self.iot.connect())
def _decode_claims(self, token): # pylint: disable=no-self-use
"""Decode the claims in a token."""

View File

@ -1,4 +1,10 @@
"""Package to communicate with the authentication API."""
import asyncio
import logging
import random
_LOGGER = logging.getLogger(__name__)
class CloudError(Exception):
@ -39,6 +45,40 @@ AWS_EXCEPTIONS = {
}
async def async_setup(hass, cloud):
"""Configure the auth api."""
refresh_task = None
async def handle_token_refresh():
"""Handle Cloud access token refresh."""
sleep_time = 5
sleep_time = random.randint(2400, 3600)
while True:
try:
await asyncio.sleep(sleep_time)
await hass.async_add_executor_job(renew_access_token, cloud)
except CloudError as err:
_LOGGER.error("Can't refresh cloud token: %s", err)
except asyncio.CancelledError:
# Task is canceled, stop it.
break
sleep_time = random.randint(3100, 3600)
async def on_connect():
"""When the instance is connected."""
nonlocal refresh_task
refresh_task = hass.async_create_task(handle_token_refresh())
async def on_disconnect():
"""When the instance is disconnected."""
nonlocal refresh_task
refresh_task.cancel()
cloud.iot.register_on_connect(on_connect)
cloud.iot.register_on_disconnect(on_disconnect)
def _map_aws_exception(err):
"""Map AWS exception to our exceptions."""
ex = AWS_EXCEPTIONS.get(err.response['Error']['Code'], UnknownError)
@ -47,7 +87,7 @@ def _map_aws_exception(err):
def register(cloud, email, password):
"""Register a new account."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito(cloud)
# Workaround for bug in Warrant. PR with fix:
@ -55,13 +95,16 @@ def register(cloud, email, password):
cognito.add_base_attributes()
try:
cognito.register(email, password)
except ClientError as err:
raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def resend_email_confirm(cloud, email):
"""Resend email confirmation."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito(cloud, username=email)
@ -72,18 +115,23 @@ def resend_email_confirm(cloud, email):
)
except ClientError as err:
raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def forgot_password(cloud, email):
"""Initialize forgotten password flow."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito(cloud, username=email)
try:
cognito.initiate_forgot_password()
except ClientError as err:
raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def login(cloud, email, password):
@ -97,7 +145,7 @@ def login(cloud, email, password):
def check_token(cloud):
"""Check that the token is valid and verify if needed."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito(
cloud,
@ -109,13 +157,17 @@ def check_token(cloud):
cloud.id_token = cognito.id_token
cloud.access_token = cognito.access_token
cloud.write_user_info()
except ClientError as err:
raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def renew_access_token(cloud):
"""Renew access token."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError
cognito = _cognito(
cloud,
@ -127,13 +179,17 @@ def renew_access_token(cloud):
cloud.id_token = cognito.id_token
cloud.access_token = cognito.access_token
cloud.write_user_info()
except ClientError as err:
raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def _authenticate(cloud, email, password):
"""Log in and return an authenticated Cognito instance."""
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError
from warrant.exceptions import ForceChangePasswordException
assert not cloud.is_logged_in, 'Cannot login if already logged in.'
@ -145,11 +201,14 @@ def _authenticate(cloud, email, password):
return cognito
except ForceChangePasswordException:
raise PasswordChangeRequired
raise PasswordChangeRequired()
except ClientError as err:
raise _map_aws_exception(err)
except EndpointConnectionError:
raise UnknownError()
def _cognito(cloud, **kwargs):
"""Get the client credentials."""

View File

@ -62,12 +62,18 @@ class CloudIoT:
# Local code waiting for a response
self._response_handler = {}
self._on_connect = []
self._on_disconnect = []
@callback
def register_on_connect(self, on_connect_cb):
"""Register an async on_connect callback."""
self._on_connect.append(on_connect_cb)
@callback
def register_on_disconnect(self, on_disconnect_cb):
"""Register an async on_disconnect callback."""
self._on_disconnect.append(on_disconnect_cb)
@property
def connected(self):
"""Return if we're currently connected."""
@ -102,6 +108,17 @@ class CloudIoT:
# Still adding it here to make sure we can always reconnect
_LOGGER.exception("Unexpected error")
if self.state == STATE_CONNECTED and self._on_disconnect:
try:
yield from asyncio.wait([
cb() for cb in self._on_disconnect
])
except Exception: # pylint: disable=broad-except
# Safety net. This should never hit.
# Still adding it here to make sure we don't break the flow
_LOGGER.exception(
"Unexpected error in on_disconnect callbacks")
if self.close_requested:
break
@ -192,7 +209,13 @@ class CloudIoT:
self.state = STATE_CONNECTED
if self._on_connect:
yield from asyncio.wait([cb() for cb in self._on_connect])
try:
yield from asyncio.wait([cb() for cb in self._on_connect])
except Exception: # pylint: disable=broad-except
# Safety net. This should never hit.
# Still adding it here to make sure we don't break the flow
_LOGGER.exception(
"Unexpected error in on_connect callbacks")
while not client.closed:
msg = yield from client.receive()
@ -326,11 +349,6 @@ async def async_handle_cloud(hass, cloud, payload):
await cloud.logout()
_LOGGER.error("You have been logged out from Home Assistant cloud: %s",
payload['reason'])
elif action == 'refresh_auth':
# Refresh the auth token between now and payload['seconds']
hass.helpers.event.async_call_later(
random.randint(0, payload['seconds']),
lambda now: auth_api.check_token(cloud))
else:
_LOGGER.warning("Received unknown cloud action: %s", action)

View File

@ -1,4 +1,5 @@
"""Tests for the tools to communicate with the cloud."""
import asyncio
from unittest.mock import MagicMock, patch
from botocore.exceptions import ClientError
@ -165,3 +166,31 @@ def test_check_token_raises(mock_cognito):
assert cloud.id_token != mock_cognito.id_token
assert cloud.access_token != mock_cognito.access_token
assert len(cloud.write_user_info.mock_calls) == 0
async def test_async_setup(hass):
"""Test async setup."""
cloud = MagicMock()
await auth_api.async_setup(hass, cloud)
assert len(cloud.iot.mock_calls) == 2
on_connect = cloud.iot.mock_calls[0][1][0]
on_disconnect = cloud.iot.mock_calls[1][1][0]
with patch('random.randint', return_value=0), patch(
'homeassistant.components.cloud.auth_api.renew_access_token'
) as mock_renew:
await on_connect()
# Let handle token sleep once
await asyncio.sleep(0)
# Let handle token refresh token
await asyncio.sleep(0)
assert len(mock_renew.mock_calls) == 1
assert mock_renew.mock_calls[0][1][0] is cloud
await on_disconnect()
# Make sure task is no longer being called
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(mock_renew.mock_calls) == 1

View File

@ -10,9 +10,8 @@ from homeassistant.components.cloud import (
Cloud, iot, auth_api, MODE_DEV)
from homeassistant.components.cloud.const import (
PREF_ENABLE_ALEXA, PREF_ENABLE_GOOGLE)
from homeassistant.util import dt as dt_util
from tests.components.alexa import test_smart_home as test_alexa
from tests.common import mock_coro, async_fire_time_changed
from tests.common import mock_coro
from . import mock_cloud_prefs
@ -158,26 +157,6 @@ async def test_handling_core_messages_logout(hass, mock_cloud):
assert len(mock_cloud.logout.mock_calls) == 1
async def test_handling_core_messages_refresh_auth(hass, mock_cloud):
"""Test handling core messages."""
mock_cloud.hass = hass
with patch('random.randint', return_value=0) as mock_rand, patch(
'homeassistant.components.cloud.auth_api.check_token'
) as mock_check:
await iot.async_handle_cloud(hass, mock_cloud, {
'action': 'refresh_auth',
'seconds': 230,
})
async_fire_time_changed(hass, dt_util.utcnow())
await hass.async_block_till_done()
assert len(mock_rand.mock_calls) == 1
assert mock_rand.mock_calls[0][1] == (0, 230)
assert len(mock_check.mock_calls) == 1
assert mock_check.mock_calls[0][1][0] is mock_cloud
@asyncio.coroutine
def test_cloud_getting_disconnected_by_server(mock_client, caplog, mock_cloud):
"""Test server disconnecting instance."""