1
mirror of https://github.com/home-assistant/core synced 2024-07-15 09:42:11 +02:00

Convert cert_expiry to use asyncio (#106919)

This commit is contained in:
J. Nick Koston 2024-01-05 08:03:53 -10:00 committed by GitHub
parent 9a15a5b6c2
commit 24ee64e20c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 22 deletions

View File

@ -1,7 +1,10 @@
"""Helper functions for the Cert Expiry platform."""
import asyncio
import datetime
from functools import cache
import socket
import ssl
from typing import Any
from homeassistant.core import HomeAssistant
from homeassistant.util import dt as dt_util
@ -21,31 +24,38 @@ def _get_default_ssl_context():
return ssl.create_default_context()
def get_cert(
async def async_get_cert(
hass: HomeAssistant,
host: str,
port: int,
):
) -> dict[str, Any]:
"""Get the certificate for the host and port combination."""
ctx = _get_default_ssl_context()
address = (host, port)
with socket.create_connection(address, timeout=TIMEOUT) as sock, ctx.wrap_socket(
sock, server_hostname=address[0]
) as ssock:
cert = ssock.getpeercert()
return cert
async with asyncio.timeout(TIMEOUT):
transport, _ = await hass.loop.create_connection(
asyncio.Protocol,
host,
port,
ssl=_get_default_ssl_context(),
happy_eyeballs_delay=0.25,
server_hostname=host,
)
try:
return transport.get_extra_info("peercert")
finally:
transport.close()
async def get_cert_expiry_timestamp(
hass: HomeAssistant,
hostname: str,
port: int,
):
) -> datetime.datetime:
"""Return the certificate's expiration timestamp."""
try:
cert = await hass.async_add_executor_job(get_cert, hostname, port)
cert = await async_get_cert(hass, hostname, port)
except socket.gaierror as err:
raise ResolveFailed(f"Cannot resolve hostname: {hostname}") from err
except socket.timeout as err:
except asyncio.TimeoutError as err:
raise ConnectionTimeout(
f"Connection timeout with server: {hostname}:{port}"
) from err

View File

@ -1,4 +1,5 @@
"""Tests for the Cert Expiry config flow."""
import asyncio
import socket
import ssl
from unittest.mock import patch
@ -48,7 +49,7 @@ async def test_user_with_bad_cert(hass: HomeAssistant) -> None:
assert result["step_id"] == "user"
with patch(
"homeassistant.components.cert_expiry.helper.get_cert",
"homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ssl.SSLError("some error"),
):
result = await hass.config_entries.flow.async_configure(
@ -153,7 +154,7 @@ async def test_import_with_name(hass: HomeAssistant) -> None:
async def test_bad_import(hass: HomeAssistant) -> None:
"""Test import step."""
with patch(
"homeassistant.components.cert_expiry.helper.get_cert",
"homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ConnectionRefusedError(),
):
result = await hass.config_entries.flow.async_init(
@ -198,7 +199,7 @@ async def test_abort_on_socket_failed(hass: HomeAssistant) -> None:
)
with patch(
"homeassistant.components.cert_expiry.helper.get_cert",
"homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=socket.gaierror(),
):
result = await hass.config_entries.flow.async_configure(
@ -208,8 +209,8 @@ async def test_abort_on_socket_failed(hass: HomeAssistant) -> None:
assert result["errors"] == {CONF_HOST: "resolve_failed"}
with patch(
"homeassistant.components.cert_expiry.helper.get_cert",
side_effect=socket.timeout(),
"homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=asyncio.TimeoutError,
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={CONF_HOST: HOST}
@ -218,7 +219,7 @@ async def test_abort_on_socket_failed(hass: HomeAssistant) -> None:
assert result["errors"] == {CONF_HOST: "connection_timeout"}
with patch(
"homeassistant.components.cert_expiry.helper.get_cert",
"homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ConnectionRefusedError,
):
result = await hass.config_entries.flow.async_configure(

View File

@ -57,7 +57,7 @@ async def test_async_setup_entry_bad_cert(hass: HomeAssistant) -> None:
)
with patch(
"homeassistant.components.cert_expiry.helper.get_cert",
"homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ssl.SSLError("some error"),
):
entry.add_to_hass(hass)
@ -146,7 +146,7 @@ async def test_update_sensor_network_errors(hass: HomeAssistant) -> None:
next_update = starting_time + timedelta(hours=24)
with freeze_time(next_update), patch(
"homeassistant.components.cert_expiry.helper.get_cert",
"homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=socket.gaierror,
):
async_fire_time_changed(hass, utcnow() + timedelta(hours=24))
@ -174,7 +174,7 @@ async def test_update_sensor_network_errors(hass: HomeAssistant) -> None:
next_update = starting_time + timedelta(hours=72)
with freeze_time(next_update), patch(
"homeassistant.components.cert_expiry.helper.get_cert",
"homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=ssl.SSLError("something bad"),
):
async_fire_time_changed(hass, utcnow() + timedelta(hours=72))
@ -189,7 +189,8 @@ async def test_update_sensor_network_errors(hass: HomeAssistant) -> None:
next_update = starting_time + timedelta(hours=96)
with freeze_time(next_update), patch(
"homeassistant.components.cert_expiry.helper.get_cert", side_effect=Exception()
"homeassistant.components.cert_expiry.helper.async_get_cert",
side_effect=Exception(),
):
async_fire_time_changed(hass, utcnow() + timedelta(hours=96))
await hass.async_block_till_done()