Add Black

This commit is contained in:
Paulus Schoutsen 2019-07-30 16:59:12 -07:00
parent 0490167a12
commit da05dfe708
16 changed files with 401 additions and 272 deletions

View File

@ -17,6 +17,10 @@
"python.pythonPath": "/usr/local/bin/python",
"python.linting.pylintEnabled": true,
"python.linting.enabled": true,
"python.formatting.provider": "black",
"editor.formatOnPaste": false,
"editor.formatOnSave": true,
"editor.formatOnType": true,
"files.trimTrailingWhitespace": true,
"editor.rulers": [80],
"terminal.integrated.shell.linux": "/bin/bash",

8
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,8 @@
repos:
- repo: https://github.com/python/black
rev: 19.3b0
hooks:
- id: black
args:
- --safe
- --quiet

View File

@ -38,7 +38,7 @@ stages:
python -m venv venv
. venv/bin/activate
pip install -r requirements_test.txt
pip install -r requirements_test.txt -c homeassistant/package_constraints.txt
displayName: 'Setup Env'
- script: |
. venv/bin/activate
@ -63,6 +63,21 @@ stages:
. venv/bin/activate
./script/gen_requirements_all.py validate
displayName: 'requirements_all validate'
- job: 'CheckFormat'
pool:
vmImage: 'ubuntu-latest'
container: $[ variables['PythonMain'] ]
steps:
- script: |
python -m venv venv
. venv/bin/activate
pip install -r requirements_test.txt -c homeassistant/package_constraints.txt
displayName: 'Setup Env'
- script: |
. venv/bin/activate
./script/check_format
displayName: 'Check Black formatting'
- stage: 'Tests'
dependsOn:

View File

