Skip to content

Commit

Permalink
feat(types): use strict type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
lepture committed Jul 13, 2024
1 parent 3aed632 commit 9854013
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 70 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ exclude_lines = [
]

[tool.mypy]
strict = true
python_version = "3.8"
files = ["src/joserfc"]
show_error_codes = true
Expand Down
4 changes: 2 additions & 2 deletions src/joserfc/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def serialize_json(
if registry is None:
registry = construct_registry(algorithms)

def find_key(obj: Any):
def find_key(obj: Any) -> Key:
return guess_key(private_key, obj, True)

_payload = to_bytes(payload)
Expand Down Expand Up @@ -271,7 +271,7 @@ def deserialize_json(
if registry is None:
registry = construct_registry(algorithms)

def find_key(obj: Any):
def find_key(obj: Any) -> Key:
return guess_key(public_key, obj)

if "signatures" in value:
Expand Down
3 changes: 2 additions & 1 deletion src/joserfc/rfc7516/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def extract_flattened_json(data: FlattenedJSONSerialization) -> FlattenedJSONEnc


def __extract_segments(
data: t.Union[GeneralJSONSerialization, FlattenedJSONSerialization]): # type: ignore[no-untyped-def]
data: t.Union[GeneralJSONSerialization, FlattenedJSONSerialization]
) -> t.Tuple[t.Dict[str, bytes], t.Dict[str, bytes], t.Optional[bytes]]:
base64_segments: t.Dict[str, bytes] = {
"iv": to_bytes(data["iv"]),
"ciphertext": to_bytes(data["ciphertext"]),
Expand Down
53 changes: 27 additions & 26 deletions src/joserfc/rfc7516/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations
import os
import typing as t
from abc import ABCMeta, abstractmethod
from ..registry import Header, HeaderRegistryDict
from ..errors import InvalidKeyTypeError, InvalidKeyLengthError
from .._keys import Key, ECKey
from .._keys import Key, ECKey, OctKey

KeyType = t.TypeVar("KeyType")

Expand All @@ -12,8 +13,8 @@ class Recipient(t.Generic[KeyType]):
def __init__(
self,
parent: t.Union["CompactEncryption", "GeneralJSONEncryption", "FlattenedJSONEncryption"],
header: t.Optional[Header] = None,
recipient_key: t.Optional[KeyType] = None):
header: Header | None = None,
recipient_key: KeyType | None = None):
self.__parent = parent
self.header = header
self.recipient_key = recipient_key
Expand All @@ -30,35 +31,35 @@ def headers(self) -> Header:
rv.update(self.header)
return rv

def add_header(self, k: str, v: t.Any):
def add_header(self, k: str, v: t.Any) -> None:
if isinstance(self.__parent, CompactEncryption):
self.__parent.protected.update({k: v})
elif self.header:
self.header.update({k: v})
else:
self.header = {k: v}

def set_kid(self, kid: str):
def set_kid(self, kid: str) -> None:
self.add_header("kid", kid)


class CompactEncryption:
"""An object to represent the JWE Compact Serialization. It is usually returned by
``decrypt_compact`` method.
"""
def __init__(self, protected: Header, plaintext: t.Optional[bytes] = None):
def __init__(self, protected: Header, plaintext: bytes | None = None):
#: protected header in dict
self.protected = protected
#: the plaintext in bytes
self.plaintext = plaintext
self.recipient: t.Optional[Recipient] = None
self.recipient: Recipient[t.Any] | None = None
self.bytes_segments: t.Dict[str, bytes] = {} # store the decoded segments
self.base64_segments: t.Dict[str, bytes] = {} # store the encoded segments

def headers(self) -> Header:
return self.protected

def attach_recipient(self, key: Key, header: t.Optional[Header] = None):
def attach_recipient(self, key: Key, header: Header | None = None) -> None:
"""Add a recipient to the JWE Compact Serialization. Please add a key that
comply with the given "alg" value.
Expand All @@ -71,7 +72,7 @@ def attach_recipient(self, key: Key, header: t.Optional[Header] = None):
self.recipient = recipient

@property
def recipients(self) -> t.List[Recipient]:
def recipients(self) -> list[Recipient[t.Any]]:
if self.recipient is not None:
return [self.recipient]
return []
Expand All @@ -89,14 +90,14 @@ class BaseJSONEncryption(metaclass=ABCMeta):
#: an optional additional authenticated data
aad: t.Optional[bytes]
#: a list of recipients
recipients: t.List[Recipient]
recipients: t.List[Recipient[t.Any]]

def __init__(
self,
protected: Header,
plaintext: t.Optional[bytes] = None,
unprotected: t.Optional[Header] = None,
aad: t.Optional[bytes] = None):
plaintext: bytes | None = None,
unprotected: Header | None = None,
aad: bytes | None = None):
self.protected = protected
self.plaintext = plaintext
self.unprotected = unprotected
Expand All @@ -106,7 +107,7 @@ def __init__(
self.base64_segments: t.Dict[str, bytes] = {} # store the encoded segments

@abstractmethod
def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
def add_recipient(self, header: Header | None = None, key: Key | None = None) -> None:
"""Add a recipient to the JWE JSON Serialization. Please add a key that
comply with the "alg" to this recipient.
Expand All @@ -131,7 +132,7 @@ class GeneralJSONEncryption(BaseJSONEncryption):
"""
flattened = False

def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
def add_recipient(self, header: Header | None = None, key: Key | None = None) -> None:
recipient = Recipient(self, header, key)
self.recipients.append(recipient)

Expand All @@ -152,7 +153,7 @@ class FlattenedJSONEncryption(BaseJSONEncryption):
"""
flattened = True

def add_recipient(self, header: t.Optional[Header] = None, key: t.Optional[Key] = None):
def add_recipient(self, header: Header | None = None, key: Key | None = None) -> None:
self.recipients = [Recipient(self, header, key)]


Expand All @@ -178,7 +179,7 @@ def check_iv(self, iv: bytes) -> bytes:
return iv

@abstractmethod
def encrypt(self, plaintext: bytes, cek: bytes, iv: bytes, aad: bytes) -> t.Tuple[bytes, bytes]:
def encrypt(self, plaintext: bytes, cek: bytes, iv: bytes, aad: bytes) -> tuple[bytes, bytes]:
pass

@abstractmethod
Expand Down Expand Up @@ -216,19 +217,19 @@ class KeyManagement:
def direct_mode(self) -> bool:
return self.key_size is None

def check_key_type(self, key: Key):
def check_key_type(self, key: Key) -> None:
if key.key_type not in self.key_types:
raise InvalidKeyTypeError()

def prepare_recipient_header(self, recipient: Recipient):
def prepare_recipient_header(self, recipient: Recipient[t.Any]) -> None:
raise NotImplementedError()


class JWEDirectEncryption(KeyManagement, metaclass=ABCMeta):
key_types = ["oct"]

@abstractmethod
def compute_cek(self, size: int, recipient: Recipient) -> bytes:
def compute_cek(self, size: int, recipient: Recipient[OctKey]) -> bytes:
pass


Expand All @@ -238,11 +239,11 @@ def direct_mode(self) -> bool:
return False

@abstractmethod
def encrypt_cek(self, cek: bytes, recipient: Recipient) -> bytes:
def encrypt_cek(self, cek: bytes, recipient: Recipient[t.Any]) -> bytes:
pass

@abstractmethod
def decrypt_cek(self, recipient: Recipient) -> bytes:
def decrypt_cek(self, recipient: Recipient[t.Any]) -> bytes:
pass


Expand All @@ -254,7 +255,7 @@ class JWEKeyWrapping(KeyManagement, metaclass=ABCMeta):
def direct_mode(self) -> bool:
return False

def check_op_key(self, op_key: bytes):
def check_op_key(self, op_key: bytes) -> None:
if len(op_key) * 8 != self.key_size:
raise InvalidKeyLengthError(f"A key of size {self.key_size} bits MUST be used")

Expand All @@ -267,11 +268,11 @@ def unwrap_cek(self, ek: bytes, key: bytes) -> bytes:
pass

@abstractmethod
def encrypt_cek(self, cek: bytes, recipient: Recipient) -> bytes:
def encrypt_cek(self, cek: bytes, recipient: Recipient[OctKey]) -> bytes:
pass

@abstractmethod
def decrypt_cek(self, recipient: Recipient) -> bytes:
def decrypt_cek(self, recipient: Recipient[OctKey]) -> bytes:
pass


Expand All @@ -280,7 +281,7 @@ class JWEKeyAgreement(KeyManagement, metaclass=ABCMeta):
tag_aware: bool = False
key_wrapping: t.Optional[JWEKeyWrapping]

def prepare_ephemeral_key(self, recipient: Recipient[ECKey]):
def prepare_ephemeral_key(self, recipient: Recipient[ECKey]) -> None:
recipient_key = recipient.recipient_key
assert recipient_key is not None
self.check_key_type(recipient_key)
Expand Down
20 changes: 12 additions & 8 deletions src/joserfc/rfc7516/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def __init__(
self.strict_check_header = strict_check_header

@classmethod
def register(cls, model: JWEAlgorithm):
def register(cls, model: JWEAlgorithm) -> None:
cls.algorithms[model.algorithm_location][model.name] = model # type: ignore
if model.recommended:
cls.recommended.append(model.name)

def check_header(self, header: Header, check_more=False):
def check_header(self, header: Header, check_more: bool = False) -> None:
"""Check and validate the fields in header part of a JWS object."""
check_crit_header(header)
validate_registry_header(self.header_registry, header)
Expand All @@ -77,24 +77,29 @@ def get_alg(self, name: str) -> JWEAlgModel:
:param name: value of the ``alg``, e.g. ``ECDH-ES``, ``A128KW``
"""
return self._get_algorithm("alg", name)
registry = self.algorithms["alg"]
self._check_algorithm(name, registry)
return registry[name]

def get_enc(self, name: str) -> JWEEncModel:
"""Get the allowed ("enc") algorithm instance of the given name.
:param name: value of the ``enc``, e.g. ``A128CBC-HS256``, ``A128GCM``
"""
return self._get_algorithm("enc", name)
registry = self.algorithms["enc"]
self._check_algorithm(name, registry)
return registry[name]

def get_zip(self, name: str) -> JWEZipModel:
"""Get the allowed ("zip") algorithm instance of the given name.
:param name: value of the ``zip``, e.g. ``DEF``
"""
return self._get_algorithm("zip", name)
registry = self.algorithms["zip"]
self._check_algorithm(name, registry)
return registry[name]

def _get_algorithm(self, location: t.Literal["alg", "enc", "zip"], name: str):
registry: t.Dict[str, JWEAlgorithm] = self.algorithms[location] # type: ignore
def _check_algorithm(self, name: str, registry: dict[str, t.Any]) -> None:
if name not in registry:
raise ValueError(f'Algorithm of "{name}" is not supported')

Expand All @@ -105,7 +110,6 @@ def _get_algorithm(self, location: t.Literal["alg", "enc", "zip"], name: str):

if name not in allowed:
raise ValueError(f'Algorithm of "{name}" is not allowed')
return registry[name]


default_registry = JWERegistry()
33 changes: 17 additions & 16 deletions src/joserfc/rfc7517/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import typing as t
from typing import overload
from collections.abc import KeysView
from abc import ABCMeta, abstractmethod
from .types import DictKey, AnyKey, KeyParameters
from ..registry import (
Expand Down Expand Up @@ -82,7 +83,7 @@ class BaseKey(t.Generic[NativePrivateKey, NativePublicKey], metaclass=ABCMeta):

def __init__(
self,
raw_value: t.Union[NativePrivateKey, NativePublicKey],
raw_value: NativePrivateKey | NativePublicKey,
original_value: t.Any,
parameters: t.Optional[KeyParameters] = None):
self._raw_value = raw_value
Expand All @@ -97,13 +98,13 @@ def __init__(
self.validate_dict_key(data)
self._dict_value = data

def keys(self):
def keys(self) -> KeysView[str]:
return self.dict_value.keys()

def __getitem__(self, k: str):
def __getitem__(self, k: str) -> str | list[str]:
return self.dict_value[k]

def get(self, k: str, default=None):
def get(self, k: str, default: str | None = None) -> str | list[str] | None:
return self.dict_value.get(k, default)

def ensure_kid(self) -> None:
Expand All @@ -114,17 +115,17 @@ def ensure_kid(self) -> None:
self._dict_value["kid"] = self.thumbprint()

@property
def kid(self) -> t.Optional[str]:
def kid(self) -> str | None:
"""The "kid" value of the JSON Web Key."""
return self.get("kid")
return t.cast(t.Optional[str], self.get("kid"))

@property
def alg(self) -> t.Optional[str]:
def alg(self) -> str | None:
"""The "alg" value of the JSON Web Key."""
return self.get("alg")
return t.cast(t.Optional[str], self.get("alg"))

@property
def raw_value(self):
def raw_value(self) -> t.Any:
raise NotImplementedError()

@property
Expand Down Expand Up @@ -220,13 +221,13 @@ def check_key_op(self, operation: str) -> None:
if reg.private and not self.is_private:
raise UnsupportedKeyOperationError(f'Invalid key_op "{operation}" for public key')

@overload
@t.overload
def get_op_key(self, operation: t.Literal["verify", "encrypt", "wrapKey", "deriveKey"]) -> NativePublicKey: ...

@overload
@t.overload
def get_op_key(self, operation: t.Literal["sign", "decrypt", "unwrapKey"]) -> NativePrivateKey: ...

def get_op_key(self, operation: str) -> t.Union[NativePublicKey, NativePrivateKey]:
def get_op_key(self, operation: str) -> NativePublicKey | NativePrivateKey:
self.check_key_op(operation)
reg = self.operation_registry[operation]
if reg.private:
Expand All @@ -235,7 +236,7 @@ def get_op_key(self, operation: str) -> t.Union[NativePublicKey, NativePrivateKe
return self.public_key

@classmethod
def validate_dict_key(cls, data: DictKey):
def validate_dict_key(cls, data: DictKey) -> None:
cls.binding.validate_dict_key_registry(data, cls.param_registry)
cls.binding.validate_dict_key_registry(data, cls.value_registry)
cls.binding.validate_dict_key_use_operations(data)
Expand All @@ -257,7 +258,7 @@ def import_key(
@classmethod
def generate_key(
cls: t.Type[GenericKey],
size_or_crv,
size_or_crv: t.Any,
parameters: t.Optional[KeyParameters] = None,
private: bool = True,
auto_kid: bool = False) -> GenericKey:
Expand Down Expand Up @@ -312,5 +313,5 @@ def curve_name(self) -> str:
pass

@abstractmethod
def exchange_derive_key(self, key) -> bytes:
def exchange_derive_key(self, key: t.Any) -> bytes:
pass
Loading

0 comments on commit 9854013

Please sign in to comment.