1
mirror of https://github.com/home-assistant/core synced 2024-09-09 12:51:22 +02:00
* Add context

* Add context to switch/light services

* Test set_state API

* Lint

* Fix tests

* Do not include context yet in comparison

* Do not pass in loop

* Fix Z-Wave tests

* Add websocket test without user
This commit is contained in:
Paulus Schoutsen 2018-07-29 01:53:37 +01:00 committed by Jason Hu
parent 867f80715e
commit c7f4bdafc0
16 changed files with 363 additions and 109 deletions

View File

@ -220,7 +220,8 @@ class APIEntityStateView(HomeAssistantView):
is_new_state = hass.states.get(entity_id) is None is_new_state = hass.states.get(entity_id) is None
# Write state # Write state
hass.states.async_set(entity_id, new_state, attributes, force_update) hass.states.async_set(entity_id, new_state, attributes, force_update,
self.context(request))
# Read the state back for our response # Read the state back for our response
status_code = HTTP_CREATED if is_new_state else 200 status_code = HTTP_CREATED if is_new_state else 200
@ -279,7 +280,8 @@ class APIEventView(HomeAssistantView):
event_data[key] = state event_data[key] = state
request.app['hass'].bus.async_fire( request.app['hass'].bus.async_fire(
event_type, event_data, ha.EventOrigin.remote) event_type, event_data, ha.EventOrigin.remote,
self.context(request))
return self.json_message("Event {} fired.".format(event_type)) return self.json_message("Event {} fired.".format(event_type))
@ -316,7 +318,8 @@ class APIDomainServicesView(HomeAssistantView):
"Data should be valid JSON.", HTTP_BAD_REQUEST) "Data should be valid JSON.", HTTP_BAD_REQUEST)
with AsyncTrackStates(hass) as changed_states: with AsyncTrackStates(hass) as changed_states:
await hass.services.async_call(domain, service, data, True) await hass.services.async_call(
domain, service, data, True, self.context(request))
return self.json(changed_states) return self.json(changed_states)

View File

@ -13,7 +13,7 @@ from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError
import homeassistant.remote as rem import homeassistant.remote as rem
from homeassistant.components.http.ban import process_success_login from homeassistant.components.http.ban import process_success_login
from homeassistant.core import is_callback from homeassistant.core import Context, is_callback
from homeassistant.const import CONTENT_TYPE_JSON from homeassistant.const import CONTENT_TYPE_JSON
from .const import KEY_AUTHENTICATED, KEY_REAL_IP from .const import KEY_AUTHENTICATED, KEY_REAL_IP
@ -32,6 +32,14 @@ class HomeAssistantView:
cors_allowed = False cors_allowed = False
# pylint: disable=no-self-use # pylint: disable=no-self-use
def context(self, request):
"""Generate a context from a request."""
user = request.get('hass_user')
if user is None:
return Context()
return Context(user_id=user.id)
def json(self, result, status_code=200, headers=None): def json(self, result, status_code=200, headers=None):
"""Return a JSON response.""" """Return a JSON response."""
try: try:

View File

@ -359,7 +359,9 @@ async def async_setup(hass, config):
if not light.should_poll: if not light.should_poll:
continue continue
update_tasks.append(light.async_update_ha_state(True))
update_tasks.append(
light.async_update_ha_state(True, service.context))
if update_tasks: if update_tasks:
await asyncio.wait(update_tasks, loop=hass.loop) await asyncio.wait(update_tasks, loop=hass.loop)

View File

@ -114,7 +114,8 @@ async def async_setup(hass, config):
if not switch.should_poll: if not switch.should_poll:
continue continue
update_tasks.append(switch.async_update_ha_state(True)) update_tasks.append(
switch.async_update_ha_state(True, service.context))
if update_tasks: if update_tasks:
await asyncio.wait(update_tasks, loop=hass.loop) await asyncio.wait(update_tasks, loop=hass.loop)

View File

