ha-core/homeassistant/components/apple_tv/__init__.py

372 lines
12 KiB
Python

"""The Apple TV integration."""
import asyncio
import logging
from random import randrange
from pyatv import connect, exceptions, scan
from pyatv.const import DeviceModel, Protocol
from pyatv.convert import model_str
from homeassistant.components import zeroconf
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
ATTR_CONNECTIONS,
ATTR_IDENTIFIERS,
ATTR_MANUFACTURER,
ATTR_MODEL,
ATTR_NAME,
ATTR_SUGGESTED_AREA,
ATTR_SW_VERSION,
CONF_ADDRESS,
CONF_NAME,
EVENT_HOMEASSISTANT_STOP,
Platform,
)
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.entity import DeviceInfo, Entity
from .const import CONF_CREDENTIALS, CONF_IDENTIFIERS, CONF_START_OFF, DOMAIN
_LOGGER = logging.getLogger(__name__)
DEFAULT_NAME = "Apple TV"
BACKOFF_TIME_LOWER_LIMIT = 15 # seconds
BACKOFF_TIME_UPPER_LIMIT = 300 # Five minutes
SIGNAL_CONNECTED = "apple_tv_connected"
SIGNAL_DISCONNECTED = "apple_tv_disconnected"
PLATFORMS = [Platform.MEDIA_PLAYER, Platform.REMOTE]
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up a config entry for Apple TV."""
manager = AppleTVManager(hass, entry)
hass.data.setdefault(DOMAIN, {})[entry.unique_id] = manager
async def on_hass_stop(event):
"""Stop push updates when hass stops."""
await manager.disconnect()
entry.async_on_unload(
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, on_hass_stop)
)
async def setup_platforms():
"""Set up platforms and initiate connection."""
await asyncio.gather(
*(
hass.config_entries.async_forward_entry_setup(entry, platform)
for platform in PLATFORMS
)
)
await manager.init()
hass.async_create_task(setup_platforms())
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload an Apple TV config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
manager = hass.data[DOMAIN].pop(entry.unique_id)
await manager.disconnect()
return unload_ok
class AppleTVEntity(Entity):
"""Device that sends commands to an Apple TV."""
_attr_should_poll = False
def __init__(self, name, identifier, manager):
"""Initialize device."""
self.atv = None
self.manager = manager
self._attr_name = name
self._attr_unique_id = identifier
self._attr_device_info = DeviceInfo(identifiers={(DOMAIN, identifier)})
async def async_added_to_hass(self):
"""Handle when an entity is about to be added to Home Assistant."""
@callback
def _async_connected(atv):
"""Handle that a connection was made to a device."""
self.atv = atv
self.async_device_connected(atv)
self.async_write_ha_state()
@callback
def _async_disconnected():
"""Handle that a connection to a device was lost."""
self.async_device_disconnected()
self.atv = None
self.async_write_ha_state()
self.async_on_remove(
async_dispatcher_connect(
self.hass, f"{SIGNAL_CONNECTED}_{self.unique_id}", _async_connected
)
)
self.async_on_remove(
async_dispatcher_connect(
self.hass,
f"{SIGNAL_DISCONNECTED}_{self.unique_id}",
_async_disconnected,
)
)
def async_device_connected(self, atv):
"""Handle when connection is made to device."""
def async_device_disconnected(self):
"""Handle when connection was lost to device."""
class AppleTVManager:
"""Connection and power manager for an Apple TV.
An instance is used per device to share the same power state between
several platforms. It also manages scanning and connection establishment
in case of problems.
"""
def __init__(self, hass, config_entry):
"""Initialize power manager."""
self.config_entry = config_entry
self.hass = hass
self.atv = None
self._is_on = not config_entry.options.get(CONF_START_OFF, False)
self._connection_attempts = 0
self._connection_was_lost = False
self._task = None
async def init(self):
"""Initialize power management."""
if self._is_on:
await self.connect()
def connection_lost(self, _):
"""Device was unexpectedly disconnected.
This is a callback function from pyatv.interface.DeviceListener.
"""
_LOGGER.warning(
'Connection lost to Apple TV "%s"', self.config_entry.data[CONF_NAME]
)
self._connection_was_lost = True
self._handle_disconnect()
def connection_closed(self):
"""Device connection was (intentionally) closed.
This is a callback function from pyatv.interface.DeviceListener.
"""
self._handle_disconnect()
def _handle_disconnect(self):
"""Handle that the device disconnected and restart connect loop."""
if self.atv:
self.atv.close()
self.atv = None
self._dispatch_send(SIGNAL_DISCONNECTED)
self._start_connect_loop()
async def connect(self):
"""Connect to device."""
self._is_on = True
self._start_connect_loop()
async def disconnect(self):
"""Disconnect from device."""
_LOGGER.debug("Disconnecting from device")
self._is_on = False
try:
if self.atv:
self.atv.close()
self.atv = None
if self._task:
self._task.cancel()
self._task = None
except Exception: # pylint: disable=broad-except
_LOGGER.exception("An error occurred while disconnecting")
def _start_connect_loop(self):
"""Start background connect loop to device."""
if not self._task and self.atv is None and self._is_on:
self._task = asyncio.create_task(self._connect_loop())
else:
_LOGGER.debug(
"Not starting connect loop (%s, %s)", self.atv is None, self._is_on
)
async def _connect_loop(self):
"""Connect loop background task function."""
_LOGGER.debug("Starting connect loop")
# Try to find device and connect as long as the user has said that
# we are allowed to connect and we are not already connected.
while self._is_on and self.atv is None:
try:
conf = await self._scan()
if conf:
await self._connect(conf)
except exceptions.AuthenticationError:
self.config_entry.async_start_reauth(self.hass)
asyncio.create_task(self.disconnect())
_LOGGER.exception(
"Authentication failed for %s, try reconfiguring device",
self.config_entry.data[CONF_NAME],
)
break
except asyncio.CancelledError:
pass
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Failed to connect")
self.atv = None
if self.atv is None:
self._connection_attempts += 1
backoff = min(
max(
BACKOFF_TIME_LOWER_LIMIT,
randrange(2**self._connection_attempts),
),
BACKOFF_TIME_UPPER_LIMIT,
)
_LOGGER.debug("Reconnecting in %d seconds", backoff)
await asyncio.sleep(backoff)
_LOGGER.debug("Connect loop ended")
self._task = None
async def _scan(self):
"""Try to find device by scanning for it."""
identifiers = set(
self.config_entry.data.get(CONF_IDENTIFIERS, [self.config_entry.unique_id])
)
address = self.config_entry.data[CONF_ADDRESS]
# Only scan for and set up protocols that was successfully paired
protocols = {
Protocol(int(protocol))
for protocol in self.config_entry.data[CONF_CREDENTIALS]
}
_LOGGER.debug("Discovering device %s", self.config_entry.title)
aiozc = await zeroconf.async_get_async_instance(self.hass)
atvs = await scan(
self.hass.loop,
identifier=identifiers,
protocol=protocols,
hosts=[address],
aiozc=aiozc,
)
if atvs:
return atvs[0]
_LOGGER.debug(
"Failed to find device %s with address %s",
self.config_entry.title,
address,
)
# We no longer multicast scan for the device since as soon as async_step_zeroconf runs,
# it will update the address and reload the config entry when the device is found.
return None
async def _connect(self, conf):
"""Connect to device."""
credentials = self.config_entry.data[CONF_CREDENTIALS]
session = async_get_clientsession(self.hass)
for protocol_int, creds in credentials.items():
protocol = Protocol(int(protocol_int))
if conf.get_service(protocol) is not None:
conf.set_credentials(protocol, creds)
else:
_LOGGER.warning(
"Protocol %s not found for %s, functionality will be reduced",
protocol.name,
self.config_entry.data[CONF_NAME],
)
_LOGGER.debug("Connecting to device %s", self.config_entry.data[CONF_NAME])
self.atv = await connect(conf, self.hass.loop, session=session)
self.atv.listener = self
self._dispatch_send(SIGNAL_CONNECTED, self.atv)
self._address_updated(str(conf.address))
self._async_setup_device_registry()
self._connection_attempts = 0
if self._connection_was_lost:
_LOGGER.info(
'Connection was re-established to device "%s"',
self.config_entry.data[CONF_NAME],
)
self._connection_was_lost = False
@callback
def _async_setup_device_registry(self):
attrs = {
ATTR_IDENTIFIERS: {(DOMAIN, self.config_entry.unique_id)},
ATTR_MANUFACTURER: "Apple",
ATTR_NAME: self.config_entry.data[CONF_NAME],
}
area = attrs[ATTR_NAME]
name_trailer = f" {DEFAULT_NAME}"
if area.endswith(name_trailer):
area = area[: -len(name_trailer)]
attrs[ATTR_SUGGESTED_AREA] = area
if self.atv:
dev_info = self.atv.device_info
attrs[ATTR_MODEL] = (
dev_info.raw_model
if dev_info.model == DeviceModel.Unknown and dev_info.raw_model
else model_str(dev_info.model)
)
attrs[ATTR_SW_VERSION] = dev_info.version
if dev_info.mac:
attrs[ATTR_CONNECTIONS] = {(dr.CONNECTION_NETWORK_MAC, dev_info.mac)}
device_registry = dr.async_get(self.hass)
device_registry.async_get_or_create(
config_entry_id=self.config_entry.entry_id, **attrs
)
@property
def is_connecting(self):
"""Return true if connection is in progress."""
return self._task is not None
def _address_updated(self, address):
"""Update cached address in config entry."""
_LOGGER.debug("Changing address to %s", address)
self.hass.config_entries.async_update_entry(
self.config_entry, data={**self.config_entry.data, CONF_ADDRESS: address}
)
def _dispatch_send(self, signal, *args):
"""Dispatch a signal to all entities managed by this manager."""
async_dispatcher_send(
self.hass, f"{signal}_{self.config_entry.unique_id}", *args
)