More mypy checks/fixes.

This commit is contained in:
Dain Nilsson 2021-08-19 10:31:20 +02:00
parent f12247ed51
commit fed257922c
No known key found for this signature in database
GPG Key ID: F04367096FBA95E8
12 changed files with 61 additions and 40 deletions

View File

@ -49,6 +49,7 @@ from typing import Type, Any, Union, Callable, Optional, Mapping, Sequence
import json import json
import platform import platform
import inspect
class ClientData(bytes): class ClientData(bytes):
@ -387,8 +388,10 @@ class Fido2ClientAssertionSelection(AssertionSelection):
return extension_outputs return extension_outputs
def _default_extensions(): def _default_extensions() -> Sequence[Type[Ctap2Extension]]:
return [cls for cls in Ctap2Extension.__subclasses__() if hasattr(cls, "NAME")] return [
cls for cls in Ctap2Extension.__subclasses__() if not inspect.isabstract(cls)
]
_CTAP1_INFO = Info.create(versions=["U2F_V2"], aaguid=b"\0" * 32) _CTAP1_INFO = Info.create(versions=["U2F_V2"], aaguid=b"\0" * 32)
@ -410,7 +413,7 @@ class Fido2Client(_BaseClient):
device: CtapDevice, device: CtapDevice,
origin: str, origin: str,
verify: Callable[[str, str], bool] = verify_rp_id, verify: Callable[[str, str], bool] = verify_rp_id,
extension_types: Optional[Type[Ctap2Extension]] = None, extension_types: Sequence[Type[Ctap2Extension]] = [],
): ):
super(Fido2Client, self).__init__(origin, verify) super(Fido2Client, self).__init__(origin, verify)
@ -420,9 +423,9 @@ class Fido2Client(_BaseClient):
self.ctap2 = Ctap2(device) self.ctap2 = Ctap2(device)
self.info = self.ctap2.info self.info = self.ctap2.info
try: try:
self.client_pin: Optional[ClientPin] = ClientPin(self.ctap2) self.client_pin: ClientPin = ClientPin(self.ctap2)
except ValueError: except ValueError:
self.client_pin = None self.client_pin = None # type: ignore
self._do_make_credential = self._ctap2_make_credential self._do_make_credential = self._ctap2_make_credential
self._do_get_assertion = self._ctap2_get_assertion self._do_get_assertion = self._ctap2_get_assertion
except (ValueError, CtapError): except (ValueError, CtapError):
@ -760,10 +763,7 @@ class Fido2Client(_BaseClient):
pin_protocol, pin_auth, internal_uv = self._get_auth_params( pin_protocol, pin_auth, internal_uv = self._get_auth_params(
client_data, rp_id, user_verification, pin, event, on_keepalive client_data, rp_id, user_verification, pin, event, on_keepalive
) )
if internal_uv: options = {"uv": True} if internal_uv else None
options = {"uv": True}
else:
options = None
if allow_list: if allow_list:
# Filter out credential IDs which are too long # Filter out credential IDs which are too long
@ -966,5 +966,12 @@ class WindowsClient(_BaseClient):
user = {"id": user_id} if user_id else None user = {"id": user_id} if user_id else None
return AssertionSelection( return AssertionSelection(
client_data, client_data,
[AssertionResponse.create(credential, auth_data, signature, user)], [
AssertionResponse.create(
credential=credential,
auth_data=auth_data,
signature=signature,
user=user,
)
],
) )

View File

@ -29,6 +29,7 @@ from .utils import bytes2int, int2bytes
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding
from typing import Sequence, Type
try: try:
from cryptography.hazmat.primitives.asymmetric import ed25519 from cryptography.hazmat.primitives.asymmetric import ed25519
@ -101,9 +102,9 @@ class CoseKey(dict):
def supported_algorithms(): def supported_algorithms():
"""Get a list of all supported algorithm identifiers""" """Get a list of all supported algorithm identifiers"""
if ed25519: if ed25519:
algs = (ES256, EdDSA, PS256, RS256) algs: Sequence[Type[CoseKey]] = [ES256, EdDSA, PS256, RS256]
else: else:
algs = (ES256, PS256, RS256) algs = [ES256, PS256, RS256]
return [cls.ALGORITHM for cls in algs] return [cls.ALGORITHM for cls in algs]

View File

