From 94219c226682e132a8f7aee8cb9439c5eaf96bbb Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 28 Oct 2020 07:47:54 -0500 Subject: [PATCH] Ensure the oauth2 token is refreshed before it is expired (#42487) The current implementation assumed clocks were in sync and did not account for the time it took to refresh the token. A 20 second buffer has been added to ensure that the token is refreshed before it expires as OAuth2Session.valid_token would assume the token was still good even though the remote would reject it and would not refresh because it was not time yet. --- .../helpers/config_entry_oauth2_flow.py | 7 +- .../helpers/test_config_entry_oauth2_flow.py | 88 +++++++++++++++++++ 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index ace1365df1b1..ba43a057ca3a 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -33,6 +33,8 @@ DATA_IMPLEMENTATIONS = "oauth2_impl" DATA_PROVIDERS = "oauth2_providers" AUTH_CALLBACK_PATH = "/auth/external/callback" +CLOCK_OUT_OF_SYNC_MAX_SEC = 20 + class AbstractOAuth2Implementation(ABC): """Base class to abstract OAuth2 authentication.""" @@ -435,7 +437,10 @@ class OAuth2Session: @property def valid_token(self) -> bool: """Return if token is still valid.""" - return cast(float, self.token["expires_at"]) > time.time() + return ( + cast(float, self.token["expires_at"]) + > time.time() + CLOCK_OUT_OF_SYNC_MAX_SEC + ) async def async_ensure_token_valid(self) -> None: """Ensure that the current token is valid.""" diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py index 691b2e93d569..7ce71defb7ea 100644 --- a/tests/helpers/test_config_entry_oauth2_flow.py +++ b/tests/helpers/test_config_entry_oauth2_flow.py @@ -420,6 +420,94 @@ async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock): assert round(config_entry.data["token"]["expires_at"] - now) == 100 +async def test_oauth_session_with_clock_slightly_out_of_sync( + hass, flow_handler, local_impl, aioclient_mock +): + """Test the OAuth2 session helper when the remote clock is slightly out of sync.""" + flow_handler.async_register_implementation(hass, local_impl) + + aioclient_mock.post( + TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 19} + ) + + aioclient_mock.post("https://example.com", status=201) + + config_entry = MockConfigEntry( + domain=TEST_DOMAIN, + data={ + "auth_implementation": TEST_DOMAIN, + "token": { + "refresh_token": REFRESH_TOKEN, + "access_token": ACCESS_TOKEN_1, + "expires_in": 19, + "expires_at": time.time() + 19, # Forces a refresh, + "token_type": "bearer", + "random_other_data": "should_stay", + }, + }, + ) + + now = time.time() + session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl) + resp = await session.async_request("post", "https://example.com") + assert resp.status == 201 + + # Refresh token, make request + assert len(aioclient_mock.mock_calls) == 2 + + assert ( + aioclient_mock.mock_calls[1][3]["authorization"] == f"Bearer {ACCESS_TOKEN_2}" + ) + + assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN + assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_2 + assert config_entry.data["token"]["expires_in"] == 19 + assert config_entry.data["token"]["random_other_data"] == "should_stay" + assert round(config_entry.data["token"]["expires_at"] - now) == 19 + + +async def test_oauth_session_no_token_refresh_needed( + hass, flow_handler, local_impl, aioclient_mock +): + """Test the OAuth2 session helper when no refresh is needed.""" + flow_handler.async_register_implementation(hass, local_impl) + + aioclient_mock.post("https://example.com", status=201) + + config_entry = MockConfigEntry( + domain=TEST_DOMAIN, + data={ + "auth_implementation": TEST_DOMAIN, + "token": { + "refresh_token": REFRESH_TOKEN, + "access_token": ACCESS_TOKEN_1, + "expires_in": 500, + "expires_at": time.time() + 500, # Should NOT refresh + "token_type": "bearer", + "random_other_data": "should_stay", + }, + }, + ) + + now = time.time() + session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl) + resp = await session.async_request("post", "https://example.com") + assert resp.status == 201 + + # make request (no refresh) + assert len(aioclient_mock.mock_calls) == 1 + + assert ( + aioclient_mock.mock_calls[0][3]["authorization"] == f"Bearer {ACCESS_TOKEN_1}" + ) + + assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN + assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_1 + assert config_entry.data["token"]["expires_in"] == 500 + assert config_entry.data["token"]["random_other_data"] == "should_stay" + assert round(config_entry.data["token"]["expires_at"] - now) == 500 + + async def test_implementation_provider(hass, local_impl): """Test providing an implementation provider.""" assert (