Use DataUpdate coordinator for Transmission (#99209)

* Switch integration to DataUpdate Coordinator

* add coordinator to .coveragerc

* Migrate TransmissionData into DUC

* update coveragerc

* Applu suggestions

* remove CONFIG_SCHEMA
This commit is contained in:
Rami Mosleh 2023-10-12 21:58:22 +03:00 committed by GitHub
parent cc3d1a11bd
commit 536ad57bf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 340 additions and 465 deletions

View File

@ -1396,6 +1396,7 @@ omit =
homeassistant/components/trafikverket_weatherstation/coordinator.py
homeassistant/components/trafikverket_weatherstation/sensor.py
homeassistant/components/transmission/__init__.py
homeassistant/components/transmission/coordinator.py
homeassistant/components/transmission/sensor.py
homeassistant/components/transmission/switch.py
homeassistant/components/travisci/sensor.py

View File

@ -1,8 +1,7 @@
"""Support for the Transmission BitTorrent client API."""
from __future__ import annotations
from collections.abc import Callable
from datetime import datetime, timedelta
from datetime import timedelta
from functools import partial
import logging
import re
@ -14,10 +13,9 @@ from transmission_rpc.error import (
TransmissionConnectError,
TransmissionError,
)
from transmission_rpc.session import SessionStats
import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import (
CONF_HOST,
CONF_ID,
@ -35,33 +33,39 @@ from homeassistant.helpers import (
entity_registry as er,
selector,
)
from homeassistant.helpers.dispatcher import dispatcher_send
from homeassistant.helpers.event import async_track_time_interval
from .const import (
ATTR_DELETE_DATA,
ATTR_TORRENT,
CONF_ENTRY_ID,
CONF_LIMIT,
CONF_ORDER,
DATA_UPDATED,
DEFAULT_DELETE_DATA,
DEFAULT_LIMIT,
DEFAULT_ORDER,
DEFAULT_SCAN_INTERVAL,
DOMAIN,
EVENT_DOWNLOADED_TORRENT,
EVENT_REMOVED_TORRENT,
EVENT_STARTED_TORRENT,
SERVICE_ADD_TORRENT,
SERVICE_REMOVE_TORRENT,
SERVICE_START_TORRENT,
SERVICE_STOP_TORRENT,
)
from .coordinator import TransmissionDataUpdateCoordinator
from .errors import AuthenticationError, CannotConnect, UnknownError
_LOGGER = logging.getLogger(__name__)
PLATFORMS = [Platform.SENSOR, Platform.SWITCH]
MIGRATION_NAME_TO_KEY = {
# Sensors
"Down Speed": "download",
"Up Speed": "upload",
"Status": "status",
"Active Torrents": "active_torrents",
"Paused Torrents": "paused_torrents",
"Total Torrents": "total_torrents",
"Completed Torrents": "completed_torrents",
"Started Torrents": "started_torrents",
# Switches
"Switch": "on_off",
"Turtle Mode": "turtle_mode",
}
SERVICE_BASE_SCHEMA = vol.Schema(
{
@ -95,25 +99,6 @@ SERVICE_STOP_TORRENT_SCHEMA = vol.All(
)
)
CONFIG_SCHEMA = cv.removed(DOMAIN, raise_if_present=False)
PLATFORMS = [Platform.SENSOR, Platform.SWITCH]
MIGRATION_NAME_TO_KEY = {
# Sensors
"Down Speed": "download",
"Up Speed": "upload",
"Status": "status",
"Active Torrents": "active_torrents",
"Paused Torrents": "paused_torrents",
"Total Torrents": "total_torrents",
"Completed Torrents": "completed_torrents",
"Started Torrents": "started_torrents",
# Switches
"Switch": "on_off",
"Turtle Mode": "turtle_mode",
}
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Set up the Transmission Component."""
@ -141,24 +126,81 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
except (AuthenticationError, UnknownError) as error:
raise ConfigEntryAuthFailed from error
client = TransmissionClient(hass, config_entry, api)
await client.async_setup()
hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = client
coordinator = TransmissionDataUpdateCoordinator(hass, config_entry, api)
await hass.async_add_executor_job(coordinator.init_torrent_list)
await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = coordinator
await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS)
client.register_services()
config_entry.add_update_listener(async_options_updated)
async def add_torrent(service: ServiceCall) -> None:
"""Add new torrent to download."""
torrent = service.data[ATTR_TORRENT]
if torrent.startswith(
("http", "ftp:", "magnet:")
) or hass.config.is_allowed_path(torrent):
await hass.async_add_executor_job(coordinator.api.add_torrent, torrent)
await coordinator.async_request_refresh()
else:
_LOGGER.warning("Could not add torrent: unsupported type or no permission")
async def start_torrent(service: ServiceCall) -> None:
"""Start torrent."""
torrent_id = service.data[CONF_ID]
await hass.async_add_executor_job(coordinator.api.start_torrent, torrent_id)
await coordinator.async_request_refresh()
async def stop_torrent(service: ServiceCall) -> None:
"""Stop torrent."""
torrent_id = service.data[CONF_ID]
await hass.async_add_executor_job(coordinator.api.stop_torrent, torrent_id)
await coordinator.async_request_refresh()
async def remove_torrent(service: ServiceCall) -> None:
"""Remove torrent."""
torrent_id = service.data[CONF_ID]
delete_data = service.data[ATTR_DELETE_DATA]
await hass.async_add_executor_job(
partial(coordinator.api.remove_torrent, torrent_id, delete_data=delete_data)
)
await coordinator.async_request_refresh()
hass.services.async_register(
DOMAIN, SERVICE_ADD_TORRENT, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA
)
hass.services.async_register(
DOMAIN,
SERVICE_REMOVE_TORRENT,
remove_torrent,
schema=SERVICE_REMOVE_TORRENT_SCHEMA,
)
hass.services.async_register(
DOMAIN,
SERVICE_START_TORRENT,
start_torrent,
schema=SERVICE_START_TORRENT_SCHEMA,
)
hass.services.async_register(
DOMAIN,
SERVICE_STOP_TORRENT,
stop_torrent,
schema=SERVICE_STOP_TORRENT_SCHEMA,
)
return True
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Unload Transmission Entry from config_entry."""
client: TransmissionClient = hass.data[DOMAIN].pop(config_entry.entry_id)
if client.unsub_timer:
client.unsub_timer()
unload_ok = await hass.config_entries.async_unload_platforms(
if unload_ok := await hass.config_entries.async_unload_platforms(
config_entry, PLATFORMS
)
):
hass.data[DOMAIN].pop(config_entry.entry_id)
if not hass.data[DOMAIN]:
hass.services.async_remove(DOMAIN, SERVICE_ADD_TORRENT)
@ -202,286 +244,8 @@ async def get_api(
raise UnknownError from error
def _get_client(hass: HomeAssistant, data: dict[str, Any]) -> TransmissionClient | None:
"""Return client from integration name or entry_id."""
if (
(entry_id := data.get(CONF_ENTRY_ID))
and (entry := hass.config_entries.async_get_entry(entry_id))
and entry.state == ConfigEntryState.LOADED
):
return hass.data[DOMAIN][entry_id]
return None
class TransmissionClient:
"""Transmission Client Object."""
def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
api: transmission_rpc.Client,
) -> None:
"""Initialize the Transmission RPC API."""
self.hass = hass
self.config_entry = config_entry
self.tm_api = api
self._tm_data = TransmissionData(hass, config_entry, api)
self.unsub_timer: Callable[[], None] | None = None
@property
def api(self) -> TransmissionData:
"""Return the TransmissionData object."""
return self._tm_data
async def async_setup(self) -> None:
"""Set up the Transmission client."""
await self.hass.async_add_executor_job(self.api.init_torrent_list)
await self.hass.async_add_executor_job(self.api.update)
self.add_options()
self.set_scan_interval(self.config_entry.options[CONF_SCAN_INTERVAL])
def register_services(self) -> None:
"""Register integration services."""
def add_torrent(service: ServiceCall) -> None:
"""Add new torrent to download."""
if not (tm_client := _get_client(self.hass, service.data)):
raise ValueError("Transmission instance is not found")
torrent = service.data[ATTR_TORRENT]
if torrent.startswith(
("http", "ftp:", "magnet:")
) or self.hass.config.is_allowed_path(torrent):
tm_client.tm_api.add_torrent(torrent)
tm_client.api.update()
else:
_LOGGER.warning(
"Could not add torrent: unsupported type or no permission"
)
def start_torrent(service: ServiceCall) -> None:
"""Start torrent."""
if not (tm_client := _get_client(self.hass, service.data)):
raise ValueError("Transmission instance is not found")
torrent_id = service.data[CONF_ID]
tm_client.tm_api.start_torrent(torrent_id)
tm_client.api.update()
def stop_torrent(service: ServiceCall) -> None:
"""Stop torrent."""
if not (tm_client := _get_client(self.hass, service.data)):
raise ValueError("Transmission instance is not found")
torrent_id = service.data[CONF_ID]
tm_client.tm_api.stop_torrent(torrent_id)
tm_client.api.update()
def remove_torrent(service: ServiceCall) -> None:
"""Remove torrent."""
if not (tm_client := _get_client(self.hass, service.data)):
raise ValueError("Transmission instance is not found")
torrent_id = service.data[CONF_ID]
delete_data = service.data[ATTR_DELETE_DATA]
tm_client.tm_api.remove_torrent(torrent_id, delete_data=delete_data)
tm_client.api.update()
self.hass.services.async_register(
DOMAIN, SERVICE_ADD_TORRENT, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA
)
self.hass.services.async_register(
DOMAIN,
SERVICE_REMOVE_TORRENT,
remove_torrent,
schema=SERVICE_REMOVE_TORRENT_SCHEMA,
)
self.hass.services.async_register(
DOMAIN,
SERVICE_START_TORRENT,
start_torrent,
schema=SERVICE_START_TORRENT_SCHEMA,
)
self.hass.services.async_register(
DOMAIN,
SERVICE_STOP_TORRENT,
stop_torrent,
schema=SERVICE_STOP_TORRENT_SCHEMA,
)
self.config_entry.add_update_listener(self.async_options_updated)
def add_options(self):
"""Add options for entry."""
if not self.config_entry.options:
scan_interval = self.config_entry.data.get(
CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL
)
limit = self.config_entry.data.get(CONF_LIMIT, DEFAULT_LIMIT)
order = self.config_entry.data.get(CONF_ORDER, DEFAULT_ORDER)
options = {
CONF_SCAN_INTERVAL: scan_interval,
CONF_LIMIT: limit,
CONF_ORDER: order,
}
self.hass.config_entries.async_update_entry(
self.config_entry, options=options
)
def set_scan_interval(self, scan_interval: float) -> None:
"""Update scan interval."""
def refresh(event_time: datetime) -> None:
"""Get the latest data from Transmission."""
self.api.update()
if self.unsub_timer is not None:
self.unsub_timer()
self.unsub_timer = async_track_time_interval(
self.hass, refresh, timedelta(seconds=scan_interval)
)
@staticmethod
async def async_options_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Triggered by config entry options updates."""
tm_client: TransmissionClient = hass.data[DOMAIN][entry.entry_id]
tm_client.set_scan_interval(entry.options[CONF_SCAN_INTERVAL])
await hass.async_add_executor_job(tm_client.api.update)
class TransmissionData:
"""Get the latest data and update the states."""
def __init__(
self, hass: HomeAssistant, config: ConfigEntry, api: transmission_rpc.Client
) -> None:
"""Initialize the Transmission RPC API."""
self.hass = hass
self.config = config
self._api: transmission_rpc.Client = api
self.data: SessionStats | None = None
self.available: bool = True
self._session: transmission_rpc.Session | None = None
self._all_torrents: list[transmission_rpc.Torrent] = []
self._completed_torrents: list[transmission_rpc.Torrent] = []
self._started_torrents: list[transmission_rpc.Torrent] = []
self._torrents: list[transmission_rpc.Torrent] = []
@property
def host(self) -> str:
"""Return the host name."""
return self.config.data[CONF_HOST]
@property
def signal_update(self) -> str:
"""Update signal per transmission entry."""
return f"{DATA_UPDATED}-{self.host}"
@property
def torrents(self) -> list[transmission_rpc.Torrent]:
"""Get the list of torrents."""
return self._torrents
def update(self) -> None:
"""Get the latest data from Transmission instance."""
try:
self.data = self._api.session_stats()
self._torrents = self._api.get_torrents()
self._session = self._api.get_session()
self.check_completed_torrent()
self.check_started_torrent()
self.check_removed_torrent()
_LOGGER.debug("Torrent Data for %s Updated", self.host)
self.available = True
except TransmissionError:
self.available = False
_LOGGER.error("Unable to connect to Transmission client %s", self.host)
dispatcher_send(self.hass, self.signal_update)
def init_torrent_list(self) -> None:
"""Initialize torrent lists."""
self._torrents = self._api.get_torrents()
self._completed_torrents = [
torrent for torrent in self._torrents if torrent.status == "seeding"
]
self._started_torrents = [
torrent for torrent in self._torrents if torrent.status == "downloading"
]
def check_completed_torrent(self) -> None:
"""Get completed torrent functionality."""
old_completed_torrent_names = {
torrent.name for torrent in self._completed_torrents
}
current_completed_torrents = [
torrent for torrent in self._torrents if torrent.status == "seeding"
]
for torrent in current_completed_torrents:
if torrent.name not in old_completed_torrent_names:
self.hass.bus.fire(
EVENT_DOWNLOADED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._completed_torrents = current_completed_torrents
def check_started_torrent(self) -> None:
"""Get started torrent functionality."""
old_started_torrent_names = {torrent.name for torrent in self._started_torrents}
current_started_torrents = [
torrent for torrent in self._torrents if torrent.status == "downloading"
]
for torrent in current_started_torrents:
if torrent.name not in old_started_torrent_names:
self.hass.bus.fire(
EVENT_STARTED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._started_torrents = current_started_torrents
def check_removed_torrent(self) -> None:
"""Get removed torrent functionality."""
current_torrent_names = {torrent.name for torrent in self._torrents}
for torrent in self._all_torrents:
if torrent.name not in current_torrent_names:
self.hass.bus.fire(
EVENT_REMOVED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._all_torrents = self._torrents.copy()
def start_torrents(self) -> None:
"""Start all torrents."""
if not self._torrents:
return
self._api.start_all()
def stop_torrents(self) -> None:
"""Stop all active torrents."""
if not self._torrents:
return
torrent_ids = [torrent.id for torrent in self._torrents]
self._api.stop_torrent(torrent_ids)
def set_alt_speed_enabled(self, is_enabled: bool) -> None:
"""Set the alternative speed flag."""
self._api.set_session(alt_speed_enabled=is_enabled)
def get_alt_speed_enabled(self) -> bool | None:
"""Get the alternative speed flag."""
if self._session is None:
return None
return self._session.alt_speed_enabled
async def async_options_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Triggered by config entry options updates."""
coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][entry.entry_id]
coordinator.update_interval = timedelta(seconds=entry.options[CONF_SCAN_INTERVAL])
await coordinator.async_request_refresh()

View File

@ -39,8 +39,6 @@ SERVICE_REMOVE_TORRENT = "remove_torrent"
SERVICE_START_TORRENT = "start_torrent"
SERVICE_STOP_TORRENT = "stop_torrent"
DATA_UPDATED = "transmission_data_updated"
EVENT_STARTED_TORRENT = "transmission_started_torrent"
EVENT_REMOVED_TORRENT = "transmission_removed_torrent"
EVENT_DOWNLOADED_TORRENT = "transmission_downloaded_torrent"

View File

@ -0,0 +1,166 @@
"""Coordinator for transmssion integration."""
from __future__ import annotations
from datetime import timedelta
import logging
import transmission_rpc
from transmission_rpc.session import SessionStats
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_SCAN_INTERVAL
from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import (
CONF_LIMIT,
CONF_ORDER,
DEFAULT_LIMIT,
DEFAULT_ORDER,
DEFAULT_SCAN_INTERVAL,
DOMAIN,
EVENT_DOWNLOADED_TORRENT,
EVENT_REMOVED_TORRENT,
EVENT_STARTED_TORRENT,
)
_LOGGER = logging.getLogger(__name__)
class TransmissionDataUpdateCoordinator(DataUpdateCoordinator[SessionStats]):
"""Transmission dataupdate coordinator class."""
config_entry: ConfigEntry
def __init__(
self, hass: HomeAssistant, entry: ConfigEntry, api: transmission_rpc.Client
) -> None:
"""Initialize the Transmission RPC API."""
self.config_entry = entry
self.api = api
self.host = entry.data[CONF_HOST]
self._session: transmission_rpc.Session | None = None
self._all_torrents: list[transmission_rpc.Torrent] = []
self._completed_torrents: list[transmission_rpc.Torrent] = []
self._started_torrents: list[transmission_rpc.Torrent] = []
self.torrents: list[transmission_rpc.Torrent] = []
super().__init__(
hass,
name=f"{DOMAIN} - {self.host}",
logger=_LOGGER,
update_interval=timedelta(seconds=self.scan_interval),
)
@property
def scan_interval(self) -> float:
"""Return scan interval."""
return self.config_entry.options.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL)
@property
def limit(self) -> int:
"""Return limit."""
return self.config_entry.data.get(CONF_LIMIT, DEFAULT_LIMIT)
@property
def order(self) -> str:
"""Return order."""
return self.config_entry.data.get(CONF_ORDER, DEFAULT_ORDER)
async def _async_update_data(self) -> SessionStats:
"""Update transmission data."""
return await self.hass.async_add_executor_job(self.update)
def update(self) -> SessionStats:
"""Get the latest data from Transmission instance."""
try:
data = self.api.session_stats()
self.torrents = self.api.get_torrents()
self._session = self.api.get_session()
self.check_completed_torrent()
self.check_started_torrent()
self.check_removed_torrent()
except transmission_rpc.TransmissionError as err:
raise UpdateFailed("Unable to connect to Transmission client") from err
return data
def init_torrent_list(self) -> None:
"""Initialize torrent lists."""
self.torrents = self.api.get_torrents()
self._completed_torrents = [
torrent for torrent in self.torrents if torrent.status == "seeding"
]
self._started_torrents = [
torrent for torrent in self.torrents if torrent.status == "downloading"
]
def check_completed_torrent(self) -> None:
"""Get completed torrent functionality."""
old_completed_torrent_names = {
torrent.name for torrent in self._completed_torrents
}
current_completed_torrents = [
torrent for torrent in self.torrents if torrent.status == "seeding"
]
for torrent in current_completed_torrents:
if torrent.name not in old_completed_torrent_names:
self.hass.bus.fire(
EVENT_DOWNLOADED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._completed_torrents = current_completed_torrents
def check_started_torrent(self) -> None:
"""Get started torrent functionality."""
old_started_torrent_names = {torrent.name for torrent in self._started_torrents}
current_started_torrents = [
torrent for torrent in self.torrents if torrent.status == "downloading"
]
for torrent in current_started_torrents:
if torrent.name not in old_started_torrent_names:
self.hass.bus.fire(
EVENT_STARTED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._started_torrents = current_started_torrents
def check_removed_torrent(self) -> None:
"""Get removed torrent functionality."""
current_torrent_names = {torrent.name for torrent in self.torrents}
for torrent in self._all_torrents:
if torrent.name not in current_torrent_names:
self.hass.bus.fire(
EVENT_REMOVED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._all_torrents = self.torrents.copy()
def start_torrents(self) -> None:
"""Start all torrents."""
if not self.torrents:
return
self.api.start_all()
def stop_torrents(self) -> None:
"""Stop all active torrents."""
if not self.torrents:
return
torrent_ids = [torrent.id for torrent in self.torrents]
self.api.stop_torrent(torrent_ids)
def set_alt_speed_enabled(self, is_enabled: bool) -> None:
"""Set the alternative speed flag."""
self.api.set_session(alt_speed_enabled=is_enabled)
def get_alt_speed_enabled(self) -> bool | None:
"""Get the alternative speed flag."""
if self._session is None:
return None
return self._session.alt_speed_enabled

View File

@ -4,21 +4,18 @@ from __future__ import annotations
from contextlib import suppress
from typing import Any
from transmission_rpc.session import SessionStats
from transmission_rpc.torrent import Torrent
from homeassistant.components.sensor import SensorDeviceClass, SensorEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_NAME, STATE_IDLE, UnitOfDataRate
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import TransmissionClient
from .const import (
CONF_LIMIT,
CONF_ORDER,
DOMAIN,
STATE_ATTR_TORRENT_INFO,
STATE_DOWNLOADING,
@ -26,6 +23,7 @@ from .const import (
STATE_UP_DOWN,
SUPPORTED_ORDER_MODES,
)
from .coordinator import TransmissionDataUpdateCoordinator
async def async_setup_entry(
@ -35,54 +33,56 @@ async def async_setup_entry(
) -> None:
"""Set up the Transmission sensors."""
tm_client: TransmissionClient = hass.data[DOMAIN][config_entry.entry_id]
coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][
config_entry.entry_id
]
name: str = config_entry.data[CONF_NAME]
dev = [
TransmissionSpeedSensor(
tm_client,
coordinator,
name,
"download_speed",
"download",
),
TransmissionSpeedSensor(
tm_client,
coordinator,
name,
"upload_speed",
"upload",
),
TransmissionStatusSensor(
tm_client,
coordinator,
name,
"transmission_status",
"status",
),
TransmissionTorrentsSensor(
tm_client,
coordinator,
name,
"active_torrents",
"active_torrents",
),
TransmissionTorrentsSensor(
tm_client,
coordinator,
name,
"paused_torrents",
"paused_torrents",
),
TransmissionTorrentsSensor(
tm_client,
coordinator,
name,
"total_torrents",
"total_torrents",
),
TransmissionTorrentsSensor(
tm_client,
coordinator,
name,
"completed_torrents",
"completed_torrents",
),
TransmissionTorrentsSensor(
tm_client,
coordinator,
name,
"started_torrents",
"started_torrents",
@ -92,7 +92,7 @@ async def async_setup_entry(
async_add_entities(dev, True)
class TransmissionSensor(SensorEntity):
class TransmissionSensor(CoordinatorEntity[SessionStats], SensorEntity):
"""A base class for all Transmission sensors."""
_attr_has_entity_name = True
@ -100,48 +100,23 @@ class TransmissionSensor(SensorEntity):
def __init__(
self,
tm_client: TransmissionClient,
coordinator: TransmissionDataUpdateCoordinator,
client_name: str,
sensor_translation_key: str,
key: str,
) -> None:
"""Initialize the sensor."""
self._tm_client = tm_client
super().__init__(coordinator)
self._attr_translation_key = sensor_translation_key
self._key = key
self._state: StateType = None
self._attr_unique_id = f"{tm_client.config_entry.entry_id}-{key}"
self._attr_unique_id = f"{coordinator.config_entry.entry_id}-{key}"
self._attr_device_info = DeviceInfo(
entry_type=DeviceEntryType.SERVICE,
identifiers={(DOMAIN, tm_client.config_entry.entry_id)},
identifiers={(DOMAIN, coordinator.config_entry.entry_id)},
manufacturer="Transmission",
name=client_name,
)
@property
def native_value(self) -> StateType:
"""Return the state of the sensor."""
return self._state
@property
def available(self) -> bool:
"""Could the device be accessed during the last update call."""
return self._tm_client.api.available
async def async_added_to_hass(self) -> None:
"""Handle entity which will be added."""
@callback
def update():
"""Update the state."""
self.async_schedule_update_ha_state(True)
self.async_on_remove(
async_dispatcher_connect(
self.hass, self._tm_client.api.signal_update, update
)
)
class TransmissionSpeedSensor(TransmissionSensor):
"""Representation of a Transmission speed sensor."""
@ -151,15 +126,15 @@ class TransmissionSpeedSensor(TransmissionSensor):
_attr_suggested_display_precision = 2
_attr_suggested_unit_of_measurement = UnitOfDataRate.MEGABYTES_PER_SECOND
def update(self) -> None:
"""Get the latest data from Transmission and updates the state."""
if data := self._tm_client.api.data:
b_spd = (
float(data.download_speed)
if self._key == "download"
else float(data.upload_speed)
)
self._state = b_spd
@property
def native_value(self) -> float:
"""Return the speed of the sensor."""
data = self.coordinator.data
return (
float(data.download_speed)
if self._key == "download"
else float(data.upload_speed)
)
class TransmissionStatusSensor(TransmissionSensor):
@ -168,21 +143,18 @@ class TransmissionStatusSensor(TransmissionSensor):
_attr_device_class = SensorDeviceClass.ENUM
_attr_options = [STATE_IDLE, STATE_UP_DOWN, STATE_SEEDING, STATE_DOWNLOADING]
def update(self) -> None:
"""Get the latest data from Transmission and updates the state."""
if data := self._tm_client.api.data:
upload = data.upload_speed
download = data.download_speed
if upload > 0 and download > 0:
self._state = STATE_UP_DOWN
elif upload > 0 and download == 0:
self._state = STATE_SEEDING
elif upload == 0 and download > 0:
self._state = STATE_DOWNLOADING
else:
self._state = STATE_IDLE
else:
self._state = None
@property
def native_value(self) -> str:
"""Return the value of the status sensor."""
upload = self.coordinator.data.upload_speed
download = self.coordinator.data.download_speed
if upload > 0 and download > 0:
return STATE_UP_DOWN
if upload > 0 and download == 0:
return STATE_SEEDING
if upload == 0 and download > 0:
return STATE_DOWNLOADING
return STATE_IDLE
class TransmissionTorrentsSensor(TransmissionSensor):
@ -208,21 +180,22 @@ class TransmissionTorrentsSensor(TransmissionSensor):
def extra_state_attributes(self) -> dict[str, Any]:
"""Return the state attributes, if any."""
info = _torrents_info(
torrents=self._tm_client.api.torrents,
order=self._tm_client.config_entry.options[CONF_ORDER],
limit=self._tm_client.config_entry.options[CONF_LIMIT],
torrents=self.coordinator.torrents,
order=self.coordinator.order,
limit=self.coordinator.limit,
statuses=self.MODES[self._key],
)
return {
STATE_ATTR_TORRENT_INFO: info,
}
def update(self) -> None:
"""Get the latest data from Transmission and updates the state."""
@property
def native_value(self) -> int:
"""Return the count of the sensor."""
torrents = _filter_torrents(
self._tm_client.api.torrents, statuses=self.MODES[self._key]
self.coordinator.torrents, statuses=self.MODES[self._key]
)
self._state = len(torrents)
return len(torrents)
def _filter_torrents(

View File

@ -3,16 +3,18 @@ from collections.abc import Callable
import logging
from typing import Any
from transmission_rpc.session import SessionStats
from homeassistant.components.switch import SwitchEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_NAME, STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant, callback
from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import TransmissionClient
from .const import DOMAIN, SWITCH_TYPES
from .coordinator import TransmissionDataUpdateCoordinator
_LOGGING = logging.getLogger(__name__)
@ -24,17 +26,19 @@ async def async_setup_entry(
) -> None:
"""Set up the Transmission switch."""
tm_client: TransmissionClient = hass.data[DOMAIN][config_entry.entry_id]
coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][
config_entry.entry_id
]
name: str = config_entry.data[CONF_NAME]
dev = []
for switch_type, switch_name in SWITCH_TYPES.items():
dev.append(TransmissionSwitch(switch_type, switch_name, tm_client, name))
dev.append(TransmissionSwitch(switch_type, switch_name, coordinator, name))
async_add_entities(dev, True)
class TransmissionSwitch(SwitchEntity):
class TransmissionSwitch(CoordinatorEntity[SessionStats], SwitchEntity):
"""Representation of a Transmission switch."""
_attr_has_entity_name = True
@ -44,20 +48,18 @@ class TransmissionSwitch(SwitchEntity):
self,
switch_type: str,
switch_name: str,
tm_client: TransmissionClient,
coordinator: TransmissionDataUpdateCoordinator,
client_name: str,
) -> None:
"""Initialize the Transmission switch."""
super().__init__(coordinator)
self._attr_name = switch_name
self.type = switch_type
self._tm_client = tm_client
self._state = STATE_OFF
self._data = None
self.unsub_update: Callable[[], None] | None = None
self._attr_unique_id = f"{tm_client.config_entry.entry_id}-{switch_type}"
self._attr_unique_id = f"{coordinator.config_entry.entry_id}-{switch_type}"
self._attr_device_info = DeviceInfo(
entry_type=DeviceEntryType.SERVICE,
identifiers={(DOMAIN, tm_client.config_entry.entry_id)},
identifiers={(DOMAIN, coordinator.config_entry.entry_id)},
manufacturer="Transmission",
name=client_name,
)
@ -65,63 +67,34 @@ class TransmissionSwitch(SwitchEntity):
@property
def is_on(self) -> bool:
"""Return true if device is on."""
return self._state == STATE_ON
active = None
if self.type == "on_off":
active = self.coordinator.data.active_torrent_count > 0
elif self.type == "turtle_mode":
active = self.coordinator.get_alt_speed_enabled()
@property
def available(self) -> bool:
"""Could the device be accessed during the last update call."""
return self._tm_client.api.available
return bool(active)
def turn_on(self, **kwargs: Any) -> None:
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn the device on."""
if self.type == "on_off":
_LOGGING.debug("Starting all torrents")
self._tm_client.api.start_torrents()
await self.hass.async_add_executor_job(self.coordinator.start_torrents)
elif self.type == "turtle_mode":
_LOGGING.debug("Turning Turtle Mode of Transmission on")
self._tm_client.api.set_alt_speed_enabled(True)
self._tm_client.api.update()
await self.hass.async_add_executor_job(
self.coordinator.set_alt_speed_enabled, True
)
await self.coordinator.async_request_refresh()
def turn_off(self, **kwargs: Any) -> None:
async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn the device off."""
if self.type == "on_off":
_LOGGING.debug("Stopping all torrents")
self._tm_client.api.stop_torrents()
await self.hass.async_add_executor_job(self.coordinator.stop_torrents)
if self.type == "turtle_mode":
_LOGGING.debug("Turning Turtle Mode of Transmission off")
self._tm_client.api.set_alt_speed_enabled(False)
self._tm_client.api.update()
async def async_added_to_hass(self) -> None:
"""Handle entity which will be added."""
self.unsub_update = async_dispatcher_connect(
self.hass,
self._tm_client.api.signal_update,
self._schedule_immediate_update,
)
@callback
def _schedule_immediate_update(self) -> None:
self.async_schedule_update_ha_state(True)
async def will_remove_from_hass(self) -> None:
"""Unsubscribe from update dispatcher."""
if self.unsub_update:
self.unsub_update()
self.unsub_update = None
def update(self) -> None:
"""Get the latest data from Transmission and updates the state."""
active = None
if self.type == "on_off":
self._data = self._tm_client.api.data
if self._data:
active = self._data.active_torrent_count > 0
elif self.type == "turtle_mode":
active = self._tm_client.api.get_alt_speed_enabled()
if active is None:
return
self._state = STATE_ON if active else STATE_OFF
await self.hass.async_add_executor_job(
self.coordinator.set_alt_speed_enabled, False
)
await self.coordinator.async_request_refresh()