@ -36,6 +36,8 @@ class Ctap2Extension(abc.ABC):
the extension. the extension.
""" """
NAME: str = None # type: ignore
def __init__(self, ctap): def __init__(self, ctap):
self.ctap = ctap self.ctap = ctap

View File

@ -163,7 +163,7 @@ class CtapHidDevice(CtapDevice):
try: try:
# Read response # Read response
seq = 0 seq = 0
response = None response = b""
last_ka = None last_ka = None
while True: while True:
if event.is_set(): if event.is_set():
@ -182,11 +182,11 @@ class CtapHidDevice(CtapDevice):
if r_channel != self._channel_id: if r_channel != self._channel_id:
raise Exception("Wrong channel") raise Exception("Wrong channel")
if response is None: # Initialization packet if not response: # Initialization packet
r_cmd, r_len = struct.unpack_from(">BH", recv) r_cmd, r_len = struct.unpack_from(">BH", recv)
recv = recv[3:] recv = recv[3:]
if r_cmd == TYPE_INIT | cmd: if r_cmd == TYPE_INIT | cmd:
response = b"" pass # first data packet
elif r_cmd == TYPE_INIT | CTAPHID.KEEPALIVE: elif r_cmd == TYPE_INIT | CTAPHID.KEEPALIVE:
ka_status = struct.unpack_from(">B", recv)[0] ka_status = struct.unpack_from(">B", recv)[0]
logger.debug("Got keepalive status: %02x" % ka_status) logger.debug("Got keepalive status: %02x" % ka_status)

View File

@ -24,6 +24,7 @@ import os
from .base import HidDescriptor, parse_report_descriptor, FileCtapHidConnection from .base import HidDescriptor, parse_report_descriptor, FileCtapHidConnection
import logging import logging
from typing import Dict, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -101,7 +102,7 @@ def _enumerate():
if retval != 0: if retval != 0:
continue continue
dev = {} dev: Dict[str, Optional[str]] = {}
dev["name"] = uhid[len(devdir) :] dev["name"] = uhid[len(devdir) :]
dev["path"] = uhid dev["path"] = uhid

View File

@ -277,7 +277,7 @@ class MacCtapHidConnection(CtapHidConnection):
raise OSError("Failed to open device for communication: {}".format(result)) raise OSError("Failed to open device for communication: {}".format(result))
# Create read queue # Create read queue
self.read_queue = Queue() self.read_queue: Queue = Queue()
# Create and start read thread # Create and start read thread
self.run_loop_ref = None self.run_loop_ref = None

View File

@ -99,7 +99,7 @@ def get_descriptor(path):
dev_info = UsbDeviceInfo() dev_info = UsbDeviceInfo()
try: try:
fcntl.ioctl(f, USB_GET_DEVICEINFO, dev_info) fcntl.ioctl(f, USB_GET_DEVICEINFO, dev_info) # type: ignore
finally: finally:
os.close(f) os.close(f)

View File