@ -18,7 +18,7 @@ from voluptuous.humanize import humanize_error
from homeassistant.const import ( from homeassistant.const import (
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
__version__) __version__)
from homeassistant.core import callback from homeassistant.core import Context, callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.remote import JSONEncoder from homeassistant.remote import JSONEncoder
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
@ -262,6 +262,18 @@ class ActiveConnection:
self._handle_task = None self._handle_task = None
self._writer_task = None self._writer_task = None
@property
def user(self):
"""Return the user associated with the connection."""
return self.request.get('hass_user')
def context(self, msg):
"""Return a context."""
user = self.user
if user is None:
return Context()
return Context(user_id=user.id)
def debug(self, message1, message2=''): def debug(self, message1, message2=''):
"""Print a debug message.""" """Print a debug message."""
_LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2) _LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2)
@ -287,7 +299,7 @@ class ActiveConnection:
@callback @callback
def send_message_outside(self, message): def send_message_outside(self, message):
"""Send a message to the client outside of the main task. """Send a message to the client.
Closes connection if the client is not reading the messages. Closes connection if the client is not reading the messages.
@ -508,7 +520,8 @@ def handle_call_service(hass, connection, msg):
async def call_service_helper(msg): async def call_service_helper(msg):
"""Call a service and fire complete message.""" """Call a service and fire complete message."""
await hass.services.async_call( await hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), True) msg['domain'], msg['service'], msg.get('service_data'), True,
connection.context(msg))
connection.send_message_outside(result_message(msg['id'])) connection.send_message_outside(result_message(msg['id']))
hass.async_add_job(call_service_helper(msg)) hass.async_add_job(call_service_helper(msg))

View File

@ -224,9 +224,6 @@ ATTR_ID = 'id'
# Name # Name
ATTR_NAME = 'name' ATTR_NAME = 'name'
# Data for a SERVICE_EXECUTED event
ATTR_SERVICE_CALL_ID = 'service_call_id'
# Contains one string or a list of strings, each being an entity id # Contains one string or a list of strings, each being an entity id
ATTR_ENTITY_ID = 'entity_id' ATTR_ENTITY_ID = 'entity_id'

View File

