Use assignment expressions 03 (#57710)

This commit is contained in:
Marc Mueller 2021-10-17 20:08:11 +02:00 committed by GitHub
parent 2a8eaf0e0f
commit 238b488642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 49 additions and 102 deletions

View File

@ -192,8 +192,7 @@ def _async_register_clientsession_shutdown(
EVENT_HOMEASSISTANT_CLOSE, _async_close_websession
)
config_entry = config_entries.current_entry.get()
if not config_entry:
if not (config_entry := config_entries.current_entry.get()):
return
config_entry.async_on_unload(unsub)

View File

@ -328,9 +328,8 @@ def async_numeric_state( # noqa: C901
if isinstance(entity, str):
entity_id = entity
entity = hass.states.get(entity)
if entity is None:
if (entity := hass.states.get(entity)) is None:
raise ConditionErrorMessage("numeric_state", f"unknown entity {entity_id}")
else:
entity_id = entity.entity_id
@ -371,8 +370,7 @@ def async_numeric_state( # noqa: C901
if below is not None:
if isinstance(below, str):
below_entity = hass.states.get(below)
if not below_entity:
if not (below_entity := hass.states.get(below)):
raise ConditionErrorMessage(
"numeric_state", f"unknown 'below' entity {below}"
)
@ -400,8 +398,7 @@ def async_numeric_state( # noqa: C901
if above is not None:
if isinstance(above, str):
above_entity = hass.states.get(above)
if not above_entity:
if not (above_entity := hass.states.get(above)):
raise ConditionErrorMessage(
"numeric_state", f"unknown 'above' entity {above}"
)
@ -497,9 +494,8 @@ def state(
if isinstance(entity, str):
entity_id = entity
entity = hass.states.get(entity)
if entity is None:
if (entity := hass.states.get(entity)) is None:
raise ConditionErrorMessage("state", f"unknown entity {entity_id}")
else:
entity_id = entity.entity_id
@ -526,8 +522,7 @@ def state(
isinstance(req_state_value, str)
and INPUT_ENTITY_ID.match(req_state_value) is not None
):
state_entity = hass.states.get(req_state_value)
if not state_entity:
if not (state_entity := hass.states.get(req_state_value)):
raise ConditionErrorMessage(
"state", f"the 'state' entity {req_state_value} is unavailable"
)
@ -738,8 +733,7 @@ def time(
if after is None:
after = dt_util.dt.time(0)
elif isinstance(after, str):
after_entity = hass.states.get(after)
if not after_entity:
if not (after_entity := hass.states.get(after)):
raise ConditionErrorMessage("time", f"unknown 'after' entity {after}")
if after_entity.domain == "input_datetime":
after = dt_util.dt.time(
@ -763,8 +757,7 @@ def time(
if before is None:
before = dt_util.dt.time(23, 59, 59, 999999)
elif isinstance(before, str):
before_entity = hass.states.get(before)
if not before_entity:
if not (before_entity := hass.states.get(before)):
raise ConditionErrorMessage("time", f"unknown 'before' entity {before}")
if before_entity.domain == "input_datetime":
before = dt_util.dt.time(
@ -840,9 +833,8 @@ def zone(
if isinstance(zone_ent, str):
zone_ent_id = zone_ent
zone_ent = hass.states.get(zone_ent)
if zone_ent is None:
if (zone_ent := hass.states.get(zone_ent)) is None:
raise ConditionErrorMessage("zone", f"unknown zone {zone_ent_id}")
if entity is None:
@ -850,9 +842,8 @@ def zone(
if isinstance(entity, str):
entity_id = entity
entity = hass.states.get(entity)
if entity is None:
if (entity := hass.states.get(entity)) is None:
raise ConditionErrorMessage("zone", f"unknown entity {entity_id}")
else:
entity_id = entity.entity_id
@ -1029,9 +1020,7 @@ def async_extract_devices(config: ConfigType | Template) -> set[str]:
if condition != "device":
continue
device_id = config.get(CONF_DEVICE_ID)
if device_id is not None:
if (device_id := config.get(CONF_DEVICE_ID)) is not None:
referenced.add(device_id)
return referenced

View File

@ -129,14 +129,10 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
@property
def redirect_uri(self) -> str:
"""Return the redirect uri."""
req = http.current_request.get()
if req is None:
if (req := http.current_request.get()) is None:
raise RuntimeError("No current request in context")
ha_host = req.headers.get(HEADER_FRONTEND_BASE)
if ha_host is None:
if (ha_host := req.headers.get(HEADER_FRONTEND_BASE)) is None:
raise RuntimeError("No header in request")
return f"{ha_host}{AUTH_CALLBACK_PATH}"
@ -501,9 +497,7 @@ async def async_oauth2_request(
@callback
def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
"""JWT encode data."""
secret = hass.data.get(DATA_JWT_SECRET)
if secret is None:
if (secret := hass.data.get(DATA_JWT_SECRET)) is None:
secret = hass.data[DATA_JWT_SECRET] = secrets.token_hex()
return jwt.encode(data, secret, algorithm="HS256")

View File

@ -38,8 +38,7 @@ class _BaseFlowManagerView(HomeAssistantView):
data = result.copy()
schema = data["data_schema"]
if schema is None:
if (schema := data["data_schema"]) is None:
data["data_schema"] = []
else:
data["data_schema"] = voluptuous_serialize.convert(

View File

@ -111,9 +111,7 @@ def async_listen_platform(
async def discovery_platform_listener(discovered: DiscoveryDict) -> None:
"""Listen for platform discovery events."""
platform = discovered["platform"]
if not platform:
if not (platform := discovered["platform"]):
return
task = hass.async_run_hass_job(job, platform, discovered.get("discovered"))

View File

@ -727,8 +727,7 @@ current_platform: ContextVar[EntityPlatform | None] = ContextVar(
@callback
def async_get_current_platform() -> EntityPlatform:
"""Get the current platform from context."""
platform = current_platform.get()
if platform is None:
if (platform := current_platform.get()) is None:
raise RuntimeError("Cannot get non-set current platform")
return platform

View File

@ -33,8 +33,7 @@ SPEECH_TYPE_SSML = "ssml"
@bind_hass
def async_register(hass: HomeAssistant, handler: IntentHandler) -> None:
"""Register an intent with Home Assistant."""
intents = hass.data.get(DATA_KEY)
if intents is None:
if (intents := hass.data.get(DATA_KEY)) is None:
intents = hass.data[DATA_KEY] = {}
assert handler.intent_type is not None, "intent_type cannot be None"

View File

@ -51,9 +51,7 @@ def find_coordinates(
hass: HomeAssistant, entity_id: str, recursion_history: list | None = None
) -> str | None:
"""Find the gps coordinates of the entity in the form of '90.000,180.000'."""
entity_state = hass.states.get(entity_id)
if entity_state is None:
if (entity_state := hass.states.get(entity_id)) is None:
_LOGGER.error("Unable to find entity %s", entity_id)
return None

View File

@ -118,8 +118,7 @@ def get_url(
def _get_request_host() -> str | None:
"""Get the host address of the current request."""
request = http.current_request.get()
if request is None:
if (request := http.current_request.get()) is None:
raise NoURLAvailableError
return yarl.URL(request.url).host

View File

@ -78,8 +78,7 @@ class KeyedRateLimit:
if rate_limit is None:
return None
last_triggered = self._last_triggered.get(key)
if not last_triggered:
if not (last_triggered := self._last_triggered.get(key)):
return None
next_call_time = last_triggered + rate_limit

View File

@ -953,8 +953,7 @@ class Script:
variables: ScriptVariables | None = None,
) -> None:
"""Initialize the script."""
all_scripts = hass.data.get(DATA_SCRIPTS)
if not all_scripts:
if not (all_scripts := hass.data.get(DATA_SCRIPTS)):
all_scripts = hass.data[DATA_SCRIPTS] = []
hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass)
@ -1273,8 +1272,7 @@ class Script:
config_cache_key = config.template
else:
config_cache_key = frozenset((k, str(v)) for k, v in config.items())
cond = self._config_cache.get(config_cache_key)
if not cond:
if not (cond := self._config_cache.get(config_cache_key)):
cond = await condition.async_from_config(self._hass, config, False)
self._config_cache[config_cache_key] = cond
return cond
@ -1297,8 +1295,7 @@ class Script:
return sub_script
def _get_repeat_script(self, step: int) -> Script:
sub_script = self._repeat_script.get(step)
if not sub_script:
if not (sub_script := self._repeat_script.get(step)):
sub_script = self._prep_repeat_script(step)
self._repeat_script[step] = sub_script
return sub_script
@ -1351,8 +1348,7 @@ class Script:
return {"choices": choices, "default": default_script}
async def _async_get_choose_data(self, step: int) -> _ChooseData:
choose_data = self._choose_data.get(step)
if not choose_data:
if not (choose_data := self._choose_data.get(step)):
choose_data = await self._async_prep_choose_data(step)
self._choose_data[step] = choose_data
return choose_data

View File

@ -22,9 +22,7 @@ def validate_selector(config: Any) -> dict:
selector_type = list(config)[0]
selector_class = SELECTORS.get(selector_type)
if selector_class is None:
if (selector_class := SELECTORS.get(selector_type)) is None:
raise vol.Invalid(f"Unknown selector type {selector_type} found")
# Selectors can be empty

View File

@ -396,10 +396,11 @@ async def async_extract_config_entry_ids(
# Some devices may have no entities
for device_id in referenced.referenced_devices:
if device_id in dev_reg.devices:
device = dev_reg.async_get(device_id)
if device is not None:
config_entry_ids.update(device.config_entries)
if (
device_id in dev_reg.devices
and (device := dev_reg.async_get(device_id)) is not None
):
config_entry_ids.update(device.config_entries)
for entity_id in referenced.referenced | referenced.indirectly_referenced:
entry = ent_reg.async_get(entity_id)

View File

@ -813,8 +813,7 @@ class TemplateState(State):
def _collect_state(hass: HomeAssistant, entity_id: str) -> None:
entity_collect = hass.data.get(_RENDER_INFO)
if entity_collect is not None:
if (entity_collect := hass.data.get(_RENDER_INFO)) is not None:
entity_collect.entities.add(entity_id)
@ -1188,8 +1187,7 @@ def state_attr(hass: HomeAssistant, entity_id: str, name: str) -> Any:
def now(hass: HomeAssistant) -> datetime:
"""Record fetching now."""
render_info = hass.data.get(_RENDER_INFO)
if render_info is not None:
if (render_info := hass.data.get(_RENDER_INFO)) is not None:
render_info.has_time = True
return dt_util.now()
@ -1197,8 +1195,7 @@ def now(hass: HomeAssistant) -> datetime:
def utcnow(hass: HomeAssistant) -> datetime:
"""Record fetching utcnow."""
render_info = hass.data.get(_RENDER_INFO)
if render_info is not None:
if (render_info := hass.data.get(_RENDER_INFO)) is not None:
render_info.has_time = True
return dt_util.utcnow()
@ -1843,9 +1840,7 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
# any instance of this.
return super().compile(source, name, filename, raw, defer_init)
cached = self.template_cache.get(source)
if cached is None:
if (cached := self.template_cache.get(source)) is None:
cached = self.template_cache[source] = super().compile(source)
return cached

View File

@ -113,8 +113,7 @@ def trace_id_get() -> tuple[tuple[str, str], str] | None:
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
"""Push an element to the top of a trace stack."""
trace_stack = trace_stack_var.get()
if trace_stack is None:
if (trace_stack := trace_stack_var.get()) is None:
trace_stack = []
trace_stack_var.set(trace_stack)
trace_stack.append(node)
@ -149,8 +148,7 @@ def trace_path_pop(count: int) -> None:
def trace_path_get() -> str:
"""Return a string representing the current location in the config tree."""
path = trace_path_stack_cv.get()
if not path:
if not (path := trace_path_stack_cv.get()):
return ""
return "/".join(path)
@ -160,12 +158,10 @@ def trace_append_element(
maxlen: int | None = None,
) -> None:
"""Append a TraceElement to trace[path]."""
path = trace_element.path
trace = trace_cv.get()
if trace is None:
if (trace := trace_cv.get()) is None:
trace = {}
trace_cv.set(trace)
if path not in trace:
if (path := trace_element.path) not in trace:
trace[path] = deque(maxlen=maxlen)
trace[path].append(trace_element)
@ -213,16 +209,14 @@ class StopReason:
def script_execution_set(reason: str) -> None:
"""Set stop reason."""
data = script_execution_cv.get()
if data is None:
if (data := script_execution_cv.get()) is None:
return
data.script_execution = reason
def script_execution_get() -> str | None:
"""Return the current trace."""
data = script_execution_cv.get()
if data is None:
if (data := script_execution_cv.get()) is None:
return None
return data.script_execution

View File

@ -146,9 +146,7 @@ async def async_get_custom_components(
hass: HomeAssistant,
) -> dict[str, Integration]:
"""Return cached list of custom integrations."""
reg_or_evt = hass.data.get(DATA_CUSTOM_COMPONENTS)
if reg_or_evt is None:
if (reg_or_evt := hass.data.get(DATA_CUSTOM_COMPONENTS)) is None:
evt = hass.data[DATA_CUSTOM_COMPONENTS] = asyncio.Event()
reg = await _async_get_custom_components(hass)
@ -543,8 +541,7 @@ class Integration:
async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration:
"""Get an integration."""
cache = hass.data.get(DATA_INTEGRATIONS)
if cache is None:
if (cache := hass.data.get(DATA_INTEGRATIONS)) is None:
if not _async_mount_config_dir(hass):
raise IntegrationNotFound(domain)
cache = hass.data[DATA_INTEGRATIONS] = {}
@ -553,12 +550,11 @@ async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration
if isinstance(int_or_evt, asyncio.Event):
await int_or_evt.wait()
int_or_evt = cache.get(domain, _UNDEF)
# When we have waited and it's _UNDEF, it doesn't exist
# We don't cache that it doesn't exist, or else people can't fix it
# and then restart, because their config will never be valid.
if int_or_evt is _UNDEF:
if (int_or_evt := cache.get(domain, _UNDEF)) is _UNDEF:
raise IntegrationNotFound(domain)
if int_or_evt is not _UNDEF:
@ -630,8 +626,7 @@ def _load_file(
with suppress(KeyError):
return hass.data[DATA_COMPONENTS][comp_or_platform] # type: ignore
cache = hass.data.get(DATA_COMPONENTS)
if cache is None:
if (cache := hass.data.get(DATA_COMPONENTS)) is None:
if not _async_mount_config_dir(hass):
return None
cache = hass.data[DATA_COMPONENTS] = {}

View File

@ -60,8 +60,7 @@ async def async_get_integration_with_requirements(
if hass.config.skip_pip:
return integration
cache = hass.data.get(DATA_INTEGRATIONS_WITH_REQS)
if cache is None:
if (cache := hass.data.get(DATA_INTEGRATIONS_WITH_REQS)) is None:
cache = hass.data[DATA_INTEGRATIONS_WITH_REQS] = {}
int_or_evt: Integration | asyncio.Event | None | UndefinedType = cache.get(
@ -71,12 +70,10 @@ async def async_get_integration_with_requirements(
if isinstance(int_or_evt, asyncio.Event):
await int_or_evt.wait()
int_or_evt = cache.get(domain, UNDEFINED)
# When we have waited and it's UNDEFINED, it doesn't exist
# We don't cache that it doesn't exist, or else people can't fix it
# and then restart, because their config will never be valid.
if int_or_evt is UNDEFINED:
if (int_or_evt := cache.get(domain, UNDEFINED)) is UNDEFINED:
raise IntegrationNotFound(domain)
if int_or_evt is not UNDEFINED:
@ -154,8 +151,7 @@ async def async_process_requirements(
This method is a coroutine. It will raise RequirementsNotFound
if an requirement can't be satisfied.
"""
pip_lock = hass.data.get(DATA_PIP_LOCK)
if pip_lock is None:
if (pip_lock := hass.data.get(DATA_PIP_LOCK)) is None:
pip_lock = hass.data[DATA_PIP_LOCK] = asyncio.Lock()
install_failure_history = hass.data.get(DATA_INSTALL_FAILURE_HISTORY)
if install_failure_history is None:

View File

@ -83,8 +83,7 @@ class HassEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore[valid
def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None:
"""Handle all exception inside the core loop."""
kwargs = {}
exception = context.get("exception")
if exception:
if exception := context.get("exception"):
kwargs["exc_info"] = (type(exception), exception, exception.__traceback__)
logging.getLogger(__package__).error(