plugin.api: fix typing issues

This commit is contained in:
bastimeyer 2022-05-21 00:21:33 +02:00 committed by back-to
parent c6b68b098b
commit 17697f0bc0
8 changed files with 44 additions and 32 deletions

View File

@ -13,18 +13,19 @@ Copyright 2015 Red Hat, Inc.
See the License for the specific language governing permissions and
limitations under the License.
"""
from io import BytesIO
import sys
from requests.adapters import BaseAdapter
from requests.compat import urlparse, unquote, urljoin
from requests import Response, codes
import errno
import io
import locale
import os
import os.path
import stat
import locale
import io
import sys
from io import BytesIO
from urllib.parse import unquote, urljoin, urlparse
from requests import Response, codes
from requests.adapters import BaseAdapter
from streamlink.compat import is_win32

View File

@ -38,7 +38,7 @@ class _HTTPResponse(urllib3.response.HTTPResponse):
# override all urllib3.response.HTTPResponse references in requests.adapters.HTTPAdapter.send
urllib3.connectionpool.HTTPConnectionPool.ResponseCls = _HTTPResponse
urllib3.connectionpool.HTTPConnectionPool.ResponseCls = _HTTPResponse # type: ignore[attr-defined]
requests.adapters.HTTPResponse = _HTTPResponse
@ -55,7 +55,7 @@ requests.adapters.HTTPResponse = _HTTPResponse
# > encodings.
if urllib3_version >= (1, 25, 4):
class Urllib3UtilUrlPercentReOverride:
_re_percent_encoding: Pattern = urllib3.util.url.PERCENT_RE
_re_percent_encoding: Pattern = urllib3.util.url.PERCENT_RE # type: ignore[attr-defined]
@classmethod
def _num_percent_encodings(cls, string) -> int:
@ -77,7 +77,7 @@ if urllib3_version >= (1, 25, 4):
return _List()
urllib3.util.url.PERCENT_RE = Urllib3UtilUrlPercentReOverride
urllib3.util.url.PERCENT_RE = Urllib3UtilUrlPercentReOverride # type: ignore[attr-defined]
def _parse_keyvalue_list(val):

View File

@ -5,19 +5,21 @@ from typing import Optional, Sequence, Union
class ValidationError(ValueError):
MAX_LENGTH = 60
errors: Union[str, Exception, Sequence[Union[str, Exception]]]
def __init__(
self,
*error: Union[str, Exception, Sequence[Union[str, Exception]]],
*errors,
schema: Optional[Union[str, object]] = None,
context: Optional[Union[Exception]] = None,
**errkeywords
):
self.schema = schema
self.context = context
if len(error) == 1 and type(error[0]) is str:
self.errors = (self._truncate(error[0], **errkeywords), )
if len(errors) == 1 and type(errors[0]) is str:
self.errors = (self._truncate(errors[0], **errkeywords), )
else:
self.errors = error
self.errors = errors
def _ellipsis(self, string: str):
return string if len(string) <= self.MAX_LENGTH else f"<{string[:self.MAX_LENGTH - 5]}...>"
@ -32,7 +34,7 @@ class ValidationError(ValueError):
return ""
if type(self.schema) is str:
return f"({self.schema})"
return f"({self.schema.__name__})"
return f"({self.schema.__name__})" # type: ignore[attr-defined]
def __str__(self):
cls = self.__class__

View File

@ -1,4 +1,4 @@
from typing import Any, Callable, FrozenSet, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, FrozenSet, List, Optional, Sequence, Set, Tuple, Type, Union
class SchemaContainer:
@ -108,7 +108,7 @@ class UnionGetSchema:
*getters,
seq: Type[Union[List, FrozenSet, Set, Tuple]] = tuple,
):
self.getters: Tuple[GetItemSchema] = tuple(GetItemSchema(getter) for getter in getters)
self.getters: Sequence[GetItemSchema] = tuple(GetItemSchema(getter) for getter in getters)
self.seq = seq

View File

@ -121,7 +121,7 @@ def _validate_dict(schema, value):
return new
@validate.register(abc.Callable)
@validate.register
def _validate_callable(schema: abc.Callable, value):
if not schema(value):
raise ValidationError(
@ -133,7 +133,7 @@ def _validate_callable(schema: abc.Callable, value):
return value
@validate.register(AllSchema)
@validate.register
def _validate_allschema(schema: AllSchema, value):
for schema in schema.schema:
value = validate(schema, value)
@ -141,7 +141,7 @@ def _validate_allschema(schema: AllSchema, value):
return value
@validate.register(AnySchema)
@validate.register
def _validate_anyschema(schema: AnySchema, value):
errors = []
for subschema in schema.schema:
@ -153,13 +153,13 @@ def _validate_anyschema(schema: AnySchema, value):
raise ValidationError(*errors, schema=AnySchema)
@validate.register(TransformSchema)
@validate.register
def _validate_transformschema(schema: TransformSchema, value):
validate(abc.Callable, schema.func)
return schema.func(value, *schema.args, **schema.kwargs)
@validate.register(GetItemSchema)
@validate.register
def _validate_getitemschema(schema: GetItemSchema, value):
item = schema.item if type(schema.item) is tuple and not schema.strict else (schema.item,)
idx = 0
@ -194,7 +194,7 @@ def _validate_getitemschema(schema: GetItemSchema, value):
)
@validate.register(AttrSchema)
@validate.register
def _validate_attrschema(schema: AttrSchema, value):
new = copy(value)
@ -222,7 +222,7 @@ def _validate_attrschema(schema: AttrSchema, value):
return new
@validate.register(XmlElementSchema)
@validate.register
def _validate_xmlelementschema(schema: XmlElementSchema, value):
validate(iselement, value)
tag = value.tag
@ -263,14 +263,14 @@ def _validate_xmlelementschema(schema: XmlElementSchema, value):
return new
@validate.register(UnionGetSchema)
@validate.register
def _validate_uniongetschema(schema: UnionGetSchema, value):
return schema.seq(
validate(getter, value) for getter in schema.getters
)
@validate.register(UnionSchema)
@validate.register
def _validate_unionschema(schema: UnionSchema, value):
try:
return validate_union(schema.schema, value)

View File

@ -167,7 +167,7 @@ def validator_hasattr(attr: Any) -> Callable[[Any], bool]:
# Sequence related validators
def validator_filter(func: Callable[[Any], bool]) -> TransformSchema:
def validator_filter(func: Callable[..., bool]) -> TransformSchema:
"""
Filter out unwanted items from the input using the specified function.
@ -187,7 +187,7 @@ def validator_filter(func: Callable[[Any], bool]) -> TransformSchema:
return TransformSchema(filter_values)
def validator_map(func: Callable[[Any], Any]) -> TransformSchema:
def validator_map(func: Callable[..., Any]) -> TransformSchema:
"""
Transform items from the input using the specified function.

View File

@ -23,7 +23,7 @@ class WebsocketClient(Thread):
session: Streamlink,
url: str,
subprotocols: Optional[List[str]] = None,
header: Optional[Union[List, Dict]] = None,
header: Optional[Union[List[str], Dict[str, str]]] = None,
cookie: Optional[str] = None,
sockopt: Optional[Tuple] = None,
sslopt: Optional[Dict] = None,
@ -39,10 +39,12 @@ class WebsocketClient(Thread):
if not header:
header = []
elif isinstance(header, dict):
header = [f"{str(k)}: {str(v)}" for k, v in header.items()]
if not any(True for h in header if h.startswith("User-Agent: ")):
header.append(f"User-Agent: {session.http.headers['User-Agent']}")
header.append(f"User-Agent: {str(session.http.headers['User-Agent'])}")
proxy_options = {}
proxy_options: Dict[str, Any] = {}
http_proxy: Optional[str] = session.get_option("http-proxy")
if http_proxy:
p = urlparse(http_proxy)
@ -127,7 +129,9 @@ class WebsocketClient(Thread):
)
def close(self, status: int = STATUS_NORMAL, reason: Union[str, bytes] = "", timeout: int = 3) -> None:
self.ws.close(status=status, reason=bytes(reason, encoding="utf-8"), timeout=timeout)
if type(reason) is str: # pragma: no branch
reason = bytes(reason, encoding="utf-8")
self.ws.close(status=status, reason=reason, timeout=timeout)
if self.is_alive(): # pragma: no branch
self.join()

View File

@ -37,6 +37,11 @@ class TestWebsocketClient(unittest.TestCase):
"User-Agent: foo"
])
client = WebsocketClient(self.session, "wss://localhost:0", header={"User-Agent": "bar"})
self.assertEqual(client.ws.header, [
"User-Agent: bar"
])
def test_args_and_proxy(self):
self.session.set_option("http-proxy", "https://username:password@hostname:1234")
client = WebsocketClient(