mirror of https://github.com/home-assistant/core
Ensure the oauth2 token is refreshed before it is expired (#42487)
The current implementation assumed clocks were in sync and did not account for the time it took to refresh the token. A 20 second buffer has been added to ensure that the token is refreshed before it expires as OAuth2Session.valid_token would assume the token was still good even though the remote would reject it and would not refresh because it was not time yet.
This commit is contained in:
parent
4e28ae8e3a
commit
94219c2266
|
@ -33,6 +33,8 @@ DATA_IMPLEMENTATIONS = "oauth2_impl"
|
|||
DATA_PROVIDERS = "oauth2_providers"
|
||||
AUTH_CALLBACK_PATH = "/auth/external/callback"
|
||||
|
||||
CLOCK_OUT_OF_SYNC_MAX_SEC = 20
|
||||
|
||||
|
||||
class AbstractOAuth2Implementation(ABC):
|
||||
"""Base class to abstract OAuth2 authentication."""
|
||||
|
@ -435,7 +437,10 @@ class OAuth2Session:
|
|||
@property
|
||||
def valid_token(self) -> bool:
|
||||
"""Return if token is still valid."""
|
||||
return cast(float, self.token["expires_at"]) > time.time()
|
||||
return (
|
||||
cast(float, self.token["expires_at"])
|
||||
> time.time() + CLOCK_OUT_OF_SYNC_MAX_SEC
|
||||
)
|
||||
|
||||
async def async_ensure_token_valid(self) -> None:
|
||||
"""Ensure that the current token is valid."""
|
||||
|
|
|
@ -420,6 +420,94 @@ async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock):
|
|||
assert round(config_entry.data["token"]["expires_at"] - now) == 100
|
||||
|
||||
|
||||
async def test_oauth_session_with_clock_slightly_out_of_sync(
|
||||
hass, flow_handler, local_impl, aioclient_mock
|
||||
):
|
||||
"""Test the OAuth2 session helper when the remote clock is slightly out of sync."""
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 19}
|
||||
)
|
||||
|
||||
aioclient_mock.post("https://example.com", status=201)
|
||||
|
||||
config_entry = MockConfigEntry(
|
||||
domain=TEST_DOMAIN,
|
||||
data={
|
||||
"auth_implementation": TEST_DOMAIN,
|
||||
"token": {
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
"access_token": ACCESS_TOKEN_1,
|
||||
"expires_in": 19,
|
||||
"expires_at": time.time() + 19, # Forces a refresh,
|
||||
"token_type": "bearer",
|
||||
"random_other_data": "should_stay",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl)
|
||||
resp = await session.async_request("post", "https://example.com")
|
||||
assert resp.status == 201
|
||||
|
||||
# Refresh token, make request
|
||||
assert len(aioclient_mock.mock_calls) == 2
|
||||
|
||||
assert (
|
||||
aioclient_mock.mock_calls[1][3]["authorization"] == f"Bearer {ACCESS_TOKEN_2}"
|
||||
)
|
||||
|
||||
assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN
|
||||
assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_2
|
||||
assert config_entry.data["token"]["expires_in"] == 19
|
||||
assert config_entry.data["token"]["random_other_data"] == "should_stay"
|
||||
assert round(config_entry.data["token"]["expires_at"] - now) == 19
|
||||
|
||||
|
||||
async def test_oauth_session_no_token_refresh_needed(
|
||||
hass, flow_handler, local_impl, aioclient_mock
|
||||
):
|
||||
"""Test the OAuth2 session helper when no refresh is needed."""
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
|
||||
aioclient_mock.post("https://example.com", status=201)
|
||||
|
||||
config_entry = MockConfigEntry(
|
||||
domain=TEST_DOMAIN,
|
||||
data={
|
||||
"auth_implementation": TEST_DOMAIN,
|
||||
"token": {
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
"access_token": ACCESS_TOKEN_1,
|
||||
"expires_in": 500,
|
||||
"expires_at": time.time() + 500, # Should NOT refresh
|
||||
"token_type": "bearer",
|
||||
"random_other_data": "should_stay",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl)
|
||||
resp = await session.async_request("post", "https://example.com")
|
||||
assert resp.status == 201
|
||||
|
||||
# make request (no refresh)
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
|
||||
assert (
|
||||
aioclient_mock.mock_calls[0][3]["authorization"] == f"Bearer {ACCESS_TOKEN_1}"
|
||||
)
|
||||
|
||||
assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN
|
||||
assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_1
|
||||
assert config_entry.data["token"]["expires_in"] == 500
|
||||
assert config_entry.data["token"]["random_other_data"] == "should_stay"
|
||||
assert round(config_entry.data["token"]["expires_at"] - now) == 500
|
||||
|
||||
|
||||
async def test_implementation_provider(hass, local_impl):
|
||||
"""Test providing an implementation provider."""
|
||||
assert (
|
||||
|
|
Loading…
Reference in New Issue