Add filter to strict-typing (#86215)

* Add filter to strict-typing

* Adjust comment
This commit is contained in:
epenet 2023-01-19 11:07:42 +01:00 committed by GitHub
parent 3f348714e2
commit 6802f3db30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 52 additions and 16 deletions

View File

@ -112,6 +112,7 @@ homeassistant.components.fastdotcom.*
homeassistant.components.feedreader.* homeassistant.components.feedreader.*
homeassistant.components.file_upload.* homeassistant.components.file_upload.*
homeassistant.components.filesize.* homeassistant.components.filesize.*
homeassistant.components.filter.*
homeassistant.components.fitbit.* homeassistant.components.fitbit.*
homeassistant.components.flux_led.* homeassistant.components.flux_led.*
homeassistant.components.forecast_solar.* homeassistant.components.forecast_solar.*

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from collections import Counter, deque from collections import Counter, deque
from copy import copy from copy import copy
from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial from functools import partial
import logging import logging
@ -40,7 +41,7 @@ import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.event import async_track_state_change_event
from homeassistant.helpers.reload import async_setup_reload_service from homeassistant.helpers.reload import async_setup_reload_service
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, StateType
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
@ -211,7 +212,7 @@ class SensorFilter(SensorEntity):
self._attr_unique_id = unique_id self._attr_unique_id = unique_id
self._entity = entity_id self._entity = entity_id
self._attr_native_unit_of_measurement = None self._attr_native_unit_of_measurement = None
self._state: str | None = None self._state: StateType = None
self._filters = filters self._filters = filters
self._attr_icon = None self._attr_icon = None
self._attr_device_class = None self._attr_device_class = None
@ -242,7 +243,7 @@ class SensorFilter(SensorEntity):
self.async_write_ha_state() self.async_write_ha_state()
return return
temp_state = new_state temp_state = _State(new_state.last_updated, new_state.state)
try: try:
for filt in self._filters: for filt in self._filters:
@ -361,10 +362,10 @@ class SensorFilter(SensorEntity):
) )
@property @property
def native_value(self) -> datetime | str | None: def native_value(self) -> datetime | StateType:
"""Return the state of the sensor.""" """Return the state of the sensor."""
if self._state is not None and self.device_class == SensorDeviceClass.TIMESTAMP: if self._state is not None and self.device_class == SensorDeviceClass.TIMESTAMP:
return datetime.fromisoformat(self._state) return datetime.fromisoformat(str(self._state))
return self._state return self._state
@ -372,7 +373,9 @@ class SensorFilter(SensorEntity):
class FilterState: class FilterState:
"""State abstraction for filter usage.""" """State abstraction for filter usage."""
def __init__(self, state): state: str | float | int
def __init__(self, state: _State) -> None:
"""Initialize with HA State object.""" """Initialize with HA State object."""
self.timestamp = state.last_updated self.timestamp = state.last_updated
try: try:
@ -380,7 +383,7 @@ class FilterState:
except ValueError: except ValueError:
self.state = state.state self.state = state.state
def set_precision(self, precision): def set_precision(self, precision: int) -> None:
"""Set precision of Number based states.""" """Set precision of Number based states."""
if isinstance(self.state, Number): if isinstance(self.state, Number):
value = round(float(self.state), precision) value = round(float(self.state), precision)
@ -395,6 +398,18 @@ class FilterState:
return f"{self.timestamp} : {self.state}" return f"{self.timestamp} : {self.state}"
@dataclass
class _State:
"""Simplified State class.
The standard State class only accepts string in `state`,
and we are only interested in two properties.
"""
last_updated: datetime
state: str | float | int
class Filter: class Filter:
"""Filter skeleton.""" """Filter skeleton."""
@ -444,7 +459,7 @@ class Filter:
"""Implement filter.""" """Implement filter."""
raise NotImplementedError() raise NotImplementedError()
def filter_state(self, new_state: State) -> State: def filter_state(self, new_state: _State) -> _State:
"""Implement a common interface for filters.""" """Implement a common interface for filters."""
fstate = FilterState(new_state) fstate = FilterState(new_state)
if self._only_numbers and not isinstance(fstate.state, Number): if self._only_numbers and not isinstance(fstate.state, Number):
@ -488,7 +503,10 @@ class RangeFilter(Filter, SensorEntity):
def _filter_state(self, new_state: FilterState) -> FilterState: def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the range filter.""" """Implement the range filter."""
if self._upper_bound is not None and new_state.state > self._upper_bound: # We can cast safely here thanks to self._only_numbers = True
new_state_value = cast(float, new_state.state)
if self._upper_bound is not None and new_state_value > self._upper_bound:
self._stats_internal["erasures_up"] += 1 self._stats_internal["erasures_up"] += 1
@ -500,7 +518,7 @@ class RangeFilter(Filter, SensorEntity):
) )
new_state.state = self._upper_bound new_state.state = self._upper_bound
elif self._lower_bound is not None and new_state.state < self._lower_bound: elif self._lower_bound is not None and new_state_value < self._lower_bound:
self._stats_internal["erasures_low"] += 1 self._stats_internal["erasures_low"] += 1
@ -537,10 +555,14 @@ class OutlierFilter(Filter, SensorEntity):
def _filter_state(self, new_state: FilterState) -> FilterState: def _filter_state(self, new_state: FilterState) -> FilterState:
"""Implement the outlier filter.""" """Implement the outlier filter."""
median = statistics.median([s.state for s in self.states]) if self.states else 0 # We can cast safely here thanks to self._only_numbers = True
previous_state_values = [cast(float, s.state) for s in self.states]
new_state_value = cast(float, new_state.state)
median = statistics.median(previous_state_values) if self.states else 0
if ( if (
len(self.states) == self.states.maxlen len(self.states) == self.states.maxlen
and abs(new_state.state - median) > self._radius and abs(new_state_value - median) > self._radius
): ):
self._stats_internal["erasures"] += 1 self._stats_internal["erasures"] += 1
@ -574,9 +596,10 @@ class LowPassFilter(Filter, SensorEntity):
new_weight = 1.0 / self._time_constant new_weight = 1.0 / self._time_constant
prev_weight = 1.0 - new_weight prev_weight = 1.0 - new_weight
new_state.state = ( # We can cast safely here thanks to self._only_numbers = True
prev_weight * self.states[-1].state + new_weight * new_state.state prev_state_value = cast(float, self.states[-1].state)
) new_state_value = cast(float, new_state.state)
new_state.state = prev_weight * prev_state_value + new_weight * new_state_value
return new_state return new_state
@ -622,7 +645,9 @@ class TimeSMAFilter(Filter, SensorEntity):
start = new_state.timestamp - self._time_window start = new_state.timestamp - self._time_window
prev_state = self.last_leak if self.last_leak is not None else self.queue[0] prev_state = self.last_leak if self.last_leak is not None else self.queue[0]
for state in self.queue: for state in self.queue:
moving_sum += (state.timestamp - start).total_seconds() * prev_state.state # We can cast safely here thanks to self._only_numbers = True
prev_state_value = cast(float, prev_state.state)
moving_sum += (state.timestamp - start).total_seconds() * prev_state_value
start = state.timestamp start = state.timestamp
prev_state = state prev_state = state

View File

@ -874,6 +874,16 @@ disallow_untyped_defs = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.filter.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.fitbit.*] [mypy-homeassistant.components.fitbit.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true