Move thread safety check in async_register/async_remove (#116077)

This commit is contained in:
J. Nick Koston 2024-04-24 10:41:11 +02:00 committed by GitHub
parent 5bded2a52d
commit e0b58c3f45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 4 deletions

View File

@ -2456,7 +2456,7 @@ class ServiceRegistry:
"""
run_callback_threadsafe(
self._hass.loop,
self.async_register,
self._async_register,
domain,
service,
service_func,
@ -2484,6 +2484,33 @@ class ServiceRegistry:
Schema is called to coerce and validate the service data.
This method must be run in the event loop.
"""
self._hass.verify_event_loop_thread("async_register")
self._async_register(
domain, service, service_func, schema, supports_response, job_type
)
@callback
def _async_register(
self,
domain: str,
service: str,
service_func: Callable[
[ServiceCall],
Coroutine[Any, Any, ServiceResponse | EntityServiceResponse]
| ServiceResponse
| EntityServiceResponse
| None,
],
schema: vol.Schema | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
job_type: HassJobType | None = None,
) -> None:
"""Register a service.
Schema is called to coerce and validate the service data.
This method must be run in the event loop.
"""
domain = domain.lower()
@ -2502,20 +2529,29 @@ class ServiceRegistry:
else:
self._services[domain] = {service: service_obj}
self._hass.bus.async_fire(
self._hass.bus.async_fire_internal(
EVENT_SERVICE_REGISTERED, {ATTR_DOMAIN: domain, ATTR_SERVICE: service}
)
def remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler."""
run_callback_threadsafe(
self._hass.loop, self.async_remove, domain, service
self._hass.loop, self._async_remove, domain, service
).result()
@callback
def async_remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler.
This method must be run in the event loop.
"""
self._hass.verify_event_loop_thread("async_remove")
self._async_remove(domain, service)
@callback
def _async_remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler.
This method must be run in the event loop.
"""
domain = domain.lower()
@ -2530,7 +2566,7 @@ class ServiceRegistry:
if not self._services[domain]:
self._services.pop(domain)
self._hass.bus.async_fire(
self._hass.bus.async_fire_internal(
EVENT_SERVICE_REMOVED, {ATTR_DOMAIN: domain, ATTR_SERVICE: service}
)

View File

@ -3457,3 +3457,26 @@ async def test_async_fire_thread_safety(hass: HomeAssistant) -> None:
await hass.async_add_executor_job(hass.bus.async_fire, "test_event")
assert len(events) == 1
async def test_async_register_thread_safety(hass: HomeAssistant) -> None:
"""Test async_register thread safety."""
with pytest.raises(
RuntimeError, match="Detected code that calls async_register from a thread."
):
await hass.async_add_executor_job(
hass.services.async_register,
"test_domain",
"test_service",
lambda call: None,
)
async def test_async_remove_thread_safety(hass: HomeAssistant) -> None:
"""Test async_remove thread safety."""
with pytest.raises(
RuntimeError, match="Detected code that calls async_remove from a thread."
):
await hass.async_add_executor_job(
hass.services.async_remove, "test_domain", "test_service"
)