mirror of https://github.com/home-assistant/core

Add Black

This commit is contained in:
parent 0490167a12
commit da05dfe708
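The hunks below are the output of running Black across the tree. As a rough, hedged sketch (assuming the black==19.3b0 pin from requirements_test.txt, the --safe --quiet arguments from the new pre-commit hook, and the directories targeted by script/check_format, all introduced further down in this diff), the equivalent manual invocation would be:

    pip install black==19.3b0
    black --safe --quiet homeassistant tests script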
@@ -17,6 +17,10 @@
"python.pythonPath": "/usr/local/bin/python",
"python.linting.pylintEnabled": true,
"python.linting.enabled": true,
"python.formatting.provider": "black",
"editor.formatOnPaste": false,
"editor.formatOnSave": true,
"editor.formatOnType": true,
"files.trimTrailingWhitespace": true,
"editor.rulers": [80],
"terminal.integrated.shell.linux": "/bin/bash",
@@ -0,0 +1,8 @@
repos:
  - repo: https://github.com/python/black
    rev: 19.3b0
    hooks:
      - id: black
        args:
          - --safe
          - --quiet
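With this hook in place, Black runs through pre-commit. A usage sketch (pre-commit install is what the updated script/setup runs later in this diff; "black" is the hook id from the config above):

    pre-commit install                 # wire the hook into git's pre-commit stage
    pre-commit run black --all-files   # run the Black hook over the whole tree on demand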
@@ -38,7 +38,7 @@ stages:
python -m venv venv

. venv/bin/activate
pip install -r requirements_test.txt
pip install -r requirements_test.txt -c homeassistant/package_constraints.txt
displayName: 'Setup Env'
- script: |
. venv/bin/activate

@@ -63,6 +63,21 @@ stages:
. venv/bin/activate
./script/gen_requirements_all.py validate
displayName: 'requirements_all validate'
- job: 'CheckFormat'
pool:
vmImage: 'ubuntu-latest'
container: $[ variables['PythonMain'] ]
steps:
- script: |
python -m venv venv

. venv/bin/activate
pip install -r requirements_test.txt -c homeassistant/package_constraints.txt
displayName: 'Setup Env'
- script: |
. venv/bin/activate
./script/check_format
displayName: 'Check Black formatting'

- stage: 'Tests'
dependsOn:
@@ -21,42 +21,42 @@ from homeassistant.helpers.entity import Entity
_LOGGER = logging.getLogger(__name__)

ATTR_STOP_ID = 'stop_id'
ATTR_STOP_NAME = 'stop'
ATTR_ROUTE = 'route'
ATTR_TYPE = 'type'
ATTR_STOP_ID = "stop_id"
ATTR_STOP_NAME = "stop"
ATTR_ROUTE = "route"
ATTR_TYPE = "type"
ATTR_DIRECTION = "direction"
ATTR_DUE_IN = 'due_in'
ATTR_DUE_AT = 'due_at'
ATTR_NEXT_UP = 'next_departures'
ATTR_DUE_IN = "due_in"
ATTR_DUE_AT = "due_at"
ATTR_NEXT_UP = "next_departures"

ATTRIBUTION = "Data provided by rejseplanen.dk"

CONF_STOP_ID = 'stop_id'
CONF_ROUTE = 'route'
CONF_DIRECTION = 'direction'
CONF_DEPARTURE_TYPE = 'departure_type'
CONF_STOP_ID = "stop_id"
CONF_ROUTE = "route"
CONF_DIRECTION = "direction"
CONF_DEPARTURE_TYPE = "departure_type"

DEFAULT_NAME = 'Next departure'
ICON = 'mdi:bus'
DEFAULT_NAME = "Next departure"
ICON = "mdi:bus"

SCAN_INTERVAL = timedelta(minutes=1)

BUS_TYPES = ['BUS', 'EXB', 'TB']
TRAIN_TYPES = ['LET', 'S', 'REG', 'IC', 'LYN', 'TOG']
METRO_TYPES = ['M']
BUS_TYPES = ["BUS", "EXB", "TB"]
TRAIN_TYPES = ["LET", "S", "REG", "IC", "LYN", "TOG"]
METRO_TYPES = ["M"]

PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Required(CONF_STOP_ID): cv.string,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_ROUTE, default=[]):
vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_DIRECTION, default=[]):
vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_DEPARTURE_TYPE, default=[]):
vol.All(cv.ensure_list,
[vol.In([*BUS_TYPES, *TRAIN_TYPES, *METRO_TYPES])])
})
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_STOP_ID): cv.string,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_ROUTE, default=[]): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_DIRECTION, default=[]): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_DEPARTURE_TYPE, default=[]): vol.All(
cv.ensure_list, [vol.In([*BUS_TYPES, *TRAIN_TYPES, *METRO_TYPES])]
),
}
)


def due_in_minutes(timestamp):

@@ -64,8 +64,9 @@ def due_in_minutes(timestamp):

The timestamp should be in the format day.month.year hour:minute
"""
diff = datetime.strptime(
timestamp, "%d.%m.%y %H:%M") - dt_util.now().replace(tzinfo=None)
diff = datetime.strptime(timestamp, "%d.%m.%y %H:%M") - dt_util.now().replace(
tzinfo=None
)

return int(diff.total_seconds() // 60)

@@ -79,8 +80,9 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
departure_type = config[CONF_DEPARTURE_TYPE]

data = PublicTransportData(stop_id, route, direction, departure_type)
add_devices([RejseplanenTransportSensor(
data, stop_id, route, direction, name)], True)
add_devices(
[RejseplanenTransportSensor(data, stop_id, route, direction, name)], True
)


class RejseplanenTransportSensor(Entity):

@@ -124,14 +126,14 @@ class RejseplanenTransportSensor(Entity):
ATTR_STOP_NAME: self._times[0][ATTR_STOP_NAME],
ATTR_STOP_ID: self._stop_id,
ATTR_ATTRIBUTION: ATTRIBUTION,
ATTR_NEXT_UP: next_up
ATTR_NEXT_UP: next_up,
}
return {k: v for k, v in params.items() if v}

@property
def unit_of_measurement(self):
"""Return the unit this state is expressed in."""
return 'min'
return "min"

@property
def icon(self):

@@ -148,7 +150,7 @@ class RejseplanenTransportSensor(Entity):
pass


class PublicTransportData():
class PublicTransportData:
"""The Class for handling the data retrieval."""

def __init__(self, stop_id, route, direction, departure_type):

@@ -161,16 +163,21 @@ class PublicTransportData():

def empty_result(self):
"""Object returned when no departures are found."""
return [{ATTR_DUE_IN: 'n/a',
ATTR_DUE_AT: 'n/a',
ATTR_TYPE: 'n/a',
ATTR_ROUTE: self.route,
ATTR_DIRECTION: 'n/a',
ATTR_STOP_NAME: 'n/a'}]
return [
{
ATTR_DUE_IN: "n/a",
ATTR_DUE_AT: "n/a",
ATTR_TYPE: "n/a",
ATTR_ROUTE: self.route,
ATTR_DIRECTION: "n/a",
ATTR_STOP_NAME: "n/a",
}
]

def update(self):
"""Get the latest data from rejseplanen."""
import rjpl

self.info = []

def intersection(lst1, lst2):

@@ -179,12 +186,9 @@ class PublicTransportData():

# Limit search to selected types, to get more results
all_types = not bool(self.departure_type)
use_train = all_types or bool(
intersection(TRAIN_TYPES, self.departure_type))
use_bus = all_types or bool(
intersection(BUS_TYPES, self.departure_type))
use_metro = all_types or bool(
intersection(METRO_TYPES, self.departure_type))
use_train = all_types or bool(intersection(TRAIN_TYPES, self.departure_type))
use_bus = all_types or bool(intersection(BUS_TYPES, self.departure_type))
use_metro = all_types or bool(intersection(METRO_TYPES, self.departure_type))

try:
results = rjpl.departureBoard(

@@ -192,7 +196,7 @@ class PublicTransportData():
timeout=5,
useTrain=use_train,
useBus=use_bus,
useMetro=use_metro
useMetro=use_metro,
)
except rjpl.rjplAPIError as error:
_LOGGER.debug("API returned error: %s", error)

@@ -204,36 +208,40 @@ class PublicTransportData():
return

# Filter result
results = [d for d in results if 'cancelled' not in d]
results = [d for d in results if "cancelled" not in d]
if self.route:
results = [d for d in results if d['name'] in self.route]
results = [d for d in results if d["name"] in self.route]
if self.direction:
results = [d for d in results if d['direction'] in self.direction]
results = [d for d in results if d["direction"] in self.direction]
if self.departure_type:
results = [d for d in results if d['type'] in self.departure_type]
results = [d for d in results if d["type"] in self.departure_type]

for item in results:
route = item.get('name')
route = item.get("name")

due_at_date = item.get('rtDate')
due_at_time = item.get('rtTime')
due_at_date = item.get("rtDate")
due_at_time = item.get("rtTime")

if due_at_date is None:
due_at_date = item.get('date') # Scheduled date
due_at_date = item.get("date") # Scheduled date
if due_at_time is None:
due_at_time = item.get('time') # Scheduled time
due_at_time = item.get("time") # Scheduled time

if (due_at_date is not None and
due_at_time is not None and
route is not None):
due_at = '{} {}'.format(due_at_date, due_at_time)
if (
due_at_date is not None
and due_at_time is not None
and route is not None
):
due_at = "{} {}".format(due_at_date, due_at_time)

departure_data = {ATTR_DUE_IN: due_in_minutes(due_at),
ATTR_DUE_AT: due_at,
ATTR_TYPE: item.get('type'),
ATTR_ROUTE: route,
ATTR_DIRECTION: item.get('direction'),
ATTR_STOP_NAME: item.get('stop')}
departure_data = {
ATTR_DUE_IN: due_in_minutes(due_at),
ATTR_DUE_AT: due_at,
ATTR_TYPE: item.get("type"),
ATTR_ROUTE: route,
ATTR_DIRECTION: item.get("direction"),
ATTR_STOP_NAME: item.get("stop"),
}
self.info.append(departure_data)

if not self.info:
@@ -14,11 +14,19 @@ from random import uniform
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from ..helpers import (
configure_reporting, construct_unique_id,
safe_read, get_attr_id_by_name, bind_cluster, LogMixin)
configure_reporting,
construct_unique_id,
safe_read,
get_attr_id_by_name,
bind_cluster,
LogMixin,
)
from ..const import (
REPORT_CONFIG_DEFAULT, SIGNAL_ATTR_UPDATED, ATTRIBUTE_CHANNEL,
EVENT_RELAY_CHANNEL, ZDO_CHANNEL
REPORT_CONFIG_DEFAULT,
SIGNAL_ATTR_UPDATED,
ATTRIBUTE_CHANNEL,
EVENT_RELAY_CHANNEL,
ZDO_CHANNEL,
)
from ..registries import CLUSTER_REPORT_CONFIGS


@@ -33,32 +41,33 @@ def parse_and_log_command(channel, tsn, command_id, args):
cmd,
args,
channel.cluster.cluster_id,
tsn
tsn,
)
return cmd


def decorate_command(channel, command):
"""Wrap a cluster command to make it safe."""

@wraps(command)
async def wrapper(*args, **kwds):
from zigpy.exceptions import DeliveryError

try:
result = await command(*args, **kwds)
channel.debug("executed command: %s %s %s %s",
command.__name__,
"{}: {}".format("with args", args),
"{}: {}".format("with kwargs", kwds),
"{}: {}".format("and result", result))
channel.debug(
"executed command: %s %s %s %s",
command.__name__,
"{}: {}".format("with args", args),
"{}: {}".format("with kwargs", kwds),
"{}: {}".format("and result", result),
)
return result

except (DeliveryError, Timeout) as ex:
channel.debug(
"command failed: %s exception: %s",
command.__name__,
str(ex)
)
channel.debug("command failed: %s exception: %s", command.__name__, str(ex))
return ex

return wrapper


@@ -80,13 +89,12 @@ class ZigbeeChannel(LogMixin):
self._channel_name = cluster.ep_attribute
if self.CHANNEL_NAME:
self._channel_name = self.CHANNEL_NAME
self._generic_id = 'channel_0x{:04x}'.format(cluster.cluster_id)
self._generic_id = "channel_0x{:04x}".format(cluster.cluster_id)
self._cluster = cluster
self._zha_device = device
self._unique_id = construct_unique_id(cluster)
self._report_config = CLUSTER_REPORT_CONFIGS.get(
self._cluster.cluster_id,
[{'attr': 0, 'config': REPORT_CONFIG_DEFAULT}]
self._cluster.cluster_id, [{"attr": 0, "config": REPORT_CONFIG_DEFAULT}]
)
self._status = ChannelStatus.CREATED
self._cluster.add_listener(self)

@@ -130,21 +138,24 @@ class ZigbeeChannel(LogMixin):
manufacturer = None
manufacturer_code = self._zha_device.manufacturer_code
# Xiaomi devices don't need this and it disrupts pairing
if self._zha_device.manufacturer != 'LUMI':
if self.cluster.cluster_id >= 0xfc00 and manufacturer_code:
if self._zha_device.manufacturer != "LUMI":
if self.cluster.cluster_id >= 0xFC00 and manufacturer_code:
manufacturer = manufacturer_code
await bind_cluster(self._unique_id, self.cluster)
if not self.cluster.bind_only:
for report_config in self._report_config:
attr = report_config.get('attr')
min_report_interval, max_report_interval, change = \
report_config.get('config')
attr = report_config.get("attr")
min_report_interval, max_report_interval, change = report_config.get(
"config"
)
await configure_reporting(
self._unique_id, self.cluster, attr,
self._unique_id,
self.cluster,
attr,
min_report=min_report_interval,
max_report=max_report_interval,
reportable_change=change,
manufacturer=manufacturer
manufacturer=manufacturer,
)
await asyncio.sleep(uniform(0.1, 0.5))

@@ -153,7 +164,7 @@ class ZigbeeChannel(LogMixin):

async def async_initialize(self, from_cache):
"""Initialize channel."""
self.debug('initializing channel: from_cache: %s', from_cache)
self.debug("initializing channel: from_cache: %s", from_cache)
self._status = ChannelStatus.INITIALIZED

@callback

@@ -175,13 +186,13 @@ class ZigbeeChannel(LogMixin):
def zha_send_event(self, cluster, command, args):
"""Relay events to hass."""
self._zha_device.hass.bus.async_fire(
'zha_event',
"zha_event",
{
'unique_id': self._unique_id,
'device_ieee': str(self._zha_device.ieee),
'command': command,
'args': args
}
"unique_id": self._unique_id,
"device_ieee": str(self._zha_device.ieee),
"command": command,
"args": args,
},
)

async def async_update(self):

@@ -192,14 +203,14 @@ class ZigbeeChannel(LogMixin):
"""Get the value for an attribute."""
manufacturer = None
manufacturer_code = self._zha_device.manufacturer_code
if self.cluster.cluster_id >= 0xfc00 and manufacturer_code:
if self.cluster.cluster_id >= 0xFC00 and manufacturer_code:
manufacturer = manufacturer_code
result = await safe_read(
self._cluster,
[attribute],
allow_cache=from_cache,
only_cache=from_cache,
manufacturer=manufacturer
manufacturer=manufacturer,
)
return result.get(attribute)

@@ -211,14 +222,10 @@ class ZigbeeChannel(LogMixin):

def __getattr__(self, name):
"""Get attribute or a decorated cluster command."""
if hasattr(self._cluster, name) and callable(
getattr(self._cluster, name)):
if hasattr(self._cluster, name) and callable(getattr(self._cluster, name)):
command = getattr(self._cluster, name)
command.__name__ = name
return decorate_command(
self,
command
)
return decorate_command(self, command)
return self.__getattribute__(name)


@@ -230,7 +237,7 @@ class AttributeListeningChannel(ZigbeeChannel):
def __init__(self, cluster, device):
"""Initialize AttributeListeningChannel."""
super().__init__(cluster, device)
attr = self._report_config[0].get('attr')
attr = self._report_config[0].get("attr")
if isinstance(attr, str):
self.value_attribute = get_attr_id_by_name(self.cluster, attr)
else:

@@ -243,13 +250,14 @@ class AttributeListeningChannel(ZigbeeChannel):
async_dispatcher_send(
self._zha_device.hass,
"{}_{}".format(self.unique_id, SIGNAL_ATTR_UPDATED),
value
value,
)

async def async_initialize(self, from_cache):
"""Initialize listener."""
await self.get_attribute_value(
self._report_config[0].get('attr'), from_cache=from_cache)
self._report_config[0].get("attr"), from_cache=from_cache
)
await super().async_initialize(from_cache)


@@ -293,7 +301,8 @@ class ZDOChannel(LogMixin):
async def async_initialize(self, from_cache):
"""Initialize channel."""
entry = self._zha_device.gateway.zha_storage.async_get_or_create(
self._zha_device)
self._zha_device
)
self.debug("entry loaded from storage: %s", entry)
self._status = ChannelStatus.INITIALIZED

@@ -320,21 +329,19 @@ class EventRelayChannel(ZigbeeChannel):
self._cluster,
SIGNAL_ATTR_UPDATED,
{
'attribute_id': attrid,
'attribute_name': self._cluster.attributes.get(
attrid,
['Unknown'])[0],
'value': value
}
"attribute_id": attrid,
"attribute_name": self._cluster.attributes.get(attrid, ["Unknown"])[0],
"value": value,
},
)

@callback
def cluster_command(self, tsn, command_id, args):
"""Handle a cluster command received on this cluster."""
if self._cluster.server_commands is not None and \
self._cluster.server_commands.get(command_id) is not None:
if (
self._cluster.server_commands is not None
and self._cluster.server_commands.get(command_id) is not None
):
self.zha_send_event(
self._cluster,
self._cluster.server_commands.get(command_id)[0],
args
self._cluster, self._cluster.server_commands.get(command_id)[0], args
)
@@ -12,18 +12,46 @@ import time

from homeassistant.core import callback
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_send)
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.event import async_track_time_interval

from .channels import EventRelayChannel
from .const import (
ATTR_ARGS, ATTR_ATTRIBUTE, ATTR_CLUSTER_ID, ATTR_COMMAND,
ATTR_COMMAND_TYPE, ATTR_ENDPOINT_ID, ATTR_MANUFACTURER, ATTR_VALUE,
BATTERY_OR_UNKNOWN, CLIENT_COMMANDS, IEEE, IN, MAINS_POWERED,
MANUFACTURER_CODE, MODEL, NAME, NWK, OUT, POWER_CONFIGURATION_CHANNEL,
POWER_SOURCE, QUIRK_APPLIED, QUIRK_CLASS, SERVER, SERVER_COMMANDS,
SIGNAL_AVAILABLE, UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, ZDO_CHANNEL,
LQI, RSSI, LAST_SEEN, ATTR_AVAILABLE)
ATTR_ARGS,
ATTR_ATTRIBUTE,
ATTR_CLUSTER_ID,
ATTR_COMMAND,
ATTR_COMMAND_TYPE,
ATTR_ENDPOINT_ID,
ATTR_MANUFACTURER,
ATTR_VALUE,
BATTERY_OR_UNKNOWN,
CLIENT_COMMANDS,
IEEE,
IN,
MAINS_POWERED,
MANUFACTURER_CODE,
MODEL,
NAME,
NWK,
OUT,
POWER_CONFIGURATION_CHANNEL,
POWER_SOURCE,
QUIRK_APPLIED,
QUIRK_CLASS,
SERVER,
SERVER_COMMANDS,
SIGNAL_AVAILABLE,
UNKNOWN_MANUFACTURER,
UNKNOWN_MODEL,
ZDO_CHANNEL,
LQI,
RSSI,
LAST_SEEN,
ATTR_AVAILABLE,
)
from .helpers import LogMixin

_LOGGER = logging.getLogger(__name__)

@@ -51,22 +79,20 @@ class ZHADevice(LogMixin):
self._all_channels = []
self._available = False
self._available_signal = "{}_{}_{}".format(
self.name, self.ieee, SIGNAL_AVAILABLE)
self.name, self.ieee, SIGNAL_AVAILABLE
)
self._unsub = async_dispatcher_connect(
self.hass,
self._available_signal,
self.async_initialize
self.hass, self._available_signal, self.async_initialize
)
from zigpy.quirks import CustomDevice

self.quirk_applied = isinstance(self._zigpy_device, CustomDevice)
self.quirk_class = "{}.{}".format(
self._zigpy_device.__class__.__module__,
self._zigpy_device.__class__.__name__
self._zigpy_device.__class__.__name__,
)
self._available_check = async_track_time_interval(
self.hass,
self._check_available,
_UPDATE_ALIVE_INTERVAL
self.hass, self._check_available, _UPDATE_ALIVE_INTERVAL
)
self.status = DeviceStatus.CREATED

@@ -184,15 +210,9 @@ class ZHADevice(LogMixin):
"""Set sensor availability."""
if self._available != available and available:
# Update the state the first time the device comes online
async_dispatcher_send(
self.hass,
self._available_signal,
False
)
async_dispatcher_send(self.hass, self._available_signal, False)
async_dispatcher_send(
self.hass,
"{}_{}".format(self._available_signal, 'entity'),
available
self.hass, "{}_{}".format(self._available_signal, "entity"), available
)
self._available = available

@@ -215,14 +235,16 @@ class ZHADevice(LogMixin):
LQI: self.lqi,
RSSI: self.rssi,
LAST_SEEN: update_time,
ATTR_AVAILABLE: self.available
ATTR_AVAILABLE: self.available,
}

def add_cluster_channel(self, cluster_channel):
"""Add cluster channel to device."""
# only keep 1 power configuration channel
if cluster_channel.name is POWER_CONFIGURATION_CHANNEL and \
POWER_CONFIGURATION_CHANNEL in self.cluster_channels:
if (
cluster_channel.name is POWER_CONFIGURATION_CHANNEL
and POWER_CONFIGURATION_CHANNEL in self.cluster_channels
):
return

if isinstance(cluster_channel, EventRelayChannel):

@@ -249,10 +271,9 @@ class ZHADevice(LogMixin):

def get_key(channel):
channel_key = "ZDO"
if hasattr(channel.cluster, 'cluster_id'):
if hasattr(channel.cluster, "cluster_id"):
channel_key = "{}_{}".format(
channel.cluster.endpoint.endpoint_id,
channel.cluster.cluster_id
channel.cluster.endpoint.endpoint_id, channel.cluster.cluster_id
)
return channel_key

@@ -273,21 +294,23 @@ class ZHADevice(LogMixin):

async def async_configure(self):
"""Configure the device."""
self.debug('started configuration')
self.debug("started configuration")
await self._execute_channel_tasks(
self.get_channels_to_configure(), 'async_configure')
self.debug('completed configuration')
self.get_channels_to_configure(), "async_configure"
)
self.debug("completed configuration")
entry = self.gateway.zha_storage.async_create_or_update(self)
self.debug('stored in registry: %s', entry)
self.debug("stored in registry: %s", entry)

async def async_initialize(self, from_cache=False):
"""Initialize channels."""
self.debug('started initialization')
self.debug("started initialization")
await self._execute_channel_tasks(
self.all_channels, 'async_initialize', from_cache)
self.debug('power source: %s', self.power_source)
self.all_channels, "async_initialize", from_cache
)
self.debug("power source: %s", self.power_source)
self.status = DeviceStatus.INITIALIZED
self.debug('completed initialization')
self.debug("completed initialization")

async def _execute_channel_tasks(self, channels, task_name, *args):
"""Gather and execute a set of CHANNEL tasks."""

@@ -299,11 +322,12 @@ class ZHADevice(LogMixin):
# pylint: disable=E1111
if zdo_task is None: # We only want to do this once
zdo_task = self._async_create_task(
semaphore, channel, task_name, *args)
semaphore, channel, task_name, *args
)
else:
channel_tasks.append(
self._async_create_task(
semaphore, channel, task_name, *args))
self._async_create_task(semaphore, channel, task_name, *args)
)
if zdo_task is not None:
await zdo_task
await asyncio.gather(*channel_tasks)

@@ -332,10 +356,8 @@ class ZHADevice(LogMixin):
def async_get_clusters(self):
"""Get all clusters for this device."""
return {
ep_id: {
IN: endpoint.in_clusters,
OUT: endpoint.out_clusters
} for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
ep_id: {IN: endpoint.in_clusters, OUT: endpoint.out_clusters}
for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
if ep_id != 0
}

@@ -343,15 +365,11 @@ class ZHADevice(LogMixin):
def async_get_std_clusters(self):
"""Get ZHA and ZLL clusters for this device."""
from zigpy.profiles import zha, zll

return {
ep_id: {
IN: endpoint.in_clusters,
OUT: endpoint.out_clusters
} for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
if ep_id != 0 and endpoint.profile_id in (
zha.PROFILE_ID,
zll.PROFILE_ID
)
ep_id: {IN: endpoint.in_clusters, OUT: endpoint.out_clusters}
for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
if ep_id != 0 and endpoint.profile_id in (zha.PROFILE_ID, zll.PROFILE_ID)
}

@callback

@@ -361,18 +379,15 @@ class ZHADevice(LogMixin):
return clusters[endpoint_id][cluster_type][cluster_id]

@callback
def async_get_cluster_attributes(self, endpoint_id, cluster_id,
cluster_type=IN):
def async_get_cluster_attributes(self, endpoint_id, cluster_id, cluster_type=IN):
"""Get zigbee attributes for specified cluster."""
cluster = self.async_get_cluster(endpoint_id, cluster_id,
cluster_type)
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None:
return None
return cluster.attributes

@callback
def async_get_cluster_commands(self, endpoint_id, cluster_id,
cluster_type=IN):
def async_get_cluster_commands(self, endpoint_id, cluster_id, cluster_type=IN):
"""Get zigbee commands for specified cluster."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None:

@@ -382,64 +397,77 @@ class ZHADevice(LogMixin):
SERVER_COMMANDS: cluster.server_commands,
}

async def write_zigbee_attribute(self, endpoint_id, cluster_id,
attribute, value, cluster_type=IN,
manufacturer=None):
async def write_zigbee_attribute(
self,
endpoint_id,
cluster_id,
attribute,
value,
cluster_type=IN,
manufacturer=None,
):
"""Write a value to a zigbee attribute for a cluster in this entity."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None:
return None

from zigpy.exceptions import DeliveryError

try:
response = await cluster.write_attributes(
{attribute: value},
manufacturer=manufacturer
{attribute: value}, manufacturer=manufacturer
)
self.debug(
'set: %s for attr: %s to cluster: %s for ept: %s - res: %s',
"set: %s for attr: %s to cluster: %s for ept: %s - res: %s",
value,
attribute,
cluster_id,
endpoint_id,
response
response,
)
return response
except DeliveryError as exc:
self.debug(
'failed to set attribute: %s %s %s %s %s',
'{}: {}'.format(ATTR_VALUE, value),
'{}: {}'.format(ATTR_ATTRIBUTE, attribute),
'{}: {}'.format(ATTR_CLUSTER_ID, cluster_id),
'{}: {}'.format(ATTR_ENDPOINT_ID, endpoint_id),
exc
"failed to set attribute: %s %s %s %s %s",
"{}: {}".format(ATTR_VALUE, value),
"{}: {}".format(ATTR_ATTRIBUTE, attribute),
"{}: {}".format(ATTR_CLUSTER_ID, cluster_id),
"{}: {}".format(ATTR_ENDPOINT_ID, endpoint_id),
exc,
)
return None

async def issue_cluster_command(self, endpoint_id, cluster_id, command,
command_type, args, cluster_type=IN,
manufacturer=None):
async def issue_cluster_command(
self,
endpoint_id,
cluster_id,
command,
command_type,
args,
cluster_type=IN,
manufacturer=None,
):
"""Issue a command against specified zigbee cluster on this entity."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None:
return None
response = None
if command_type == SERVER:
response = await cluster.command(command, *args,
manufacturer=manufacturer,
expect_reply=True)
response = await cluster.command(
command, *args, manufacturer=manufacturer, expect_reply=True
)
else:
response = await cluster.client_command(command, *args)

self.debug(
'Issued cluster command: %s %s %s %s %s %s %s',
'{}: {}'.format(ATTR_CLUSTER_ID, cluster_id),
'{}: {}'.format(ATTR_COMMAND, command),
'{}: {}'.format(ATTR_COMMAND_TYPE, command_type),
'{}: {}'.format(ATTR_ARGS, args),
'{}: {}'.format(ATTR_CLUSTER_ID, cluster_type),
'{}: {}'.format(ATTR_MANUFACTURER, manufacturer),
'{}: {}'.format(ATTR_ENDPOINT_ID, endpoint_id)
"Issued cluster command: %s %s %s %s %s %s %s",
"{}: {}".format(ATTR_CLUSTER_ID, cluster_id),
"{}: {}".format(ATTR_COMMAND, command),
"{}: {}".format(ATTR_COMMAND_TYPE, command_type),
"{}: {}".format(ATTR_ARGS, args),
"{}: {}".format(ATTR_CLUSTER_ID, cluster_type),
"{}: {}".format(ATTR_MANUFACTURER, manufacturer),
"{}: {}".format(ATTR_ENDPOINT_ID, endpoint_id),
)
return response
@@ -12,13 +12,19 @@ from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import slugify

from .core.const import (
ATTR_MANUFACTURER, DATA_ZHA, DATA_ZHA_BRIDGE_ID, DOMAIN, MODEL, NAME,
SIGNAL_REMOVE)
ATTR_MANUFACTURER,
DATA_ZHA,
DATA_ZHA_BRIDGE_ID,
DOMAIN,
MODEL,
NAME,
SIGNAL_REMOVE,
)
from .core.helpers import LogMixin

_LOGGER = logging.getLogger(__name__)

ENTITY_SUFFIX = 'entity_suffix'
ENTITY_SUFFIX = "entity_suffix"
RESTART_GRACE_PERIOD = 7200 # 2 hours


@@ -27,29 +33,28 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):

_domain = None # Must be overridden by subclasses

def __init__(self, unique_id, zha_device, channels,
skip_entity_id=False, **kwargs):
def __init__(self, unique_id, zha_device, channels, skip_entity_id=False, **kwargs):
"""Init ZHA entity."""
self._force_update = False
self._should_poll = False
self._unique_id = unique_id
if not skip_entity_id:
ieee = zha_device.ieee
ieeetail = ''.join(['%02x' % (o, ) for o in ieee[-4:]])
ieeetail = "".join(["%02x" % (o,) for o in ieee[-4:]])
self.entity_id = "{}.{}_{}_{}_{}{}".format(
self._domain,
slugify(zha_device.manufacturer),
slugify(zha_device.model),
ieeetail,
channels[0].cluster.endpoint.endpoint_id,
kwargs.get(ENTITY_SUFFIX, ''),
kwargs.get(ENTITY_SUFFIX, ""),
)
self._state = None
self._device_state_attributes = {}
self._zha_device = zha_device
self.cluster_channels = {}
self._available = False
self._component = kwargs['component']
self._component = kwargs["component"]
self._unsubs = []
self.remove_future = None
for channel in channels:

@@ -89,15 +94,14 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
def device_info(self):
"""Return a device description for device registry."""
zha_device_info = self._zha_device.device_info
ieee = zha_device_info['ieee']
ieee = zha_device_info["ieee"]
return {
'connections': {(CONNECTION_ZIGBEE, ieee)},
'identifiers': {(DOMAIN, ieee)},
"connections": {(CONNECTION_ZIGBEE, ieee)},
"identifiers": {(DOMAIN, ieee)},
ATTR_MANUFACTURER: zha_device_info[ATTR_MANUFACTURER],
MODEL: zha_device_info[MODEL],
NAME: zha_device_info[NAME],
'via_device': (
DOMAIN, self.hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID]),
"via_device": (DOMAIN, self.hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID]),
}

@property

@@ -112,9 +116,7 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):

def async_update_state_attribute(self, key, value):
"""Update a single device state attribute."""
self._device_state_attributes.update({
key: value
})
self._device_state_attributes.update({key: value})
self.async_schedule_update_ha_state()

def async_set_state(self, state):

@@ -127,24 +129,34 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
self.remove_future = asyncio.Future()
await self.async_check_recently_seen()
await self.async_accept_signal(
None, "{}_{}".format(self.zha_device.available_signal, 'entity'),
None,
"{}_{}".format(self.zha_device.available_signal, "entity"),
self.async_set_available,
signal_override=True)
signal_override=True,
)
await self.async_accept_signal(
None, "{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)),
None,
"{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)),
self.async_remove,
signal_override=True
signal_override=True,
)
self._zha_device.gateway.register_entity_reference(
self._zha_device.ieee, self.entity_id, self._zha_device,
self.cluster_channels, self.device_info, self.remove_future)
self._zha_device.ieee,
self.entity_id,
self._zha_device,
self.cluster_channels,
self.device_info,
self.remove_future,
)

async def async_check_recently_seen(self):
"""Check if the device was seen within the last 2 hours."""
last_state = await self.async_get_last_state()
if last_state and self._zha_device.last_seen and (
time.time() - self._zha_device.last_seen <
RESTART_GRACE_PERIOD):
if (
last_state
and self._zha_device.last_seen
and (time.time() - self._zha_device.last_seen < RESTART_GRACE_PERIOD)
):
self.async_set_available(True)
if not self.zha_device.is_mains_powered:
# mains powered devices will get real time state

@@ -167,24 +179,17 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
async def async_update(self):
"""Retrieve latest state."""
for channel in self.cluster_channels.values():
if hasattr(channel, 'async_update'):
if hasattr(channel, "async_update"):
await channel.async_update()

async def async_accept_signal(self, channel, signal, func,
signal_override=False):
async def async_accept_signal(self, channel, signal, func, signal_override=False):
"""Accept a signal from a channel."""
unsub = None
if signal_override:
unsub = async_dispatcher_connect(
self.hass,
signal,
func
)
unsub = async_dispatcher_connect(self.hass, signal, func)
else:
unsub = async_dispatcher_connect(
self.hass,
"{}_{}".format(channel.unique_id, signal),
func
self.hass, "{}_{}".format(channel.unique_id, signal), func
)
self._unsubs.append(unsub)
@@ -9,21 +9,29 @@ import random
import string
from functools import wraps
from types import MappingProxyType
from typing import (Any, Optional, TypeVar, Callable, KeysView, Union, # noqa
Iterable, List, Dict, Iterator, Coroutine, MutableSet)
from typing import (
Any,
Optional,
TypeVar,
Callable,
KeysView,
Union, # noqa
Iterable,
Coroutine,
)

import slugify as unicode_slug

from .dt import as_local, utcnow

# pylint: disable=invalid-name
T = TypeVar('T')
U = TypeVar('U')
ENUM_T = TypeVar('ENUM_T', bound=enum.Enum)
T = TypeVar("T")
U = TypeVar("U")
ENUM_T = TypeVar("ENUM_T", bound=enum.Enum)
# pylint: enable=invalid-name

RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
RE_SANITIZE_FILENAME = re.compile(r"(~|\.\.|/|\\)")
RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)")


def sanitize_filename(filename: str) -> str:

@@ -38,23 +46,24 @@ def sanitize_path(path: str) -> str:

def slugify(text: str) -> str:
"""Slugify a given text."""
return unicode_slug.slugify(text, separator='_') # type: ignore
return unicode_slug.slugify(text, separator="_") # type: ignore


def repr_helper(inp: Any) -> str:
"""Help creating a more readable string representation of objects."""
if isinstance(inp, (dict, MappingProxyType)):
return ", ".join(
repr_helper(key)+"="+repr_helper(item) for key, item
in inp.items())
repr_helper(key) + "=" + repr_helper(item) for key, item in inp.items()
)
if isinstance(inp, datetime):
return as_local(inp).isoformat()

return str(inp)


def convert(value: Optional[T], to_type: Callable[[T], U],
default: Optional[U] = None) -> Optional[U]:
def convert(
value: Optional[T], to_type: Callable[[T], U], default: Optional[U] = None
) -> Optional[U]:
"""Convert value to to_type, returns default if fails."""
try:
return default if value is None else to_type(value)

@@ -63,8 +72,9 @@ def convert(value: Optional[T], to_type: Callable[[T], U],
return default


def ensure_unique_string(preferred_string: str, current_strings:
Union[Iterable[str], KeysView[str]]) -> str:
def ensure_unique_string(
preferred_string: str, current_strings: Union[Iterable[str], KeysView[str]]
) -> str:
"""Return a string that is not present in current_strings.

If preferred string exists will append _2, _3, ..

@@ -88,14 +98,14 @@ def get_local_ip() -> str:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

# Use Google Public DNS server to determine own IP
sock.connect(('8.8.8.8', 80))
sock.connect(("8.8.8.8", 80))

return sock.getsockname()[0] # type: ignore
except socket.error:
try:
return socket.gethostbyname(socket.gethostname())
except socket.gaierror:
return '127.0.0.1'
return "127.0.0.1"
finally:
sock.close()

@@ -106,7 +116,7 @@ def get_random_string(length: int = 10) -> str:
generator = random.SystemRandom()
source_chars = string.ascii_letters + string.digits

return ''.join(generator.choice(source_chars) for _ in range(length))
return "".join(generator.choice(source_chars) for _ in range(length))


class OrderedEnum(enum.Enum):

@@ -158,8 +168,9 @@ class Throttle:
Adds a datetime attribute `last_call` to the method.
"""

def __init__(self, min_time: timedelta,
limit_no_throttle: Optional[timedelta] = None) -> None:
def __init__(
self, min_time: timedelta, limit_no_throttle: Optional[timedelta] = None
) -> None:
"""Initialize the throttle."""
self.min_time = min_time
self.limit_no_throttle = limit_no_throttle

@@ -168,10 +179,13 @@ class Throttle:
"""Caller for the throttle."""
# Make sure we return a coroutine if the method is async.
if asyncio.iscoroutinefunction(method):

async def throttled_value() -> None:
"""Stand-in function for when real func is being throttled."""
return None

else:

def throttled_value() -> None: # type: ignore
"""Stand-in function for when real func is being throttled."""
return None

@@ -189,8 +203,10 @@ class Throttle:
# All methods have the classname in their qualname separated by a '.'
# Functions have a '.' in their qualname if defined inline, but will
# be prefixed by '.<locals>.' so we strip that out.
is_func = (not hasattr(method, '__self__') and
'.' not in method.__qualname__.split('.<locals>.')[-1])
is_func = (
not hasattr(method, "__self__")
and "." not in method.__qualname__.split(".<locals>.")[-1]
)

@wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]:

@@ -199,14 +215,14 @@ class Throttle:
If we cannot acquire the lock, it is running so return None.
"""
# pylint: disable=protected-access
if hasattr(method, '__self__'):
host = getattr(method, '__self__')
if hasattr(method, "__self__"):
host = getattr(method, "__self__")
elif is_func:
host = wrapper
else:
host = args[0] if args else wrapper

if not hasattr(host, '_throttle'):
if not hasattr(host, "_throttle"):
host._throttle = {}

if id(self) not in host._throttle:

@@ -217,7 +233,7 @@ class Throttle:
return throttled_value()

# Check if method is never called or no_throttle is given
force = kwargs.pop('no_throttle', False) or not throttle[1]
force = kwargs.pop("no_throttle", False) or not throttle[1]

try:
if force or utcnow() - throttle[1] > self.min_time:
pylintrc
@@ -6,6 +6,7 @@ good-names=i,j,k,ex,Run,_,fp

[MESSAGES CONTROL]
# Reasons disabled:
# format - handled by black
# locally-disabled - it spams too much
# duplicate-code - unavoidable
# cyclic-import - doesn't test if both import on load

@@ -20,6 +21,7 @@ good-names=i,j,k,ex,Run,_,fp
# not-an-iterable - https://github.com/PyCQA/pylint/issues/2311
# unnecessary-pass - readability for functions which only contain pass
disable=
format,
abstract-class-little-used,
abstract-method,
cyclic-import,
@@ -0,0 +1,3 @@
[tool.black]
target-version = ["py36", "py37", "py38"]
exclude = 'generated'
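Black discovers pyproject.toml from the project root, so the target-version and exclude settings above should apply without extra flags. A minimal sketch, assuming the command is run from a checkout of this repository:

    black homeassistant    # picks up [tool.black] from pyproject.toml automatically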
@@ -1,7 +1,10 @@
# linters such as flake8 and pylint should be pinned, as new releases
# make new things fail. Manually update these pins when pulling in a
# new version

# When updating this file, update .pre-commit-config.yaml too
asynctest==0.13.0
black==19.3b0
codecov==2.0.15
coveralls==1.2.0
flake8-docstrings==1.3.0

@@ -16,3 +19,4 @@ pytest-sugar==0.9.2
pytest-timeout==1.3.3
pytest==5.0.1
requests_mock==1.6.0
pre-commit==1.17.0
@@ -2,7 +2,10 @@
# linters such as flake8 and pylint should be pinned, as new releases
# make new things fail. Manually update these pins when pulling in a
# new version

# When updating this file, update .pre-commit-config.yaml too
asynctest==0.13.0
black==19.3b0
codecov==2.0.15
coveralls==1.2.0
flake8-docstrings==1.3.0

@@ -17,6 +20,7 @@ pytest-sugar==0.9.2
pytest-timeout==1.3.3
pytest==5.0.1
requests_mock==1.6.0
pre-commit==1.17.0


# homeassistant.components.homekit
@@ -7,4 +7,4 @@ set -e
cd "$(dirname "$0")/.."

echo "Installing test dependencies..."
python3 -m pip install tox colorlog
python3 -m pip install tox colorlog pre-commit
@@ -0,0 +1,10 @@
#!/bin/sh
# Format code with black.

cd "$(dirname "$0")/.."

black \
--check \
--fast \
--quiet \
homeassistant tests script
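Because the script passes --check, Black only reports files that would be reformatted and exits non-zero if any are found, which is what the CheckFormat CI job above relies on. A hedged local workflow sketch:

    script/check_format                               # verify formatting, as CI does
    black --safe --quiet homeassistant tests script   # actually rewrite files (no --check)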
@@ -7,4 +7,5 @@ set -e
cd "$(dirname "$0")/.."
script/bootstrap

pre-commit install
pip3 install -e .
setup.cfg
@@ -21,12 +21,27 @@ norecursedirs = .git testing_config

[flake8]
exclude = .venv,.git,.tox,docs,venv,bin,lib,deps,build
# To work with Black
max-line-length = 88
# E501: line too long
# W503: Line break occurred before a binary operator
# E203: Whitespace before ':'
# D202 No blank lines allowed after function docstring
ignore =
E501,
W503,
E203,
D202

[isort]
# https://github.com/timothycrosley/isort
# https://github.com/timothycrosley/isort/wiki/isort-Settings
# splits long import on multiple lines indented by 4 spaces
multi_line_output = 4
multi_line_output = 3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88
indent = " "
# by default isort don't check module indexes
not_skip = __init__.py

@@ -37,4 +52,3 @@ default_section = THIRDPARTY
known_first_party = homeassistant,tests
forced_separate = tests
combine_as_imports = true
use_parentheses = true