mirror of https://github.com/streamlink/streamlink
plugin.api: fix typing issues
This commit is contained in:
parent
c6b68b098b
commit
17697f0bc0
|
@ -13,18 +13,19 @@ Copyright 2015 Red Hat, Inc.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from io import BytesIO
|
||||
|
||||
import sys
|
||||
from requests.adapters import BaseAdapter
|
||||
from requests.compat import urlparse, unquote, urljoin
|
||||
from requests import Response, codes
|
||||
import errno
|
||||
import io
|
||||
import locale
|
||||
import os
|
||||
import os.path
|
||||
import stat
|
||||
import locale
|
||||
import io
|
||||
import sys
|
||||
from io import BytesIO
|
||||
from urllib.parse import unquote, urljoin, urlparse
|
||||
|
||||
from requests import Response, codes
|
||||
from requests.adapters import BaseAdapter
|
||||
|
||||
from streamlink.compat import is_win32
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ class _HTTPResponse(urllib3.response.HTTPResponse):
|
|||
|
||||
|
||||
# override all urllib3.response.HTTPResponse references in requests.adapters.HTTPAdapter.send
|
||||
urllib3.connectionpool.HTTPConnectionPool.ResponseCls = _HTTPResponse
|
||||
urllib3.connectionpool.HTTPConnectionPool.ResponseCls = _HTTPResponse # type: ignore[attr-defined]
|
||||
requests.adapters.HTTPResponse = _HTTPResponse
|
||||
|
||||
|
||||
|
@ -55,7 +55,7 @@ requests.adapters.HTTPResponse = _HTTPResponse
|
|||
# > encodings.
|
||||
if urllib3_version >= (1, 25, 4):
|
||||
class Urllib3UtilUrlPercentReOverride:
|
||||
_re_percent_encoding: Pattern = urllib3.util.url.PERCENT_RE
|
||||
_re_percent_encoding: Pattern = urllib3.util.url.PERCENT_RE # type: ignore[attr-defined]
|
||||
|
||||
@classmethod
|
||||
def _num_percent_encodings(cls, string) -> int:
|
||||
|
@ -77,7 +77,7 @@ if urllib3_version >= (1, 25, 4):
|
|||
|
||||
return _List()
|
||||
|
||||
urllib3.util.url.PERCENT_RE = Urllib3UtilUrlPercentReOverride
|
||||
urllib3.util.url.PERCENT_RE = Urllib3UtilUrlPercentReOverride # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _parse_keyvalue_list(val):
|
||||
|
|
|
@ -5,19 +5,21 @@ from typing import Optional, Sequence, Union
|
|||
class ValidationError(ValueError):
|
||||
MAX_LENGTH = 60
|
||||
|
||||
errors: Union[str, Exception, Sequence[Union[str, Exception]]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*error: Union[str, Exception, Sequence[Union[str, Exception]]],
|
||||
*errors,
|
||||
schema: Optional[Union[str, object]] = None,
|
||||
context: Optional[Union[Exception]] = None,
|
||||
**errkeywords
|
||||
):
|
||||
self.schema = schema
|
||||
self.context = context
|
||||
if len(error) == 1 and type(error[0]) is str:
|
||||
self.errors = (self._truncate(error[0], **errkeywords), )
|
||||
if len(errors) == 1 and type(errors[0]) is str:
|
||||
self.errors = (self._truncate(errors[0], **errkeywords), )
|
||||
else:
|
||||
self.errors = error
|
||||
self.errors = errors
|
||||
|
||||
def _ellipsis(self, string: str):
|
||||
return string if len(string) <= self.MAX_LENGTH else f"<{string[:self.MAX_LENGTH - 5]}...>"
|
||||
|
@ -32,7 +34,7 @@ class ValidationError(ValueError):
|
|||
return ""
|
||||
if type(self.schema) is str:
|
||||
return f"({self.schema})"
|
||||
return f"({self.schema.__name__})"
|
||||
return f"({self.schema.__name__})" # type: ignore[attr-defined]
|
||||
|
||||
def __str__(self):
|
||||
cls = self.__class__
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Callable, FrozenSet, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Callable, FrozenSet, List, Optional, Sequence, Set, Tuple, Type, Union
|
||||
|
||||
|
||||
class SchemaContainer:
|
||||
|
@ -108,7 +108,7 @@ class UnionGetSchema:
|
|||
*getters,
|
||||
seq: Type[Union[List, FrozenSet, Set, Tuple]] = tuple,
|
||||
):
|
||||
self.getters: Tuple[GetItemSchema] = tuple(GetItemSchema(getter) for getter in getters)
|
||||
self.getters: Sequence[GetItemSchema] = tuple(GetItemSchema(getter) for getter in getters)
|
||||
self.seq = seq
|
||||
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ def _validate_dict(schema, value):
|
|||
return new
|
||||
|
||||
|
||||
@validate.register(abc.Callable)
|
||||
@validate.register
|
||||
def _validate_callable(schema: abc.Callable, value):
|
||||
if not schema(value):
|
||||
raise ValidationError(
|
||||
|
@ -133,7 +133,7 @@ def _validate_callable(schema: abc.Callable, value):
|
|||
return value
|
||||
|
||||
|
||||
@validate.register(AllSchema)
|
||||
@validate.register
|
||||
def _validate_allschema(schema: AllSchema, value):
|
||||
for schema in schema.schema:
|
||||
value = validate(schema, value)
|
||||
|
@ -141,7 +141,7 @@ def _validate_allschema(schema: AllSchema, value):
|
|||
return value
|
||||
|
||||
|
||||
@validate.register(AnySchema)
|
||||
@validate.register
|
||||
def _validate_anyschema(schema: AnySchema, value):
|
||||
errors = []
|
||||
for subschema in schema.schema:
|
||||
|
@ -153,13 +153,13 @@ def _validate_anyschema(schema: AnySchema, value):
|
|||
raise ValidationError(*errors, schema=AnySchema)
|
||||
|
||||
|
||||
@validate.register(TransformSchema)
|
||||
@validate.register
|
||||
def _validate_transformschema(schema: TransformSchema, value):
|
||||
validate(abc.Callable, schema.func)
|
||||
return schema.func(value, *schema.args, **schema.kwargs)
|
||||
|
||||
|
||||
@validate.register(GetItemSchema)
|
||||
@validate.register
|
||||
def _validate_getitemschema(schema: GetItemSchema, value):
|
||||
item = schema.item if type(schema.item) is tuple and not schema.strict else (schema.item,)
|
||||
idx = 0
|
||||
|
@ -194,7 +194,7 @@ def _validate_getitemschema(schema: GetItemSchema, value):
|
|||
)
|
||||
|
||||
|
||||
@validate.register(AttrSchema)
|
||||
@validate.register
|
||||
def _validate_attrschema(schema: AttrSchema, value):
|
||||
new = copy(value)
|
||||
|
||||
|
@ -222,7 +222,7 @@ def _validate_attrschema(schema: AttrSchema, value):
|
|||
return new
|
||||
|
||||
|
||||
@validate.register(XmlElementSchema)
|
||||
@validate.register
|
||||
def _validate_xmlelementschema(schema: XmlElementSchema, value):
|
||||
validate(iselement, value)
|
||||
tag = value.tag
|
||||
|
@ -263,14 +263,14 @@ def _validate_xmlelementschema(schema: XmlElementSchema, value):
|
|||
return new
|
||||
|
||||
|
||||
@validate.register(UnionGetSchema)
|
||||
@validate.register
|
||||
def _validate_uniongetschema(schema: UnionGetSchema, value):
|
||||
return schema.seq(
|
||||
validate(getter, value) for getter in schema.getters
|
||||
)
|
||||
|
||||
|
||||
@validate.register(UnionSchema)
|
||||
@validate.register
|
||||
def _validate_unionschema(schema: UnionSchema, value):
|
||||
try:
|
||||
return validate_union(schema.schema, value)
|
||||
|
|
|
@ -167,7 +167,7 @@ def validator_hasattr(attr: Any) -> Callable[[Any], bool]:
|
|||
# Sequence related validators
|
||||
|
||||
|
||||
def validator_filter(func: Callable[[Any], bool]) -> TransformSchema:
|
||||
def validator_filter(func: Callable[..., bool]) -> TransformSchema:
|
||||
"""
|
||||
Filter out unwanted items from the input using the specified function.
|
||||
|
||||
|
@ -187,7 +187,7 @@ def validator_filter(func: Callable[[Any], bool]) -> TransformSchema:
|
|||
return TransformSchema(filter_values)
|
||||
|
||||
|
||||
def validator_map(func: Callable[[Any], Any]) -> TransformSchema:
|
||||
def validator_map(func: Callable[..., Any]) -> TransformSchema:
|
||||
"""
|
||||
Transform items from the input using the specified function.
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ class WebsocketClient(Thread):
|
|||
session: Streamlink,
|
||||
url: str,
|
||||
subprotocols: Optional[List[str]] = None,
|
||||
header: Optional[Union[List, Dict]] = None,
|
||||
header: Optional[Union[List[str], Dict[str, str]]] = None,
|
||||
cookie: Optional[str] = None,
|
||||
sockopt: Optional[Tuple] = None,
|
||||
sslopt: Optional[Dict] = None,
|
||||
|
@ -39,10 +39,12 @@ class WebsocketClient(Thread):
|
|||
|
||||
if not header:
|
||||
header = []
|
||||
elif isinstance(header, dict):
|
||||
header = [f"{str(k)}: {str(v)}" for k, v in header.items()]
|
||||
if not any(True for h in header if h.startswith("User-Agent: ")):
|
||||
header.append(f"User-Agent: {session.http.headers['User-Agent']}")
|
||||
header.append(f"User-Agent: {str(session.http.headers['User-Agent'])}")
|
||||
|
||||
proxy_options = {}
|
||||
proxy_options: Dict[str, Any] = {}
|
||||
http_proxy: Optional[str] = session.get_option("http-proxy")
|
||||
if http_proxy:
|
||||
p = urlparse(http_proxy)
|
||||
|
@ -127,7 +129,9 @@ class WebsocketClient(Thread):
|
|||
)
|
||||
|
||||
def close(self, status: int = STATUS_NORMAL, reason: Union[str, bytes] = "", timeout: int = 3) -> None:
|
||||
self.ws.close(status=status, reason=bytes(reason, encoding="utf-8"), timeout=timeout)
|
||||
if type(reason) is str: # pragma: no branch
|
||||
reason = bytes(reason, encoding="utf-8")
|
||||
self.ws.close(status=status, reason=reason, timeout=timeout)
|
||||
if self.is_alive(): # pragma: no branch
|
||||
self.join()
|
||||
|
||||
|
|
|
@ -37,6 +37,11 @@ class TestWebsocketClient(unittest.TestCase):
|
|||
"User-Agent: foo"
|
||||
])
|
||||
|
||||
client = WebsocketClient(self.session, "wss://localhost:0", header={"User-Agent": "bar"})
|
||||
self.assertEqual(client.ws.header, [
|
||||
"User-Agent: bar"
|
||||
])
|
||||
|
||||
def test_args_and_proxy(self):
|
||||
self.session.set_option("http-proxy", "https://username:password@hostname:1234")
|
||||
client = WebsocketClient(
|
||||
|
|
Loading…
Reference in New Issue