Add coordinator to Swiss public transport (#106278)

This commit is contained in:
Cyrill Raccaud 2023-12-27 12:54:41 +01:00 committed by GitHub
parent 9944047b35
commit b935facec8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 110 additions and 67 deletions

View File

@ -1257,6 +1257,7 @@ omit =
homeassistant/components/surepetcare/sensor.py
homeassistant/components/swiss_hydrological_data/sensor.py
homeassistant/components/swiss_public_transport/__init__.py
homeassistant/components/swiss_public_transport/coordinator.py
homeassistant/components/swiss_public_transport/sensor.py
homeassistant/components/swisscom/device_tracker.py
homeassistant/components/switchbee/__init__.py

View File

@ -13,6 +13,7 @@ from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import CONF_DESTINATION, CONF_START, DOMAIN
from .coordinator import SwissPublicTransportDataUpdateCoordinator
_LOGGER = logging.getLogger(__name__)
@ -48,7 +49,9 @@ async def async_setup_entry(
f"Setup failed for entry '{start} {destination}' with invalid data"
) from e
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = opendata
coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata)
await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True

View File

@ -2,7 +2,6 @@
DOMAIN = "swiss_public_transport"
CONF_DESTINATION = "to"
CONF_START = "from"

View File

@ -0,0 +1,81 @@
"""DataUpdateCoordinator for the swiss_public_transport integration."""
from __future__ import annotations
from datetime import timedelta
import logging
from typing import TypedDict
from opendata_transport import OpendataTransport
from opendata_transport.exceptions import OpendataTransportError
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
import homeassistant.util.dt as dt_util
from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
class DataConnection(TypedDict):
"""A connection data class."""
departure: str
next_departure: str
next_on_departure: str
duration: str
platform: str
remaining_time: str
start: str
destination: str
train_number: str
transfers: str
delay: int
class SwissPublicTransportDataUpdateCoordinator(DataUpdateCoordinator[DataConnection]):
"""A SwissPublicTransport Data Update Coordinator."""
config_entry: ConfigEntry
def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None:
"""Initialize the SwissPublicTransport data coordinator."""
super().__init__(
hass,
_LOGGER,
name=DOMAIN,
update_interval=timedelta(seconds=90),
)
self._opendata = opendata
async def _async_update_data(self) -> DataConnection:
try:
await self._opendata.async_get_data()
except OpendataTransportError as e:
_LOGGER.warning(
"Unable to connect and retrieve data from transport.opendata.ch"
)
raise UpdateFailed from e
departure_time = dt_util.parse_datetime(
self._opendata.connections[0]["departure"]
)
if departure_time:
remaining_time = departure_time - dt_util.as_local(dt_util.utcnow())
else:
remaining_time = None
return DataConnection(
departure=self._opendata.connections[0]["departure"],
next_departure=self._opendata.connections[1]["departure"],
next_on_departure=self._opendata.connections[2]["departure"],
train_number=self._opendata.connections[0]["number"],
platform=self._opendata.connections[0]["platform"],
transfers=self._opendata.connections[0]["transfers"],
duration=self._opendata.connections[0]["duration"],
start=self._opendata.from_name,
destination=self._opendata.to_name,
remaining_time=f"{remaining_time}",
delay=self._opendata.connections[0]["delay"],
)

View File

