Ensure the oauth2 token is refreshed before it is expired (#42487)

The current implementation assumed clocks were in sync and did
not account for the time it took to refresh the token.  A 20
second buffer has been added to ensure that the token is refreshed
before it expires as OAuth2Session.valid_token would assume the
token was still good even though the remote would reject it
and would not refresh because it was not time yet.
This commit is contained in:
J. Nick Koston 2020-10-28 07:47:54 -05:00 committed by GitHub
parent 4e28ae8e3a
commit 94219c2266
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 94 additions and 1 deletions

View File

@ -33,6 +33,8 @@ DATA_IMPLEMENTATIONS = "oauth2_impl"
DATA_PROVIDERS = "oauth2_providers"
AUTH_CALLBACK_PATH = "/auth/external/callback"
CLOCK_OUT_OF_SYNC_MAX_SEC = 20
class AbstractOAuth2Implementation(ABC):
"""Base class to abstract OAuth2 authentication."""
@ -435,7 +437,10 @@ class OAuth2Session:
@property
def valid_token(self) -> bool:
"""Return if token is still valid."""
return cast(float, self.token["expires_at"]) > time.time()
return (
cast(float, self.token["expires_at"])
> time.time() + CLOCK_OUT_OF_SYNC_MAX_SEC
)
async def async_ensure_token_valid(self) -> None:
"""Ensure that the current token is valid."""

View File

@ -420,6 +420,94 @@ async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock):
assert round(config_entry.data["token"]["expires_at"] - now) == 100
async def test_oauth_session_with_clock_slightly_out_of_sync(
hass, flow_handler, local_impl, aioclient_mock
):
"""Test the OAuth2 session helper when the remote clock is slightly out of sync."""
flow_handler.async_register_implementation(hass, local_impl)
aioclient_mock.post(
TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 19}
)
aioclient_mock.post("https://example.com", status=201)
config_entry = MockConfigEntry(
domain=TEST_DOMAIN,
data={
"auth_implementation": TEST_DOMAIN,
"token": {
"refresh_token": REFRESH_TOKEN,
"access_token": ACCESS_TOKEN_1,
"expires_in": 19,
"expires_at": time.time() + 19, # Forces a refresh,
"token_type": "bearer",
"random_other_data": "should_stay",
},
},
)
now = time.time()
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl)
resp = await session.async_request("post", "https://example.com")
assert resp.status == 201
# Refresh token, make request
assert len(aioclient_mock.mock_calls) == 2
assert (
aioclient_mock.mock_calls[1][3]["authorization"] == f"Bearer {ACCESS_TOKEN_2}"
)
assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN
assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_2
assert config_entry.data["token"]["expires_in"] == 19
assert config_entry.data["token"]["random_other_data"] == "should_stay"
assert round(config_entry.data["token"]["expires_at"] - now) == 19
async def test_oauth_session_no_token_refresh_needed(
hass, flow_handler, local_impl, aioclient_mock
):
"""Test the OAuth2 session helper when no refresh is needed."""
flow_handler.async_register_implementation(hass, local_impl)
aioclient_mock.post("https://example.com", status=201)
config_entry = MockConfigEntry(
domain=TEST_DOMAIN,
data={
"auth_implementation": TEST_DOMAIN,
"token": {
"refresh_token": REFRESH_TOKEN,
"access_token": ACCESS_TOKEN_1,
"expires_in": 500,
"expires_at": time.time() + 500, # Should NOT refresh
"token_type": "bearer",
"random_other_data": "should_stay",
},
},
)
now = time.time()
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl)
resp = await session.async_request("post", "https://example.com")
assert resp.status == 201
# make request (no refresh)
assert len(aioclient_mock.mock_calls) == 1
assert (
aioclient_mock.mock_calls[0][3]["authorization"] == f"Bearer {ACCESS_TOKEN_1}"
)
assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN
assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_1
assert config_entry.data["token"]["expires_in"] == 500
assert config_entry.data["token"]["random_other_data"] == "should_stay"
assert round(config_entry.data["token"]["expires_at"] - now) == 500
async def test_implementation_provider(hass, local_impl):
"""Test providing an implementation provider."""
assert (