diff --git a/homeassistant/util/__init__.py b/homeassistant/util/__init__.py index 2d449285493d..ec693f312f37 100644 --- a/homeassistant/util/__init__.py +++ b/homeassistant/util/__init__.py @@ -236,11 +236,18 @@ class Throttle(object): if self.limit_no_throttle is not None: method = Throttle(self.limit_no_throttle)(method) - # We want to be able to differentiate between function and method calls + # Different methods that can be passed in: + # - a function + # - an unbound function on a class + # - a method (bound function on a class) + + # We want to be able to differentiate between function and unbound + # methods (which are considered functions). # All methods have the classname in their qualname seperated by a '.' # Functions have a '.' in their qualname if defined inline, but will # be prefixed by '..' so we strip that out. - is_func = '.' not in method.__qualname__.split('..')[-1] + is_func = (not hasattr(method, '__self__') and + '.' not in method.__qualname__.split('..')[-1]) @wraps(method) def wrapper(*args, **kwargs): @@ -248,8 +255,13 @@ class Throttle(object): Wrapper that allows wrapped to be called only once per min_time. If we cannot acquire the lock, it is running so return None. """ - # pylint: disable=protected-access - host = wrapper if is_func else args[0] + if hasattr(method, '__self__'): + host = method.__self__ + elif is_func: + host = wrapper + else: + host = args[0] if args else wrapper + if not hasattr(host, '_throttle_lock'): host._throttle_lock = threading.Lock() diff --git a/tests/util/test_init.py b/tests/util/test_init.py index 2bf917f4e25c..2e520ac4980f 100644 --- a/tests/util/test_init.py +++ b/tests/util/test_init.py @@ -229,3 +229,16 @@ class TestUtil(unittest.TestCase): self.assertTrue(Tester().hello()) self.assertTrue(Tester().hello()) + + def test_throttle_on_method(self): + """ Test that throttle works when wrapping a method. """ + + class Tester(object): + def hello(self): + return True + + tester = Tester() + throttled = util.Throttle(timedelta(seconds=1))(tester.hello) + + self.assertTrue(throttled()) + self.assertIsNone(throttled())