@ -19,7 +19,7 @@ from .base import HidDescriptor, CtapHidConnection, FIDO_USAGE_PAGE, FIDO_USAGE
import ctypes import ctypes
import platform import platform
from ctypes import WinDLL # type: ignore from ctypes import WinDLL, WinError # type: ignore
from ctypes import wintypes, LibraryLoader from ctypes import wintypes, LibraryLoader
import logging import logging
@ -193,7 +193,7 @@ class WinCtapHidConnection(CtapHidConnection):
None, None,
) )
if self.handle == INVALID_HANDLE_VALUE: if self.handle == INVALID_HANDLE_VALUE:
raise ctypes.WinError() raise WinError()
def close(self): def close(self):
kernel32.CloseHandle(self.handle) kernel32.CloseHandle(self.handle)
@ -205,7 +205,7 @@ class WinCtapHidConnection(CtapHidConnection):
self.handle, out, len(out), ctypes.byref(num_written), None self.handle, out, len(out), ctypes.byref(num_written), None
) )
if not ret: if not ret:
raise ctypes.WinError() raise WinError()
if num_written.value != len(out): if num_written.value != len(out):
raise OSError( raise OSError(
"Failed to write complete packet. " "Failed to write complete packet. "
@ -219,7 +219,7 @@ class WinCtapHidConnection(CtapHidConnection):
self.handle, buf, len(buf), ctypes.byref(num_read), None self.handle, buf, len(buf), ctypes.byref(num_read), None
) )
if not ret: if not ret:
raise ctypes.WinError() raise WinError()
if num_read.value != self.descriptor.report_size_in + 1: if num_read.value != self.descriptor.report_size_in + 1:
raise OSError("Failed to read full length report from device.") raise OSError("Failed to read full length report from device.")
@ -231,7 +231,7 @@ def get_vid_pid(device):
attributes = HidAttributes() attributes = HidAttributes()
result = hid.HidD_GetAttributes(device, ctypes.byref(attributes)) result = hid.HidD_GetAttributes(device, ctypes.byref(attributes))
if not result: if not result:
raise ctypes.WinError() raise WinError()
return attributes.VendorID, attributes.ProductID return attributes.VendorID, attributes.ProductID
@ -269,19 +269,19 @@ def get_descriptor(path):
None, None,
) )
if device == INVALID_HANDLE_VALUE: if device == INVALID_HANDLE_VALUE:
raise ctypes.WinError() raise WinError()
try: try:
preparsed_data = PHIDP_PREPARSED_DATA(0) preparsed_data = PHIDP_PREPARSED_DATA(0)
ret = hid.HidD_GetPreparsedData(device, ctypes.byref(preparsed_data)) ret = hid.HidD_GetPreparsedData(device, ctypes.byref(preparsed_data))
if not ret: if not ret:
raise ctypes.WinError() raise WinError()
try: try:
caps = HidCapabilities() caps = HidCapabilities()
ret = hid.HidP_GetCaps(preparsed_data, ctypes.byref(caps)) ret = hid.HidP_GetCaps(preparsed_data, ctypes.byref(caps))
if ret != HIDP_STATUS_SUCCESS: if ret != HIDP_STATUS_SUCCESS:
raise ctypes.WinError() raise WinError()
if caps.UsagePage == FIDO_USAGE_PAGE and caps.Usage == FIDO_USAGE: if caps.UsagePage == FIDO_USAGE_PAGE and caps.Usage == FIDO_USAGE:
vid, pid = get_vid_pid(device) vid, pid = get_vid_pid(device)
@ -331,19 +331,19 @@ def list_descriptors():
if not result: if not result:
break break
detail_len = wintypes.DWORD() dw_detail_len = wintypes.DWORD()
result = setupapi.SetupDiGetDeviceInterfaceDetailA( result = setupapi.SetupDiGetDeviceInterfaceDetailA(
collection, collection,
ctypes.byref(interface_info), ctypes.byref(interface_info),
None, None,
0, 0,
ctypes.byref(detail_len), ctypes.byref(dw_detail_len),
None, None,
) )
if result: if result:
raise ctypes.WinError() raise WinError()
detail_len = detail_len.value detail_len = dw_detail_len.value
if detail_len == 0: if detail_len == 0:
# skip this device, some kind of error # skip this device, some kind of error
continue continue
@ -361,7 +361,7 @@ def list_descriptors():
None, None,
) )
if not result: if not result:
raise ctypes.WinError() raise WinError()
path = ctypes.string_at(interface_detail.DevicePath) path = ctypes.string_at(interface_detail.DevicePath)

View File

@ -103,7 +103,7 @@ def _ignore_attestation(attestation_object, client_data_hash):
def _default_attestations(): def _default_attestations():
return [ return [
cls() cls() # type: ignore
for cls in Attestation.__subclasses__() for cls in Attestation.__subclasses__()
if getattr(cls, "FORMAT", "none") != "none" if getattr(cls, "FORMAT", "none") != "none"
] ]
@ -183,7 +183,7 @@ class Fido2Server:
self.timeout = None self.timeout = None
self.attestation = AttestationConveyancePreference._wrap(attestation) self.attestation = AttestationConveyancePreference._wrap(attestation)
self.allowed_algorithms = [ self.allowed_algorithms = [
PublicKeyCredentialParameters("public-key", alg) PublicKeyCredentialParameters(PublicKeyCredentialType.PUBLIC_KEY, alg)
for alg in CoseKey.supported_algorithms() for alg in CoseKey.supported_algorithms()
] ]
self._verify_attestation = verify_attestation or _ignore_attestation self._verify_attestation = verify_attestation or _ignore_attestation

View File

