Add new network apis to reduce code duplication (#54832)

This commit is contained in:
J. Nick Koston 2021-08-18 12:33:26 -05:00 committed by GitHub
parent 30564d59b6
commit 6d0ce814e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 51 deletions

View File

@ -1,13 +1,14 @@
"""The Network Configuration integration."""
from __future__ import annotations
from ipaddress import IPv4Address, IPv6Address
import logging
import voluptuous as vol
from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
@ -45,6 +46,35 @@ async def async_get_source_ip(hass: HomeAssistant, target_ip: str) -> str:
return source_ip if source_ip in all_ipv4s else all_ipv4s[0]
@bind_hass
async def async_get_enabled_source_ips(
hass: HomeAssistant,
) -> list[IPv4Address | IPv6Address]:
"""Build the list of enabled source ips."""
adapters = await async_get_adapters(hass)
sources: list[IPv4Address | IPv6Address] = []
for adapter in adapters:
if not adapter["enabled"]:
continue
if adapter["ipv4"]:
sources.extend(IPv4Address(ipv4["address"]) for ipv4 in adapter["ipv4"])
if adapter["ipv6"]:
# With python 3.9 add scope_ids can be
# added by enumerating adapter["ipv6"]s
# IPv6Address(f"::%{ipv6['scope_id']}")
sources.extend(IPv6Address(ipv6["address"]) for ipv6 in adapter["ipv6"])
return sources
@callback
def async_only_default_interface_enabled(adapters: list[Adapter]) -> bool:
"""Check to see if any non-default adapter is enabled."""
return not any(
adapter["enabled"] and not adapter["default"] for adapter in adapters
)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up network for Home Assistant."""

View File

@ -116,14 +116,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
@core_callback
def _async_use_default_interface(adapters: list[network.Adapter]) -> bool:
for adapter in adapters:
if adapter["enabled"] and not adapter["default"]:
return False
return True
@core_callback
def _async_process_callbacks(
callbacks: list[Callable[[dict], None]], discovery_info: dict[str, str]
@ -204,24 +196,16 @@ class Scanner:
"""Build the list of ssdp sources."""
adapters = await network.async_get_adapters(self.hass)
sources: set[IPv4Address | IPv6Address] = set()
if _async_use_default_interface(adapters):
if network.async_only_default_interface_enabled(adapters):
sources.add(IPv4Address("0.0.0.0"))
return sources
for adapter in adapters:
if not adapter["enabled"]:
continue
if adapter["ipv4"]:
ipv4 = adapter["ipv4"][0]
sources.add(IPv4Address(ipv4["address"]))
if adapter["ipv6"]:
ipv6 = adapter["ipv6"][0]
# With python 3.9 add scope_ids can be
# added by enumerating adapter["ipv6"]s
# IPv6Address(f"::%{ipv6['scope_id']}")
sources.add(IPv6Address(ipv6["address"]))
return sources
return {
source_ip
for source_ip in await network.async_get_enabled_source_ips(self.hass)
if not source_ip.is_loopback
and not (isinstance(source_ip, IPv6Address) and source_ip.is_global)
}
async def async_scan(self, *_: Any) -> None:
"""Scan for new entries using ssdp default and broadcast target."""

View File

@ -5,7 +5,7 @@ import asyncio
from collections.abc import Coroutine
from contextlib import suppress
import fnmatch
import ipaddress
from ipaddress import IPv6Address, ip_address
import logging
import socket
from typing import Any, TypedDict, cast
@ -131,13 +131,6 @@ async def _async_get_instance(hass: HomeAssistant, **zcargs: Any) -> HaAsyncZero
return aio_zc
def _async_use_default_interface(adapters: list[Adapter]) -> bool:
for adapter in adapters:
if adapter["enabled"] and not adapter["default"]:
return False
return True
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Zeroconf and make Home Assistant discoverable."""
zc_args: dict = {}
@ -151,25 +144,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
else:
zc_args["ip_version"] = IPVersion.All
if not ipv6 and _async_use_default_interface(adapters):
if not ipv6 and network.async_only_default_interface_enabled(adapters):
zc_args["interfaces"] = InterfaceChoice.Default
else:
interfaces = zc_args["interfaces"] = []
for adapter in adapters:
if not adapter["enabled"]:
continue
if ipv4s := adapter["ipv4"]:
interfaces.extend(
ipv4["address"]
for ipv4 in ipv4s
if not ipaddress.IPv4Address(ipv4["address"]).is_loopback
)
if ipv6s := adapter["ipv6"]:
for ipv6_addr in ipv6s:
address = ipv6_addr["address"]
v6_ip_address = ipaddress.IPv6Address(address)
if not v6_ip_address.is_global and not v6_ip_address.is_loopback:
interfaces.append(ipv6_addr["address"])
zc_args["interfaces"] = [
str(source_ip)
for source_ip in await network.async_get_enabled_source_ips(hass)
if not source_ip.is_loopback
and not (isinstance(source_ip, IPv6Address) and source_ip.is_global)
]
aio_zc = await _async_get_instance(hass, **zc_args)
zeroconf = cast(HaZeroconf, aio_zc.zeroconf)
@ -213,7 +196,7 @@ def _get_announced_addresses(
addresses = {
addr.packed
for addr in [
ipaddress.ip_address(ip["address"])
ip_address(ip["address"])
for adapter in adapters
if adapter["enabled"]
for ip in cast(list, adapter["ipv6"]) + cast(list, adapter["ipv4"])
@ -530,7 +513,7 @@ def info_from_service(service: AsyncServiceInfo) -> HaServiceInfo | None:
address = service.addresses[0]
return {
"host": str(ipaddress.ip_address(address)),
"host": str(ip_address(address)),
"port": service.port,
"hostname": service.server,
"type": service.type,