mirror of
https://github.com/home-assistant/core
synced 2024-09-06 10:29:55 +02:00
Add product filtering feature to Trafikverket Train (#86343)
This commit is contained in:
parent
0317afeb17
commit
4c03077dfe
@ -15,7 +15,7 @@ from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
|
||||
from .const import CONF_FROM, CONF_TO, DOMAIN, PLATFORMS
|
||||
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TO, DOMAIN, PLATFORMS
|
||||
from .coordinator import TVDataUpdateCoordinator
|
||||
|
||||
|
||||
@ -36,7 +36,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
f" {entry.data[CONF_TO]}. Error: {error} "
|
||||
) from error
|
||||
|
||||
coordinator = TVDataUpdateCoordinator(hass, entry, to_station, from_station)
|
||||
coordinator = TVDataUpdateCoordinator(
|
||||
hass, entry, to_station, from_station, entry.options.get(CONF_FILTER_PRODUCT)
|
||||
)
|
||||
await coordinator.async_config_entry_first_refresh()
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
|
||||
|
||||
@ -49,6 +51,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
)
|
||||
|
||||
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
|
||||
entry.async_on_unload(entry.add_update_listener(update_listener))
|
||||
|
||||
return True
|
||||
|
||||
@ -57,3 +60,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload Trafikverket Weatherstation config entry."""
|
||||
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
||||
|
||||
async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||
"""Handle options update."""
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
|
@ -19,7 +19,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
@ -32,11 +32,15 @@ from homeassistant.helpers.selector import (
|
||||
)
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
from .const import CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
|
||||
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
|
||||
from .util import create_unique_id, next_departuredate
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
OPTION_SCHEMA = {
|
||||
vol.Optional(CONF_FILTER_PRODUCT, default=""): TextSelector(),
|
||||
}
|
||||
|
||||
DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_API_KEY): TextSelector(),
|
||||
@ -52,7 +56,7 @@ DATA_SCHEMA = vol.Schema(
|
||||
)
|
||||
),
|
||||
}
|
||||
)
|
||||
).extend(OPTION_SCHEMA)
|
||||
DATA_SCHEMA_REAUTH = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_API_KEY): cv.string,
|
||||
@ -67,6 +71,7 @@ async def validate_input(
|
||||
train_to: str,
|
||||
train_time: str | None,
|
||||
weekdays: list[str],
|
||||
product_filter: str | None,
|
||||
) -> dict[str, str]:
|
||||
"""Validate input from user input."""
|
||||
errors: dict[str, str] = {}
|
||||
@ -87,9 +92,13 @@ async def validate_input(
|
||||
from_station = await train_api.async_get_train_station(train_from)
|
||||
to_station = await train_api.async_get_train_station(train_to)
|
||||
if train_time:
|
||||
await train_api.async_get_train_stop(from_station, to_station, when)
|
||||
await train_api.async_get_train_stop(
|
||||
from_station, to_station, when, product_filter
|
||||
)
|
||||
else:
|
||||
await train_api.async_get_next_train_stop(from_station, to_station, when)
|
||||
await train_api.async_get_next_train_stop(
|
||||
from_station, to_station, when, product_filter
|
||||
)
|
||||
except InvalidAuthentication:
|
||||
errors["base"] = "invalid_auth"
|
||||
except NoTrainStationFound:
|
||||
@ -117,6 +126,14 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
|
||||
entry: config_entries.ConfigEntry | None
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(
|
||||
config_entry: config_entries.ConfigEntry,
|
||||
) -> TVTrainOptionsFlowHandler:
|
||||
"""Get the options flow for this handler."""
|
||||
return TVTrainOptionsFlowHandler(config_entry)
|
||||
|
||||
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
|
||||
"""Handle re-authentication with Trafikverket."""
|
||||
|
||||
@ -140,6 +157,7 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
self.entry.data[CONF_TO],
|
||||
self.entry.data.get(CONF_TIME),
|
||||
self.entry.data[CONF_WEEKDAY],
|
||||
self.entry.options.get(CONF_FILTER_PRODUCT),
|
||||
)
|
||||
if not errors:
|
||||
self.hass.config_entries.async_update_entry(
|
||||
@ -170,6 +188,10 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
train_to: str = user_input[CONF_TO]
|
||||
train_time: str | None = user_input.get(CONF_TIME)
|
||||
train_days: list = user_input[CONF_WEEKDAY]
|
||||
filter_product: str | None = user_input[CONF_FILTER_PRODUCT]
|
||||
|
||||
if filter_product == "":
|
||||
filter_product = None
|
||||
|
||||
name = f"{train_from} to {train_to}"
|
||||
if train_time:
|
||||
@ -182,6 +204,7 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
train_to,
|
||||
train_time,
|
||||
train_days,
|
||||
filter_product,
|
||||
)
|
||||
if not errors:
|
||||
unique_id = create_unique_id(
|
||||
@ -199,6 +222,7 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
CONF_TIME: train_time,
|
||||
CONF_WEEKDAY: train_days,
|
||||
},
|
||||
options={CONF_FILTER_PRODUCT: filter_product},
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
@ -208,3 +232,27 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
class TVTrainOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
|
||||
"""Handle Trafikverket Train options."""
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Manage Trafikverket Train options."""
|
||||
errors: dict[str, Any] = {}
|
||||
|
||||
if user_input:
|
||||
if not (_filter := user_input.get(CONF_FILTER_PRODUCT)) or _filter == "":
|
||||
user_input[CONF_FILTER_PRODUCT] = None
|
||||
return self.async_create_entry(data=user_input)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=self.add_suggested_values_to_schema(
|
||||
vol.Schema(OPTION_SCHEMA),
|
||||
user_input or self.options,
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
@ -8,3 +8,4 @@ ATTRIBUTION = "Data provided by Trafikverket"
|
||||
CONF_FROM = "from"
|
||||
CONF_TO = "to"
|
||||
CONF_TIME = "time"
|
||||
CONF_FILTER_PRODUCT = "filter_product"
|
||||
|
@ -39,6 +39,7 @@ class TrainData:
|
||||
actual_time: datetime | None
|
||||
other_info: str | None
|
||||
deviation: str | None
|
||||
product_filter: str | None
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -68,6 +69,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
|
||||
entry: ConfigEntry,
|
||||
to_station: StationInfo,
|
||||
from_station: StationInfo,
|
||||
filter_product: str | None,
|
||||
) -> None:
|
||||
"""Initialize the Trafikverket coordinator."""
|
||||
super().__init__(
|
||||
@ -83,6 +85,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
|
||||
self.to_station: StationInfo = to_station
|
||||
self._time: time | None = dt_util.parse_time(entry.data[CONF_TIME])
|
||||
self._weekdays: list[str] = entry.data[CONF_WEEKDAY]
|
||||
self._filter_product = filter_product
|
||||
|
||||
async def _async_update_data(self) -> TrainData:
|
||||
"""Fetch data from Trafikverket."""
|
||||
@ -99,11 +102,11 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
|
||||
try:
|
||||
if self._time:
|
||||
state = await self._train_api.async_get_train_stop(
|
||||
self.from_station, self.to_station, when
|
||||
self.from_station, self.to_station, when, self._filter_product
|
||||
)
|
||||
else:
|
||||
state = await self._train_api.async_get_next_train_stop(
|
||||
self.from_station, self.to_station, when
|
||||
self.from_station, self.to_station, when, self._filter_product
|
||||
)
|
||||
except InvalidAuthentication as error:
|
||||
raise ConfigEntryAuthFailed from error
|
||||
@ -134,6 +137,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
|
||||
actual_time=_get_as_utc(state.time_at_location),
|
||||
other_info=_get_as_joined(state.other_information),
|
||||
deviation=_get_as_joined(state.deviations),
|
||||
product_filter=self._filter_product,
|
||||
)
|
||||
|
||||
return states
|
||||
|
@ -1,9 +1,10 @@
|
||||
"""Train information for departures and delays, provided by Trafikverket."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.components.sensor import (
|
||||
SensorDeviceClass,
|
||||
@ -22,6 +23,8 @@ from homeassistant.helpers.update_coordinator import CoordinatorEntity
|
||||
from .const import ATTRIBUTION, DOMAIN
|
||||
from .coordinator import TrainData, TVDataUpdateCoordinator
|
||||
|
||||
ATTR_PRODUCT_FILTER = "product_filter"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrafikverketRequiredKeysMixin:
|
||||
@ -158,3 +161,10 @@ class TrainSensor(CoordinatorEntity[TVDataUpdateCoordinator], SensorEntity):
|
||||
def _handle_coordinator_update(self) -> None:
|
||||
self._update_attr()
|
||||
return super()._handle_coordinator_update()
|
||||
|
||||
@property
|
||||
def extra_state_attributes(self) -> Mapping[str, Any] | None:
|
||||
"""Return additional attributes for Trafikverket Train sensor."""
|
||||
if self.coordinator.data.product_filter:
|
||||
return {ATTR_PRODUCT_FILTER: self.coordinator.data.product_filter}
|
||||
return None
|
||||
|
@ -20,10 +20,12 @@
|
||||
"to": "To station",
|
||||
"from": "From station",
|
||||
"time": "Time (optional)",
|
||||
"weekday": "Days"
|
||||
"weekday": "Days",
|
||||
"filter_product": "Filter by product description"
|
||||
},
|
||||
"data_description": {
|
||||
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure"
|
||||
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure",
|
||||
"filter_product": "To filter by product description add the phrase here to match"
|
||||
}
|
||||
},
|
||||
"reauth_confirm": {
|
||||
@ -33,6 +35,18 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"data": {
|
||||
"filter_product": "[%key:component::trafikverket_train::config::step::user::data::filter_product%]"
|
||||
},
|
||||
"data_description": {
|
||||
"filter_product": "[%key:component::trafikverket_train::config::step::user::data_description::filter_product%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"selector": {
|
||||
"weekday": {
|
||||
"options": {
|
||||
@ -49,7 +63,12 @@
|
||||
"entity": {
|
||||
"sensor": {
|
||||
"departure_time": {
|
||||
"name": "Departure time"
|
||||
"name": "Departure time",
|
||||
"state_attributes": {
|
||||
"product_filter": {
|
||||
"name": "Train filtering"
|
||||
}
|
||||
}
|
||||
},
|
||||
"departure_state": {
|
||||
"name": "Departure state",
|
||||
@ -57,28 +76,68 @@
|
||||
"on_time": "On time",
|
||||
"delayed": "Delayed",
|
||||
"canceled": "Cancelled"
|
||||
},
|
||||
"state_attributes": {
|
||||
"product_filter": {
|
||||
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"cancelled": {
|
||||
"name": "Cancelled"
|
||||
"name": "Cancelled",
|
||||
"state_attributes": {
|
||||
"product_filter": {
|
||||
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"delayed_time": {
|
||||
"name": "Delayed time"
|
||||
"name": "Delayed time",
|
||||
"state_attributes": {
|
||||
"product_filter": {
|
||||
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"planned_time": {
|
||||
"name": "Planned time"
|
||||
"name": "Planned time",
|
||||
"state_attributes": {
|
||||
"product_filter": {
|
||||
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"estimated_time": {
|
||||
"name": "Estimated time"
|
||||
"name": "Estimated time",
|
||||
"state_attributes": {
|
||||
"product_filter": {
|
||||
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"actual_time": {
|
||||
"name": "Actual time"
|
||||
"name": "Actual time",
|
||||
"state_attributes": {
|
||||
"product_filter": {
|
||||
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"other_info": {
|
||||
"name": "Other information"
|
||||
"name": "Other information",
|
||||
"state_attributes": {
|
||||
"product_filter": {
|
||||
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"deviation": {
|
||||
"name": "Deviation"
|
||||
"name": "Deviation",
|
||||
"state_attributes": {
|
||||
"product_filter": {
|
||||
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -66,6 +66,7 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
"time": "10:00",
|
||||
"weekday": ["mon", "fri"],
|
||||
}
|
||||
assert result["options"] == {"filter_product": None}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
assert result["result"].unique_id == "{}-{}-{}-{}".format(
|
||||
"stockholmc", "uppsalac", "10:00", "['mon', 'fri']"
|
||||
@ -448,3 +449,55 @@ async def test_reauth_flow_error_departures(
|
||||
"time": "10:00",
|
||||
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
|
||||
}
|
||||
|
||||
|
||||
async def test_options_flow(hass: HomeAssistant) -> None:
|
||||
"""Test a reauthentication flow."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_API_KEY: "1234567890",
|
||||
CONF_NAME: "Stockholm C to Uppsala C at 10:00",
|
||||
CONF_FROM: "Stockholm C",
|
||||
CONF_TO: "Uppsala C",
|
||||
CONF_TIME: "10:00",
|
||||
CONF_WEEKDAY: WEEKDAYS,
|
||||
},
|
||||
unique_id=f"stockholmc-uppsalac-10:00-{WEEKDAYS}",
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.trafikverket_train.async_setup_entry",
|
||||
return_value=True,
|
||||
):
|
||||
assert await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
result = await hass.config_entries.options.async_init(entry.entry_id)
|
||||
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={"filter_product": "SJ Regionaltåg"},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == {"filter_product": "SJ Regionaltåg"}
|
||||
|
||||
result = await hass.config_entries.options.async_init(entry.entry_id)
|
||||
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
result = await hass.config_entries.options.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={"filter_product": ""},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == {"filter_product": None}
|
||||
|
Loading…
Reference in New Issue
Block a user