1
mirror of https://github.com/home-assistant/core synced 2024-09-25 00:41:32 +02:00

Use assignment expressions 01 (#56394)

This commit is contained in:
Marc Mueller 2021-09-19 01:31:35 +02:00 committed by GitHub
parent a4f6c3336f
commit 7af67d34cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 73 additions and 171 deletions

View File

@ -132,16 +132,14 @@ def get_arguments() -> argparse.Namespace:
def daemonize() -> None:
"""Move current process to daemon process."""
# Create first fork
pid = os.fork()
if pid > 0:
if os.fork() > 0:
sys.exit(0)
# Decouple fork
os.setsid()
# Create second fork
pid = os.fork()
if pid > 0:
if os.fork() > 0:
sys.exit(0)
# redirect standard file descriptors to devnull

View File

@ -341,8 +341,7 @@ class AuthManager:
"System generated users cannot enable multi-factor auth module."
)
module = self.get_auth_mfa_module(mfa_module_id)
if module is None:
if (module := self.get_auth_mfa_module(mfa_module_id)) is None:
raise ValueError(f"Unable find multi-factor auth module: {mfa_module_id}")
await module.async_setup_user(user.id, data)
@ -356,8 +355,7 @@ class AuthManager:
"System generated users cannot disable multi-factor auth module."
)
module = self.get_auth_mfa_module(mfa_module_id)
if module is None:
if (module := self.get_auth_mfa_module(mfa_module_id)) is None:
raise ValueError(f"Unable find multi-factor auth module: {mfa_module_id}")
await module.async_depose_user(user.id)
@ -498,8 +496,7 @@ class AuthManager:
Will raise InvalidAuthError on errors.
"""
provider = self._async_resolve_provider(refresh_token)
if provider:
if provider := self._async_resolve_provider(refresh_token):
provider.async_validate_refresh_token(refresh_token, remote_ip)
async def async_validate_access_token(

View File

@ -96,8 +96,7 @@ class AuthStore:
groups = []
for group_id in group_ids or []:
group = self._groups.get(group_id)
if group is None:
if (group := self._groups.get(group_id)) is None:
raise ValueError(f"Invalid group specified {group_id}")
groups.append(group)
@ -160,8 +159,7 @@ class AuthStore:
if group_ids is not None:
groups = []
for grid in group_ids:
group = self._groups.get(grid)
if group is None:
if (group := self._groups.get(grid)) is None:
raise ValueError("Invalid group specified.")
groups.append(group)
@ -446,16 +444,14 @@ class AuthStore:
)
continue
token_type = rt_dict.get("token_type")
if token_type is None:
if (token_type := rt_dict.get("token_type")) is None:
if rt_dict["client_id"] is None:
token_type = models.TOKEN_TYPE_SYSTEM
else:
token_type = models.TOKEN_TYPE_NORMAL
# old refresh_token don't have last_used_at (pre-0.78)
last_used_at_str = rt_dict.get("last_used_at")
if last_used_at_str:
if last_used_at_str := rt_dict.get("last_used_at"):
last_used_at = dt_util.parse_datetime(last_used_at_str)
else:
last_used_at = None

View File

@ -118,9 +118,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
if self._user_settings is not None:
return
data = await self._user_store.async_load()
if data is None:
if (data := await self._user_store.async_load()) is None:
data = {STORAGE_USERS: {}}
self._user_settings = {
@ -207,8 +205,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
await self._async_load()
assert self._user_settings is not None
notify_setting = self._user_settings.get(user_id)
if notify_setting is None:
if (notify_setting := self._user_settings.get(user_id)) is None:
return False
# user_input has been validate in caller
@ -225,8 +222,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
await self._async_load()
assert self._user_settings is not None
notify_setting = self._user_settings.get(user_id)
if notify_setting is None:
if (notify_setting := self._user_settings.get(user_id)) is None:
raise ValueError("Cannot find user_id")
def generate_secret_and_one_time_password() -> str:

View File

@ -92,9 +92,7 @@ class TotpAuthModule(MultiFactorAuthModule):
if self._users is not None:
return
data = await self._user_store.async_load()
if data is None:
if (data := await self._user_store.async_load()) is None:
data = {STORAGE_USERS: {}}
self._users = data.get(STORAGE_USERS, {})
@ -163,8 +161,7 @@ class TotpAuthModule(MultiFactorAuthModule):
"""Validate two factor authentication code."""
import pyotp # pylint: disable=import-outside-toplevel
ota_secret = self._users.get(user_id) # type: ignore
if ota_secret is None:
if (ota_secret := self._users.get(user_id)) is None: # type: ignore
# even we cannot find user, we still do verify
# to make timing the same as if user was found.
pyotp.TOTP(DUMMY_SECRET).verify(code, valid_window=1)

View File

@ -33,9 +33,7 @@ class AbstractPermissions:
def check_entity(self, entity_id: str, key: str) -> bool:
"""Check if we can access entity."""
entity_func = self._cached_entity_func
if entity_func is None:
if (entity_func := self._cached_entity_func) is None:
entity_func = self._cached_entity_func = self._entity_func()
return entity_func(entity_id, key)

View File

@ -72,8 +72,7 @@ def compile_policy(
def apply_policy_funcs(object_id: str, key: str) -> bool:
"""Apply several policy functions."""
for func in funcs:
result = func(object_id, key)
if result is not None:
if (result := func(object_id, key)) is not None:
return result
return False

View File

@ -169,9 +169,7 @@ async def load_auth_provider_module(
if hass.config.skip_pip or not hasattr(module, "REQUIREMENTS"):
return module
processed = hass.data.get(DATA_REQS)
if processed is None:
if (processed := hass.data.get(DATA_REQS)) is None:
processed = hass.data[DATA_REQS] = set()
elif provider in processed:
return module

View File

@ -82,9 +82,7 @@ class Data:
async def async_load(self) -> None:
"""Load stored data."""
data = await self._store.async_load()
if data is None:
if (data := await self._store.async_load()) is None:
data = {"users": []}
seen: set[str] = set()
@ -93,9 +91,7 @@ class Data:
username = user["username"]
# check if we have duplicates
folded = username.casefold()
if folded in seen:
if (folded := username.casefold()) in seen:
self.is_legacy = True
logging.getLogger(__name__).warning(

View File

@ -109,9 +109,8 @@ async def async_setup_hass(
config_dict = None
basic_setup_success = False
safe_mode = runtime_config.safe_mode
if not safe_mode:
if not (safe_mode := runtime_config.safe_mode):
await hass.async_add_executor_job(conf_util.process_ha_config_upgrade, hass)
try:
@ -368,8 +367,7 @@ async def async_mount_local_lib_path(config_dir: str) -> str:
This function is a coroutine.
"""
deps_dir = os.path.join(config_dir, "deps")
lib_dir = await async_get_user_site(deps_dir)
if lib_dir not in sys.path:
if (lib_dir := await async_get_user_site(deps_dir)) not in sys.path:
sys.path.insert(0, lib_dir)
return deps_dir
@ -494,17 +492,13 @@ async def _async_set_up_integrations(
_LOGGER.info("Domains to be set up: %s", domains_to_setup)
logging_domains = domains_to_setup & LOGGING_INTEGRATIONS
# Load logging as soon as possible
if logging_domains:
if logging_domains := domains_to_setup & LOGGING_INTEGRATIONS:
_LOGGER.info("Setting up logging: %s", logging_domains)
await async_setup_multi_components(hass, logging_domains, config)
# Start up debuggers. Start these first in case they want to wait.
debuggers = domains_to_setup & DEBUGGER_INTEGRATIONS
if debuggers:
if debuggers := domains_to_setup & DEBUGGER_INTEGRATIONS:
_LOGGER.debug("Setting up debuggers: %s", debuggers)
await async_setup_multi_components(hass, debuggers, config)
@ -524,9 +518,7 @@ async def _async_set_up_integrations(
stage_1_domains.add(domain)
dep_itg = integration_cache.get(domain)
if dep_itg is None:
if (dep_itg := integration_cache.get(domain)) is None:
continue
deps_promotion.update(dep_itg.all_dependencies)

View File

@ -512,9 +512,7 @@ async def async_process_ha_core_config(hass: HomeAssistant, config: dict) -> Non
# Only load auth during startup.
if not hasattr(hass, "auth"):
auth_conf = config.get(CONF_AUTH_PROVIDERS)
if auth_conf is None:
if (auth_conf := config.get(CONF_AUTH_PROVIDERS)) is None:
auth_conf = [{"type": "homeassistant"}]
mfa_conf = config.get(
@ -598,9 +596,7 @@ async def async_process_ha_core_config(hass: HomeAssistant, config: dict) -> Non
cust_glob = OrderedDict(config[CONF_CUSTOMIZE_GLOB])
for name, pkg in config[CONF_PACKAGES].items():
pkg_cust = pkg.get(CONF_CORE)
if pkg_cust is None:
if (pkg_cust := pkg.get(CONF_CORE)) is None:
continue
try:
@ -957,9 +953,7 @@ def async_notify_setup_error(
# pylint: disable=import-outside-toplevel
from homeassistant.components import persistent_notification
errors = hass.data.get(DATA_PERSISTENT_ERRORS)
if errors is None:
if (errors := hass.data.get(DATA_PERSISTENT_ERRORS)) is None:
errors = hass.data[DATA_PERSISTENT_ERRORS] = {}
errors[component] = errors.get(component) or display_link

View File

@ -492,8 +492,7 @@ class ConfigEntry:
Returns True if config entry is up-to-date or has been migrated.
"""
handler = HANDLERS.get(self.domain)
if handler is None:
if (handler := HANDLERS.get(self.domain)) is None:
_LOGGER.error(
"Flow handler not found for entry %s for %s", self.title, self.domain
)
@ -716,9 +715,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager):
)
raise data_entry_flow.UnknownHandler
handler = HANDLERS.get(handler_key)
if handler is None:
if (handler := HANDLERS.get(handler_key)) is None:
raise data_entry_flow.UnknownHandler
if not context or "source" not in context:
@ -814,9 +811,7 @@ class ConfigEntries:
async def async_remove(self, entry_id: str) -> dict[str, Any]:
"""Remove an entry."""
entry = self.async_get_entry(entry_id)
if entry is None:
if (entry := self.async_get_entry(entry_id)) is None:
raise UnknownEntry
if not entry.state.recoverable:
@ -933,9 +928,7 @@ class ConfigEntries:
Return True if entry has been successfully loaded.
"""
entry = self.async_get_entry(entry_id)
if entry is None:
if (entry := self.async_get_entry(entry_id)) is None:
raise UnknownEntry
if entry.state is not ConfigEntryState.NOT_LOADED:
@ -957,9 +950,7 @@ class ConfigEntries:
async def async_unload(self, entry_id: str) -> bool:
"""Unload a config entry."""
entry = self.async_get_entry(entry_id)
if entry is None:
if (entry := self.async_get_entry(entry_id)) is None:
raise UnknownEntry
if not entry.state.recoverable:
@ -972,9 +963,7 @@ class ConfigEntries:
If an entry was not loaded, will just load.
"""
entry = self.async_get_entry(entry_id)
if entry is None:
if (entry := self.async_get_entry(entry_id)) is None:
raise UnknownEntry
unload_result = await self.async_unload(entry_id)
@ -991,9 +980,7 @@ class ConfigEntries:
If disabled_by is changed, the config entry will be reloaded.
"""
entry = self.async_get_entry(entry_id)
if entry is None:
if (entry := self.async_get_entry(entry_id)) is None:
raise UnknownEntry
if entry.disabled_by == disabled_by:
@ -1066,8 +1053,7 @@ class ConfigEntries:
return False
for listener_ref in entry.update_listeners:
listener = listener_ref()
if listener is not None:
if (listener := listener_ref()) is not None:
self.hass.async_create_task(listener(self.hass, entry))
self._async_schedule_save()

View File

@ -971,8 +971,7 @@ class State:
if isinstance(last_updated, str):
last_updated = dt_util.parse_datetime(last_updated)
context = json_dict.get("context")
if context:
if context := json_dict.get("context"):
context = Context(id=context.get("id"), user_id=context.get("user_id"))
return cls(
@ -1199,8 +1198,7 @@ class StateMachine:
entity_id = entity_id.lower()
new_state = str(new_state)
attributes = attributes or {}
old_state = self._states.get(entity_id)
if old_state is None:
if (old_state := self._states.get(entity_id)) is None:
same_state = False
same_attr = False
last_changed = None
@ -1658,9 +1656,7 @@ class Config:
def set_time_zone(self, time_zone_str: str) -> None:
"""Help to set the time zone."""
time_zone = dt_util.get_time_zone(time_zone_str)
if time_zone:
if time_zone := dt_util.get_time_zone(time_zone_str):
self.time_zone = time_zone_str
dt_util.set_default_time_zone(time_zone)
else:
@ -1717,9 +1713,8 @@ class Config:
store = self.hass.helpers.storage.Store(
CORE_STORAGE_VERSION, CORE_STORAGE_KEY, private=True
)
data = await store.async_load()
if not data:
if not (data := await store.async_load()):
return
# In 2021.9 we fixed validation to disallow a path (because that's never correct)
@ -1792,8 +1787,7 @@ def _async_create_timer(hass: HomeAssistant) -> None:
)
# If we are more than a second late, a tick was missed
late = monotonic() - target
if late > 1:
if (late := monotonic() - target) > 1:
hass.bus.async_fire(
EVENT_TIMER_OUT_OF_SYNC,
{ATTR_SECONDS: late},

View File

@ -93,9 +93,7 @@ class FlowManager(abc.ABC):
async def async_wait_init_flow_finish(self, handler: str) -> None:
"""Wait till all flows in progress are initialized."""
current = self._initializing.get(handler)
if not current:
if not (current := self._initializing.get(handler)):
return
await asyncio.wait(current)
@ -189,9 +187,7 @@ class FlowManager(abc.ABC):
self, flow_id: str, user_input: dict | None = None
) -> FlowResult:
"""Continue a configuration flow."""
flow = self._progress.get(flow_id)
if flow is None:
if (flow := self._progress.get(flow_id)) is None:
raise UnknownFlow
cur_step = flow.cur_step

View File

@ -18,9 +18,7 @@ def config_per_platform(config: ConfigType, domain: str) -> Iterable[tuple[Any,
Async friendly.
"""
for config_key in extract_domain_configs(config, domain):
platform_config = config[config_key]
if not platform_config:
if not (platform_config := config[config_key]):
continue
if not isinstance(platform_config, list):

View File

@ -99,13 +99,11 @@ def get_capability(hass: HomeAssistant, entity_id: str, capability: str) -> Any
First try the statemachine, then entity registry.
"""
state = hass.states.get(entity_id)
if state:
if state := hass.states.get(entity_id):
return state.attributes.get(capability)
entity_registry = er.async_get(hass)
entry = entity_registry.async_get(entity_id)
if not entry:
if not (entry := entity_registry.async_get(entity_id)):
raise HomeAssistantError(f"Unknown entity {entity_id}")
return entry.capabilities.get(capability) if entry.capabilities else None
@ -116,13 +114,11 @@ def get_device_class(hass: HomeAssistant, entity_id: str) -> str | None:
First try the statemachine, then entity registry.
"""
state = hass.states.get(entity_id)
if state:
if state := hass.states.get(entity_id):
return state.attributes.get(ATTR_DEVICE_CLASS)
entity_registry = er.async_get(hass)
entry = entity_registry.async_get(entity_id)
if not entry:
if not (entry := entity_registry.async_get(entity_id)):
raise HomeAssistantError(f"Unknown entity {entity_id}")
return entry.device_class
@ -133,13 +129,11 @@ def get_supported_features(hass: HomeAssistant, entity_id: str) -> int:
First try the statemachine, then entity registry.
"""
state = hass.states.get(entity_id)
if state:
if state := hass.states.get(entity_id):
return state.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
entity_registry = er.async_get(hass)
entry = entity_registry.async_get(entity_id)
if not entry:
if not (entry := entity_registry.async_get(entity_id)):
raise HomeAssistantError(f"Unknown entity {entity_id}")
return entry.supported_features or 0
@ -150,13 +144,11 @@ def get_unit_of_measurement(hass: HomeAssistant, entity_id: str) -> str | None:
First try the statemachine, then entity registry.
"""
state = hass.states.get(entity_id)
if state:
if state := hass.states.get(entity_id):
return state.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
entity_registry = er.async_get(hass)
entry = entity_registry.async_get(entity_id)
if not entry:
if not (entry := entity_registry.async_get(entity_id)):
raise HomeAssistantError(f"Unknown entity {entity_id}")
return entry.unit_of_measurement
@ -467,8 +459,7 @@ class Entity(ABC):
"""Convert state to string."""
if not self.available:
return STATE_UNAVAILABLE
state = self.state
if state is None:
if (state := self.state) is None:
return STATE_UNKNOWN
if isinstance(state, float):
# If the entity's state is a float, limit precision according to machine
@ -511,28 +502,22 @@ class Entity(ABC):
entry = self.registry_entry
# pylint: disable=consider-using-ternary
name = (entry and entry.name) or self.name
if name is not None:
if (name := (entry and entry.name) or self.name) is not None:
attr[ATTR_FRIENDLY_NAME] = name
icon = (entry and entry.icon) or self.icon
if icon is not None:
if (icon := (entry and entry.icon) or self.icon) is not None:
attr[ATTR_ICON] = icon
entity_picture = self.entity_picture
if entity_picture is not None:
if (entity_picture := self.entity_picture) is not None:
attr[ATTR_ENTITY_PICTURE] = entity_picture
assumed_state = self.assumed_state
if assumed_state:
if assumed_state := self.assumed_state:
attr[ATTR_ASSUMED_STATE] = assumed_state
supported_features = self.supported_features
if supported_features is not None:
if (supported_features := self.supported_features) is not None:
attr[ATTR_SUPPORTED_FEATURES] = supported_features
device_class = self.device_class
if device_class is not None:
if (device_class := self.device_class) is not None:
attr[ATTR_DEVICE_CLASS] = str(device_class)
end = timer()
@ -636,8 +621,7 @@ class Entity(ABC):
finished, _ = await asyncio.wait([task], timeout=SLOW_UPDATE_WARNING)
for done in finished:
exc = done.exception()
if exc:
if exc := done.exception():
raise exc
return

View File

@ -175,16 +175,14 @@ def async_track_state_change(
def state_change_filter(event: Event) -> bool:
"""Handle specific state changes."""
if from_state is not None:
old_state = event.data.get("old_state")
if old_state is not None:
if (old_state := event.data.get("old_state")) is not None:
old_state = old_state.state
if not match_from_state(old_state):
return False
if to_state is not None:
new_state = event.data.get("new_state")
if new_state is not None:
if (new_state := event.data.get("new_state")) is not None:
new_state = new_state.state
if not match_to_state(new_state):
@ -246,8 +244,7 @@ def async_track_state_change_event(
care about the state change events so we can
do a fast dict lookup to route events.
"""
entity_ids = _async_string_to_lower_list(entity_ids)
if not entity_ids:
if not (entity_ids := _async_string_to_lower_list(entity_ids)):
return _remove_empty_listener
entity_callbacks = hass.data.setdefault(TRACK_STATE_CHANGE_CALLBACKS, {})
@ -336,8 +333,7 @@ def async_track_entity_registry_updated_event(
Similar to async_track_state_change_event.
"""
entity_ids = _async_string_to_lower_list(entity_ids)
if not entity_ids:
if not (entity_ids := _async_string_to_lower_list(entity_ids)):
return _remove_empty_listener
entity_callbacks = hass.data.setdefault(TRACK_ENTITY_REGISTRY_UPDATED_CALLBACKS, {})
@ -419,8 +415,7 @@ def async_track_state_added_domain(
action: Callable[[Event], Any],
) -> Callable[[], None]:
"""Track state change events when an entity is added to domains."""
domains = _async_string_to_lower_list(domains)
if not domains:
if not (domains := _async_string_to_lower_list(domains)):
return _remove_empty_listener
domain_callbacks = hass.data.setdefault(TRACK_STATE_ADDED_DOMAIN_CALLBACKS, {})
@ -472,8 +467,7 @@ def async_track_state_removed_domain(
action: Callable[[Event], Any],
) -> Callable[[], None]:
"""Track state change events when an entity is removed from domains."""
domains = _async_string_to_lower_list(domains)
if not domains:
if not (domains := _async_string_to_lower_list(domains)):
return _remove_empty_listener
domain_callbacks = hass.data.setdefault(TRACK_STATE_REMOVED_DOMAIN_CALLBACKS, {})
@ -1185,8 +1179,7 @@ def async_track_point_in_utc_time(
# as measured by utcnow(). That is bad when callbacks have assumptions
# about the current time. Thus, we rearm the timer for the remaining
# time.
delta = (utc_point_in_time - now).total_seconds()
if delta > 0:
if (delta := (utc_point_in_time - now).total_seconds()) > 0:
_LOGGER.debug("Called %f seconds too early, rearming", delta)
cancel_callback = hass.loop.call_later(delta, run_action, job)
@ -1520,11 +1513,9 @@ def _rate_limit_for_event(
event: Event, info: RenderInfo, track_template_: TrackTemplate
) -> timedelta | None:
"""Determine the rate limit for an event."""
entity_id = event.data.get(ATTR_ENTITY_ID)
# Specifically referenced entities are excluded
# from the rate limit
if entity_id in info.entities:
if event.data.get(ATTR_ENTITY_ID) in info.entities:
return None
if track_template_.rate_limit is not None:

View File

@ -366,9 +366,7 @@ async def async_process_deps_reqs(
Module is a Python module of either a component or platform.
"""
processed = hass.data.get(DATA_DEPS_REQS)
if processed is None:
if (processed := hass.data.get(DATA_DEPS_REQS)) is None:
processed = hass.data[DATA_DEPS_REQS] = set()
elif integration.domain in processed:
return

View File

@ -132,8 +132,7 @@ def parse_datetime(dt_str: str) -> dt.datetime | None:
with suppress(ValueError, IndexError):
return ciso8601.parse_datetime(dt_str)
match = DATETIME_RE.match(dt_str)
if not match:
if not (match := DATETIME_RE.match(dt_str)):
return None
kws: dict[str, Any] = match.groupdict()
if kws["microsecond"]:
@ -269,16 +268,14 @@ def find_next_time_expression_time(
Return None if no such value exists.
"""
left = bisect.bisect_left(arr, cmp)
if left == len(arr):
if (left := bisect.bisect_left(arr, cmp)) == len(arr):
return None
return arr[left]
result = now.replace(microsecond=0)
# Match next second
next_second = _lower_bound(seconds, result.second)
if next_second is None:
if (next_second := _lower_bound(seconds, result.second)) is None:
# No second to match in this minute. Roll-over to next minute.
next_second = seconds[0]
result += dt.timedelta(minutes=1)

View File

@ -43,8 +43,7 @@ def percentage_to_ordered_list_item(ordered_list: list[T], percentage: int) -> T
51-75: high
76-100: very_high
"""
list_len = len(ordered_list)
if not list_len:
if not (list_len := len(ordered_list)):
raise ValueError("The ordered list is empty")
for offset, speed in enumerate(ordered_list):

View File

@ -60,9 +60,7 @@ class Secrets:
def _load_secret_yaml(self, secret_dir: Path) -> dict[str, str]:
"""Load the secrets yaml from path."""
secret_path = secret_dir / SECRET_YAML
if secret_path in self._cache:
if (secret_path := secret_dir / SECRET_YAML) in self._cache:
return self._cache[secret_path]
_LOGGER.debug("Loading %s", secret_path)