@ -21,42 +21,42 @@ from homeassistant.helpers.entity import Entity
_LOGGER = logging.getLogger(__name__)
ATTR_STOP_ID = 'stop_id'
ATTR_STOP_NAME = 'stop'
ATTR_ROUTE = 'route'
ATTR_TYPE = 'type'
ATTR_STOP_ID = "stop_id"
ATTR_STOP_NAME = "stop"
ATTR_ROUTE = "route"
ATTR_TYPE = "type"
ATTR_DIRECTION = "direction"
ATTR_DUE_IN = 'due_in'
ATTR_DUE_AT = 'due_at'
ATTR_NEXT_UP = 'next_departures'
ATTR_DUE_IN = "due_in"
ATTR_DUE_AT = "due_at"
ATTR_NEXT_UP = "next_departures"
ATTRIBUTION = "Data provided by rejseplanen.dk"
CONF_STOP_ID = 'stop_id'
CONF_ROUTE = 'route'
CONF_DIRECTION = 'direction'
CONF_DEPARTURE_TYPE = 'departure_type'
CONF_STOP_ID = "stop_id"
CONF_ROUTE = "route"
CONF_DIRECTION = "direction"
CONF_DEPARTURE_TYPE = "departure_type"
DEFAULT_NAME = 'Next departure'
ICON = 'mdi:bus'
DEFAULT_NAME = "Next departure"
ICON = "mdi:bus"
SCAN_INTERVAL = timedelta(minutes=1)
BUS_TYPES = ['BUS', 'EXB', 'TB']
TRAIN_TYPES = ['LET', 'S', 'REG', 'IC', 'LYN', 'TOG']
METRO_TYPES = ['M']
BUS_TYPES = ["BUS", "EXB", "TB"]
TRAIN_TYPES = ["LET", "S", "REG", "IC", "LYN", "TOG"]
METRO_TYPES = ["M"]
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Required(CONF_STOP_ID): cv.string,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_ROUTE, default=[]):
vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_DIRECTION, default=[]):
vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_DEPARTURE_TYPE, default=[]):
vol.All(cv.ensure_list,
[vol.In([*BUS_TYPES, *TRAIN_TYPES, *METRO_TYPES])])
})
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
{
vol.Required(CONF_STOP_ID): cv.string,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_ROUTE, default=[]): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_DIRECTION, default=[]): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_DEPARTURE_TYPE, default=[]): vol.All(
cv.ensure_list, [vol.In([*BUS_TYPES, *TRAIN_TYPES, *METRO_TYPES])]
),
}
)
def due_in_minutes(timestamp):
@ -64,8 +64,9 @@ def due_in_minutes(timestamp):
The timestamp should be in the format day.month.year hour:minute
"""
diff = datetime.strptime(
timestamp, "%d.%m.%y %H:%M") - dt_util.now().replace(tzinfo=None)
diff = datetime.strptime(timestamp, "%d.%m.%y %H:%M") - dt_util.now().replace(
tzinfo=None
)
return int(diff.total_seconds() // 60)
@ -79,8 +80,9 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
departure_type = config[CONF_DEPARTURE_TYPE]
data = PublicTransportData(stop_id, route, direction, departure_type)
add_devices([RejseplanenTransportSensor(
data, stop_id, route, direction, name)], True)
add_devices(
[RejseplanenTransportSensor(data, stop_id, route, direction, name)], True
)
class RejseplanenTransportSensor(Entity):
@ -124,14 +126,14 @@ class RejseplanenTransportSensor(Entity):
ATTR_STOP_NAME: self._times[0][ATTR_STOP_NAME],
ATTR_STOP_ID: self._stop_id,
ATTR_ATTRIBUTION: ATTRIBUTION,
ATTR_NEXT_UP: next_up
ATTR_NEXT_UP: next_up,
}
return {k: v for k, v in params.items() if v}
@property
def unit_of_measurement(self):
"""Return the unit this state is expressed in."""
return 'min'
return "min"
@property
def icon(self):
@ -148,7 +150,7 @@ class RejseplanenTransportSensor(Entity):
pass
class PublicTransportData():
class PublicTransportData:
"""The Class for handling the data retrieval."""
def __init__(self, stop_id, route, direction, departure_type):
@ -161,16 +163,21 @@ class PublicTransportData():
def empty_result(self):
"""Object returned when no departures are found."""
return [{ATTR_DUE_IN: 'n/a',
ATTR_DUE_AT: 'n/a',
ATTR_TYPE: 'n/a',
ATTR_ROUTE: self.route,
ATTR_DIRECTION: 'n/a',
ATTR_STOP_NAME: 'n/a'}]
return [
{
ATTR_DUE_IN: "n/a",
ATTR_DUE_AT: "n/a",
ATTR_TYPE: "n/a",
ATTR_ROUTE: self.route,
ATTR_DIRECTION: "n/a",
ATTR_STOP_NAME: "n/a",
}
]
def update(self):
"""Get the latest data from rejseplanen."""
import rjpl
self.info = []
def intersection(lst1, lst2):
@ -179,12 +186,9 @@ class PublicTransportData():
# Limit search to selected types, to get more results
all_types = not bool(self.departure_type)
use_train = all_types or bool(
intersection(TRAIN_TYPES, self.departure_type))
use_bus = all_types or bool(
intersection(BUS_TYPES, self.departure_type))
use_metro = all_types or bool(
intersection(METRO_TYPES, self.departure_type))
use_train = all_types or bool(intersection(TRAIN_TYPES, self.departure_type))
use_bus = all_types or bool(intersection(BUS_TYPES, self.departure_type))
use_metro = all_types or bool(intersection(METRO_TYPES, self.departure_type))
try:
results = rjpl.departureBoard(
@ -192,7 +196,7 @@ class PublicTransportData():
timeout=5,
useTrain=use_train,
useBus=use_bus,
useMetro=use_metro
useMetro=use_metro,
)
except rjpl.rjplAPIError as error:
_LOGGER.debug("API returned error: %s", error)
@ -204,36 +208,40 @@ class PublicTransportData():
return
# Filter result
results = [d for d in results if 'cancelled' not in d]
results = [d for d in results if "cancelled" not in d]
if self.route:
results = [d for d in results if d['name'] in self.route]
results = [d for d in results if d["name"] in self.route]
if self.direction:
results = [d for d in results if d['direction'] in self.direction]
results = [d for d in results if d["direction"] in self.direction]
if self.departure_type:
results = [d for d in results if d['type'] in self.departure_type]
results = [d for d in results if d["type"] in self.departure_type]
for item in results:
route = item.get('name')
route = item.get("name")
due_at_date = item.get('rtDate')
due_at_time = item.get('rtTime')
due_at_date = item.get("rtDate")
due_at_time = item.get("rtTime")
if due_at_date is None:
due_at_date = item.get('date') # Scheduled date
due_at_date = item.get("date") # Scheduled date
if due_at_time is None:
due_at_time = item.get('time') # Scheduled time
due_at_time = item.get("time") # Scheduled time
if (due_at_date is not None and
due_at_time is not None and
route is not None):
due_at = '{} {}'.format(due_at_date, due_at_time)
if (
due_at_date is not None
and due_at_time is not None
and route is not None
):
due_at = "{} {}".format(due_at_date, due_at_time)
departure_data = {ATTR_DUE_IN: due_in_minutes(due_at),
ATTR_DUE_AT: due_at,
ATTR_TYPE: item.get('type'),
ATTR_ROUTE: route,
ATTR_DIRECTION: item.get('direction'),
ATTR_STOP_NAME: item.get('stop')}
departure_data = {
ATTR_DUE_IN: due_in_minutes(due_at),
ATTR_DUE_AT: due_at,
ATTR_TYPE: item.get("type"),
ATTR_ROUTE: route,
ATTR_DIRECTION: item.get("direction"),
ATTR_STOP_NAME: item.get("stop"),
}
self.info.append(departure_data)
if not self.info:

View File

@ -14,11 +14,19 @@ from random import uniform
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from ..helpers import (
configure_reporting, construct_unique_id,
safe_read, get_attr_id_by_name, bind_cluster, LogMixin)
configure_reporting,
construct_unique_id,
safe_read,
get_attr_id_by_name,
bind_cluster,
LogMixin,
)
from ..const import (
REPORT_CONFIG_DEFAULT, SIGNAL_ATTR_UPDATED, ATTRIBUTE_CHANNEL,
EVENT_RELAY_CHANNEL, ZDO_CHANNEL
REPORT_CONFIG_DEFAULT,
SIGNAL_ATTR_UPDATED,
ATTRIBUTE_CHANNEL,
EVENT_RELAY_CHANNEL,
ZDO_CHANNEL,
)
from ..registries import CLUSTER_REPORT_CONFIGS
@ -33,32 +41,33 @@ def parse_and_log_command(channel, tsn, command_id, args):
cmd,
args,
channel.cluster.cluster_id,
tsn
tsn,
)
return cmd
def decorate_command(channel, command):
"""Wrap a cluster command to make it safe."""
@wraps(command)
async def wrapper(*args, **kwds):
from zigpy.exceptions import DeliveryError
try:
result = await command(*args, **kwds)
channel.debug("executed command: %s %s %s %s",
command.__name__,
"{}: {}".format("with args", args),
"{}: {}".format("with kwargs", kwds),
"{}: {}".format("and result", result))
channel.debug(
"executed command: %s %s %s %s",
command.__name__,
"{}: {}".format("with args", args),
"{}: {}".format("with kwargs", kwds),
"{}: {}".format("and result", result),
)
return result
except (DeliveryError, Timeout) as ex:
channel.debug(
"command failed: %s exception: %s",
command.__name__,
str(ex)
)
channel.debug("command failed: %s exception: %s", command.__name__, str(ex))
return ex
return wrapper
@ -80,13 +89,12 @@ class ZigbeeChannel(LogMixin):
self._channel_name = cluster.ep_attribute
if self.CHANNEL_NAME:
self._channel_name = self.CHANNEL_NAME
self._generic_id = 'channel_0x{:04x}'.format(cluster.cluster_id)
self._generic_id = "channel_0x{:04x}".format(cluster.cluster_id)
self._cluster = cluster
self._zha_device = device
self._unique_id = construct_unique_id(cluster)
self._report_config = CLUSTER_REPORT_CONFIGS.get(
self._cluster.cluster_id,
[{'attr': 0, 'config': REPORT_CONFIG_DEFAULT}]
self._cluster.cluster_id, [{"attr": 0, "config": REPORT_CONFIG_DEFAULT}]
)
self._status = ChannelStatus.CREATED
self._cluster.add_listener(self)
@ -130,21 +138,24 @@ class ZigbeeChannel(LogMixin):
manufacturer = None
manufacturer_code = self._zha_device.manufacturer_code
# Xiaomi devices don't need this and it disrupts pairing
if self._zha_device.manufacturer != 'LUMI':
if self.cluster.cluster_id >= 0xfc00 and manufacturer_code:
if self._zha_device.manufacturer != "LUMI":
if self.cluster.cluster_id >= 0xFC00 and manufacturer_code:
manufacturer = manufacturer_code
await bind_cluster(self._unique_id, self.cluster)
if not self.cluster.bind_only:
for report_config in self._report_config:
attr = report_config.get('attr')
min_report_interval, max_report_interval, change = \
report_config.get('config')
attr = report_config.get("attr")
min_report_interval, max_report_interval, change = report_config.get(
"config"
)
await configure_reporting(
self._unique_id, self.cluster, attr,
self._unique_id,
self.cluster,
attr,
min_report=min_report_interval,
max_report=max_report_interval,
reportable_change=change,
manufacturer=manufacturer
manufacturer=manufacturer,
)
await asyncio.sleep(uniform(0.1, 0.5))
@ -153,7 +164,7 @@ class ZigbeeChannel(LogMixin):
async def async_initialize(self, from_cache):
"""Initialize channel."""
self.debug('initializing channel: from_cache: %s', from_cache)
self.debug("initializing channel: from_cache: %s", from_cache)
self._status = ChannelStatus.INITIALIZED
@callback
@ -175,13 +186,13 @@ class ZigbeeChannel(LogMixin):
def zha_send_event(self, cluster, command, args):
"""Relay events to hass."""
self._zha_device.hass.bus.async_fire(
'zha_event',
"zha_event",
{
'unique_id': self._unique_id,
'device_ieee': str(self._zha_device.ieee),
'command': command,
'args': args
}
"unique_id": self._unique_id,
"device_ieee": str(self._zha_device.ieee),
"command": command,
"args": args,
},
)
async def async_update(self):
@ -192,14 +203,14 @@ class ZigbeeChannel(LogMixin):
"""Get the value for an attribute."""
manufacturer = None
manufacturer_code = self._zha_device.manufacturer_code
if self.cluster.cluster_id >= 0xfc00 and manufacturer_code:
if self.cluster.cluster_id >= 0xFC00 and manufacturer_code:
manufacturer = manufacturer_code
result = await safe_read(
self._cluster,
[attribute],
allow_cache=from_cache,
only_cache=from_cache,
manufacturer=manufacturer
manufacturer=manufacturer,
)
return result.get(attribute)
@ -211,14 +222,10 @@ class ZigbeeChannel(LogMixin):
def __getattr__(self, name):
"""Get attribute or a decorated cluster command."""
if hasattr(self._cluster, name) and callable(
getattr(self._cluster, name)):
if hasattr(self._cluster, name) and callable(getattr(self._cluster, name)):
command = getattr(self._cluster, name)
command.__name__ = name
return decorate_command(
self,
command
)
return decorate_command(self, command)
return self.__getattribute__(name)
@ -230,7 +237,7 @@ class AttributeListeningChannel(ZigbeeChannel):
def __init__(self, cluster, device):
"""Initialize AttributeListeningChannel."""
super().__init__(cluster, device)
attr = self._report_config[0].get('attr')
attr = self._report_config[0].get("attr")
if isinstance(attr, str):
self.value_attribute = get_attr_id_by_name(self.cluster, attr)
else:
@ -243,13 +250,14 @@ class AttributeListeningChannel(ZigbeeChannel):
async_dispatcher_send(
self._zha_device.hass,
"{}_{}".format(self.unique_id, SIGNAL_ATTR_UPDATED),
value
value,
)
async def async_initialize(self, from_cache):
"""Initialize listener."""
await self.get_attribute_value(
self._report_config[0].get('attr'), from_cache=from_cache)
self._report_config[0].get("attr"), from_cache=from_cache
)
await super().async_initialize(from_cache)
@ -293,7 +301,8 @@ class ZDOChannel(LogMixin):
async def async_initialize(self, from_cache):
"""Initialize channel."""
entry = self._zha_device.gateway.zha_storage.async_get_or_create(
self._zha_device)
self._zha_device
)
self.debug("entry loaded from storage: %s", entry)
self._status = ChannelStatus.INITIALIZED
@ -320,21 +329,19 @@ class EventRelayChannel(ZigbeeChannel):
self._cluster,
SIGNAL_ATTR_UPDATED,
{
'attribute_id': attrid,
'attribute_name': self._cluster.attributes.get(
attrid,
['Unknown'])[0],
'value': value
}
"attribute_id": attrid,
"attribute_name": self._cluster.attributes.get(attrid, ["Unknown"])[0],
"value": value,
},
)
@callback
def cluster_command(self, tsn, command_id, args):
"""Handle a cluster command received on this cluster."""
if self._cluster.server_commands is not None and \
self._cluster.server_commands.get(command_id) is not None:
if (
self._cluster.server_commands is not None
and self._cluster.server_commands.get(command_id) is not None
):
self.zha_send_event(
self._cluster,
self._cluster.server_commands.get(command_id)[0],
args
self._cluster, self._cluster.server_commands.get(command_id)[0], args
)

View File

@ -12,18 +12,46 @@ import time
from homeassistant.core import callback
from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_send)
async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.event import async_track_time_interval
from .channels import EventRelayChannel
from .const import (
ATTR_ARGS, ATTR_ATTRIBUTE, ATTR_CLUSTER_ID, ATTR_COMMAND,
ATTR_COMMAND_TYPE, ATTR_ENDPOINT_ID, ATTR_MANUFACTURER, ATTR_VALUE,
BATTERY_OR_UNKNOWN, CLIENT_COMMANDS, IEEE, IN, MAINS_POWERED,
MANUFACTURER_CODE, MODEL, NAME, NWK, OUT, POWER_CONFIGURATION_CHANNEL,
POWER_SOURCE, QUIRK_APPLIED, QUIRK_CLASS, SERVER, SERVER_COMMANDS,
SIGNAL_AVAILABLE, UNKNOWN_MANUFACTURER, UNKNOWN_MODEL, ZDO_CHANNEL,
LQI, RSSI, LAST_SEEN, ATTR_AVAILABLE)
ATTR_ARGS,
ATTR_ATTRIBUTE,
ATTR_CLUSTER_ID,
ATTR_COMMAND,
ATTR_COMMAND_TYPE,
ATTR_ENDPOINT_ID,
ATTR_MANUFACTURER,
ATTR_VALUE,
BATTERY_OR_UNKNOWN,
CLIENT_COMMANDS,
IEEE,
IN,
MAINS_POWERED,
MANUFACTURER_CODE,
MODEL,
NAME,
NWK,
OUT,
POWER_CONFIGURATION_CHANNEL,
POWER_SOURCE,
QUIRK_APPLIED,
QUIRK_CLASS,
SERVER,
SERVER_COMMANDS,
SIGNAL_AVAILABLE,
UNKNOWN_MANUFACTURER,
UNKNOWN_MODEL,
ZDO_CHANNEL,
LQI,
RSSI,
LAST_SEEN,
ATTR_AVAILABLE,
)
from .helpers import LogMixin
_LOGGER = logging.getLogger(__name__)
@ -51,22 +79,20 @@ class ZHADevice(LogMixin):
self._all_channels = []
self._available = False
self._available_signal = "{}_{}_{}".format(
self.name, self.ieee, SIGNAL_AVAILABLE)
self.name, self.ieee, SIGNAL_AVAILABLE
)
self._unsub = async_dispatcher_connect(
self.hass,
self._available_signal,
self.async_initialize
self.hass, self._available_signal, self.async_initialize
)
from zigpy.quirks import CustomDevice
self.quirk_applied = isinstance(self._zigpy_device, CustomDevice)
self.quirk_class = "{}.{}".format(
self._zigpy_device.__class__.__module__,
self._zigpy_device.__class__.__name__
self._zigpy_device.__class__.__name__,
)
self._available_check = async_track_time_interval(
self.hass,
self._check_available,
_UPDATE_ALIVE_INTERVAL
self.hass, self._check_available, _UPDATE_ALIVE_INTERVAL
)
self.status = DeviceStatus.CREATED
@ -184,15 +210,9 @@ class ZHADevice(LogMixin):
"""Set sensor availability."""
if self._available != available and available:
# Update the state the first time the device comes online
async_dispatcher_send(
self.hass,
self._available_signal,
False
)
async_dispatcher_send(self.hass, self._available_signal, False)
async_dispatcher_send(
self.hass,
"{}_{}".format(self._available_signal, 'entity'),
available
self.hass, "{}_{}".format(self._available_signal, "entity"), available
)
self._available = available
@ -215,14 +235,16 @@ class ZHADevice(LogMixin):
LQI: self.lqi,
RSSI: self.rssi,
LAST_SEEN: update_time,
ATTR_AVAILABLE: self.available
ATTR_AVAILABLE: self.available,
}
def add_cluster_channel(self, cluster_channel):
"""Add cluster channel to device."""
# only keep 1 power configuration channel
if cluster_channel.name is POWER_CONFIGURATION_CHANNEL and \
POWER_CONFIGURATION_CHANNEL in self.cluster_channels:
if (
cluster_channel.name is POWER_CONFIGURATION_CHANNEL
and POWER_CONFIGURATION_CHANNEL in self.cluster_channels
):
return
if isinstance(cluster_channel, EventRelayChannel):
@ -249,10 +271,9 @@ class ZHADevice(LogMixin):
def get_key(channel):
channel_key = "ZDO"
if hasattr(channel.cluster, 'cluster_id'):
if hasattr(channel.cluster, "cluster_id"):
channel_key = "{}_{}".format(
channel.cluster.endpoint.endpoint_id,
channel.cluster.cluster_id
channel.cluster.endpoint.endpoint_id, channel.cluster.cluster_id
)
return channel_key
@ -273,21 +294,23 @@ class ZHADevice(LogMixin):
async def async_configure(self):
"""Configure the device."""
self.debug('started configuration')
self.debug("started configuration")
await self._execute_channel_tasks(
self.get_channels_to_configure(), 'async_configure')
self.debug('completed configuration')
self.get_channels_to_configure(), "async_configure"
)
self.debug("completed configuration")
entry = self.gateway.zha_storage.async_create_or_update(self)
self.debug('stored in registry: %s', entry)
self.debug("stored in registry: %s", entry)
async def async_initialize(self, from_cache=False):
"""Initialize channels."""
self.debug('started initialization')
self.debug("started initialization")
await self._execute_channel_tasks(
self.all_channels, 'async_initialize', from_cache)
self.debug('power source: %s', self.power_source)
self.all_channels, "async_initialize", from_cache
)
self.debug("power source: %s", self.power_source)
self.status = DeviceStatus.INITIALIZED
self.debug('completed initialization')
self.debug("completed initialization")
async def _execute_channel_tasks(self, channels, task_name, *args):
"""Gather and execute a set of CHANNEL tasks."""
@ -299,11 +322,12 @@ class ZHADevice(LogMixin):
# pylint: disable=E1111
if zdo_task is None: # We only want to do this once
zdo_task = self._async_create_task(
semaphore, channel, task_name, *args)
semaphore, channel, task_name, *args
)
else:
channel_tasks.append(
self._async_create_task(
semaphore, channel, task_name, *args))
self._async_create_task(semaphore, channel, task_name, *args)
)
if zdo_task is not None:
await zdo_task
await asyncio.gather(*channel_tasks)
@ -332,10 +356,8 @@ class ZHADevice(LogMixin):
def async_get_clusters(self):
"""Get all clusters for this device."""
return {
ep_id: {
IN: endpoint.in_clusters,
OUT: endpoint.out_clusters
} for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
ep_id: {IN: endpoint.in_clusters, OUT: endpoint.out_clusters}
for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
if ep_id != 0
}
@ -343,15 +365,11 @@ class ZHADevice(LogMixin):
def async_get_std_clusters(self):
"""Get ZHA and ZLL clusters for this device."""
from zigpy.profiles import zha, zll
return {
ep_id: {
IN: endpoint.in_clusters,
OUT: endpoint.out_clusters
} for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
if ep_id != 0 and endpoint.profile_id in (
zha.PROFILE_ID,
zll.PROFILE_ID
)
ep_id: {IN: endpoint.in_clusters, OUT: endpoint.out_clusters}
for (ep_id, endpoint) in self._zigpy_device.endpoints.items()
if ep_id != 0 and endpoint.profile_id in (zha.PROFILE_ID, zll.PROFILE_ID)
}
@callback
@ -361,18 +379,15 @@ class ZHADevice(LogMixin):
return clusters[endpoint_id][cluster_type][cluster_id]
@callback
def async_get_cluster_attributes(self, endpoint_id, cluster_id,
cluster_type=IN):
def async_get_cluster_attributes(self, endpoint_id, cluster_id, cluster_type=IN):
"""Get zigbee attributes for specified cluster."""
cluster = self.async_get_cluster(endpoint_id, cluster_id,
cluster_type)
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None:
return None
return cluster.attributes
@callback
def async_get_cluster_commands(self, endpoint_id, cluster_id,
cluster_type=IN):
def async_get_cluster_commands(self, endpoint_id, cluster_id, cluster_type=IN):
"""Get zigbee commands for specified cluster."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None:
@ -382,64 +397,77 @@ class ZHADevice(LogMixin):
SERVER_COMMANDS: cluster.server_commands,
}
async def write_zigbee_attribute(self, endpoint_id, cluster_id,
attribute, value, cluster_type=IN,
manufacturer=None):
async def write_zigbee_attribute(
self,
endpoint_id,
cluster_id,
attribute,
value,
cluster_type=IN,
manufacturer=None,
):
"""Write a value to a zigbee attribute for a cluster in this entity."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None:
return None
from zigpy.exceptions import DeliveryError
try:
response = await cluster.write_attributes(
{attribute: value},
manufacturer=manufacturer
{attribute: value}, manufacturer=manufacturer
)
self.debug(
'set: %s for attr: %s to cluster: %s for ept: %s - res: %s',
"set: %s for attr: %s to cluster: %s for ept: %s - res: %s",
value,
attribute,
cluster_id,
endpoint_id,
response
response,
)
return response
except DeliveryError as exc:
self.debug(
'failed to set attribute: %s %s %s %s %s',
'{}: {}'.format(ATTR_VALUE, value),
'{}: {}'.format(ATTR_ATTRIBUTE, attribute),
'{}: {}'.format(ATTR_CLUSTER_ID, cluster_id),
'{}: {}'.format(ATTR_ENDPOINT_ID, endpoint_id),
exc
"failed to set attribute: %s %s %s %s %s",
"{}: {}".format(ATTR_VALUE, value),
"{}: {}".format(ATTR_ATTRIBUTE, attribute),
"{}: {}".format(ATTR_CLUSTER_ID, cluster_id),
"{}: {}".format(ATTR_ENDPOINT_ID, endpoint_id),
exc,
)
return None
async def issue_cluster_command(self, endpoint_id, cluster_id, command,
command_type, args, cluster_type=IN,
manufacturer=None):
async def issue_cluster_command(
self,
endpoint_id,
cluster_id,
command,
command_type,
args,
cluster_type=IN,
manufacturer=None,
):
"""Issue a command against specified zigbee cluster on this entity."""
cluster = self.async_get_cluster(endpoint_id, cluster_id, cluster_type)
if cluster is None:
return None
response = None
if command_type == SERVER:
response = await cluster.command(command, *args,
manufacturer=manufacturer,
expect_reply=True)
response = await cluster.command(
command, *args, manufacturer=manufacturer, expect_reply=True
)
else:
response = await cluster.client_command(command, *args)
self.debug(
'Issued cluster command: %s %s %s %s %s %s %s',
'{}: {}'.format(ATTR_CLUSTER_ID, cluster_id),
'{}: {}'.format(ATTR_COMMAND, command),
'{}: {}'.format(ATTR_COMMAND_TYPE, command_type),
'{}: {}'.format(ATTR_ARGS, args),
'{}: {}'.format(ATTR_CLUSTER_ID, cluster_type),
'{}: {}'.format(ATTR_MANUFACTURER, manufacturer),
'{}: {}'.format(ATTR_ENDPOINT_ID, endpoint_id)
"Issued cluster command: %s %s %s %s %s %s %s",
"{}: {}".format(ATTR_CLUSTER_ID, cluster_id),
"{}: {}".format(ATTR_COMMAND, command),
"{}: {}".format(ATTR_COMMAND_TYPE, command_type),
"{}: {}".format(ATTR_ARGS, args),
"{}: {}".format(ATTR_CLUSTER_ID, cluster_type),
"{}: {}".format(ATTR_MANUFACTURER, manufacturer),
"{}: {}".format(ATTR_ENDPOINT_ID, endpoint_id),
)
return response

View File

@ -12,13 +12,19 @@ from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.util import slugify
from .core.const import (
ATTR_MANUFACTURER, DATA_ZHA, DATA_ZHA_BRIDGE_ID, DOMAIN, MODEL, NAME,
SIGNAL_REMOVE)
ATTR_MANUFACTURER,
DATA_ZHA,
DATA_ZHA_BRIDGE_ID,
DOMAIN,
MODEL,
NAME,
SIGNAL_REMOVE,
)
from .core.helpers import LogMixin
_LOGGER = logging.getLogger(__name__)
ENTITY_SUFFIX = 'entity_suffix'
ENTITY_SUFFIX = "entity_suffix"
RESTART_GRACE_PERIOD = 7200 # 2 hours
@ -27,29 +33,28 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
_domain = None # Must be overridden by subclasses
def __init__(self, unique_id, zha_device, channels,
skip_entity_id=False, **kwargs):
def __init__(self, unique_id, zha_device, channels, skip_entity_id=False, **kwargs):
"""Init ZHA entity."""
self._force_update = False
self._should_poll = False
self._unique_id = unique_id
if not skip_entity_id:
ieee = zha_device.ieee
ieeetail = ''.join(['%02x' % (o, ) for o in ieee[-4:]])
ieeetail = "".join(["%02x" % (o,) for o in ieee[-4:]])
self.entity_id = "{}.{}_{}_{}_{}{}".format(
self._domain,
slugify(zha_device.manufacturer),
slugify(zha_device.model),
ieeetail,
channels[0].cluster.endpoint.endpoint_id,
kwargs.get(ENTITY_SUFFIX, ''),
kwargs.get(ENTITY_SUFFIX, ""),
)
self._state = None
self._device_state_attributes = {}
self._zha_device = zha_device
self.cluster_channels = {}
self._available = False
self._component = kwargs['component']
self._component = kwargs["component"]
self._unsubs = []
self.remove_future = None
for channel in channels:
@ -89,15 +94,14 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
def device_info(self):
"""Return a device description for device registry."""
zha_device_info = self._zha_device.device_info
ieee = zha_device_info['ieee']
ieee = zha_device_info["ieee"]
return {
'connections': {(CONNECTION_ZIGBEE, ieee)},
'identifiers': {(DOMAIN, ieee)},
"connections": {(CONNECTION_ZIGBEE, ieee)},
"identifiers": {(DOMAIN, ieee)},
ATTR_MANUFACTURER: zha_device_info[ATTR_MANUFACTURER],
MODEL: zha_device_info[MODEL],
NAME: zha_device_info[NAME],
'via_device': (
DOMAIN, self.hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID]),
"via_device": (DOMAIN, self.hass.data[DATA_ZHA][DATA_ZHA_BRIDGE_ID]),
}
@property
@ -112,9 +116,7 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
def async_update_state_attribute(self, key, value):
"""Update a single device state attribute."""
self._device_state_attributes.update({
key: value
})
self._device_state_attributes.update({key: value})
self.async_schedule_update_ha_state()
def async_set_state(self, state):
@ -127,24 +129,34 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
self.remove_future = asyncio.Future()
await self.async_check_recently_seen()
await self.async_accept_signal(
None, "{}_{}".format(self.zha_device.available_signal, 'entity'),
None,
"{}_{}".format(self.zha_device.available_signal, "entity"),
self.async_set_available,
signal_override=True)
signal_override=True,
)
await self.async_accept_signal(
None, "{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)),
None,
"{}_{}".format(SIGNAL_REMOVE, str(self.zha_device.ieee)),
self.async_remove,
signal_override=True
signal_override=True,
)
self._zha_device.gateway.register_entity_reference(
self._zha_device.ieee, self.entity_id, self._zha_device,
self.cluster_channels, self.device_info, self.remove_future)
self._zha_device.ieee,
self.entity_id,
self._zha_device,
self.cluster_channels,
self.device_info,
self.remove_future,
)
async def async_check_recently_seen(self):
"""Check if the device was seen within the last 2 hours."""
last_state = await self.async_get_last_state()
if last_state and self._zha_device.last_seen and (
time.time() - self._zha_device.last_seen <
RESTART_GRACE_PERIOD):
if (
last_state
and self._zha_device.last_seen
and (time.time() - self._zha_device.last_seen < RESTART_GRACE_PERIOD)
):
self.async_set_available(True)
if not self.zha_device.is_mains_powered:
# mains powered devices will get real time state
@ -167,24 +179,17 @@ class ZhaEntity(RestoreEntity, LogMixin, entity.Entity):
async def async_update(self):
"""Retrieve latest state."""
for channel in self.cluster_channels.values():
if hasattr(channel, 'async_update'):
if hasattr(channel, "async_update"):
await channel.async_update()
async def async_accept_signal(self, channel, signal, func,
signal_override=False):
async def async_accept_signal(self, channel, signal, func, signal_override=False):
"""Accept a signal from a channel."""
unsub = None
if signal_override:
unsub = async_dispatcher_connect(
self.hass,
signal,
func
)
unsub = async_dispatcher_connect(self.hass, signal, func)
else:
unsub = async_dispatcher_connect(
self.hass,
"{}_{}".format(channel.unique_id, signal),
func
self.hass, "{}_{}".format(channel.unique_id, signal), func
)
self._unsubs.append(unsub)

View File

@ -9,21 +9,29 @@ import random
import string
from functools import wraps
from types import MappingProxyType
from typing import (Any, Optional, TypeVar, Callable, KeysView, Union, # noqa
Iterable, List, Dict, Iterator, Coroutine, MutableSet)
from typing import (
Any,
Optional,
TypeVar,
Callable,
KeysView,
Union, # noqa
Iterable,
Coroutine,
)
import slugify as unicode_slug
from .dt import as_local, utcnow
# pylint: disable=invalid-name
T = TypeVar('T')
U = TypeVar('U')
ENUM_T = TypeVar('ENUM_T', bound=enum.Enum)
T = TypeVar("T")
U = TypeVar("U")
ENUM_T = TypeVar("ENUM_T", bound=enum.Enum)
# pylint: enable=invalid-name
RE_SANITIZE_FILENAME = re.compile(r'(~|\.\.|/|\\)')
RE_SANITIZE_PATH = re.compile(r'(~|\.(\.)+)')
RE_SANITIZE_FILENAME = re.compile(r"(~|\.\.|/|\\)")
RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)")
def sanitize_filename(filename: str) -> str:
@ -38,23 +46,24 @@ def sanitize_path(path: str) -> str:
def slugify(text: str) -> str:
"""Slugify a given text."""
return unicode_slug.slugify(text, separator='_') # type: ignore
return unicode_slug.slugify(text, separator="_") # type: ignore
def repr_helper(inp: Any) -> str:
"""Help creating a more readable string representation of objects."""
if isinstance(inp, (dict, MappingProxyType)):
return ", ".join(
repr_helper(key)+"="+repr_helper(item) for key, item
in inp.items())
repr_helper(key) + "=" + repr_helper(item) for key, item in inp.items()
)
if isinstance(inp, datetime):
return as_local(inp).isoformat()
return str(inp)
def convert(value: Optional[T], to_type: Callable[[T], U],
default: Optional[U] = None) -> Optional[U]:
def convert(
value: Optional[T], to_type: Callable[[T], U], default: Optional[U] = None
) -> Optional[U]:
"""Convert value to to_type, returns default if fails."""
try:
return default if value is None else to_type(value)
@ -63,8 +72,9 @@ def convert(value: Optional[T], to_type: Callable[[T], U],
return default
def ensure_unique_string(preferred_string: str, current_strings:
Union[Iterable[str], KeysView[str]]) -> str:
def ensure_unique_string(
preferred_string: str, current_strings: Union[Iterable[str], KeysView[str]]
) -> str:
"""Return a string that is not present in current_strings.
If preferred string exists will append _2, _3, ..
@ -88,14 +98,14 @@ def get_local_ip() -> str:
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# Use Google Public DNS server to determine own IP
sock.connect(('8.8.8.8', 80))
sock.connect(("8.8.8.8", 80))
return sock.getsockname()[0] # type: ignore
except socket.error:
try:
return socket.gethostbyname(socket.gethostname())
except socket.gaierror:
return '127.0.0.1'
return "127.0.0.1"
finally:
sock.close()
@ -106,7 +116,7 @@ def get_random_string(length: int = 10) -> str:
generator = random.SystemRandom()
source_chars = string.ascii_letters + string.digits
return ''.join(generator.choice(source_chars) for _ in range(length))
return "".join(generator.choice(source_chars) for _ in range(length))
class OrderedEnum(enum.Enum):
@ -158,8 +168,9 @@ class Throttle:
Adds a datetime attribute `last_call` to the method.
"""
def __init__(self, min_time: timedelta,
limit_no_throttle: Optional[timedelta] = None) -> None:
def __init__(
self, min_time: timedelta, limit_no_throttle: Optional[timedelta] = None
) -> None:
"""Initialize the throttle."""
self.min_time = min_time
self.limit_no_throttle = limit_no_throttle
@ -168,10 +179,13 @@ class Throttle:
"""Caller for the throttle."""
# Make sure we return a coroutine if the method is async.
if asyncio.iscoroutinefunction(method):
async def throttled_value() -> None:
"""Stand-in function for when real func is being throttled."""
return None
else:
def throttled_value() -> None: # type: ignore
"""Stand-in function for when real func is being throttled."""
return None
@ -189,8 +203,10 @@ class Throttle:
# All methods have the classname in their qualname separated by a '.'
# Functions have a '.' in their qualname if defined inline, but will
# be prefixed by '.<locals>.' so we strip that out.
is_func = (not hasattr(method, '__self__') and
'.' not in method.__qualname__.split('.<locals>.')[-1])
is_func = (
not hasattr(method, "__self__")
and "." not in method.__qualname__.split(".<locals>.")[-1]
)
@wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Union[Callable, Coroutine]:
@ -199,14 +215,14 @@ class Throttle:
If we cannot acquire the lock, it is running so return None.
"""
# pylint: disable=protected-access
if hasattr(method, '__self__'):
host = getattr(method, '__self__')
if hasattr(method, "__self__"):
host = getattr(method, "__self__")
elif is_func:
host = wrapper
else:
host = args[0] if args else wrapper
if not hasattr(host, '_throttle'):
if not hasattr(host, "_throttle"):
host._throttle = {}
if id(self) not in host._throttle:
@ -217,7 +233,7 @@ class Throttle:
return throttled_value()
# Check if method is never called or no_throttle is given
force = kwargs.pop('no_throttle', False) or not throttle[1]
force = kwargs.pop("no_throttle", False) or not throttle[1]
try:
if force or utcnow() - throttle[1] > self.min_time:

