diff --git a/.coveragerc b/.coveragerc index ee8e165c9b6..44e424260c1 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1257,6 +1257,7 @@ omit = homeassistant/components/surepetcare/sensor.py homeassistant/components/swiss_hydrological_data/sensor.py homeassistant/components/swiss_public_transport/__init__.py + homeassistant/components/swiss_public_transport/coordinator.py homeassistant/components/swiss_public_transport/sensor.py homeassistant/components/swisscom/device_tracker.py homeassistant/components/switchbee/__init__.py diff --git a/homeassistant/components/swiss_public_transport/__init__.py b/homeassistant/components/swiss_public_transport/__init__.py index 37f1eeb6765..9e01a07416f 100644 --- a/homeassistant/components/swiss_public_transport/__init__.py +++ b/homeassistant/components/swiss_public_transport/__init__.py @@ -13,6 +13,7 @@ from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import CONF_DESTINATION, CONF_START, DOMAIN +from .coordinator import SwissPublicTransportDataUpdateCoordinator _LOGGER = logging.getLogger(__name__) @@ -48,7 +49,9 @@ async def async_setup_entry( f"Setup failed for entry '{start} {destination}' with invalid data" ) from e - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = opendata + coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata) + await coordinator.async_config_entry_first_refresh() + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True diff --git a/homeassistant/components/swiss_public_transport/const.py b/homeassistant/components/swiss_public_transport/const.py index d14a77feb2a..6d9fb8bb960 100644 --- a/homeassistant/components/swiss_public_transport/const.py +++ b/homeassistant/components/swiss_public_transport/const.py @@ -2,7 +2,6 @@ DOMAIN = "swiss_public_transport" - CONF_DESTINATION = "to" CONF_START = "from" diff --git a/homeassistant/components/swiss_public_transport/coordinator.py b/homeassistant/components/swiss_public_transport/coordinator.py new file mode 100644 index 00000000000..93b3312b099 --- /dev/null +++ b/homeassistant/components/swiss_public_transport/coordinator.py @@ -0,0 +1,81 @@ +"""DataUpdateCoordinator for the swiss_public_transport integration.""" +from __future__ import annotations + +from datetime import timedelta +import logging +from typing import TypedDict + +from opendata_transport import OpendataTransport +from opendata_transport.exceptions import OpendataTransportError + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed +import homeassistant.util.dt as dt_util + +from .const import DOMAIN + +_LOGGER = logging.getLogger(__name__) + + +class DataConnection(TypedDict): + """A connection data class.""" + + departure: str + next_departure: str + next_on_departure: str + duration: str + platform: str + remaining_time: str + start: str + destination: str + train_number: str + transfers: str + delay: int + + +class SwissPublicTransportDataUpdateCoordinator(DataUpdateCoordinator[DataConnection]): + """A SwissPublicTransport Data Update Coordinator.""" + + config_entry: ConfigEntry + + def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None: + """Initialize the SwissPublicTransport data coordinator.""" + super().__init__( + hass, + _LOGGER, + name=DOMAIN, + update_interval=timedelta(seconds=90), + ) + self._opendata = opendata + + async def _async_update_data(self) -> DataConnection: + try: + await self._opendata.async_get_data() + except OpendataTransportError as e: + _LOGGER.warning( + "Unable to connect and retrieve data from transport.opendata.ch" + ) + raise UpdateFailed from e + + departure_time = dt_util.parse_datetime( + self._opendata.connections[0]["departure"] + ) + if departure_time: + remaining_time = departure_time - dt_util.as_local(dt_util.utcnow()) + else: + remaining_time = None + + return DataConnection( + departure=self._opendata.connections[0]["departure"], + next_departure=self._opendata.connections[1]["departure"], + next_on_departure=self._opendata.connections[2]["departure"], + train_number=self._opendata.connections[0]["number"], + platform=self._opendata.connections[0]["platform"], + transfers=self._opendata.connections[0]["transfers"], + duration=self._opendata.connections[0]["duration"], + start=self._opendata.from_name, + destination=self._opendata.to_name, + remaining_time=f"{remaining_time}", + delay=self._opendata.connections[0]["delay"], + ) diff --git a/homeassistant/components/swiss_public_transport/sensor.py b/homeassistant/components/swiss_public_transport/sensor.py index bc03b8d61e1..63b5891e48d 100644 --- a/homeassistant/components/swiss_public_transport/sensor.py +++ b/homeassistant/components/swiss_public_transport/sensor.py @@ -1,44 +1,30 @@ """Support for transport.opendata.ch.""" from __future__ import annotations -from collections.abc import Mapping from datetime import timedelta import logging -from typing import Any -from opendata_transport import OpendataTransport -from opendata_transport.exceptions import OpendataTransportError import voluptuous as vol from homeassistant import config_entries, core from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.const import CONF_NAME -from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant +from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant, callback from homeassistant.data_entry_flow import FlowResultType import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType -import homeassistant.util.dt as dt_util +from homeassistant.helpers.update_coordinator import CoordinatorEntity from .const import CONF_DESTINATION, CONF_START, DEFAULT_NAME, DOMAIN, PLACEHOLDERS +from .coordinator import SwissPublicTransportDataUpdateCoordinator _LOGGER = logging.getLogger(__name__) SCAN_INTERVAL = timedelta(seconds=90) -ATTR_DEPARTURE_TIME1 = "next_departure" -ATTR_DEPARTURE_TIME2 = "next_on_departure" -ATTR_DURATION = "duration" -ATTR_PLATFORM = "platform" -ATTR_REMAINING_TIME = "remaining_time" -ATTR_START = "start" -ATTR_TARGET = "destination" -ATTR_TRAIN_NUMBER = "train_number" -ATTR_TRANSFERS = "transfers" -ATTR_DELAY = "delay" - PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( { vol.Required(CONF_DESTINATION): cv.string, @@ -54,12 +40,12 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the sensor from a config entry created in the integrations UI.""" - opendata = hass.data[DOMAIN][config_entry.entry_id] + coordinator = hass.data[DOMAIN][config_entry.entry_id] name = config_entry.title async_add_entities( - [SwissPublicTransportSensor(opendata, name)], + [SwissPublicTransportSensor(coordinator, name)], True, ) @@ -108,60 +94,33 @@ async def async_setup_platform( ) -class SwissPublicTransportSensor(SensorEntity): +class SwissPublicTransportSensor( + CoordinatorEntity[SwissPublicTransportDataUpdateCoordinator], SensorEntity +): """Implementation of a Swiss public transport sensor.""" _attr_attribution = "Data provided by transport.opendata.ch" _attr_icon = "mdi:bus" - def __init__(self, opendata: OpendataTransport, name: str) -> None: + def __init__( + self, coordinator: SwissPublicTransportDataUpdateCoordinator, name: str + ) -> None: """Initialize the sensor.""" - self._opendata = opendata + super().__init__(coordinator) + self._coordinator = coordinator self._attr_name = name - self._remaining_time: timedelta | None = None + + @callback + def _handle_coordinator_update(self) -> None: + """Handle the state update and prepare the extra state attributes.""" + self._attr_extra_state_attributes = { + key: value + for key, value in self.coordinator.data.items() + if key not in {"departure"} + } + return super()._handle_coordinator_update() @property def native_value(self) -> str: """Return the state of the sensor.""" - return self._opendata.connections[0]["departure"] - - @property - def extra_state_attributes(self) -> Mapping[str, Any]: - """Return the state attributes.""" - departure_time = dt_util.parse_datetime( - self._opendata.connections[0]["departure"] - ) - if departure_time: - remaining_time = departure_time - dt_util.as_local(dt_util.utcnow()) - else: - remaining_time = None - self._remaining_time = remaining_time - - return { - ATTR_TRAIN_NUMBER: self._opendata.connections[0]["number"], - ATTR_PLATFORM: self._opendata.connections[0]["platform"], - ATTR_TRANSFERS: self._opendata.connections[0]["transfers"], - ATTR_DURATION: self._opendata.connections[0]["duration"], - ATTR_DEPARTURE_TIME1: self._opendata.connections[1]["departure"], - ATTR_DEPARTURE_TIME2: self._opendata.connections[2]["departure"], - ATTR_START: self._opendata.from_name, - ATTR_TARGET: self._opendata.to_name, - ATTR_REMAINING_TIME: f"{self._remaining_time}", - ATTR_DELAY: self._opendata.connections[0]["delay"], - } - - async def async_update(self) -> None: - """Get the latest data from opendata.ch and update the states.""" - - try: - if not self._remaining_time or self._remaining_time.total_seconds() < 0: - await self._opendata.async_get_data() - except OpendataTransportError: - self._attr_available = False - _LOGGER.warning( - "Unable to connect and retrieve data from transport.opendata.ch" - ) - else: - if not self._attr_available: - self._attr_available = True - _LOGGER.info("Connection established with transport.opendata.ch") + return self.coordinator.data["departure"]