Adjust pylint plugin for return type inheritance (#90046)

This commit is contained in:
epenet 2023-03-26 18:37:26 +02:00 committed by GitHub
parent bec7bbeb92
commit 6e92dac61f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 10 deletions

View File

@ -42,7 +42,6 @@ class TypeHintMatch:
"""named_arg_types is for named or keyword arguments"""
kwargs_type: str | None = None
"""kwargs_type is for the special case `**kwargs`"""
check_return_type_inheritance: bool = False
has_async_counterpart: bool = False
def need_to_check_function(self, node: nodes.FunctionDef) -> bool:
@ -398,7 +397,6 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
1: "ConfigType",
},
return_type=["DeviceScanner", None],
check_return_type_inheritance=True,
has_async_counterpart=True,
),
],
@ -466,7 +464,6 @@ _FUNCTION_MATCH: dict[str, list[TypeHintMatch]] = {
2: "DiscoveryInfoType | None",
},
return_type=["BaseNotificationService", None],
check_return_type_inheritance=True,
has_async_counterpart=True,
),
],
@ -493,7 +490,6 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = {
0: "ConfigEntry",
},
return_type="OptionsFlow",
check_return_type_inheritance=True,
),
TypeHintMatch(
function_name="async_step_dhcp",
@ -681,7 +677,6 @@ _RESTORE_ENTITY_MATCH: list[TypeHintMatch] = [
TypeHintMatch(
function_name="extra_restore_state_data",
return_type=["ExtraStoredData", None],
check_return_type_inheritance=True,
),
]
_TOGGLE_ENTITY_MATCH: list[TypeHintMatch] = [
@ -2842,15 +2837,13 @@ def _is_valid_return_type(match: TypeHintMatch, node: nodes.NodeNG) -> bool:
match, node.right
)
if (
match.check_return_type_inheritance
and isinstance(match.return_type, (str, list))
and isinstance(node, nodes.Name)
):
if isinstance(match.return_type, (str, list)) and isinstance(node, nodes.Name):
if isinstance(match.return_type, str):
valid_types = {match.return_type}
else:
valid_types = {el for el in match.return_type if isinstance(el, str)}
if "Mapping[str, Any]" in valid_types:
valid_types.add("TypedDict")
try:
for infer_node in node.infer():

View File

@ -724,6 +724,7 @@ def test_invalid_mapping_return_type(
"-> Mapping[str, bool | int]",
"-> dict[str, Any]",
"-> dict[str, str]",
"-> CustomTypedDict",
],
)
def test_valid_mapping_return_type(
@ -737,6 +738,11 @@ def test_valid_mapping_return_type(
class_node = astroid.extract_node(
f"""
from typing import TypedDict
class CustomTypedDict(TypedDict):
pass
class Entity():
pass