View File

@ -6,6 +6,7 @@ good-names=i,j,k,ex,Run,_,fp
[MESSAGES CONTROL]
# Reasons disabled:
# format - handled by black
# locally-disabled - it spams too much
# duplicate-code - unavoidable
# cyclic-import - doesn't test if both import on load
@ -20,6 +21,7 @@ good-names=i,j,k,ex,Run,_,fp
# not-an-iterable - https://github.com/PyCQA/pylint/issues/2311
# unnecessary-pass - readability for functions which only contain pass
disable=
format,
abstract-class-little-used,
abstract-method,
cyclic-import,

3
pyproject.toml Normal file
View File

@ -0,0 +1,3 @@
[tool.black]
target-version = ["py36", "py37", "py38"]
exclude = 'generated'

View File

@ -1,7 +1,10 @@
# linters such as flake8 and pylint should be pinned, as new releases
# make new things fail. Manually update these pins when pulling in a
# new version
# When updating this file, update .pre-commit-config.yaml too
asynctest==0.13.0
black==19.3b0
codecov==2.0.15
coveralls==1.2.0
flake8-docstrings==1.3.0
@ -16,3 +19,4 @@ pytest-sugar==0.9.2
pytest-timeout==1.3.3
pytest==5.0.1
requests_mock==1.6.0
pre-commit==1.17.0

View File

@ -2,7 +2,10 @@
# linters such as flake8 and pylint should be pinned, as new releases
# make new things fail. Manually update these pins when pulling in a
# new version
# When updating this file, update .pre-commit-config.yaml too
asynctest==0.13.0
black==19.3b0
codecov==2.0.15
coveralls==1.2.0
flake8-docstrings==1.3.0
@ -17,6 +20,7 @@ pytest-sugar==0.9.2
pytest-timeout==1.3.3
pytest==5.0.1
requests_mock==1.6.0
pre-commit==1.17.0
# homeassistant.components.homekit

View File

@ -7,4 +7,4 @@ set -e
cd "$(dirname "$0")/.."
echo "Installing test dependencies..."
python3 -m pip install tox colorlog
python3 -m pip install tox colorlog pre-commit

10
script/check_format Executable file
View File

@ -0,0 +1,10 @@
#!/bin/sh
# Format code with black.
cd "$(dirname "$0")/.."
black \
--check \
--fast \
--quiet \
homeassistant tests script

View File

@ -7,4 +7,5 @@ set -e
cd "$(dirname "$0")/.."
script/bootstrap
pre-commit install
pip3 install -e .

View File

@ -21,12 +21,27 @@ norecursedirs = .git testing_config
[flake8]
exclude = .venv,.git,.tox,docs,venv,bin,lib,deps,build
# To work with Black
max-line-length = 88
# E501: line too long
# W503: Line break occurred before a binary operator
# E203: Whitespace before ':'
# D202 No blank lines allowed after function docstring
ignore =
E501,
W503,
E203,
D202
[isort]
# https://github.com/timothycrosley/isort
# https://github.com/timothycrosley/isort/wiki/isort-Settings
# splits long import on multiple lines indented by 4 spaces
multi_line_output = 4
multi_line_output = 3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88
indent = " "
# by default isort don't check module indexes
not_skip = __init__.py
@ -37,4 +52,3 @@ default_section = THIRDPARTY
known_first_party = homeassistant,tests
forced_separate = tests
combine_as_imports = true
use_parentheses = true