More mypy checks/fixes.

This commit is contained in:
Dain Nilsson 2021-08-19 10:31:20 +02:00
parent f12247ed51
commit fed257922c
No known key found for this signature in database
GPG Key ID: F04367096FBA95E8
12 changed files with 61 additions and 40 deletions

View File

@ -49,6 +49,7 @@ from typing import Type, Any, Union, Callable, Optional, Mapping, Sequence
import json
import platform
import inspect
class ClientData(bytes):
@ -387,8 +388,10 @@ class Fido2ClientAssertionSelection(AssertionSelection):
return extension_outputs
def _default_extensions():
return [cls for cls in Ctap2Extension.__subclasses__() if hasattr(cls, "NAME")]
def _default_extensions() -> Sequence[Type[Ctap2Extension]]:
return [
cls for cls in Ctap2Extension.__subclasses__() if not inspect.isabstract(cls)
]
_CTAP1_INFO = Info.create(versions=["U2F_V2"], aaguid=b"\0" * 32)
@ -410,7 +413,7 @@ class Fido2Client(_BaseClient):
device: CtapDevice,
origin: str,
verify: Callable[[str, str], bool] = verify_rp_id,
extension_types: Optional[Type[Ctap2Extension]] = None,
extension_types: Sequence[Type[Ctap2Extension]] = [],
):
super(Fido2Client, self).__init__(origin, verify)
@ -420,9 +423,9 @@ class Fido2Client(_BaseClient):
self.ctap2 = Ctap2(device)
self.info = self.ctap2.info
try:
self.client_pin: Optional[ClientPin] = ClientPin(self.ctap2)
self.client_pin: ClientPin = ClientPin(self.ctap2)
except ValueError:
self.client_pin = None
self.client_pin = None # type: ignore
self._do_make_credential = self._ctap2_make_credential
self._do_get_assertion = self._ctap2_get_assertion
except (ValueError, CtapError):
@ -760,10 +763,7 @@ class Fido2Client(_BaseClient):
pin_protocol, pin_auth, internal_uv = self._get_auth_params(
client_data, rp_id, user_verification, pin, event, on_keepalive
)
if internal_uv:
options = {"uv": True}
else:
options = None
options = {"uv": True} if internal_uv else None
if allow_list:
# Filter out credential IDs which are too long
@ -966,5 +966,12 @@ class WindowsClient(_BaseClient):
user = {"id": user_id} if user_id else None
return AssertionSelection(
client_data,
[AssertionResponse.create(credential, auth_data, signature, user)],
[
AssertionResponse.create(
credential=credential,
auth_data=auth_data,
signature=signature,
user=user,
)
],
)

View File

@ -29,6 +29,7 @@ from .utils import bytes2int, int2bytes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding
from typing import Sequence, Type
try:
from cryptography.hazmat.primitives.asymmetric import ed25519
@ -101,9 +102,9 @@ class CoseKey(dict):
def supported_algorithms():
"""Get a list of all supported algorithm identifiers"""
if ed25519:
algs = (ES256, EdDSA, PS256, RS256)
algs: Sequence[Type[CoseKey]] = [ES256, EdDSA, PS256, RS256]
else:
algs = (ES256, PS256, RS256)
algs = [ES256, PS256, RS256]
return [cls.ALGORITHM for cls in algs]

View File

