From b1da22e9f7f26ff1b78c81394a10532f6ebb625b Mon Sep 17 00:00:00 2001 From: Erik Larsson Date: Thu, 18 Jan 2024 18:11:16 +0100 Subject: [PATCH] utils: fix type hints Signed-off-by: Erik Larsson --- pyproject.toml | 1 - src/tpm2_pytss/internal/crypto.py | 19 +++++++++++++------ src/tpm2_pytss/types.py | 10 ++++++++-- src/tpm2_pytss/utils.py | 31 +++++++++++++++++-------------- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 73299146..81b63c7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ exclude = ''' [tool.mypy] mypy_path = "mypy_stubs" exclude = [ - 'src/tpm2_pytss/utils.py', 'src/tpm2_pytss/internal/templates.py', 'src/tpm2_pytss/encoding.py', 'src/tpm2_pytss/policy.py', diff --git a/src/tpm2_pytss/internal/crypto.py b/src/tpm2_pytss/internal/crypto.py index 93e51813..120d0b30 100644 --- a/src/tpm2_pytss/internal/crypto.py +++ b/src/tpm2_pytss/internal/crypto.py @@ -23,7 +23,7 @@ from cryptography.hazmat.primitives.ciphers import modes, Cipher, CipherAlgorithm from cryptography.hazmat.backends import default_backend from cryptography.exceptions import UnsupportedAlgorithm, InvalidSignature -from typing import Tuple, Type +from typing import Tuple, Type, Union import secrets import sys @@ -77,7 +77,7 @@ def _get_curve(curveid): return None -def _get_digest(digestid): +def _get_digest(digestid: TPM2_ALG): for (algid, d) in _digesttable: if algid == digestid: return d @@ -336,7 +336,14 @@ def _getname(obj): return name -def _kdfa(hashAlg, key, label, contextU, contextV, bits): +def _kdfa( + hashAlg: "TPM2_ALG", + key: bytes, + label: bytes, + contextU: bytes, + contextV: bytes, + bits: int, +) -> bytes: halg = _get_digest(hashAlg) if halg is None: raise ValueError(f"unsupported digest algorithm: {hashAlg}") @@ -373,7 +380,7 @@ def kdfe(hashAlg, z, use, partyuinfo, partyvinfo, bits): return kdf.derive(z) -def _symdef_to_crypt(symdef): +def _symdef_to_crypt(symdef: "TPMT_SYM_DEF"): alg = _get_alg(symdef.algorithm) if alg is None: raise ValueError(f"unsupported symmetric algorithm {symdef.algorithm}") @@ -394,7 +401,7 @@ def _calculate_sym_unique(nameAlg, secret, seed): return d.finalize() -def _get_digest_size(alg): +def _get_digest_size(alg: TPM2_ALG) -> int: dt = _get_digest(alg) if dt is None: raise ValueError(f"unsupported digest algorithm: {alg}") @@ -591,7 +598,7 @@ def _secret_to_seed( private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC", label: bytes, - outsymseed: bytes, + outsymseed: Union[bytes, "types.TPM2B_SIMPLE_OBJECT"], ): key = private_to_key(private, public) if isinstance(key, rsa.RSAPrivateKey): diff --git a/src/tpm2_pytss/types.py b/src/tpm2_pytss/types.py index 828e12fe..73ea0d79 100644 --- a/src/tpm2_pytss/types.py +++ b/src/tpm2_pytss/types.py @@ -41,6 +41,12 @@ TPM2_HR, ) from typing import Union, Tuple, Optional, Dict, Any + +try: + # assume mypy is running on python 3.11+ + from typing import Self +except ImportError: + pass import sys try: @@ -223,7 +229,7 @@ def marshal(self) -> bytes: return bytes(buf[0 : offset[0]]) @classmethod - def unmarshal(cls, buf: bytes): + def unmarshal(cls, buf: bytes) -> Tuple["Self", int]: """Unmarshal bytes into type instance. Args: @@ -247,7 +253,7 @@ class TPM2B_SIMPLE_OBJECT(TPM_OBJECT): """ Abstract Base class for all TPM2B Simple Objects. A Simple object contains only a size and byte buffer fields. This is not suitable for direct instantiation.""" - def __init__(self, _cdata: Optional[ffi.CData] = None, **kwargs: Dict[str, Any]): + def __init__(self, _cdata: Optional[Union[ffi.CData, bytes]] = None, **kwargs: Any): _cdata, kwargs = _fixup_cdata_kwargs(self, _cdata, kwargs) _bytefield = type(self)._get_bytefield() diff --git a/src/tpm2_pytss/utils.py b/src/tpm2_pytss/utils.py index 51dc34ee..1e1c8a2d 100644 --- a/src/tpm2_pytss/utils.py +++ b/src/tpm2_pytss/utils.py @@ -19,19 +19,22 @@ TPM2_ECC, TPM2_PT, TPM2_RH, + TPM2_ALG, ) from .internal.templates import _ek from .TSS2_Exception import TSS2_Exception from cryptography.hazmat.primitives import constant_time as ct from cryptography.hazmat.primitives import hashes from cryptography.hazmat.backends import default_backend -from typing import Optional, Tuple, Callable, List +from typing import Optional, Tuple, Callable, List, Union import secrets def make_credential( - public: TPM2B_PUBLIC, credential: bytes, name: TPM2B_NAME + public: TPM2B_PUBLIC, + credential: Union[bytes, TPM2B_DIGEST], + name: Union[TPM2B_NAME, bytes], ) -> Tuple[TPM2B_ID_OBJECT, TPM2B_ENCRYPTED_SECRET]: """Encrypts credential for use with activate_credential @@ -283,8 +286,8 @@ def unwrap( # unwrap the inner encryption which is the integrity + TPM2B_SENSITIVE innerint_and_decsens = _decrypt(cipher, mode, symkey, sensb) - innerint, offset = TPM2B_DIGEST.unmarshal(innerint_and_decsens) - innerint = bytes(innerint) + innerint2b, offset = TPM2B_DIGEST.unmarshal(innerint_and_decsens) + innerint = bytes(innerint2b) decsensb = innerint_and_decsens[offset:] h = hashes.Hash(halg(), backend=default_backend()) @@ -315,11 +318,11 @@ class NoSuchIndex(Exception): index (int): The NV index requested """ - def __init__(self, index): + def __init__(self, index: int): self.index = index - def __str__(self): - return f"NV index 0x{index:08x} does not exist" + def __str__(self) -> str: + return f"NV index 0x{self.index:08x} does not exist" class NVReadEK: @@ -336,10 +339,10 @@ class NVReadEK: def __init__( self, ectx: ESAPI, - auth_handle: ESYS_TR = None, - session1: ESYS_TR = ESYS_TR.PASSWORD, - session2: ESYS_TR = ESYS_TR.NONE, - session3: ESYS_TR = ESYS_TR.NONE, + auth_handle: Optional[ESYS_TR] = None, + session1: ESYS_TR = ESYS_TR(ESYS_TR.PASSWORD), + session2: ESYS_TR = ESYS_TR(ESYS_TR.NONE), + session3: ESYS_TR = ESYS_TR(ESYS_TR.NONE), ): self._ectx = ectx self._auth_handle = auth_handle @@ -351,7 +354,7 @@ def __init__( more = True while more: more, data = self._ectx.get_capability( - TPM2_CAP.TPM_PROPERTIES, + TPM2_CAP(TPM2_CAP.TPM_PROPERTIES), TPM2_PT.FIXED, 4096, session1=session2, @@ -367,7 +370,7 @@ def __init__( def __call__(self, index: Union[int, TPM2_RH]) -> bytes: try: nvh = self._ectx.tr_from_tpmpublic( - index, session1=self._session2, session2=self._session3 + TPM2_HANDLE(index), session1=self._session2, session2=self._session3 ) except TSS2_Exception as e: if e.rc == 0x18B: @@ -399,7 +402,7 @@ def __call__(self, index: Union[int, TPM2_RH]) -> bytes: def create_ek_template( ektype: str, nv_read_cb: Callable[[Union[int, TPM2_RH]], bytes] -) -> Tuple[bytes, TPM2B_PUBLIC]: +) -> Tuple[Optional[bytes], TPM2B_PUBLIC]: """Creates an Endorsenment Key template which when created matches the EK certificate The template is created according to TCG EK Credential Profile For TPM Family 2.0: