1
mirror of https://github.com/home-assistant/core synced 2024-08-02 23:40:32 +02:00

StateMachine is now case insensitive for entity ids

This commit is contained in:
Paulus Schoutsen 2014-12-26 23:26:39 -08:00
parent df3ddea23c
commit 89a548252a
2 changed files with 28 additions and 8 deletions

View File

@ -285,6 +285,19 @@ class TestStateMachine(unittest.TestCase):
self.assertEqual(1, len(specific_runs))
self.assertEqual(3, len(wildcard_runs))
def test_case_insensitivty(self):
runs = []
self.states.track_change(
'light.BoWl', lambda a, b, c: runs.append(1),
ha.MATCH_ALL, ha.MATCH_ALL)
self.states.set('light.BOWL', 'off')
self.bus._pool.block_till_done()
self.assertTrue(self.states.is_state('light.bowl', 'off'))
self.assertEqual(1, len(runs))
class TestServiceCall(unittest.TestCase):
""" Test ServiceCall class. """

View File

@ -15,6 +15,8 @@ import re
import datetime as dt
import functools as ft
from requests.structures import CaseInsensitiveDict
from homeassistant.const import (
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
SERVICE_HOMEASSISTANT_STOP, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED,
@ -482,15 +484,18 @@ class StateMachine(object):
""" Helper class that tracks the state of different entities. """
def __init__(self, bus):
self._states = {}
self._states = CaseInsensitiveDict()
self._bus = bus
self._lock = threading.Lock()
def entity_ids(self, domain_filter=None):
""" List of entity ids that are being tracked. """
if domain_filter is not None:
return [entity_id for entity_id in self._states.keys()
if util.split_entity_id(entity_id)[0] == domain_filter]
domain_filter = domain_filter.lower()
return [state.entity_id for key, state
in self._states.lower_items()
if util.split_entity_id(key)[0] == domain_filter]
else:
return list(self._states.keys())
@ -524,9 +529,9 @@ class StateMachine(object):
self._states[entity_id].state == state)
def remove(self, entity_id):
""" Removes a entity from the state machine.
""" Removes an entity from the state machine.
Returns boolean to indicate if a entity was removed. """
Returns boolean to indicate if an entity was removed. """
with self._lock:
return self._states.pop(entity_id, None) is not None
@ -567,14 +572,16 @@ class StateMachine(object):
from_state = _process_match_param(from_state)
to_state = _process_match_param(to_state)
# Ensure it is a list with entity ids we want to match on
# Ensure it is a lowercase list with entity ids we want to match on
if isinstance(entity_ids, str):
entity_ids = [entity_ids]
entity_ids = [entity_ids.lower()]
else:
entity_ids = [entity_id.lower() for entity_id in entity_ids]
@ft.wraps(action)
def state_listener(event):
""" The listener that listens for specific state changes. """
if event.data['entity_id'] in entity_ids and \
if event.data['entity_id'].lower() in entity_ids and \
'old_state' in event.data and \
_matcher(event.data['old_state'].state, from_state) and \
_matcher(event.data['new_state'].state, to_state):