mirror of https://github.com/Yubico/python-fido2
More mypy checks/fixes.
This commit is contained in:
parent
f12247ed51
commit
fed257922c
|
@ -49,6 +49,7 @@ from typing import Type, Any, Union, Callable, Optional, Mapping, Sequence
|
|||
|
||||
import json
|
||||
import platform
|
||||
import inspect
|
||||
|
||||
|
||||
class ClientData(bytes):
|
||||
|
@ -387,8 +388,10 @@ class Fido2ClientAssertionSelection(AssertionSelection):
|
|||
return extension_outputs
|
||||
|
||||
|
||||
def _default_extensions():
|
||||
return [cls for cls in Ctap2Extension.__subclasses__() if hasattr(cls, "NAME")]
|
||||
def _default_extensions() -> Sequence[Type[Ctap2Extension]]:
|
||||
return [
|
||||
cls for cls in Ctap2Extension.__subclasses__() if not inspect.isabstract(cls)
|
||||
]
|
||||
|
||||
|
||||
_CTAP1_INFO = Info.create(versions=["U2F_V2"], aaguid=b"\0" * 32)
|
||||
|
@ -410,7 +413,7 @@ class Fido2Client(_BaseClient):
|
|||
device: CtapDevice,
|
||||
origin: str,
|
||||
verify: Callable[[str, str], bool] = verify_rp_id,
|
||||
extension_types: Optional[Type[Ctap2Extension]] = None,
|
||||
extension_types: Sequence[Type[Ctap2Extension]] = [],
|
||||
):
|
||||
super(Fido2Client, self).__init__(origin, verify)
|
||||
|
||||
|
@ -420,9 +423,9 @@ class Fido2Client(_BaseClient):
|
|||
self.ctap2 = Ctap2(device)
|
||||
self.info = self.ctap2.info
|
||||
try:
|
||||
self.client_pin: Optional[ClientPin] = ClientPin(self.ctap2)
|
||||
self.client_pin: ClientPin = ClientPin(self.ctap2)
|
||||
except ValueError:
|
||||
self.client_pin = None
|
||||
self.client_pin = None # type: ignore
|
||||
self._do_make_credential = self._ctap2_make_credential
|
||||
self._do_get_assertion = self._ctap2_get_assertion
|
||||
except (ValueError, CtapError):
|
||||
|
@ -760,10 +763,7 @@ class Fido2Client(_BaseClient):
|
|||
pin_protocol, pin_auth, internal_uv = self._get_auth_params(
|
||||
client_data, rp_id, user_verification, pin, event, on_keepalive
|
||||
)
|
||||
if internal_uv:
|
||||
options = {"uv": True}
|
||||
else:
|
||||
options = None
|
||||
options = {"uv": True} if internal_uv else None
|
||||
|
||||
if allow_list:
|
||||
# Filter out credential IDs which are too long
|
||||
|
@ -966,5 +966,12 @@ class WindowsClient(_BaseClient):
|
|||
user = {"id": user_id} if user_id else None
|
||||
return AssertionSelection(
|
||||
client_data,
|
||||
[AssertionResponse.create(credential, auth_data, signature, user)],
|
||||
[
|
||||
AssertionResponse.create(
|
||||
credential=credential,
|
||||
auth_data=auth_data,
|
||||
signature=signature,
|
||||
user=user,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
|
@ -29,6 +29,7 @@ from .utils import bytes2int, int2bytes
|
|||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding
|
||||
from typing import Sequence, Type
|
||||
|
||||
try:
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
@ -101,9 +102,9 @@ class CoseKey(dict):
|
|||
def supported_algorithms():
|
||||
"""Get a list of all supported algorithm identifiers"""
|
||||
if ed25519:
|
||||
algs = (ES256, EdDSA, PS256, RS256)
|
||||
algs: Sequence[Type[CoseKey]] = [ES256, EdDSA, PS256, RS256]
|
||||
else:
|
||||
algs = (ES256, PS256, RS256)
|
||||
algs = [ES256, PS256, RS256]
|
||||
return [cls.ALGORITHM for cls in algs]
|
||||
|
||||
|
||||
|
|
|
@ -36,6 +36,8 @@ class Ctap2Extension(abc.ABC):
|
|||
the extension.
|
||||
"""
|
||||
|
||||
NAME: str = None # type: ignore
|
||||
|
||||
def __init__(self, ctap):
|
||||
self.ctap = ctap
|
||||
|
||||
|
|
|
@ -163,7 +163,7 @@ class CtapHidDevice(CtapDevice):
|
|||
try:
|
||||
# Read response
|
||||
seq = 0
|
||||
response = None
|
||||
response = b""
|
||||
last_ka = None
|
||||
while True:
|
||||
if event.is_set():
|
||||
|
@ -182,11 +182,11 @@ class CtapHidDevice(CtapDevice):
|
|||
if r_channel != self._channel_id:
|
||||
raise Exception("Wrong channel")
|
||||
|
||||
if response is None: # Initialization packet
|
||||
if not response: # Initialization packet
|
||||
r_cmd, r_len = struct.unpack_from(">BH", recv)
|
||||
recv = recv[3:]
|
||||
if r_cmd == TYPE_INIT | cmd:
|
||||
response = b""
|
||||
pass # first data packet
|
||||
elif r_cmd == TYPE_INIT | CTAPHID.KEEPALIVE:
|
||||
ka_status = struct.unpack_from(">B", recv)[0]
|
||||
logger.debug("Got keepalive status: %02x" % ka_status)
|
||||
|
|
|
@ -24,6 +24,7 @@ import os
|
|||
from .base import HidDescriptor, parse_report_descriptor, FileCtapHidConnection
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -101,7 +102,7 @@ def _enumerate():
|
|||
if retval != 0:
|
||||
continue
|
||||
|
||||
dev = {}
|
||||
dev: Dict[str, Optional[str]] = {}
|
||||
dev["name"] = uhid[len(devdir) :]
|
||||
dev["path"] = uhid
|
||||
|
||||
|
|
|
@ -277,7 +277,7 @@ class MacCtapHidConnection(CtapHidConnection):
|
|||
raise OSError("Failed to open device for communication: {}".format(result))
|
||||
|
||||
# Create read queue
|
||||
self.read_queue = Queue()
|
||||
self.read_queue: Queue = Queue()
|
||||
|
||||
# Create and start read thread
|
||||
self.run_loop_ref = None
|
||||
|
|
|
@ -99,7 +99,7 @@ def get_descriptor(path):
|
|||
dev_info = UsbDeviceInfo()
|
||||
|
||||
try:
|
||||
fcntl.ioctl(f, USB_GET_DEVICEINFO, dev_info)
|
||||
fcntl.ioctl(f, USB_GET_DEVICEINFO, dev_info) # type: ignore
|
||||
finally:
|
||||
os.close(f)
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from .base import HidDescriptor, CtapHidConnection, FIDO_USAGE_PAGE, FIDO_USAGE
|
|||
|
||||
import ctypes
|
||||
import platform
|
||||
from ctypes import WinDLL # type: ignore
|
||||
from ctypes import WinDLL, WinError # type: ignore
|
||||
from ctypes import wintypes, LibraryLoader
|
||||
|
||||
import logging
|
||||
|
@ -193,7 +193,7 @@ class WinCtapHidConnection(CtapHidConnection):
|
|||
None,
|
||||
)
|
||||
if self.handle == INVALID_HANDLE_VALUE:
|
||||
raise ctypes.WinError()
|
||||
raise WinError()
|
||||
|
||||
def close(self):
|
||||
kernel32.CloseHandle(self.handle)
|
||||
|
@ -205,7 +205,7 @@ class WinCtapHidConnection(CtapHidConnection):
|
|||
self.handle, out, len(out), ctypes.byref(num_written), None
|
||||
)
|
||||
if not ret:
|
||||
raise ctypes.WinError()
|
||||
raise WinError()
|
||||
if num_written.value != len(out):
|
||||
raise OSError(
|
||||
"Failed to write complete packet. "
|
||||
|
@ -219,7 +219,7 @@ class WinCtapHidConnection(CtapHidConnection):
|
|||
self.handle, buf, len(buf), ctypes.byref(num_read), None
|
||||
)
|
||||
if not ret:
|
||||
raise ctypes.WinError()
|
||||
raise WinError()
|
||||
|
||||
if num_read.value != self.descriptor.report_size_in + 1:
|
||||
raise OSError("Failed to read full length report from device.")
|
||||
|
@ -231,7 +231,7 @@ def get_vid_pid(device):
|
|||
attributes = HidAttributes()
|
||||
result = hid.HidD_GetAttributes(device, ctypes.byref(attributes))
|
||||
if not result:
|
||||
raise ctypes.WinError()
|
||||
raise WinError()
|
||||
|
||||
return attributes.VendorID, attributes.ProductID
|
||||
|
||||
|
@ -269,19 +269,19 @@ def get_descriptor(path):
|
|||
None,
|
||||
)
|
||||
if device == INVALID_HANDLE_VALUE:
|
||||
raise ctypes.WinError()
|
||||
raise WinError()
|
||||
try:
|
||||
preparsed_data = PHIDP_PREPARSED_DATA(0)
|
||||
ret = hid.HidD_GetPreparsedData(device, ctypes.byref(preparsed_data))
|
||||
if not ret:
|
||||
raise ctypes.WinError()
|
||||
raise WinError()
|
||||
|
||||
try:
|
||||
caps = HidCapabilities()
|
||||
ret = hid.HidP_GetCaps(preparsed_data, ctypes.byref(caps))
|
||||
|
||||
if ret != HIDP_STATUS_SUCCESS:
|
||||
raise ctypes.WinError()
|
||||
raise WinError()
|
||||
|
||||
if caps.UsagePage == FIDO_USAGE_PAGE and caps.Usage == FIDO_USAGE:
|
||||
vid, pid = get_vid_pid(device)
|
||||
|
@ -331,19 +331,19 @@ def list_descriptors():
|
|||
if not result:
|
||||
break
|
||||
|
||||
detail_len = wintypes.DWORD()
|
||||
dw_detail_len = wintypes.DWORD()
|
||||
result = setupapi.SetupDiGetDeviceInterfaceDetailA(
|
||||
collection,
|
||||
ctypes.byref(interface_info),
|
||||
None,
|
||||
0,
|
||||
ctypes.byref(detail_len),
|
||||
ctypes.byref(dw_detail_len),
|
||||
None,
|
||||
)
|
||||
if result:
|
||||
raise ctypes.WinError()
|
||||
raise WinError()
|
||||
|
||||
detail_len = detail_len.value
|
||||
detail_len = dw_detail_len.value
|
||||
if detail_len == 0:
|
||||
# skip this device, some kind of error
|
||||
continue
|
||||
|
@ -361,7 +361,7 @@ def list_descriptors():
|
|||
None,
|
||||
)
|
||||
if not result:
|
||||
raise ctypes.WinError()
|
||||
raise WinError()
|
||||
|
||||
path = ctypes.string_at(interface_detail.DevicePath)
|
||||
|
||||
|
|
|
@ -103,7 +103,7 @@ def _ignore_attestation(attestation_object, client_data_hash):
|
|||
|
||||
def _default_attestations():
|
||||
return [
|
||||
cls()
|
||||
cls() # type: ignore
|
||||
for cls in Attestation.__subclasses__()
|
||||
if getattr(cls, "FORMAT", "none") != "none"
|
||||
]
|
||||
|
@ -183,7 +183,7 @@ class Fido2Server:
|
|||
self.timeout = None
|
||||
self.attestation = AttestationConveyancePreference._wrap(attestation)
|
||||
self.allowed_algorithms = [
|
||||
PublicKeyCredentialParameters("public-key", alg)
|
||||
PublicKeyCredentialParameters(PublicKeyCredentialType.PUBLIC_KEY, alg)
|
||||
for alg in CoseKey.supported_algorithms()
|
||||
]
|
||||
self._verify_attestation = verify_attestation or _ignore_attestation
|
||||
|
|
|
@ -30,7 +30,7 @@ from .cose import CoseKey, ES256
|
|||
from .utils import sha256, ByteBuffer
|
||||
from enum import Enum, unique, IntFlag
|
||||
from dataclasses import dataclass, fields, field as _field
|
||||
from typing import Any, Mapping, Optional, Sequence, Tuple
|
||||
from typing import Any, Mapping, Optional, Sequence, Tuple, cast
|
||||
import re
|
||||
import struct
|
||||
|
||||
|
@ -256,7 +256,7 @@ class AttestationObject(bytes): # , Mapping[str, Any]):
|
|||
def __init__(self, _):
|
||||
super().__init__()
|
||||
|
||||
data = cbor.decode(bytes(self))
|
||||
data = cast(Mapping[str, Any], cbor.decode(bytes(self)))
|
||||
self.fmt = data["fmt"]
|
||||
self.auth_data = AuthenticatorData(data["authData"])
|
||||
self.att_stmt = data["attStmt"]
|
||||
|
@ -368,7 +368,10 @@ class _DataObject(Mapping[str, Any]):
|
|||
self._keys.append(_snake2camel(f.name))
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, _camel2snake(key))
|
||||
try:
|
||||
return getattr(self, _camel2snake(key))
|
||||
except AttributeError as e:
|
||||
raise KeyError(e)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._keys)
|
||||
|
|
|
@ -40,8 +40,14 @@ https://github.com/microsoft/webauthn
|
|||
from enum import IntEnum, unique
|
||||
from ctypes.wintypes import BOOL, DWORD, LONG, LPCWSTR, HWND
|
||||
from threading import Thread
|
||||
from typing import Mapping
|
||||
|
||||
import ctypes
|
||||
from ctypes import WinDLL # type: ignore
|
||||
from ctypes import LibraryLoader
|
||||
|
||||
|
||||
windll = LibraryLoader(WinDLL)
|
||||
|
||||
|
||||
PBYTE = ctypes.POINTER(ctypes.c_ubyte) # Different from wintypes.PBYTE, which is signed
|
||||
|
@ -558,7 +564,7 @@ class WebAuthNCTAPTransport(IntEnum):
|
|||
|
||||
|
||||
HRESULT = ctypes.HRESULT # type: ignore
|
||||
WEBAUTHN = ctypes.windll.webauthn # type: ignore
|
||||
WEBAUTHN = windll.webauthn # type: ignore
|
||||
WEBAUTHN_API_VERSION = WEBAUTHN.WebAuthNGetApiVersionNumber()
|
||||
# The following is derived from
|
||||
# https://github.com/microsoft/webauthn/blob/master/webauthn.h#L37
|
||||
|
@ -603,7 +609,7 @@ WEBAUTHN.WebAuthNGetErrorName.argtypes = [HRESULT]
|
|||
WEBAUTHN.WebAuthNGetErrorName.restype = PCWSTR
|
||||
|
||||
|
||||
WEBAUTHN_STRUCT_VERSIONS = {
|
||||
WEBAUTHN_STRUCT_VERSIONS: Mapping[int, Mapping[str, int]] = {
|
||||
1: {
|
||||
"WebAuthNRpEntityInformation": 1,
|
||||
"WebAuthNUserEntityInformation": 1,
|
||||
|
@ -621,7 +627,7 @@ WEBAUTHN_STRUCT_VERSIONS = {
|
|||
}
|
||||
|
||||
|
||||
def get_version(class_name):
|
||||
def get_version(class_name: str) -> int:
|
||||
"""Get version of struct.
|
||||
|
||||
:param str class_name: Struct class name.
|
||||
|
@ -663,7 +669,7 @@ class WinAPI:
|
|||
version = WEBAUTHN_API_VERSION
|
||||
|
||||
def __init__(self, handle=None):
|
||||
self.handle = handle or ctypes.windll.user32.GetForegroundWindow()
|
||||
self.handle = handle or windll.user32.GetForegroundWindow()
|
||||
|
||||
def get_error_name(self, winerror):
|
||||
"""Returns an error name given an error HRESULT value.
|
||||
|
|
Loading…
Reference in New Issue