Add negative tests for identify schema for packages (#33050)

This commit is contained in:
Paulus Schoutsen 2020-03-20 13:34:56 -07:00 committed by GitHub
parent c2ac8e813a
commit d16d44d3e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 14 deletions

View File

@ -562,12 +562,12 @@ def _log_pkg_error(package: str, component: str, config: Dict, message: str) ->
_LOGGER.error(message)
def _identify_config_schema(module: ModuleType) -> Tuple[Optional[str], Optional[Dict]]:
def _identify_config_schema(module: ModuleType) -> Optional[str]:
"""Extract the schema and identify list or dict based."""
try:
key = next(k for k in module.CONFIG_SCHEMA.schema if k == module.DOMAIN) # type: ignore
except (AttributeError, StopIteration):
return None, None
return None
schema = module.CONFIG_SCHEMA.schema[key] # type: ignore
@ -577,19 +577,19 @@ def _identify_config_schema(module: ModuleType) -> Tuple[Optional[str], Optional
default_value = schema(key.default())
if isinstance(default_value, dict):
return "dict", schema
return "dict"
if isinstance(default_value, list):
return "list", schema
return "list"
return None, None
return None
t_schema = str(schema)
if t_schema.startswith("{") or "schema_with_slug_keys" in t_schema:
return ("dict", schema)
return "dict"
if t_schema.startswith(("[", "All(<function ensure_list")):
return ("list", schema)
return "", schema
return "list"
return None
def _recursive_merge(conf: Dict[str, Any], package: Dict[str, Any]) -> Union[bool, str]:
@ -642,8 +642,7 @@ async def merge_packages_config(
merge_list = hasattr(component, "PLATFORM_SCHEMA")
if not merge_list and hasattr(component, "CONFIG_SCHEMA"):
merge_type, _ = _identify_config_schema(component)
merge_list = merge_type == "list"
merge_list = _identify_config_schema(component) == "list"
if merge_list:
config[comp_name] = cv.remove_falsy(

View File

@ -722,7 +722,7 @@ async def test_merge_id_schema(hass):
for domain, expected_type in types.items():
integration = await async_get_integration(hass, domain)
module = integration.get_component()
typ, _ = config_util._identify_config_schema(module)
typ = config_util._identify_config_schema(module)
assert typ == expected_type, f"{domain} expected {expected_type}, got {typ}"
@ -997,13 +997,16 @@ async def test_component_config_exceptions(hass, caplog):
[
("zone", vol.Schema({vol.Optional("zone", default=[]): list}), "list"),
("zone", vol.Schema({vol.Optional("zone", default=dict): dict}), "dict"),
("zone", vol.Schema({vol.Optional("zone"): int}), None),
("zone", vol.Schema({"zone": int}), None),
("not_existing", vol.Schema({vol.Optional("zone", default=dict): dict}), None,),
("non_existing", vol.Schema({"zone": int}), None),
("zone", vol.Schema({}), None),
],
)
def test_identify_config_schema(domain, schema, expected):
"""Test identify config schema."""
assert (
config_util._identify_config_schema(Mock(DOMAIN=domain, CONFIG_SCHEMA=schema))[
0
]
config_util._identify_config_schema(Mock(DOMAIN=domain, CONFIG_SCHEMA=schema))
== expected
)