@ -1,44 +1,30 @@
"""Support for transport.opendata.ch."""
from __future__ import annotations
from collections.abc import Mapping
from datetime import timedelta
import logging
from typing import Any
from opendata_transport import OpendataTransport
from opendata_transport.exceptions import OpendataTransportError
import voluptuous as vol
from homeassistant import config_entries, core
from homeassistant.components.sensor import PLATFORM_SCHEMA, SensorEntity
from homeassistant.config_entries import SOURCE_IMPORT
from homeassistant.const import CONF_NAME
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResultType
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
import homeassistant.util.dt as dt_util
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import CONF_DESTINATION, CONF_START, DEFAULT_NAME, DOMAIN, PLACEHOLDERS
from .coordinator import SwissPublicTransportDataUpdateCoordinator
_LOGGER = logging.getLogger(__name__)
SCAN_INTERVAL = timedelta(seconds=90)
ATTR_DEPARTURE_TIME1 = "next_departure"
ATTR_DEPARTURE_TIME2 = "next_on_departure"
ATTR_DURATION = "duration"
ATTR_PLATFORM = "platform"
ATTR_REMAINING_TIME = "remaining_time"
ATTR_START = "start"
ATTR_TARGET = "destination"
ATTR_TRAIN_NUMBER = "train_number"
ATTR_TRANSFERS = "transfers"
ATTR_DELAY = "delay"
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_DESTINATION): cv.string,
@ -54,12 +40,12 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up the sensor from a config entry created in the integrations UI."""
opendata = hass.data[DOMAIN][config_entry.entry_id]
coordinator = hass.data[DOMAIN][config_entry.entry_id]
name = config_entry.title
async_add_entities(
[SwissPublicTransportSensor(opendata, name)],
[SwissPublicTransportSensor(coordinator, name)],
True,
)
@ -108,60 +94,33 @@ async def async_setup_platform(
)
class SwissPublicTransportSensor(SensorEntity):
class SwissPublicTransportSensor(
CoordinatorEntity[SwissPublicTransportDataUpdateCoordinator], SensorEntity
):
"""Implementation of a Swiss public transport sensor."""
_attr_attribution = "Data provided by transport.opendata.ch"
_attr_icon = "mdi:bus"
def __init__(self, opendata: OpendataTransport, name: str) -> None:
def __init__(
self, coordinator: SwissPublicTransportDataUpdateCoordinator, name: str
) -> None:
"""Initialize the sensor."""
self._opendata = opendata
super().__init__(coordinator)
self._coordinator = coordinator
self._attr_name = name
self._remaining_time: timedelta | None = None
@callback
def _handle_coordinator_update(self) -> None:
"""Handle the state update and prepare the extra state attributes."""
self._attr_extra_state_attributes = {
key: value
for key, value in self.coordinator.data.items()
if key not in {"departure"}
}
return super()._handle_coordinator_update()
@property
def native_value(self) -> str:
"""Return the state of the sensor."""
return self._opendata.connections[0]["departure"]
@property
def extra_state_attributes(self) -> Mapping[str, Any]:
"""Return the state attributes."""
departure_time = dt_util.parse_datetime(
self._opendata.connections[0]["departure"]
)
if departure_time:
remaining_time = departure_time - dt_util.as_local(dt_util.utcnow())
else:
remaining_time = None
self._remaining_time = remaining_time
return {
ATTR_TRAIN_NUMBER: self._opendata.connections[0]["number"],
ATTR_PLATFORM: self._opendata.connections[0]["platform"],
ATTR_TRANSFERS: self._opendata.connections[0]["transfers"],
ATTR_DURATION: self._opendata.connections[0]["duration"],
ATTR_DEPARTURE_TIME1: self._opendata.connections[1]["departure"],
ATTR_DEPARTURE_TIME2: self._opendata.connections[2]["departure"],
ATTR_START: self._opendata.from_name,
ATTR_TARGET: self._opendata.to_name,
ATTR_REMAINING_TIME: f"{self._remaining_time}",
ATTR_DELAY: self._opendata.connections[0]["delay"],
}
async def async_update(self) -> None:
"""Get the latest data from opendata.ch and update the states."""
try:
if not self._remaining_time or self._remaining_time.total_seconds() < 0:
await self._opendata.async_get_data()
except OpendataTransportError:
self._attr_available = False
_LOGGER.warning(
"Unable to connect and retrieve data from transport.opendata.ch"
)
else:
if not self._attr_available:
self._attr_available = True
_LOGGER.info("Connection established with transport.opendata.ch")
return self.coordinator.data["departure"]