@ -30,7 +30,7 @@ from .cose import CoseKey, ES256
from .utils import sha256, ByteBuffer from .utils import sha256, ByteBuffer
from enum import Enum, unique, IntFlag from enum import Enum, unique, IntFlag
from dataclasses import dataclass, fields, field as _field from dataclasses import dataclass, fields, field as _field
from typing import Any, Mapping, Optional, Sequence, Tuple from typing import Any, Mapping, Optional, Sequence, Tuple, cast
import re import re
import struct import struct
@ -256,7 +256,7 @@ class AttestationObject(bytes): # , Mapping[str, Any]):
def __init__(self, _): def __init__(self, _):
super().__init__() super().__init__()
data = cbor.decode(bytes(self)) data = cast(Mapping[str, Any], cbor.decode(bytes(self)))
self.fmt = data["fmt"] self.fmt = data["fmt"]
self.auth_data = AuthenticatorData(data["authData"]) self.auth_data = AuthenticatorData(data["authData"])
self.att_stmt = data["attStmt"] self.att_stmt = data["attStmt"]
@ -368,7 +368,10 @@ class _DataObject(Mapping[str, Any]):
self._keys.append(_snake2camel(f.name)) self._keys.append(_snake2camel(f.name))
def __getitem__(self, key): def __getitem__(self, key):
return getattr(self, _camel2snake(key)) try:
return getattr(self, _camel2snake(key))
except AttributeError as e:
raise KeyError(e)
def __iter__(self): def __iter__(self):
return iter(self._keys) return iter(self._keys)

View File

@ -40,8 +40,14 @@ https://github.com/microsoft/webauthn
from enum import IntEnum, unique from enum import IntEnum, unique
from ctypes.wintypes import BOOL, DWORD, LONG, LPCWSTR, HWND from ctypes.wintypes import BOOL, DWORD, LONG, LPCWSTR, HWND
from threading import Thread from threading import Thread
from typing import Mapping
import ctypes import ctypes
from ctypes import WinDLL # type: ignore
from ctypes import LibraryLoader
windll = LibraryLoader(WinDLL)
PBYTE = ctypes.POINTER(ctypes.c_ubyte) # Different from wintypes.PBYTE, which is signed PBYTE = ctypes.POINTER(ctypes.c_ubyte) # Different from wintypes.PBYTE, which is signed
@ -558,7 +564,7 @@ class WebAuthNCTAPTransport(IntEnum):
HRESULT = ctypes.HRESULT # type: ignore HRESULT = ctypes.HRESULT # type: ignore
WEBAUTHN = ctypes.windll.webauthn # type: ignore WEBAUTHN = windll.webauthn # type: ignore
WEBAUTHN_API_VERSION = WEBAUTHN.WebAuthNGetApiVersionNumber() WEBAUTHN_API_VERSION = WEBAUTHN.WebAuthNGetApiVersionNumber()
# The following is derived from # The following is derived from
# https://github.com/microsoft/webauthn/blob/master/webauthn.h#L37 # https://github.com/microsoft/webauthn/blob/master/webauthn.h#L37
@ -603,7 +609,7 @@ WEBAUTHN.WebAuthNGetErrorName.argtypes = [HRESULT]
WEBAUTHN.WebAuthNGetErrorName.restype = PCWSTR WEBAUTHN.WebAuthNGetErrorName.restype = PCWSTR
WEBAUTHN_STRUCT_VERSIONS = { WEBAUTHN_STRUCT_VERSIONS: Mapping[int, Mapping[str, int]] = {
1: { 1: {
"WebAuthNRpEntityInformation": 1, "WebAuthNRpEntityInformation": 1,
"WebAuthNUserEntityInformation": 1, "WebAuthNUserEntityInformation": 1,
@ -621,7 +627,7 @@ WEBAUTHN_STRUCT_VERSIONS = {
} }
def get_version(class_name): def get_version(class_name: str) -> int:
"""Get version of struct. """Get version of struct.
:param str class_name: Struct class name. :param str class_name: Struct class name.
@ -663,7 +669,7 @@ class WinAPI:
version = WEBAUTHN_API_VERSION version = WEBAUTHN_API_VERSION
def __init__(self, handle=None): def __init__(self, handle=None):
self.handle = handle or ctypes.windll.user32.GetForegroundWindow() self.handle = handle or windll.user32.GetForegroundWindow()
def get_error_name(self, winerror): def get_error_name(self, winerror):
"""Returns an error name given an error HRESULT value. """Returns an error name given an error HRESULT value.

View File

@ -1,5 +1,6 @@
[mypy] [mypy]
files = fido2/ files = fido2/
check_untyped_defs = True
[mypy-smartcard.*] [mypy-smartcard.*]
ignore_missing_imports = True ignore_missing_imports = True