Fix bluetooth callback matchers when only matching on connectable (#78687)

This commit is contained in:
J. Nick Koston 2022-09-18 10:22:54 -05:00 committed by GitHub
parent 4d6151666e
commit d4181aa911
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 93 additions and 19 deletions

View File

@ -410,11 +410,11 @@ class BluetoothManager:
callback_matcher[CONNECTABLE] = matcher.get(CONNECTABLE, True)
connectable = callback_matcher[CONNECTABLE]
self._callback_index.add_with_address(callback_matcher)
self._callback_index.add_callback_matcher(callback_matcher)
@hass_callback
def _async_remove_callback() -> None:
self._callback_index.remove_with_address(callback_matcher)
self._callback_index.remove_callback_matcher(callback_matcher)
# If we have history for the subscriber, we can trigger the callback
# immediately with the last packet so the subscriber can see the

View File

@ -173,7 +173,7 @@ class BluetoothMatcherIndexBase(Generic[_T]):
self.service_data_uuid_set: set[str] = set()
self.manufacturer_id_set: set[int] = set()
def add(self, matcher: _T) -> None:
def add(self, matcher: _T) -> bool:
"""Add a matcher to the index.
Matchers must end up only in one bucket.
@ -185,26 +185,28 @@ class BluetoothMatcherIndexBase(Generic[_T]):
self.local_name.setdefault(
_local_name_to_index_key(matcher[LOCAL_NAME]), []
).append(matcher)
return
return True
# Manufacturer data is 2nd cheapest since its all ints
if MANUFACTURER_ID in matcher:
self.manufacturer_id.setdefault(matcher[MANUFACTURER_ID], []).append(
matcher
)
return
return True
if SERVICE_UUID in matcher:
self.service_uuid.setdefault(matcher[SERVICE_UUID], []).append(matcher)
return
return True
if SERVICE_DATA_UUID in matcher:
self.service_data_uuid.setdefault(matcher[SERVICE_DATA_UUID], []).append(
matcher
)
return
return True
def remove(self, matcher: _T) -> None:
return False
def remove(self, matcher: _T) -> bool:
"""Remove a matcher from the index.
Matchers only end up in one bucket, so once we have
@ -214,19 +216,21 @@ class BluetoothMatcherIndexBase(Generic[_T]):
self.local_name[_local_name_to_index_key(matcher[LOCAL_NAME])].remove(
matcher
)
return
return True
if MANUFACTURER_ID in matcher:
self.manufacturer_id[matcher[MANUFACTURER_ID]].remove(matcher)
return
return True
if SERVICE_UUID in matcher:
self.service_uuid[matcher[SERVICE_UUID]].remove(matcher)
return
return True
if SERVICE_DATA_UUID in matcher:
self.service_data_uuid[matcher[SERVICE_DATA_UUID]].remove(matcher)
return
return True
return False
def build(self) -> None:
"""Rebuild the index sets."""
@ -284,8 +288,11 @@ class BluetoothCallbackMatcherIndex(
"""Initialize the matcher index."""
super().__init__()
self.address: dict[str, list[BluetoothCallbackMatcherWithCallback]] = {}
self.connectable: list[BluetoothCallbackMatcherWithCallback] = []
def add_with_address(self, matcher: BluetoothCallbackMatcherWithCallback) -> None:
def add_callback_matcher(
self, matcher: BluetoothCallbackMatcherWithCallback
) -> None:
"""Add a matcher to the index.
Matchers must end up only in one bucket.
@ -296,10 +303,15 @@ class BluetoothCallbackMatcherIndex(
self.address.setdefault(matcher[ADDRESS], []).append(matcher)
return
super().add(matcher)
self.build()
if super().add(matcher):
self.build()
return
def remove_with_address(
if CONNECTABLE in matcher:
self.connectable.append(matcher)
return
def remove_callback_matcher(
self, matcher: BluetoothCallbackMatcherWithCallback
) -> None:
"""Remove a matcher from the index.
@ -311,8 +323,13 @@ class BluetoothCallbackMatcherIndex(
self.address[matcher[ADDRESS]].remove(matcher)
return
super().remove(matcher)
self.build()
if super().remove(matcher):
self.build()
return
if CONNECTABLE in matcher:
self.connectable.remove(matcher)
return
def match_callbacks(
self, service_info: BluetoothServiceInfoBleak
@ -322,6 +339,9 @@ class BluetoothCallbackMatcherIndex(
for matcher in self.address.get(service_info.address, []):
if ble_device_matches(matcher, service_info):
matches.append(matcher)
for matcher in self.connectable:
if ble_device_matches(matcher, service_info):
matches.append(matcher)
return matches
@ -355,7 +375,6 @@ def ble_device_matches(
# Don't check address here since all callers already
# check the address and we don't want to double check
# since it would result in an unreachable reject case.
if matcher.get(CONNECTABLE, True) and not service_info.connectable:
return False

View File

@ -1327,6 +1327,61 @@ async def test_register_callback_by_manufacturer_id(
assert service_info.manufacturer_id == 21
async def test_register_callback_by_connectable(
hass, mock_bleak_scanner_start, enable_bluetooth
):
"""Test registering a callback by connectable."""
mock_bt = []
callbacks = []
def _fake_subscriber(
service_info: BluetoothServiceInfo, change: BluetoothChange
) -> None:
"""Fake subscriber for the BleakScanner."""
callbacks.append((service_info, change))
with patch(
"homeassistant.components.bluetooth.async_get_bluetooth", return_value=mock_bt
):
await async_setup_with_default_adapter(hass)
with patch.object(hass.config_entries.flow, "async_init"):
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
cancel = bluetooth.async_register_callback(
hass,
_fake_subscriber,
{CONNECTABLE: False},
BluetoothScanningMode.ACTIVE,
)
assert len(mock_bleak_scanner_start.mock_calls) == 1
apple_device = BLEDevice("44:44:33:11:23:45", "rtx")
apple_adv = AdvertisementData(
local_name="rtx",
manufacturer_data={7676: b"\xd8.\xad\xcd\r\x85"},
)
inject_advertisement(hass, apple_device, apple_adv)
empty_device = BLEDevice("11:22:33:44:55:66", "empty")
empty_adv = AdvertisementData(local_name="empty")
inject_advertisement(hass, empty_device, empty_adv)
await hass.async_block_till_done()
cancel()
assert len(callbacks) == 2
service_info: BluetoothServiceInfo = callbacks[0][0]
assert service_info.name == "rtx"
service_info: BluetoothServiceInfo = callbacks[1][0]
assert service_info.name == "empty"
async def test_not_filtering_wanted_apple_devices(
hass, mock_bleak_scanner_start, enable_bluetooth
):