@ -36,6 +36,8 @@ class Ctap2Extension(abc.ABC):
the extension.
"""
NAME: str = None # type: ignore
def __init__(self, ctap):
self.ctap = ctap

View File

@ -163,7 +163,7 @@ class CtapHidDevice(CtapDevice):
try:
# Read response
seq = 0
response = None
response = b""
last_ka = None
while True:
if event.is_set():
@ -182,11 +182,11 @@ class CtapHidDevice(CtapDevice):
if r_channel != self._channel_id:
raise Exception("Wrong channel")
if response is None: # Initialization packet
if not response: # Initialization packet
r_cmd, r_len = struct.unpack_from(">BH", recv)
recv = recv[3:]
if r_cmd == TYPE_INIT | cmd:
response = b""
pass # first data packet
elif r_cmd == TYPE_INIT | CTAPHID.KEEPALIVE:
ka_status = struct.unpack_from(">B", recv)[0]
logger.debug("Got keepalive status: %02x" % ka_status)

View File

@ -24,6 +24,7 @@ import os
from .base import HidDescriptor, parse_report_descriptor, FileCtapHidConnection
import logging
from typing import Dict, Optional
logger = logging.getLogger(__name__)
@ -101,7 +102,7 @@ def _enumerate():
if retval != 0:
continue
dev = {}
dev: Dict[str, Optional[str]] = {}
dev["name"] = uhid[len(devdir) :]
dev["path"] = uhid

View File

@ -277,7 +277,7 @@ class MacCtapHidConnection(CtapHidConnection):
raise OSError("Failed to open device for communication: {}".format(result))
# Create read queue
self.read_queue = Queue()
self.read_queue: Queue = Queue()
# Create and start read thread
self.run_loop_ref = None

View File

@ -99,7 +99,7 @@ def get_descriptor(path):
dev_info = UsbDeviceInfo()
try:
fcntl.ioctl(f, USB_GET_DEVICEINFO, dev_info)
fcntl.ioctl(f, USB_GET_DEVICEINFO, dev_info) # type: ignore
finally:
os.close(f)

View File

@ -19,7 +19,7 @@ from .base import HidDescriptor, CtapHidConnection, FIDO_USAGE_PAGE, FIDO_USAGE
import ctypes
import platform
from ctypes import WinDLL # type: ignore
from ctypes import WinDLL, WinError # type: ignore
from ctypes import wintypes, LibraryLoader
import logging
@ -193,7 +193,7 @@ class WinCtapHidConnection(CtapHidConnection):
None,
)
if self.handle == INVALID_HANDLE_VALUE:
raise ctypes.WinError()
raise WinError()
def close(self):
kernel32.CloseHandle(self.handle)
@ -205,7 +205,7 @@ class WinCtapHidConnection(CtapHidConnection):
self.handle, out, len(out), ctypes.byref(num_written), None
)
if not ret:
raise ctypes.WinError()
raise WinError()
if num_written.value != len(out):
raise OSError(
"Failed to write complete packet. "
@ -219,7 +219,7 @@ class WinCtapHidConnection(CtapHidConnection):
self.handle, buf, len(buf), ctypes.byref(num_read), None
)
if not ret:
raise ctypes.WinError()
raise WinError()
if num_read.value != self.descriptor.report_size_in + 1:
raise OSError("Failed to read full length report from device.")
@ -231,7 +231,7 @@ def get_vid_pid(device):
attributes = HidAttributes()
result = hid.HidD_GetAttributes(device, ctypes.byref(attributes))
if not result:
raise ctypes.WinError()
raise WinError()
return attributes.VendorID, attributes.ProductID
@ -269,19 +269,19 @@ def get_descriptor(path):
None,
)
if device == INVALID_HANDLE_VALUE:
raise ctypes.WinError()
raise WinError()
try:
preparsed_data = PHIDP_PREPARSED_DATA(0)
ret = hid.HidD_GetPreparsedData(device, ctypes.byref(preparsed_data))
if not ret:
raise ctypes.WinError()
raise WinError()
try:
caps = HidCapabilities()
ret = hid.HidP_GetCaps(preparsed_data, ctypes.byref(caps))
if ret != HIDP_STATUS_SUCCESS:
raise ctypes.WinError()
raise WinError()
if caps.UsagePage == FIDO_USAGE_PAGE and caps.Usage == FIDO_USAGE:
vid, pid = get_vid_pid(device)
@ -331,19 +331,19 @@ def list_descriptors():
if not result:
break
detail_len = wintypes.DWORD()
dw_detail_len = wintypes.DWORD()
result = setupapi.SetupDiGetDeviceInterfaceDetailA(
collection,
ctypes.byref(interface_info),
None,
0,
ctypes.byref(detail_len),
ctypes.byref(dw_detail_len),
None,
)
if result:
raise ctypes.WinError()
raise WinError()
detail_len = detail_len.value
detail_len = dw_detail_len.value
if detail_len == 0:
# skip this device, some kind of error
continue
@ -361,7 +361,7 @@ def list_descriptors():
None,
)
if not result:
raise ctypes.WinError()
raise WinError()
path = ctypes.string_at(interface_detail.DevicePath)

View File

@ -103,7 +103,7 @@ def _ignore_attestation(attestation_object, client_data_hash):
def _default_attestations():
return [
cls()
cls() # type: ignore
for cls in Attestation.__subclasses__()
if getattr(cls, "FORMAT", "none") != "none"
]
@ -183,7 +183,7 @@ class Fido2Server:
self.timeout = None
self.attestation = AttestationConveyancePreference._wrap(attestation)
self.allowed_algorithms = [
PublicKeyCredentialParameters("public-key", alg)
PublicKeyCredentialParameters(PublicKeyCredentialType.PUBLIC_KEY, alg)
for alg in CoseKey.supported_algorithms()
]
self._verify_attestation = verify_attestation or _ignore_attestation

View File

@ -30,7 +30,7 @@ from .cose import CoseKey, ES256
from .utils import sha256, ByteBuffer
from enum import Enum, unique, IntFlag
from dataclasses import dataclass, fields, field as _field
from typing import Any, Mapping, Optional, Sequence, Tuple
from typing import Any, Mapping, Optional, Sequence, Tuple, cast
import re
import struct
@ -256,7 +256,7 @@ class AttestationObject(bytes): # , Mapping[str, Any]):
def __init__(self, _):
super().__init__()
data = cbor.decode(bytes(self))
data = cast(Mapping[str, Any], cbor.decode(bytes(self)))
self.fmt = data["fmt"]
self.auth_data = AuthenticatorData(data["authData"])
self.att_stmt = data["attStmt"]
@ -368,7 +368,10 @@ class _DataObject(Mapping[str, Any]):
self._keys.append(_snake2camel(f.name))
def __getitem__(self, key):
return getattr(self, _camel2snake(key))
try:
return getattr(self, _camel2snake(key))
except AttributeError as e:
raise KeyError(e)
def __iter__(self):
return iter(self._keys)

View File

@ -40,8 +40,14 @@ https://github.com/microsoft/webauthn
from enum import IntEnum, unique
from ctypes.wintypes import BOOL, DWORD, LONG, LPCWSTR, HWND
from threading import Thread
from typing import Mapping
import ctypes
from ctypes import WinDLL # type: ignore
from ctypes import LibraryLoader
windll = LibraryLoader(WinDLL)
PBYTE = ctypes.POINTER(ctypes.c_ubyte) # Different from wintypes.PBYTE, which is signed
@ -558,7 +564,7 @@ class WebAuthNCTAPTransport(IntEnum):
HRESULT = ctypes.HRESULT # type: ignore
WEBAUTHN = ctypes.windll.webauthn # type: ignore
WEBAUTHN = windll.webauthn # type: ignore
WEBAUTHN_API_VERSION = WEBAUTHN.WebAuthNGetApiVersionNumber()
# The following is derived from
# https://github.com/microsoft/webauthn/blob/master/webauthn.h#L37
@ -603,7 +609,7 @@ WEBAUTHN.WebAuthNGetErrorName.argtypes = [HRESULT]
WEBAUTHN.WebAuthNGetErrorName.restype = PCWSTR
WEBAUTHN_STRUCT_VERSIONS = {
WEBAUTHN_STRUCT_VERSIONS: Mapping[int, Mapping[str, int]] = {
1: {
"WebAuthNRpEntityInformation": 1,
"WebAuthNUserEntityInformation": 1,
@ -621,7 +627,7 @@ WEBAUTHN_STRUCT_VERSIONS = {
}
def get_version(class_name):
def get_version(class_name: str) -> int:
"""Get version of struct.
:param str class_name: Struct class name.
@ -663,7 +669,7 @@ class WinAPI:
version = WEBAUTHN_API_VERSION
def __init__(self, handle=None):
self.handle = handle or ctypes.windll.user32.GetForegroundWindow()
self.handle = handle or windll.user32.GetForegroundWindow()
def get_error_name(self, winerror):
"""Returns an error name given an error HRESULT value.

View File

@ -1,5 +1,6 @@
[mypy]
files = fido2/
check_untyped_defs = True
[mypy-smartcard.*]
ignore_missing_imports = True