1
mirror of https://github.com/home-assistant/core synced 2024-09-06 10:29:55 +02:00

Add product filtering feature to Trafikverket Train (#86343)

This commit is contained in:
G Johansson 2023-08-09 17:20:30 +02:00 committed by GitHub
parent 0317afeb17
commit 4c03077dfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 203 additions and 20 deletions

View File

@ -15,7 +15,7 @@ from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import CONF_FROM, CONF_TO, DOMAIN, PLATFORMS
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TO, DOMAIN, PLATFORMS
from .coordinator import TVDataUpdateCoordinator
@ -36,7 +36,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
f" {entry.data[CONF_TO]}. Error: {error} "
) from error
coordinator = TVDataUpdateCoordinator(hass, entry, to_station, from_station)
coordinator = TVDataUpdateCoordinator(
hass, entry, to_station, from_station, entry.options.get(CONF_FILTER_PRODUCT)
)
await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
@ -49,6 +51,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
)
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
entry.async_on_unload(entry.add_update_listener(update_listener))
return True
@ -57,3 +60,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Trafikverket Weatherstation config entry."""
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle options update."""
await hass.config_entries.async_reload(entry.entry_id)

View File

@ -19,7 +19,7 @@ import voluptuous as vol
from homeassistant import config_entries
from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
@ -32,11 +32,15 @@ from homeassistant.helpers.selector import (
)
import homeassistant.util.dt as dt_util
from .const import CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .util import create_unique_id, next_departuredate
_LOGGER = logging.getLogger(__name__)
OPTION_SCHEMA = {
vol.Optional(CONF_FILTER_PRODUCT, default=""): TextSelector(),
}
DATA_SCHEMA = vol.Schema(
{
vol.Required(CONF_API_KEY): TextSelector(),
@ -52,7 +56,7 @@ DATA_SCHEMA = vol.Schema(
)
),
}
)
).extend(OPTION_SCHEMA)
DATA_SCHEMA_REAUTH = vol.Schema(
{
vol.Required(CONF_API_KEY): cv.string,
@ -67,6 +71,7 @@ async def validate_input(
train_to: str,
train_time: str | None,
weekdays: list[str],
product_filter: str | None,
) -> dict[str, str]:
"""Validate input from user input."""
errors: dict[str, str] = {}
@ -87,9 +92,13 @@ async def validate_input(
from_station = await train_api.async_get_train_station(train_from)
to_station = await train_api.async_get_train_station(train_to)
if train_time:
await train_api.async_get_train_stop(from_station, to_station, when)
await train_api.async_get_train_stop(
from_station, to_station, when, product_filter
)
else:
await train_api.async_get_next_train_stop(from_station, to_station, when)
await train_api.async_get_next_train_stop(
from_station, to_station, when, product_filter
)
except InvalidAuthentication:
errors["base"] = "invalid_auth"
except NoTrainStationFound:
@ -117,6 +126,14 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
entry: config_entries.ConfigEntry | None
@staticmethod
@callback
def async_get_options_flow(
config_entry: config_entries.ConfigEntry,
) -> TVTrainOptionsFlowHandler:
"""Get the options flow for this handler."""
return TVTrainOptionsFlowHandler(config_entry)
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Handle re-authentication with Trafikverket."""
@ -140,6 +157,7 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self.entry.data[CONF_TO],
self.entry.data.get(CONF_TIME),
self.entry.data[CONF_WEEKDAY],
self.entry.options.get(CONF_FILTER_PRODUCT),
)
if not errors:
self.hass.config_entries.async_update_entry(
@ -170,6 +188,10 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
train_to: str = user_input[CONF_TO]
train_time: str | None = user_input.get(CONF_TIME)
train_days: list = user_input[CONF_WEEKDAY]
filter_product: str | None = user_input[CONF_FILTER_PRODUCT]
if filter_product == "":
filter_product = None
name = f"{train_from} to {train_to}"
if train_time:
@ -182,6 +204,7 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
train_to,
train_time,
train_days,
filter_product,
)
if not errors:
unique_id = create_unique_id(
@ -199,6 +222,7 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)
return self.async_show_form(
@ -208,3 +232,27 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
),
errors=errors,
)
class TVTrainOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
"""Handle Trafikverket Train options."""
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage Trafikverket Train options."""
errors: dict[str, Any] = {}
if user_input:
if not (_filter := user_input.get(CONF_FILTER_PRODUCT)) or _filter == "":
user_input[CONF_FILTER_PRODUCT] = None
return self.async_create_entry(data=user_input)
return self.async_show_form(
step_id="init",
data_schema=self.add_suggested_values_to_schema(
vol.Schema(OPTION_SCHEMA),
user_input or self.options,
),
errors=errors,
)

View File

@ -8,3 +8,4 @@ ATTRIBUTION = "Data provided by Trafikverket"
CONF_FROM = "from"
CONF_TO = "to"
CONF_TIME = "time"
CONF_FILTER_PRODUCT = "filter_product"

View File

@ -39,6 +39,7 @@ class TrainData:
actual_time: datetime | None
other_info: str | None
deviation: str | None
product_filter: str | None
_LOGGER = logging.getLogger(__name__)
@ -68,6 +69,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
entry: ConfigEntry,
to_station: StationInfo,
from_station: StationInfo,
filter_product: str | None,
) -> None:
"""Initialize the Trafikverket coordinator."""
super().__init__(
@ -83,6 +85,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
self.to_station: StationInfo = to_station
self._time: time | None = dt_util.parse_time(entry.data[CONF_TIME])
self._weekdays: list[str] = entry.data[CONF_WEEKDAY]
self._filter_product = filter_product
async def _async_update_data(self) -> TrainData:
"""Fetch data from Trafikverket."""
@ -99,11 +102,11 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
try:
if self._time:
state = await self._train_api.async_get_train_stop(
self.from_station, self.to_station, when
self.from_station, self.to_station, when, self._filter_product
)
else:
state = await self._train_api.async_get_next_train_stop(
self.from_station, self.to_station, when
self.from_station, self.to_station, when, self._filter_product
)
except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error
@ -134,6 +137,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
actual_time=_get_as_utc(state.time_at_location),
other_info=_get_as_joined(state.other_information),
deviation=_get_as_joined(state.deviations),
product_filter=self._filter_product,
)
return states

View File

@ -1,9 +1,10 @@
"""Train information for departures and delays, provided by Trafikverket."""
from __future__ import annotations
from collections.abc import Callable
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from homeassistant.components.sensor import (
SensorDeviceClass,
@ -22,6 +23,8 @@ from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import ATTRIBUTION, DOMAIN
from .coordinator import TrainData, TVDataUpdateCoordinator
ATTR_PRODUCT_FILTER = "product_filter"
@dataclass
class TrafikverketRequiredKeysMixin:
@ -158,3 +161,10 @@ class TrainSensor(CoordinatorEntity[TVDataUpdateCoordinator], SensorEntity):
def _handle_coordinator_update(self) -> None:
self._update_attr()
return super()._handle_coordinator_update()
@property
def extra_state_attributes(self) -> Mapping[str, Any] | None:
"""Return additional attributes for Trafikverket Train sensor."""
if self.coordinator.data.product_filter:
return {ATTR_PRODUCT_FILTER: self.coordinator.data.product_filter}
return None

View File

@ -20,10 +20,12 @@
"to": "To station",
"from": "From station",
"time": "Time (optional)",
"weekday": "Days"
"weekday": "Days",
"filter_product": "Filter by product description"
},
"data_description": {
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure"
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure",
"filter_product": "To filter by product description add the phrase here to match"
}
},
"reauth_confirm": {
@ -33,6 +35,18 @@
}
}
},
"options": {
"step": {
"init": {
"data": {
"filter_product": "[%key:component::trafikverket_train::config::step::user::data::filter_product%]"
},
"data_description": {
"filter_product": "[%key:component::trafikverket_train::config::step::user::data_description::filter_product%]"
}
}
}
},
"selector": {
"weekday": {
"options": {
@ -49,7 +63,12 @@
"entity": {
"sensor": {
"departure_time": {
"name": "Departure time"
"name": "Departure time",
"state_attributes": {
"product_filter": {
"name": "Train filtering"
}
}
},
"departure_state": {
"name": "Departure state",
@ -57,28 +76,68 @@
"on_time": "On time",
"delayed": "Delayed",
"canceled": "Cancelled"
},
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"cancelled": {
"name": "Cancelled"
"name": "Cancelled",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"delayed_time": {
"name": "Delayed time"
"name": "Delayed time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"planned_time": {
"name": "Planned time"
"name": "Planned time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"estimated_time": {
"name": "Estimated time"
"name": "Estimated time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"actual_time": {
"name": "Actual time"
"name": "Actual time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"other_info": {
"name": "Other information"
"name": "Other information",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
},
"deviation": {
"name": "Deviation"
"name": "Deviation",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
}
}
}

View File

@ -66,6 +66,7 @@ async def test_form(hass: HomeAssistant) -> None:
"time": "10:00",
"weekday": ["mon", "fri"],
}
assert result["options"] == {"filter_product": None}
assert len(mock_setup_entry.mock_calls) == 1
assert result["result"].unique_id == "{}-{}-{}-{}".format(
"stockholmc", "uppsalac", "10:00", "['mon', 'fri']"
@ -448,3 +449,55 @@ async def test_reauth_flow_error_departures(
"time": "10:00",
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
}
async def test_options_flow(hass: HomeAssistant) -> None:
"""Test a reauthentication flow."""
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS,
},
unique_id=f"stockholmc-uppsalac-10:00-{WEEKDAYS}",
)
entry.add_to_hass(hass)
with patch(
"homeassistant.components.trafikverket_train.async_setup_entry",
return_value=True,
):
assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "init"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"filter_product": "SJ Regionaltåg"},
)
await hass.async_block_till_done()
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["data"] == {"filter_product": "SJ Regionaltåg"}
result = await hass.config_entries.options.async_init(entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "init"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"filter_product": ""},
)
await hass.async_block_till_done()
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["data"] == {"filter_product": None}