1
mirror of https://github.com/home-assistant/core synced 2024-07-09 04:58:30 +02:00

Fix async_get_scanner to return the correct bluetooth scanner (#75637)

This commit is contained in:
J. Nick Koston 2022-07-22 18:12:08 -05:00 committed by GitHub
parent cb543a21b3
commit 326e05dcf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 52 deletions

View File

@ -35,7 +35,7 @@ from homeassistant.loader import (
from . import models
from .const import DOMAIN
from .models import HaBleakScanner
from .models import HaBleakScanner, HaBleakScannerWrapper
from .usage import install_multiple_bleak_catcher, uninstall_multiple_bleak_catcher
_LOGGER = logging.getLogger(__name__)
@ -117,8 +117,12 @@ BluetoothCallback = Callable[
@hass_callback
def async_get_scanner(hass: HomeAssistant) -> HaBleakScanner:
"""Return a HaBleakScanner."""
def async_get_scanner(hass: HomeAssistant) -> HaBleakScannerWrapper:
"""Return a HaBleakScannerWrapper.
This is a wrapper around our BleakScanner singleton that allows
multiple integrations to share the same BleakScanner.
"""
if DOMAIN not in hass.data:
raise RuntimeError("Bluetooth integration not loaded")
manager: BluetoothManager = hass.data[DOMAIN]
@ -320,10 +324,9 @@ class BluetoothManager:
models.HA_BLEAK_SCANNER = self.scanner = HaBleakScanner()
@hass_callback
def async_get_scanner(self) -> HaBleakScanner:
def async_get_scanner(self) -> HaBleakScannerWrapper:
"""Get the scanner."""
assert self.scanner is not None
return self.scanner
return HaBleakScannerWrapper()
async def async_start(self, scanning_mode: BluetoothScanningMode) -> None:
"""Set up BT Discovery."""

View File

@ -1 +1,8 @@
"""Tests for the Bluetooth integration."""
from homeassistant.components.bluetooth import models
def _get_underlying_scanner():
"""Return the underlying scanner that has been wrapped."""
return models.HA_BLEAK_SCANNER

View File

@ -12,7 +12,6 @@ from homeassistant.components.bluetooth import (
UNAVAILABLE_TRACK_SECONDS,
BluetoothChange,
BluetoothServiceInfo,
async_get_scanner,
async_track_unavailable,
models,
)
@ -21,6 +20,8 @@ from homeassistant.core import callback
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from . import _get_underlying_scanner
from tests.common import MockConfigEntry, async_fire_time_changed
@ -135,7 +136,7 @@ async def test_discovery_match_by_service_uuid(
wrong_device = BLEDevice("44:44:33:11:23:45", "wrong_name")
wrong_adv = AdvertisementData(local_name="wrong_name", service_uuids=[])
async_get_scanner(hass)._callback(wrong_device, wrong_adv)
_get_underlying_scanner()._callback(wrong_device, wrong_adv)
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 0
@ -145,7 +146,7 @@ async def test_discovery_match_by_service_uuid(
local_name="wohand", service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]
)
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 1
@ -172,7 +173,7 @@ async def test_discovery_match_by_local_name(hass, mock_bleak_scanner_start):
wrong_device = BLEDevice("44:44:33:11:23:45", "wrong_name")
wrong_adv = AdvertisementData(local_name="wrong_name", service_uuids=[])
async_get_scanner(hass)._callback(wrong_device, wrong_adv)
_get_underlying_scanner()._callback(wrong_device, wrong_adv)
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 0
@ -180,7 +181,7 @@ async def test_discovery_match_by_local_name(hass, mock_bleak_scanner_start):
switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand")
switchbot_adv = AdvertisementData(local_name="wohand", service_uuids=[])
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 1
@ -219,7 +220,7 @@ async def test_discovery_match_by_manufacturer_id_and_first_byte(
manufacturer_data={76: b"\x06\x02\x03\x99"},
)
async_get_scanner(hass)._callback(hkc_device, hkc_adv)
_get_underlying_scanner()._callback(hkc_device, hkc_adv)
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 1
@ -227,7 +228,7 @@ async def test_discovery_match_by_manufacturer_id_and_first_byte(
mock_config_flow.reset_mock()
# 2nd discovery should not generate another flow
async_get_scanner(hass)._callback(hkc_device, hkc_adv)
_get_underlying_scanner()._callback(hkc_device, hkc_adv)
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 0
@ -238,7 +239,7 @@ async def test_discovery_match_by_manufacturer_id_and_first_byte(
local_name="lock", service_uuids=[], manufacturer_data={76: b"\x02"}
)
async_get_scanner(hass)._callback(not_hkc_device, not_hkc_adv)
_get_underlying_scanner()._callback(not_hkc_device, not_hkc_adv)
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 0
@ -247,7 +248,7 @@ async def test_discovery_match_by_manufacturer_id_and_first_byte(
local_name="lock", service_uuids=[], manufacturer_data={21: b"\x02"}
)
async_get_scanner(hass)._callback(not_apple_device, not_apple_adv)
_get_underlying_scanner()._callback(not_apple_device, not_apple_adv)
await hass.async_block_till_done()
assert len(mock_config_flow.mock_calls) == 0
@ -279,10 +280,10 @@ async def test_async_discovered_device_api(hass, mock_bleak_scanner_start):
wrong_device = BLEDevice("44:44:33:11:23:42", "wrong_name")
wrong_adv = AdvertisementData(local_name="wrong_name", service_uuids=[])
async_get_scanner(hass)._callback(wrong_device, wrong_adv)
_get_underlying_scanner()._callback(wrong_device, wrong_adv)
switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand")
switchbot_adv = AdvertisementData(local_name="wohand", service_uuids=[])
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
wrong_device_went_unavailable = False
switchbot_device_went_unavailable = False
@ -316,8 +317,8 @@ async def test_async_discovered_device_api(hass, mock_bleak_scanner_start):
assert wrong_device_went_unavailable is True
# See the devices again
async_get_scanner(hass)._callback(wrong_device, wrong_adv)
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(wrong_device, wrong_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
# Cancel the callbacks
wrong_device_unavailable_cancel()
switchbot_device_unavailable_cancel()
@ -382,25 +383,25 @@ async def test_register_callbacks(hass, mock_bleak_scanner_start, enable_bluetoo
service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"},
)
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
await hass.async_block_till_done()
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
# 3rd callback raises ValueError but is still tracked
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
await hass.async_block_till_done()
cancel()
# 4th callback should not be tracked since we canceled
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
await hass.async_block_till_done()
assert len(callbacks) == 3
@ -467,25 +468,25 @@ async def test_register_callback_by_address(
service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"},
)
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
await hass.async_block_till_done()
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
# 3rd callback raises ValueError but is still tracked
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
await hass.async_block_till_done()
cancel()
# 4th callback should not be tracked since we canceled
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
await hass.async_block_till_done()
# Now register again with a callback that fails to
@ -549,15 +550,15 @@ async def test_wrapped_instance_with_filter(
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
assert async_get_scanner(hass) is not None
assert _get_underlying_scanner() is not None
scanner = models.HaBleakScannerWrapper(
filters={"UUIDs": ["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]}
)
scanner.register_detection_callback(_device_detected)
mock_discovered = [MagicMock()]
type(async_get_scanner(hass)).discovered_devices = mock_discovered
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
type(_get_underlying_scanner()).discovered_devices = mock_discovered
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
discovered = await scanner.discover(timeout=0)
@ -572,17 +573,17 @@ async def test_wrapped_instance_with_filter(
# We should get a reply from the history when we register again
assert len(detected) == 3
type(async_get_scanner(hass)).discovered_devices = []
type(_get_underlying_scanner()).discovered_devices = []
discovered = await scanner.discover(timeout=0)
assert len(discovered) == 0
assert discovered == []
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
assert len(detected) == 4
# The filter we created in the wrapped scanner with should be respected
# and we should not get another callback
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
assert len(detected) == 4
@ -620,22 +621,22 @@ async def test_wrapped_instance_with_service_uuids(
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
assert async_get_scanner(hass) is not None
assert _get_underlying_scanner() is not None
scanner = models.HaBleakScannerWrapper(
service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]
)
scanner.register_detection_callback(_device_detected)
type(async_get_scanner(hass)).discovered_devices = [MagicMock()]
type(_get_underlying_scanner()).discovered_devices = [MagicMock()]
for _ in range(2):
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert len(detected) == 2
# The UUIDs list we created in the wrapped scanner with should be respected
# and we should not get another callback
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
assert len(detected) == 2
@ -673,15 +674,15 @@ async def test_wrapped_instance_with_broken_callbacks(
service_data={"00000d00-0000-1000-8000-00805f9b34fb": b"H\x10c"},
)
assert async_get_scanner(hass) is not None
assert _get_underlying_scanner() is not None
scanner = models.HaBleakScannerWrapper(
service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]
)
scanner.register_detection_callback(_device_detected)
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert len(detected) == 1
@ -719,23 +720,23 @@ async def test_wrapped_instance_changes_uuids(
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
assert async_get_scanner(hass) is not None
assert _get_underlying_scanner() is not None
scanner = models.HaBleakScannerWrapper()
scanner.set_scanning_filter(
service_uuids=["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]
)
scanner.register_detection_callback(_device_detected)
type(async_get_scanner(hass)).discovered_devices = [MagicMock()]
type(_get_underlying_scanner()).discovered_devices = [MagicMock()]
for _ in range(2):
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert len(detected) == 2
# The UUIDs list we created in the wrapped scanner with should be respected
# and we should not get another callback
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
assert len(detected) == 2
@ -772,23 +773,23 @@ async def test_wrapped_instance_changes_filters(
empty_device = BLEDevice("11:22:33:44:55:62", "empty")
empty_adv = AdvertisementData(local_name="empty")
assert async_get_scanner(hass) is not None
assert _get_underlying_scanner() is not None
scanner = models.HaBleakScannerWrapper()
scanner.set_scanning_filter(
filters={"UUIDs": ["cba20d00-224d-11e6-9fb8-0002a5d5c51b"]}
)
scanner.register_detection_callback(_device_detected)
type(async_get_scanner(hass)).discovered_devices = [MagicMock()]
type(_get_underlying_scanner()).discovered_devices = [MagicMock()]
for _ in range(2):
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert len(detected) == 2
# The UUIDs list we created in the wrapped scanner with should be respected
# and we should not get another callback
async_get_scanner(hass)._callback(empty_device, empty_adv)
_get_underlying_scanner()._callback(empty_device, empty_adv)
assert len(detected) == 2
@ -807,7 +808,7 @@ async def test_wrapped_instance_unsupported_filter(
with patch.object(hass.config_entries.flow, "async_init"):
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
assert async_get_scanner(hass) is not None
assert _get_underlying_scanner() is not None
scanner = models.HaBleakScannerWrapper()
scanner.set_scanning_filter(
filters={
@ -845,7 +846,7 @@ async def test_async_ble_device_from_address(hass, mock_bleak_scanner_start):
switchbot_device = BLEDevice("44:44:33:11:23:45", "wohand")
switchbot_adv = AdvertisementData(local_name="wohand", service_uuids=[])
async_get_scanner(hass)._callback(switchbot_device, switchbot_adv)
_get_underlying_scanner()._callback(switchbot_device, switchbot_adv)
await hass.async_block_till_done()
assert (
@ -935,3 +936,9 @@ async def test_raising_runtime_error_when_no_bluetooth(hass):
"""Test we raise an exception if we try to get the scanner when its not there."""
with pytest.raises(RuntimeError):
bluetooth.async_get_scanner(hass)
async def test_getting_the_scanner_returns_the_wrapped_instance(hass, enable_bluetooth):
"""Test getting the scanner returns the wrapped instance."""
scanner = bluetooth.async_get_scanner(hass)
assert isinstance(scanner, models.HaBleakScannerWrapper)

View File

@ -12,7 +12,6 @@ from homeassistant.components.bluetooth import (
DOMAIN,
UNAVAILABLE_TRACK_SECONDS,
BluetoothChange,
async_get_scanner,
)
from homeassistant.components.bluetooth.passive_update_coordinator import (
PassiveBluetoothCoordinatorEntity,
@ -27,6 +26,8 @@ from homeassistant.helpers.entity import DeviceInfo
from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util
from . import _get_underlying_scanner
from tests.common import MockEntityPlatform, async_fire_time_changed
_LOGGER = logging.getLogger(__name__)
@ -208,7 +209,7 @@ async def test_unavailable_after_no_data(hass, mock_bleak_scanner_start):
saved_callback(GENERIC_BLUETOOTH_SERVICE_INFO, BluetoothChange.ADVERTISEMENT)
assert len(mock_add_entities.mock_calls) == 1
assert coordinator.available is True
scanner = async_get_scanner(hass)
scanner = _get_underlying_scanner()
with patch(
"homeassistant.components.bluetooth.models.HaBleakScanner.discovered_devices",