mirror of https://github.com/home-assistant/core
Add new network apis to reduce code duplication (#54832)
This commit is contained in:
parent
30564d59b6
commit
6d0ce814e7
|
@ -1,13 +1,14 @@
|
|||
"""The Network Configuration integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
import logging
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api.connection import ActiveConnection
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
|
||||
|
@ -45,6 +46,35 @@ async def async_get_source_ip(hass: HomeAssistant, target_ip: str) -> str:
|
|||
return source_ip if source_ip in all_ipv4s else all_ipv4s[0]
|
||||
|
||||
|
||||
@bind_hass
|
||||
async def async_get_enabled_source_ips(
|
||||
hass: HomeAssistant,
|
||||
) -> list[IPv4Address | IPv6Address]:
|
||||
"""Build the list of enabled source ips."""
|
||||
adapters = await async_get_adapters(hass)
|
||||
sources: list[IPv4Address | IPv6Address] = []
|
||||
for adapter in adapters:
|
||||
if not adapter["enabled"]:
|
||||
continue
|
||||
if adapter["ipv4"]:
|
||||
sources.extend(IPv4Address(ipv4["address"]) for ipv4 in adapter["ipv4"])
|
||||
if adapter["ipv6"]:
|
||||
# With python 3.9 add scope_ids can be
|
||||
# added by enumerating adapter["ipv6"]s
|
||||
# IPv6Address(f"::%{ipv6['scope_id']}")
|
||||
sources.extend(IPv6Address(ipv6["address"]) for ipv6 in adapter["ipv6"])
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
@callback
|
||||
def async_only_default_interface_enabled(adapters: list[Adapter]) -> bool:
|
||||
"""Check to see if any non-default adapter is enabled."""
|
||||
return not any(
|
||||
adapter["enabled"] and not adapter["default"] for adapter in adapters
|
||||
)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up network for Home Assistant."""
|
||||
|
||||
|
|
|
@ -116,14 +116,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
@core_callback
|
||||
def _async_use_default_interface(adapters: list[network.Adapter]) -> bool:
|
||||
for adapter in adapters:
|
||||
if adapter["enabled"] and not adapter["default"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@core_callback
|
||||
def _async_process_callbacks(
|
||||
callbacks: list[Callable[[dict], None]], discovery_info: dict[str, str]
|
||||
|
@ -204,24 +196,16 @@ class Scanner:
|
|||
"""Build the list of ssdp sources."""
|
||||
adapters = await network.async_get_adapters(self.hass)
|
||||
sources: set[IPv4Address | IPv6Address] = set()
|
||||
if _async_use_default_interface(adapters):
|
||||
if network.async_only_default_interface_enabled(adapters):
|
||||
sources.add(IPv4Address("0.0.0.0"))
|
||||
return sources
|
||||
|
||||
for adapter in adapters:
|
||||
if not adapter["enabled"]:
|
||||
continue
|
||||
if adapter["ipv4"]:
|
||||
ipv4 = adapter["ipv4"][0]
|
||||
sources.add(IPv4Address(ipv4["address"]))
|
||||
if adapter["ipv6"]:
|
||||
ipv6 = adapter["ipv6"][0]
|
||||
# With python 3.9 add scope_ids can be
|
||||
# added by enumerating adapter["ipv6"]s
|
||||
# IPv6Address(f"::%{ipv6['scope_id']}")
|
||||
sources.add(IPv6Address(ipv6["address"]))
|
||||
|
||||
return sources
|
||||
return {
|
||||
source_ip
|
||||
for source_ip in await network.async_get_enabled_source_ips(self.hass)
|
||||
if not source_ip.is_loopback
|
||||
and not (isinstance(source_ip, IPv6Address) and source_ip.is_global)
|
||||
}
|
||||
|
||||
async def async_scan(self, *_: Any) -> None:
|
||||
"""Scan for new entries using ssdp default and broadcast target."""
|
||||
|
|
|
@ -5,7 +5,7 @@ import asyncio
|
|||
from collections.abc import Coroutine
|
||||
from contextlib import suppress
|
||||
import fnmatch
|
||||
import ipaddress
|
||||
from ipaddress import IPv6Address, ip_address
|
||||
import logging
|
||||
import socket
|
||||
from typing import Any, TypedDict, cast
|
||||
|
@ -131,13 +131,6 @@ async def _async_get_instance(hass: HomeAssistant, **zcargs: Any) -> HaAsyncZero
|
|||
return aio_zc
|
||||
|
||||
|
||||
def _async_use_default_interface(adapters: list[Adapter]) -> bool:
|
||||
for adapter in adapters:
|
||||
if adapter["enabled"] and not adapter["default"]:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up Zeroconf and make Home Assistant discoverable."""
|
||||
zc_args: dict = {}
|
||||
|
@ -151,25 +144,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
else:
|
||||
zc_args["ip_version"] = IPVersion.All
|
||||
|
||||
if not ipv6 and _async_use_default_interface(adapters):
|
||||
if not ipv6 and network.async_only_default_interface_enabled(adapters):
|
||||
zc_args["interfaces"] = InterfaceChoice.Default
|
||||
else:
|
||||
interfaces = zc_args["interfaces"] = []
|
||||
for adapter in adapters:
|
||||
if not adapter["enabled"]:
|
||||
continue
|
||||
if ipv4s := adapter["ipv4"]:
|
||||
interfaces.extend(
|
||||
ipv4["address"]
|
||||
for ipv4 in ipv4s
|
||||
if not ipaddress.IPv4Address(ipv4["address"]).is_loopback
|
||||
)
|
||||
if ipv6s := adapter["ipv6"]:
|
||||
for ipv6_addr in ipv6s:
|
||||
address = ipv6_addr["address"]
|
||||
v6_ip_address = ipaddress.IPv6Address(address)
|
||||
if not v6_ip_address.is_global and not v6_ip_address.is_loopback:
|
||||
interfaces.append(ipv6_addr["address"])
|
||||
zc_args["interfaces"] = [
|
||||
str(source_ip)
|
||||
for source_ip in await network.async_get_enabled_source_ips(hass)
|
||||
if not source_ip.is_loopback
|
||||
and not (isinstance(source_ip, IPv6Address) and source_ip.is_global)
|
||||
]
|
||||
|
||||
aio_zc = await _async_get_instance(hass, **zc_args)
|
||||
zeroconf = cast(HaZeroconf, aio_zc.zeroconf)
|
||||
|
@ -213,7 +196,7 @@ def _get_announced_addresses(
|
|||
addresses = {
|
||||
addr.packed
|
||||
for addr in [
|
||||
ipaddress.ip_address(ip["address"])
|
||||
ip_address(ip["address"])
|
||||
for adapter in adapters
|
||||
if adapter["enabled"]
|
||||
for ip in cast(list, adapter["ipv6"]) + cast(list, adapter["ipv4"])
|
||||
|
@ -530,7 +513,7 @@ def info_from_service(service: AsyncServiceInfo) -> HaServiceInfo | None:
|
|||
address = service.addresses[0]
|
||||
|
||||
return {
|
||||
"host": str(ipaddress.ip_address(address)),
|
||||
"host": str(ip_address(address)),
|
||||
"port": service.port,
|
||||
"hostname": service.server,
|
||||
"type": service.type,
|
||||
|
|
Loading…
Reference in New Issue