mirror of https://github.com/Yubico/python-fido2
More mypy checks/fixes.
This commit is contained in:
parent
f12247ed51
commit
fed257922c
|
@ -49,6 +49,7 @@ from typing import Type, Any, Union, Callable, Optional, Mapping, Sequence
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import platform
|
import platform
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
|
||||||
class ClientData(bytes):
|
class ClientData(bytes):
|
||||||
|
@ -387,8 +388,10 @@ class Fido2ClientAssertionSelection(AssertionSelection):
|
||||||
return extension_outputs
|
return extension_outputs
|
||||||
|
|
||||||
|
|
||||||
def _default_extensions():
|
def _default_extensions() -> Sequence[Type[Ctap2Extension]]:
|
||||||
return [cls for cls in Ctap2Extension.__subclasses__() if hasattr(cls, "NAME")]
|
return [
|
||||||
|
cls for cls in Ctap2Extension.__subclasses__() if not inspect.isabstract(cls)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
_CTAP1_INFO = Info.create(versions=["U2F_V2"], aaguid=b"\0" * 32)
|
_CTAP1_INFO = Info.create(versions=["U2F_V2"], aaguid=b"\0" * 32)
|
||||||
|
@ -410,7 +413,7 @@ class Fido2Client(_BaseClient):
|
||||||
device: CtapDevice,
|
device: CtapDevice,
|
||||||
origin: str,
|
origin: str,
|
||||||
verify: Callable[[str, str], bool] = verify_rp_id,
|
verify: Callable[[str, str], bool] = verify_rp_id,
|
||||||
extension_types: Optional[Type[Ctap2Extension]] = None,
|
extension_types: Sequence[Type[Ctap2Extension]] = [],
|
||||||
):
|
):
|
||||||
super(Fido2Client, self).__init__(origin, verify)
|
super(Fido2Client, self).__init__(origin, verify)
|
||||||
|
|
||||||
|
@ -420,9 +423,9 @@ class Fido2Client(_BaseClient):
|
||||||
self.ctap2 = Ctap2(device)
|
self.ctap2 = Ctap2(device)
|
||||||
self.info = self.ctap2.info
|
self.info = self.ctap2.info
|
||||||
try:
|
try:
|
||||||
self.client_pin: Optional[ClientPin] = ClientPin(self.ctap2)
|
self.client_pin: ClientPin = ClientPin(self.ctap2)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
self.client_pin = None
|
self.client_pin = None # type: ignore
|
||||||
self._do_make_credential = self._ctap2_make_credential
|
self._do_make_credential = self._ctap2_make_credential
|
||||||
self._do_get_assertion = self._ctap2_get_assertion
|
self._do_get_assertion = self._ctap2_get_assertion
|
||||||
except (ValueError, CtapError):
|
except (ValueError, CtapError):
|
||||||
|
@ -760,10 +763,7 @@ class Fido2Client(_BaseClient):
|
||||||
pin_protocol, pin_auth, internal_uv = self._get_auth_params(
|
pin_protocol, pin_auth, internal_uv = self._get_auth_params(
|
||||||
client_data, rp_id, user_verification, pin, event, on_keepalive
|
client_data, rp_id, user_verification, pin, event, on_keepalive
|
||||||
)
|
)
|
||||||
if internal_uv:
|
options = {"uv": True} if internal_uv else None
|
||||||
options = {"uv": True}
|
|
||||||
else:
|
|
||||||
options = None
|
|
||||||
|
|
||||||
if allow_list:
|
if allow_list:
|
||||||
# Filter out credential IDs which are too long
|
# Filter out credential IDs which are too long
|
||||||
|
@ -966,5 +966,12 @@ class WindowsClient(_BaseClient):
|
||||||
user = {"id": user_id} if user_id else None
|
user = {"id": user_id} if user_id else None
|
||||||
return AssertionSelection(
|
return AssertionSelection(
|
||||||
client_data,
|
client_data,
|
||||||
[AssertionResponse.create(credential, auth_data, signature, user)],
|
[
|
||||||
|
AssertionResponse.create(
|
||||||
|
credential=credential,
|
||||||
|
auth_data=auth_data,
|
||||||
|
signature=signature,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -29,6 +29,7 @@ from .utils import bytes2int, int2bytes
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding
|
from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding
|
||||||
|
from typing import Sequence, Type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||||
|
@ -101,9 +102,9 @@ class CoseKey(dict):
|
||||||
def supported_algorithms():
|
def supported_algorithms():
|
||||||
"""Get a list of all supported algorithm identifiers"""
|
"""Get a list of all supported algorithm identifiers"""
|
||||||
if ed25519:
|
if ed25519:
|
||||||
algs = (ES256, EdDSA, PS256, RS256)
|
algs: Sequence[Type[CoseKey]] = [ES256, EdDSA, PS256, RS256]
|
||||||
else:
|
else:
|
||||||
algs = (ES256, PS256, RS256)
|
algs = [ES256, PS256, RS256]
|
||||||
return [cls.ALGORITHM for cls in algs]
|
return [cls.ALGORITHM for cls in algs]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,8 @@ class Ctap2Extension(abc.ABC):
|
||||||
the extension.
|
the extension.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
NAME: str = None # type: ignore
|
||||||
|
|
||||||
def __init__(self, ctap):
|
def __init__(self, ctap):
|
||||||
self.ctap = ctap
|
self.ctap = ctap
|
||||||
|
|
||||||
|
|
|
@ -163,7 +163,7 @@ class CtapHidDevice(CtapDevice):
|
||||||
try:
|
try:
|
||||||
# Read response
|
# Read response
|
||||||
seq = 0
|
seq = 0
|
||||||
response = None
|
response = b""
|
||||||
last_ka = None
|
last_ka = None
|
||||||
while True:
|
while True:
|
||||||
if event.is_set():
|
if event.is_set():
|
||||||
|
@ -182,11 +182,11 @@ class CtapHidDevice(CtapDevice):
|
||||||
if r_channel != self._channel_id:
|
if r_channel != self._channel_id:
|
||||||
raise Exception("Wrong channel")
|
raise Exception("Wrong channel")
|
||||||
|
|
||||||
if response is None: # Initialization packet
|
if not response: # Initialization packet
|
||||||
r_cmd, r_len = struct.unpack_from(">BH", recv)
|
r_cmd, r_len = struct.unpack_from(">BH", recv)
|
||||||
recv = recv[3:]
|
recv = recv[3:]
|
||||||
if r_cmd == TYPE_INIT | cmd:
|
if r_cmd == TYPE_INIT | cmd:
|
||||||
response = b""
|
pass # first data packet
|
||||||
elif r_cmd == TYPE_INIT | CTAPHID.KEEPALIVE:
|
elif r_cmd == TYPE_INIT | CTAPHID.KEEPALIVE:
|
||||||
ka_status = struct.unpack_from(">B", recv)[0]
|
ka_status = struct.unpack_from(">B", recv)[0]
|
||||||
logger.debug("Got keepalive status: %02x" % ka_status)
|
logger.debug("Got keepalive status: %02x" % ka_status)
|
||||||
|
|
|
@ -24,6 +24,7 @@ import os
|
||||||
from .base import HidDescriptor, parse_report_descriptor, FileCtapHidConnection
|
from .base import HidDescriptor, parse_report_descriptor, FileCtapHidConnection
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -101,7 +102,7 @@ def _enumerate():
|
||||||
if retval != 0:
|
if retval != 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
dev = {}
|
dev: Dict[str, Optional[str]] = {}
|
||||||
dev["name"] = uhid[len(devdir) :]
|
dev["name"] = uhid[len(devdir) :]
|
||||||
dev["path"] = uhid
|
dev["path"] = uhid
|
||||||
|
|
||||||
|
|
|
@ -277,7 +277,7 @@ class MacCtapHidConnection(CtapHidConnection):
|
||||||
raise OSError("Failed to open device for communication: {}".format(result))
|
raise OSError("Failed to open device for communication: {}".format(result))
|
||||||
|
|
||||||
# Create read queue
|
# Create read queue
|
||||||
self.read_queue = Queue()
|
self.read_queue: Queue = Queue()
|
||||||
|
|
||||||
# Create and start read thread
|
# Create and start read thread
|
||||||
self.run_loop_ref = None
|
self.run_loop_ref = None
|
||||||
|
|
|
@ -99,7 +99,7 @@ def get_descriptor(path):
|
||||||
dev_info = UsbDeviceInfo()
|
dev_info = UsbDeviceInfo()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fcntl.ioctl(f, USB_GET_DEVICEINFO, dev_info)
|
fcntl.ioctl(f, USB_GET_DEVICEINFO, dev_info) # type: ignore
|
||||||
finally:
|
finally:
|
||||||
os.close(f)
|
os.close(f)
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from .base import HidDescriptor, CtapHidConnection, FIDO_USAGE_PAGE, FIDO_USAGE
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
import platform
|
import platform
|
||||||
from ctypes import WinDLL # type: ignore
|
from ctypes import WinDLL, WinError # type: ignore
|
||||||
from ctypes import wintypes, LibraryLoader
|
from ctypes import wintypes, LibraryLoader
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -193,7 +193,7 @@ class WinCtapHidConnection(CtapHidConnection):
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if self.handle == INVALID_HANDLE_VALUE:
|
if self.handle == INVALID_HANDLE_VALUE:
|
||||||
raise ctypes.WinError()
|
raise WinError()
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
kernel32.CloseHandle(self.handle)
|
kernel32.CloseHandle(self.handle)
|
||||||
|
@ -205,7 +205,7 @@ class WinCtapHidConnection(CtapHidConnection):
|
||||||
self.handle, out, len(out), ctypes.byref(num_written), None
|
self.handle, out, len(out), ctypes.byref(num_written), None
|
||||||
)
|
)
|
||||||
if not ret:
|
if not ret:
|
||||||
raise ctypes.WinError()
|
raise WinError()
|
||||||
if num_written.value != len(out):
|
if num_written.value != len(out):
|
||||||
raise OSError(
|
raise OSError(
|
||||||
"Failed to write complete packet. "
|
"Failed to write complete packet. "
|
||||||
|
@ -219,7 +219,7 @@ class WinCtapHidConnection(CtapHidConnection):
|
||||||
self.handle, buf, len(buf), ctypes.byref(num_read), None
|
self.handle, buf, len(buf), ctypes.byref(num_read), None
|
||||||
)
|
)
|
||||||
if not ret:
|
if not ret:
|
||||||
raise ctypes.WinError()
|
raise WinError()
|
||||||
|
|
||||||
if num_read.value != self.descriptor.report_size_in + 1:
|
if num_read.value != self.descriptor.report_size_in + 1:
|
||||||
raise OSError("Failed to read full length report from device.")
|
raise OSError("Failed to read full length report from device.")
|
||||||
|
@ -231,7 +231,7 @@ def get_vid_pid(device):
|
||||||
attributes = HidAttributes()
|
attributes = HidAttributes()
|
||||||
result = hid.HidD_GetAttributes(device, ctypes.byref(attributes))
|
result = hid.HidD_GetAttributes(device, ctypes.byref(attributes))
|
||||||
if not result:
|
if not result:
|
||||||
raise ctypes.WinError()
|
raise WinError()
|
||||||
|
|
||||||
return attributes.VendorID, attributes.ProductID
|
return attributes.VendorID, attributes.ProductID
|
||||||
|
|
||||||
|
@ -269,19 +269,19 @@ def get_descriptor(path):
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if device == INVALID_HANDLE_VALUE:
|
if device == INVALID_HANDLE_VALUE:
|
||||||
raise ctypes.WinError()
|
raise WinError()
|
||||||
try:
|
try:
|
||||||
preparsed_data = PHIDP_PREPARSED_DATA(0)
|
preparsed_data = PHIDP_PREPARSED_DATA(0)
|
||||||
ret = hid.HidD_GetPreparsedData(device, ctypes.byref(preparsed_data))
|
ret = hid.HidD_GetPreparsedData(device, ctypes.byref(preparsed_data))
|
||||||
if not ret:
|
if not ret:
|
||||||
raise ctypes.WinError()
|
raise WinError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
caps = HidCapabilities()
|
caps = HidCapabilities()
|
||||||
ret = hid.HidP_GetCaps(preparsed_data, ctypes.byref(caps))
|
ret = hid.HidP_GetCaps(preparsed_data, ctypes.byref(caps))
|
||||||
|
|
||||||
if ret != HIDP_STATUS_SUCCESS:
|
if ret != HIDP_STATUS_SUCCESS:
|
||||||
raise ctypes.WinError()
|
raise WinError()
|
||||||
|
|
||||||
if caps.UsagePage == FIDO_USAGE_PAGE and caps.Usage == FIDO_USAGE:
|
if caps.UsagePage == FIDO_USAGE_PAGE and caps.Usage == FIDO_USAGE:
|
||||||
vid, pid = get_vid_pid(device)
|
vid, pid = get_vid_pid(device)
|
||||||
|
@ -331,19 +331,19 @@ def list_descriptors():
|
||||||
if not result:
|
if not result:
|
||||||
break
|
break
|
||||||
|
|
||||||
detail_len = wintypes.DWORD()
|
dw_detail_len = wintypes.DWORD()
|
||||||
result = setupapi.SetupDiGetDeviceInterfaceDetailA(
|
result = setupapi.SetupDiGetDeviceInterfaceDetailA(
|
||||||
collection,
|
collection,
|
||||||
ctypes.byref(interface_info),
|
ctypes.byref(interface_info),
|
||||||
None,
|
None,
|
||||||
0,
|
0,
|
||||||
ctypes.byref(detail_len),
|
ctypes.byref(dw_detail_len),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if result:
|
if result:
|
||||||
raise ctypes.WinError()
|
raise WinError()
|
||||||
|
|
||||||
detail_len = detail_len.value
|
detail_len = dw_detail_len.value
|
||||||
if detail_len == 0:
|
if detail_len == 0:
|
||||||
# skip this device, some kind of error
|
# skip this device, some kind of error
|
||||||
continue
|
continue
|
||||||
|
@ -361,7 +361,7 @@ def list_descriptors():
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise ctypes.WinError()
|
raise WinError()
|
||||||
|
|
||||||
path = ctypes.string_at(interface_detail.DevicePath)
|
path = ctypes.string_at(interface_detail.DevicePath)
|
||||||
|
|
||||||
|
|
|
@ -103,7 +103,7 @@ def _ignore_attestation(attestation_object, client_data_hash):
|
||||||
|
|
||||||
def _default_attestations():
|
def _default_attestations():
|
||||||
return [
|
return [
|
||||||
cls()
|
cls() # type: ignore
|
||||||
for cls in Attestation.__subclasses__()
|
for cls in Attestation.__subclasses__()
|
||||||
if getattr(cls, "FORMAT", "none") != "none"
|
if getattr(cls, "FORMAT", "none") != "none"
|
||||||
]
|
]
|
||||||
|
@ -183,7 +183,7 @@ class Fido2Server:
|
||||||
self.timeout = None
|
self.timeout = None
|
||||||
self.attestation = AttestationConveyancePreference._wrap(attestation)
|
self.attestation = AttestationConveyancePreference._wrap(attestation)
|
||||||
self.allowed_algorithms = [
|
self.allowed_algorithms = [
|
||||||
PublicKeyCredentialParameters("public-key", alg)
|
PublicKeyCredentialParameters(PublicKeyCredentialType.PUBLIC_KEY, alg)
|
||||||
for alg in CoseKey.supported_algorithms()
|
for alg in CoseKey.supported_algorithms()
|
||||||
]
|
]
|
||||||
self._verify_attestation = verify_attestation or _ignore_attestation
|
self._verify_attestation = verify_attestation or _ignore_attestation
|
||||||
|
|
|
@ -30,7 +30,7 @@ from .cose import CoseKey, ES256
|
||||||
from .utils import sha256, ByteBuffer
|
from .utils import sha256, ByteBuffer
|
||||||
from enum import Enum, unique, IntFlag
|
from enum import Enum, unique, IntFlag
|
||||||
from dataclasses import dataclass, fields, field as _field
|
from dataclasses import dataclass, fields, field as _field
|
||||||
from typing import Any, Mapping, Optional, Sequence, Tuple
|
from typing import Any, Mapping, Optional, Sequence, Tuple, cast
|
||||||
import re
|
import re
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
@ -256,7 +256,7 @@ class AttestationObject(bytes): # , Mapping[str, Any]):
|
||||||
def __init__(self, _):
|
def __init__(self, _):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
data = cbor.decode(bytes(self))
|
data = cast(Mapping[str, Any], cbor.decode(bytes(self)))
|
||||||
self.fmt = data["fmt"]
|
self.fmt = data["fmt"]
|
||||||
self.auth_data = AuthenticatorData(data["authData"])
|
self.auth_data = AuthenticatorData(data["authData"])
|
||||||
self.att_stmt = data["attStmt"]
|
self.att_stmt = data["attStmt"]
|
||||||
|
@ -368,7 +368,10 @@ class _DataObject(Mapping[str, Any]):
|
||||||
self._keys.append(_snake2camel(f.name))
|
self._keys.append(_snake2camel(f.name))
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return getattr(self, _camel2snake(key))
|
try:
|
||||||
|
return getattr(self, _camel2snake(key))
|
||||||
|
except AttributeError as e:
|
||||||
|
raise KeyError(e)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self._keys)
|
return iter(self._keys)
|
||||||
|
|
|
@ -40,8 +40,14 @@ https://github.com/microsoft/webauthn
|
||||||
from enum import IntEnum, unique
|
from enum import IntEnum, unique
|
||||||
from ctypes.wintypes import BOOL, DWORD, LONG, LPCWSTR, HWND
|
from ctypes.wintypes import BOOL, DWORD, LONG, LPCWSTR, HWND
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
from ctypes import WinDLL # type: ignore
|
||||||
|
from ctypes import LibraryLoader
|
||||||
|
|
||||||
|
|
||||||
|
windll = LibraryLoader(WinDLL)
|
||||||
|
|
||||||
|
|
||||||
PBYTE = ctypes.POINTER(ctypes.c_ubyte) # Different from wintypes.PBYTE, which is signed
|
PBYTE = ctypes.POINTER(ctypes.c_ubyte) # Different from wintypes.PBYTE, which is signed
|
||||||
|
@ -558,7 +564,7 @@ class WebAuthNCTAPTransport(IntEnum):
|
||||||
|
|
||||||
|
|
||||||
HRESULT = ctypes.HRESULT # type: ignore
|
HRESULT = ctypes.HRESULT # type: ignore
|
||||||
WEBAUTHN = ctypes.windll.webauthn # type: ignore
|
WEBAUTHN = windll.webauthn # type: ignore
|
||||||
WEBAUTHN_API_VERSION = WEBAUTHN.WebAuthNGetApiVersionNumber()
|
WEBAUTHN_API_VERSION = WEBAUTHN.WebAuthNGetApiVersionNumber()
|
||||||
# The following is derived from
|
# The following is derived from
|
||||||
# https://github.com/microsoft/webauthn/blob/master/webauthn.h#L37
|
# https://github.com/microsoft/webauthn/blob/master/webauthn.h#L37
|
||||||
|
@ -603,7 +609,7 @@ WEBAUTHN.WebAuthNGetErrorName.argtypes = [HRESULT]
|
||||||
WEBAUTHN.WebAuthNGetErrorName.restype = PCWSTR
|
WEBAUTHN.WebAuthNGetErrorName.restype = PCWSTR
|
||||||
|
|
||||||
|
|
||||||
WEBAUTHN_STRUCT_VERSIONS = {
|
WEBAUTHN_STRUCT_VERSIONS: Mapping[int, Mapping[str, int]] = {
|
||||||
1: {
|
1: {
|
||||||
"WebAuthNRpEntityInformation": 1,
|
"WebAuthNRpEntityInformation": 1,
|
||||||
"WebAuthNUserEntityInformation": 1,
|
"WebAuthNUserEntityInformation": 1,
|
||||||
|
@ -621,7 +627,7 @@ WEBAUTHN_STRUCT_VERSIONS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_version(class_name):
|
def get_version(class_name: str) -> int:
|
||||||
"""Get version of struct.
|
"""Get version of struct.
|
||||||
|
|
||||||
:param str class_name: Struct class name.
|
:param str class_name: Struct class name.
|
||||||
|
@ -663,7 +669,7 @@ class WinAPI:
|
||||||
version = WEBAUTHN_API_VERSION
|
version = WEBAUTHN_API_VERSION
|
||||||
|
|
||||||
def __init__(self, handle=None):
|
def __init__(self, handle=None):
|
||||||
self.handle = handle or ctypes.windll.user32.GetForegroundWindow()
|
self.handle = handle or windll.user32.GetForegroundWindow()
|
||||||
|
|
||||||
def get_error_name(self, winerror):
|
def get_error_name(self, winerror):
|
||||||
"""Returns an error name given an error HRESULT value.
|
"""Returns an error name given an error HRESULT value.
|
||||||
|
|
Loading…
Reference in New Issue