@ -15,6 +15,7 @@ import re
import sys import sys
import threading import threading
from time import monotonic from time import monotonic
import uuid
from types import MappingProxyType from types import MappingProxyType
# pylint: disable=unused-import # pylint: disable=unused-import
@ -23,12 +24,13 @@ from typing import ( # NOQA
TYPE_CHECKING, Awaitable, Iterator) TYPE_CHECKING, Awaitable, Iterator)
from async_timeout import timeout from async_timeout import timeout
import attr
import voluptuous as vol import voluptuous as vol
from voluptuous.humanize import humanize_error from voluptuous.humanize import humanize_error
from homeassistant.const import ( from homeassistant.const import (
ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE, ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE,
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED, EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE, EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
@ -191,7 +193,7 @@ class HomeAssistant:
try: try:
# Only block for EVENT_HOMEASSISTANT_START listener # Only block for EVENT_HOMEASSISTANT_START listener
self.async_stop_track_tasks() self.async_stop_track_tasks()
with timeout(TIMEOUT_EVENT_START, loop=self.loop): with timeout(TIMEOUT_EVENT_START):
await self.async_block_till_done() await self.async_block_till_done()
except asyncio.TimeoutError: except asyncio.TimeoutError:
_LOGGER.warning( _LOGGER.warning(
@ -201,7 +203,7 @@ class HomeAssistant:
', '.join(self.config.components)) ', '.join(self.config.components))
# Allow automations to set up the start triggers before changing state # Allow automations to set up the start triggers before changing state
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0)
self.state = CoreState.running self.state = CoreState.running
_async_create_timer(self) _async_create_timer(self)
@ -307,16 +309,16 @@ class HomeAssistant:
async def async_block_till_done(self) -> None: async def async_block_till_done(self) -> None:
"""Block till all pending work is done.""" """Block till all pending work is done."""
# To flush out any call_soon_threadsafe # To flush out any call_soon_threadsafe
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0)
while self._pending_tasks: while self._pending_tasks:
pending = [task for task in self._pending_tasks pending = [task for task in self._pending_tasks
if not task.done()] if not task.done()]
self._pending_tasks.clear() self._pending_tasks.clear()
if pending: if pending:
await asyncio.wait(pending, loop=self.loop) await asyncio.wait(pending)
else: else:
await asyncio.sleep(0, loop=self.loop) await asyncio.sleep(0)
def stop(self) -> None: def stop(self) -> None:
"""Stop Home Assistant and shuts down all threads.""" """Stop Home Assistant and shuts down all threads."""
@ -343,6 +345,27 @@ class HomeAssistant:
self.loop.stop() self.loop.stop()
@attr.s(slots=True, frozen=True)
class Context:
"""The context that triggered something."""
user_id = attr.ib(
type=str,
default=None,
)
id = attr.ib(
type=str,
default=attr.Factory(lambda: uuid.uuid4().hex),
)
def as_dict(self) -> dict:
"""Return a dictionary representation of the context."""
return {
'id': self.id,
'user_id': self.user_id,
}
class EventOrigin(enum.Enum): class EventOrigin(enum.Enum):
"""Represent the origin of an event.""" """Represent the origin of an event."""
@ -357,16 +380,18 @@ class EventOrigin(enum.Enum):
class Event: class Event:
"""Representation of an event within the bus.""" """Representation of an event within the bus."""
__slots__ = ['event_type', 'data', 'origin', 'time_fired'] __slots__ = ['event_type', 'data', 'origin', 'time_fired', 'context']
def __init__(self, event_type: str, data: Optional[Dict] = None, def __init__(self, event_type: str, data: Optional[Dict] = None,
origin: EventOrigin = EventOrigin.local, origin: EventOrigin = EventOrigin.local,
time_fired: Optional[int] = None) -> None: time_fired: Optional[int] = None,
context: Optional[Context] = None) -> None:
"""Initialize a new event.""" """Initialize a new event."""
self.event_type = event_type self.event_type = event_type
self.data = data or {} self.data = data or {}
self.origin = origin self.origin = origin
self.time_fired = time_fired or dt_util.utcnow() self.time_fired = time_fired or dt_util.utcnow()
self.context = context or Context()
def as_dict(self) -> Dict: def as_dict(self) -> Dict:
"""Create a dict representation of this Event. """Create a dict representation of this Event.
@ -378,6 +403,7 @@ class Event:
'data': dict(self.data), 'data': dict(self.data),
'origin': str(self.origin), 'origin': str(self.origin),
'time_fired': self.time_fired, 'time_fired': self.time_fired,
'context': self.context.as_dict()
} }
def __repr__(self) -> str: def __repr__(self) -> str:
@ -425,14 +451,16 @@ class EventBus:
).result() ).result()
def fire(self, event_type: str, event_data: Optional[Dict] = None, def fire(self, event_type: str, event_data: Optional[Dict] = None,
origin: EventOrigin = EventOrigin.local) -> None: origin: EventOrigin = EventOrigin.local,
context: Optional[Context] = None) -> None:
"""Fire an event.""" """Fire an event."""
self._hass.loop.call_soon_threadsafe( self._hass.loop.call_soon_threadsafe(
self.async_fire, event_type, event_data, origin) self.async_fire, event_type, event_data, origin, context)
@callback @callback
def async_fire(self, event_type: str, event_data: Optional[Dict] = None, def async_fire(self, event_type: str, event_data: Optional[Dict] = None,
origin: EventOrigin = EventOrigin.local) -> None: origin: EventOrigin = EventOrigin.local,
context: Optional[Context] = None) -> None:
"""Fire an event. """Fire an event.
This method must be run in the event loop. This method must be run in the event loop.
@ -445,7 +473,7 @@ class EventBus:
event_type != EVENT_HOMEASSISTANT_CLOSE): event_type != EVENT_HOMEASSISTANT_CLOSE):
listeners = match_all_listeners + listeners listeners = match_all_listeners + listeners
event = Event(event_type, event_data, origin) event = Event(event_type, event_data, origin, None, context)
if event_type != EVENT_TIME_CHANGED: if event_type != EVENT_TIME_CHANGED:
_LOGGER.info("Bus:Handling %s", event) _LOGGER.info("Bus:Handling %s", event)
@ -569,15 +597,17 @@ class State:
attributes: extra information on entity and state attributes: extra information on entity and state
last_changed: last time the state was changed, not the attributes. last_changed: last time the state was changed, not the attributes.
last_updated: last time this object was updated. last_updated: last time this object was updated.
context: Context in which it was created
""" """
__slots__ = ['entity_id', 'state', 'attributes', __slots__ = ['entity_id', 'state', 'attributes',
'last_changed', 'last_updated'] 'last_changed', 'last_updated', 'context']
def __init__(self, entity_id: str, state: Any, def __init__(self, entity_id: str, state: Any,
attributes: Optional[Dict] = None, attributes: Optional[Dict] = None,
last_changed: Optional[datetime.datetime] = None, last_changed: Optional[datetime.datetime] = None,
last_updated: Optional[datetime.datetime] = None) -> None: last_updated: Optional[datetime.datetime] = None,
context: Optional[Context] = None) -> None:
"""Initialize a new state.""" """Initialize a new state."""
state = str(state) state = str(state)
@ -596,6 +626,7 @@ class State:
self.attributes = MappingProxyType(attributes or {}) self.attributes = MappingProxyType(attributes or {})
self.last_updated = last_updated or dt_util.utcnow() self.last_updated = last_updated or dt_util.utcnow()
self.last_changed = last_changed or self.last_updated self.last_changed = last_changed or self.last_updated
self.context = context or Context()
@property @property
def domain(self) -> str: def domain(self) -> str:
@ -626,7 +657,8 @@ class State:
'state': self.state, 'state': self.state,
'attributes': dict(self.attributes), 'attributes': dict(self.attributes),
'last_changed': self.last_changed, 'last_changed': self.last_changed,
'last_updated': self.last_updated} 'last_updated': self.last_updated,
'context': self.context.as_dict()}
@classmethod @classmethod
def from_dict(cls, json_dict: Dict) -> Any: def from_dict(cls, json_dict: Dict) -> Any:
@ -650,8 +682,13 @@ class State:
if isinstance(last_updated, str): if isinstance(last_updated, str):
last_updated = dt_util.parse_datetime(last_updated) last_updated = dt_util.parse_datetime(last_updated)
context = json_dict.get('context')
if context:
context = Context(**context)
return cls(json_dict['entity_id'], json_dict['state'], return cls(json_dict['entity_id'], json_dict['state'],
json_dict.get('attributes'), last_changed, last_updated) json_dict.get('attributes'), last_changed, last_updated,
context)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Return the comparison of the state.""" """Return the comparison of the state."""
@ -662,11 +699,11 @@ class State:
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return the representation of the states.""" """Return the representation of the states."""
attr = "; {}".format(util.repr_helper(self.attributes)) \ attrs = "; {}".format(util.repr_helper(self.attributes)) \
if self.attributes else "" if self.attributes else ""
return "<state {}={}{} @ {}>".format( return "<state {}={}{} @ {}>".format(
self.entity_id, self.state, attr, self.entity_id, self.state, attrs,
dt_util.as_local(self.last_changed).isoformat()) dt_util.as_local(self.last_changed).isoformat())
@ -761,7 +798,8 @@ class StateMachine:
def set(self, entity_id: str, new_state: Any, def set(self, entity_id: str, new_state: Any,
attributes: Optional[Dict] = None, attributes: Optional[Dict] = None,
force_update: bool = False) -> None: force_update: bool = False,
context: Optional[Context] = None) -> None:
"""Set the state of an entity, add entity if it does not exist. """Set the state of an entity, add entity if it does not exist.
Attributes is an optional dict to specify attributes of this state. Attributes is an optional dict to specify attributes of this state.
@ -772,12 +810,14 @@ class StateMachine:
run_callback_threadsafe( run_callback_threadsafe(
self._loop, self._loop,
self.async_set, entity_id, new_state, attributes, force_update, self.async_set, entity_id, new_state, attributes, force_update,
context,
).result() ).result()
@callback @callback
def async_set(self, entity_id: str, new_state: Any, def async_set(self, entity_id: str, new_state: Any,
attributes: Optional[Dict] = None, attributes: Optional[Dict] = None,
force_update: bool = False) -> None: force_update: bool = False,
context: Optional[Context] = None) -> None:
"""Set the state of an entity, add entity if it does not exist. """Set the state of an entity, add entity if it does not exist.
Attributes is an optional dict to specify attributes of this state. Attributes is an optional dict to specify attributes of this state.
@ -804,13 +844,17 @@ class StateMachine:
if same_state and same_attr: if same_state and same_attr:
return return
state = State(entity_id, new_state, attributes, last_changed) if context is None:
context = Context()
state = State(entity_id, new_state, attributes, last_changed, None,
context)
self._states[entity_id] = state self._states[entity_id] = state
self._bus.async_fire(EVENT_STATE_CHANGED, { self._bus.async_fire(EVENT_STATE_CHANGED, {
'entity_id': entity_id, 'entity_id': entity_id,
'old_state': old_state, 'old_state': old_state,
'new_state': state, 'new_state': state,
}) }, EventOrigin.local, context)
class Service: class Service:
@ -818,7 +862,8 @@ class Service:
__slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction'] __slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction']
def __init__(self, func: Callable, schema: Optional[vol.Schema]) -> None: def __init__(self, func: Callable, schema: Optional[vol.Schema],
context: Optional[Context] = None) -> None:
"""Initialize a service.""" """Initialize a service."""
self.func = func self.func = func
self.schema = schema self.schema = schema
@ -829,23 +874,25 @@ class Service:
class ServiceCall: class ServiceCall:
"""Representation of a call to a service.""" """Representation of a call to a service."""
__slots__ = ['domain', 'service', 'data', 'call_id'] __slots__ = ['domain', 'service', 'data', 'context']
def __init__(self, domain: str, service: str, data: Optional[Dict] = None, def __init__(self, domain: str, service: str, data: Optional[Dict] = None,
call_id: Optional[str] = None) -> None: context: Optional[Context] = None) -> None:
"""Initialize a service call.""" """Initialize a service call."""
self.domain = domain.lower() self.domain = domain.lower()
self.service = service.lower() self.service = service.lower()
self.data = MappingProxyType(data or {}) self.data = MappingProxyType(data or {})
self.call_id = call_id self.context = context or Context()
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return the representation of the service.""" """Return the representation of the service."""
if self.data: if self.data:
return "<ServiceCall {}.{}: {}>".format( return "<ServiceCall {}.{} (c:{}): {}>".format(
self.domain, self.service, util.repr_helper(self.data)) self.domain, self.service, self.context.id,
util.repr_helper(self.data))
return "<ServiceCall {}.{}>".format(self.domain, self.service) return "<ServiceCall {}.{} (c:{})>".format(
self.domain, self.service, self.context.id)
class ServiceRegistry: class ServiceRegistry:
@ -857,15 +904,6 @@ class ServiceRegistry:
self._hass = hass self._hass = hass
self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE] self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE]
def _gen_unique_id() -> Iterator[str]:
cur_id = 1
while True:
yield '{}-{}'.format(id(self), cur_id)
cur_id += 1
gen = _gen_unique_id()
self._generate_unique_id = lambda: next(gen)
@property @property
def services(self) -> Dict[str, Dict[str, Service]]: def services(self) -> Dict[str, Dict[str, Service]]:
"""Return dictionary with per domain a list of available services.""" """Return dictionary with per domain a list of available services."""
@ -957,7 +995,8 @@ class ServiceRegistry:
def call(self, domain: str, service: str, def call(self, domain: str, service: str,
service_data: Optional[Dict] = None, service_data: Optional[Dict] = None,
blocking: bool = False) -> Optional[bool]: blocking: bool = False,
context: Optional[Context] = None) -> Optional[bool]:
""" """
Call a service. Call a service.
@ -975,13 +1014,14 @@ class ServiceRegistry:
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data. the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
""" """
return run_coroutine_threadsafe( # type: ignore return run_coroutine_threadsafe( # type: ignore
self.async_call(domain, service, service_data, blocking), self.async_call(domain, service, service_data, blocking, context),
self._hass.loop self._hass.loop
).result() ).result()
async def async_call(self, domain: str, service: str, async def async_call(self, domain: str, service: str,
service_data: Optional[Dict] = None, service_data: Optional[Dict] = None,
blocking: bool = False) -> Optional[bool]: blocking: bool = False,
context: Optional[Context] = None) -> Optional[bool]:
""" """
Call a service. Call a service.
@ -1000,44 +1040,42 @@ class ServiceRegistry:
This method is a coroutine. This method is a coroutine.
""" """
call_id = self._generate_unique_id() context = context or Context()
event_data = { event_data = {
ATTR_DOMAIN: domain.lower(), ATTR_DOMAIN: domain.lower(),
ATTR_SERVICE: service.lower(), ATTR_SERVICE: service.lower(),
ATTR_SERVICE_DATA: service_data, ATTR_SERVICE_DATA: service_data,
ATTR_SERVICE_CALL_ID: call_id,
} }
if blocking: if not blocking:
fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future self._hass.bus.async_fire(
EVENT_CALL_SERVICE, event_data, EventOrigin.local, context)
return None
fut = asyncio.Future() # type: asyncio.Future
@callback @callback
def service_executed(event: Event) -> None: def service_executed(event: Event) -> None:
"""Handle an executed service.""" """Handle an executed service."""
if event.data[ATTR_SERVICE_CALL_ID] == call_id: if event.context == context:
fut.set_result(True) fut.set_result(True)
unsub = self._hass.bus.async_listen( unsub = self._hass.bus.async_listen(
EVENT_SERVICE_EXECUTED, service_executed) EVENT_SERVICE_EXECUTED, service_executed)
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data) self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data,
EventOrigin.local, context)
done, _ = await asyncio.wait( done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT)
[fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT)
success = bool(done) success = bool(done)
unsub() unsub()
return success return success
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
return None
async def _event_to_service_call(self, event: Event) -> None: async def _event_to_service_call(self, event: Event) -> None:
"""Handle the SERVICE_CALLED events from the EventBus.""" """Handle the SERVICE_CALLED events from the EventBus."""
service_data = event.data.get(ATTR_SERVICE_DATA) or {} service_data = event.data.get(ATTR_SERVICE_DATA) or {}
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
service = event.data.get(ATTR_SERVICE).lower() # type: ignore service = event.data.get(ATTR_SERVICE).lower() # type: ignore
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
if not self.has_service(domain, service): if not self.has_service(domain, service):
if event.origin == EventOrigin.local: if event.origin == EventOrigin.local:
@ -1049,16 +1087,13 @@ class ServiceRegistry:
def fire_service_executed() -> None: def fire_service_executed() -> None:
"""Fire service executed event.""" """Fire service executed event."""
if not call_id:
return
data = {ATTR_SERVICE_CALL_ID: call_id}
if (service_handler.is_coroutinefunction or if (service_handler.is_coroutinefunction or
service_handler.is_callback): service_handler.is_callback):
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data) self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, {},
EventOrigin.local, event.context)
else: else:
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data) self._hass.bus.fire(EVENT_SERVICE_EXECUTED, {},
EventOrigin.local, event.context)
try: try:
if service_handler.schema: if service_handler.schema:
@ -1069,7 +1104,8 @@ class ServiceRegistry:
fire_service_executed() fire_service_executed()
return return
service_call = ServiceCall(domain, service, service_data, call_id) service_call = ServiceCall(
domain, service, service_data, event.context)
try: try:
if service_handler.is_callback: if service_handler.is_callback:

View File

@ -179,7 +179,7 @@ class Entity:
# produce undesirable effects in the entity's operation. # produce undesirable effects in the entity's operation.
@asyncio.coroutine @asyncio.coroutine
def async_update_ha_state(self, force_refresh=False): def async_update_ha_state(self, force_refresh=False, context=None):
"""Update Home Assistant with current state of entity. """Update Home Assistant with current state of entity.
If force_refresh == True will update entity before setting state. If force_refresh == True will update entity before setting state.
@ -279,7 +279,7 @@ class Entity:
pass pass
self.hass.states.async_set( self.hass.states.async_set(
self.entity_id, state, attr, self.force_update) self.entity_id, state, attr, self.force_update, context)
def schedule_update_ha_state(self, force_refresh=False): def schedule_update_ha_state(self, force_refresh=False):
"""Schedule an update ha state change task. """Schedule an update ha state change task.

View File

@ -187,7 +187,7 @@ def async_mock_service(hass, domain, service, schema=None):
"""Set up a fake service & return a calls log list to this service.""" """Set up a fake service & return a calls log list to this service."""
calls = [] calls = []
@asyncio.coroutine @ha.callback
def mock_service_log(call): # pylint: disable=unnecessary-lambda def mock_service_log(call): # pylint: disable=unnecessary-lambda
"""Mock service call.""" """Mock service call."""
calls.append(call) calls.append(call)

View File

@ -5,12 +5,12 @@ import unittest.mock as mock
import os import os
from io import StringIO from io import StringIO
from homeassistant.setup import setup_component from homeassistant import core, loader
import homeassistant.loader as loader from homeassistant.setup import setup_component, async_setup_component
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, STATE_ON, STATE_OFF, CONF_PLATFORM, ATTR_ENTITY_ID, STATE_ON, STATE_OFF, CONF_PLATFORM,
SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE, ATTR_SUPPORTED_FEATURES) SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE, ATTR_SUPPORTED_FEATURES)
import homeassistant.components.light as light from homeassistant.components import light
from homeassistant.helpers.intent import IntentHandleError from homeassistant.helpers.intent import IntentHandleError
from tests.common import ( from tests.common import (
@ -475,3 +475,24 @@ async def test_intent_set_color_and_brightness(hass):
assert call.data.get(ATTR_ENTITY_ID) == 'light.hello_2' assert call.data.get(ATTR_ENTITY_ID) == 'light.hello_2'
assert call.data.get(light.ATTR_RGB_COLOR) == (0, 0, 255) assert call.data.get(light.ATTR_RGB_COLOR) == (0, 0, 255)
assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20 assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20
async def test_light_context(hass):
"""Test that light context works."""
assert await async_setup_component(hass, 'light', {
'light': {
'platform': 'test'
}
})
state = hass.states.get('light.ceiling')
assert state is not None
await hass.services.async_call('light', 'toggle', {
'entity_id': state.entity_id,
}, True, core.Context(user_id='abcd'))
state2 = hass.states.get('light.ceiling')
assert state2 is not None
assert state.state != state2.state
assert state2.context.user_id == 'abcd'

View File

@ -2,8 +2,8 @@
# pylint: disable=protected-access # pylint: disable=protected-access
import unittest import unittest
from homeassistant.setup import setup_component from homeassistant.setup import setup_component, async_setup_component
from homeassistant import loader from homeassistant import core, loader
from homeassistant.components import switch from homeassistant.components import switch
from homeassistant.const import STATE_ON, STATE_OFF, CONF_PLATFORM from homeassistant.const import STATE_ON, STATE_OFF, CONF_PLATFORM
@ -91,3 +91,24 @@ class TestSwitch(unittest.TestCase):
'{} 2'.format(switch.DOMAIN): {CONF_PLATFORM: 'test2'}, '{} 2'.format(switch.DOMAIN): {CONF_PLATFORM: 'test2'},
} }
)) ))
async def test_switch_context(hass):
"""Test that switch context works."""
assert await async_setup_component(hass, 'switch', {
'switch': {
'platform': 'test'
}
})
state = hass.states.get('switch.ac')
assert state is not None
await hass.services.async_call('switch', 'toggle', {
'entity_id': state.entity_id,
}, True, core.Context(user_id='abcd'))
state2 = hass.states.get('switch.ac')
assert state2 is not None
assert state.state != state2.state
assert state2.context.user_id == 'abcd'

View File

@ -12,6 +12,8 @@ from homeassistant.bootstrap import DATA_LOGGING
import homeassistant.core as ha import homeassistant.core as ha
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import async_mock_service
@pytest.fixture @pytest.fixture
def mock_api_client(hass, aiohttp_client): def mock_api_client(hass, aiohttp_client):
@ -429,3 +431,58 @@ async def test_api_error_log(hass, aiohttp_client):
assert mock_file.mock_calls[0][1][0] == hass.data[DATA_LOGGING] assert mock_file.mock_calls[0][1][0] == hass.data[DATA_LOGGING]
assert resp.status == 200 assert resp.status == 200
assert await resp.text() == 'Hello' assert await resp.text() == 'Hello'
async def test_api_fire_event_context(hass, mock_api_client,
hass_access_token):
"""Test if the API sets right context if we fire an event."""
test_value = []
@ha.callback
def listener(event):
"""Helper method that will verify our event got called."""
test_value.append(event)
hass.bus.async_listen("test.event", listener)
await mock_api_client.post(
const.URL_API_EVENTS_EVENT.format("test.event"),
headers={
'authorization': 'Bearer {}'.format(hass_access_token.token)
})
await hass.async_block_till_done()
assert len(test_value) == 1
assert test_value[0].context.user_id == \
hass_access_token.refresh_token.user.id
async def test_api_call_service_context(hass, mock_api_client,
hass_access_token):
"""Test if the API sets right context if we call a service."""
calls = async_mock_service(hass, 'test_domain', 'test_service')
await mock_api_client.post(
'/api/services/test_domain/test_service',
headers={
'authorization': 'Bearer {}'.format(hass_access_token.token)
})
await hass.async_block_till_done()
assert len(calls) == 1
assert calls[0].context.user_id == hass_access_token.refresh_token.user.id
async def test_api_set_state_context(hass, mock_api_client, hass_access_token):
"""Test if the API sets right context if we set state."""
await mock_api_client.post(
'/api/states/light.kitchen',
json={
'state': 'on'
},
headers={
'authorization': 'Bearer {}'.format(hass_access_token.token)
})
state = hass.states.get('light.kitchen')
assert state.context.user_id == hass_access_token.refresh_token.user.id

View File

@ -104,12 +104,14 @@ class TestMqttEventStream:
"state": "on", "state": "on",
"entity_id": e_id, "entity_id": e_id,
"attributes": {}, "attributes": {},
"last_changed": now.isoformat() "last_changed": now.isoformat(),
} }
event['event_data'] = {"new_state": new_state, "entity_id": e_id} event['event_data'] = {"new_state": new_state, "entity_id": e_id}
# Verify that the message received was that expected # Verify that the message received was that expected
assert json.loads(msg) == event result = json.loads(msg)
result['event_data']['new_state'].pop('context')
assert result == event
@patch('homeassistant.components.mqtt.async_publish') @patch('homeassistant.components.mqtt.async_publish')
def test_time_event_does_not_send_message(self, mock_pub): def test_time_event_does_not_send_message(self, mock_pub):

View File

@ -10,7 +10,7 @@ from homeassistant.core import callback
from homeassistant.components import websocket_api as wapi from homeassistant.components import websocket_api as wapi
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import mock_coro from tests.common import mock_coro, async_mock_service
API_PASSWORD = 'test1234' API_PASSWORD = 'test1234'
@ -443,3 +443,94 @@ async def test_auth_with_invalid_token(hass, aiohttp_client):
auth_msg = await ws.receive_json() auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_call_service_context_with_user(hass, aiohttp_client,
hass_access_token):
"""Test that the user is set in the service call context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token.token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id == hass_access_token.refresh_token.user.id
async def test_call_service_context_no_user(hass, aiohttp_client):
"""Test that connection without user sets context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id is None

View File

@ -163,10 +163,10 @@ def test_zwave_ready_wait(hass, mock_openzwave):
asyncio_sleep = asyncio.sleep asyncio_sleep = asyncio.sleep
@asyncio.coroutine @asyncio.coroutine
def sleep(duration, loop): def sleep(duration, loop=None):
if duration > 0: if duration > 0:
sleeps.append(duration) sleeps.append(duration)
yield from asyncio_sleep(0, loop=loop) yield from asyncio_sleep(0)
with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow): with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow):
with patch('asyncio.sleep', new=sleep): with patch('asyncio.sleep', new=sleep):
@ -248,10 +248,10 @@ async def test_unparsed_node_discovery(hass, mock_openzwave):
asyncio_sleep = asyncio.sleep asyncio_sleep = asyncio.sleep
async def sleep(duration, loop): async def sleep(duration, loop=None):
if duration > 0: if duration > 0:
sleeps.append(duration) sleeps.append(duration)
await asyncio_sleep(0, loop=loop) await asyncio_sleep(0)
with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow): with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow):
with patch('asyncio.sleep', new=sleep): with patch('asyncio.sleep', new=sleep):

View File

@ -277,6 +277,10 @@ class TestEvent(unittest.TestCase):
'data': data, 'data': data,
'origin': 'LOCAL', 'origin': 'LOCAL',
'time_fired': now, 'time_fired': now,
'context': {
'id': event.context.id,
'user_id': event.context.user_id,
},
} }
self.assertEqual(expected, event.as_dict()) self.assertEqual(expected, event.as_dict())
@ -598,18 +602,16 @@ class TestStateMachine(unittest.TestCase):
self.assertEqual(1, len(events)) self.assertEqual(1, len(events))
class TestServiceCall(unittest.TestCase): def test_service_call_repr():
"""Test ServiceCall class.""" """Test ServiceCall repr."""
call = ha.ServiceCall('homeassistant', 'start')
assert str(call) == \
"<ServiceCall homeassistant.start (c:{})>".format(call.context.id)
def test_repr(self): call2 = ha.ServiceCall('homeassistant', 'start', {'fast': 'yes'})
"""Test repr method.""" assert str(call2) == \
self.assertEqual( "<ServiceCall homeassistant.start (c:{}): fast=yes>".format(
"<ServiceCall homeassistant.start>", call2.context.id)
str(ha.ServiceCall('homeassistant', 'start')))
self.assertEqual(
"<ServiceCall homeassistant.start: fast=yes>",
str(ha.ServiceCall('homeassistant', 'start', {"fast": "yes"})))
class TestServiceRegistry(unittest.TestCase): class TestServiceRegistry(unittest.TestCase):