Skip to content

Commit

Permalink
utils: fix type hints
Browse files Browse the repository at this point in the history
Signed-off-by: Erik Larsson <who+github@cnackers.org>
  • Loading branch information
whooo committed Jan 18, 2024
1 parent 39c5651 commit b1da22e
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 23 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ exclude = '''
[tool.mypy]
mypy_path = "mypy_stubs"
exclude = [
'src/tpm2_pytss/utils.py',
'src/tpm2_pytss/internal/templates.py',
'src/tpm2_pytss/encoding.py',
'src/tpm2_pytss/policy.py',
Expand Down
19 changes: 13 additions & 6 deletions src/tpm2_pytss/internal/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from cryptography.hazmat.primitives.ciphers import modes, Cipher, CipherAlgorithm
from cryptography.hazmat.backends import default_backend
from cryptography.exceptions import UnsupportedAlgorithm, InvalidSignature
from typing import Tuple, Type
from typing import Tuple, Type, Union
import secrets
import sys

Expand Down Expand Up @@ -77,7 +77,7 @@ def _get_curve(curveid):
return None


def _get_digest(digestid):
def _get_digest(digestid: TPM2_ALG):
for (algid, d) in _digesttable:
if algid == digestid:
return d
Expand Down Expand Up @@ -336,7 +336,14 @@ def _getname(obj):
return name


def _kdfa(hashAlg, key, label, contextU, contextV, bits):
def _kdfa(
hashAlg: "TPM2_ALG",
key: bytes,
label: bytes,
contextU: bytes,
contextV: bytes,
bits: int,
) -> bytes:
halg = _get_digest(hashAlg)
if halg is None:
raise ValueError(f"unsupported digest algorithm: {hashAlg}")
Expand Down Expand Up @@ -373,7 +380,7 @@ def kdfe(hashAlg, z, use, partyuinfo, partyvinfo, bits):
return kdf.derive(z)


def _symdef_to_crypt(symdef):
def _symdef_to_crypt(symdef: "TPMT_SYM_DEF"):
alg = _get_alg(symdef.algorithm)
if alg is None:
raise ValueError(f"unsupported symmetric algorithm {symdef.algorithm}")
Expand All @@ -394,7 +401,7 @@ def _calculate_sym_unique(nameAlg, secret, seed):
return d.finalize()


def _get_digest_size(alg):
def _get_digest_size(alg: TPM2_ALG) -> int:
dt = _get_digest(alg)
if dt is None:
raise ValueError(f"unsupported digest algorithm: {alg}")
Expand Down Expand Up @@ -591,7 +598,7 @@ def _secret_to_seed(
private: "types.TPMT_SENSITIVE",
public: "types.TPMT_PUBLIC",
label: bytes,
outsymseed: bytes,
outsymseed: Union[bytes, "types.TPM2B_SIMPLE_OBJECT"],
):
key = private_to_key(private, public)
if isinstance(key, rsa.RSAPrivateKey):
Expand Down
10 changes: 8 additions & 2 deletions src/tpm2_pytss/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@
TPM2_HR,
)
from typing import Union, Tuple, Optional, Dict, Any

try:
# assume mypy is running on python 3.11+
from typing import Self
except ImportError:

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass
import sys

try:
Expand Down Expand Up @@ -223,7 +229,7 @@ def marshal(self) -> bytes:
return bytes(buf[0 : offset[0]])

@classmethod
def unmarshal(cls, buf: bytes):
def unmarshal(cls, buf: bytes) -> Tuple["Self", int]:
"""Unmarshal bytes into type instance.
Args:
Expand All @@ -247,7 +253,7 @@ class TPM2B_SIMPLE_OBJECT(TPM_OBJECT):
""" Abstract Base class for all TPM2B Simple Objects. A Simple object contains only
a size and byte buffer fields. This is not suitable for direct instantiation."""

def __init__(self, _cdata: Optional[ffi.CData] = None, **kwargs: Dict[str, Any]):
def __init__(self, _cdata: Optional[Union[ffi.CData, bytes]] = None, **kwargs: Any):

_cdata, kwargs = _fixup_cdata_kwargs(self, _cdata, kwargs)
_bytefield = type(self)._get_bytefield()
Expand Down
31 changes: 17 additions & 14 deletions src/tpm2_pytss/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,22 @@
TPM2_ECC,
TPM2_PT,
TPM2_RH,
TPM2_ALG,
)
from .internal.templates import _ek
from .TSS2_Exception import TSS2_Exception
from cryptography.hazmat.primitives import constant_time as ct
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.backends import default_backend
from typing import Optional, Tuple, Callable, List
from typing import Optional, Tuple, Callable, List, Union

import secrets


def make_credential(
public: TPM2B_PUBLIC, credential: bytes, name: TPM2B_NAME
public: TPM2B_PUBLIC,
credential: Union[bytes, TPM2B_DIGEST],
name: Union[TPM2B_NAME, bytes],
) -> Tuple[TPM2B_ID_OBJECT, TPM2B_ENCRYPTED_SECRET]:
"""Encrypts credential for use with activate_credential
Expand Down Expand Up @@ -283,8 +286,8 @@ def unwrap(

# unwrap the inner encryption which is the integrity + TPM2B_SENSITIVE
innerint_and_decsens = _decrypt(cipher, mode, symkey, sensb)
innerint, offset = TPM2B_DIGEST.unmarshal(innerint_and_decsens)
innerint = bytes(innerint)
innerint2b, offset = TPM2B_DIGEST.unmarshal(innerint_and_decsens)
innerint = bytes(innerint2b)
decsensb = innerint_and_decsens[offset:]

h = hashes.Hash(halg(), backend=default_backend())
Expand Down Expand Up @@ -315,11 +318,11 @@ class NoSuchIndex(Exception):
index (int): The NV index requested
"""

def __init__(self, index):
def __init__(self, index: int):
self.index = index

def __str__(self):
return f"NV index 0x{index:08x} does not exist"
def __str__(self) -> str:
return f"NV index 0x{self.index:08x} does not exist"


class NVReadEK:
Expand All @@ -336,10 +339,10 @@ class NVReadEK:
def __init__(
self,
ectx: ESAPI,
auth_handle: ESYS_TR = None,
session1: ESYS_TR = ESYS_TR.PASSWORD,
session2: ESYS_TR = ESYS_TR.NONE,
session3: ESYS_TR = ESYS_TR.NONE,
auth_handle: Optional[ESYS_TR] = None,
session1: ESYS_TR = ESYS_TR(ESYS_TR.PASSWORD),
session2: ESYS_TR = ESYS_TR(ESYS_TR.NONE),
session3: ESYS_TR = ESYS_TR(ESYS_TR.NONE),
):
self._ectx = ectx
self._auth_handle = auth_handle
Expand All @@ -351,7 +354,7 @@ def __init__(
more = True
while more:
more, data = self._ectx.get_capability(
TPM2_CAP.TPM_PROPERTIES,
TPM2_CAP(TPM2_CAP.TPM_PROPERTIES),
TPM2_PT.FIXED,
4096,
session1=session2,
Expand All @@ -367,7 +370,7 @@ def __init__(
def __call__(self, index: Union[int, TPM2_RH]) -> bytes:
try:
nvh = self._ectx.tr_from_tpmpublic(
index, session1=self._session2, session2=self._session3
TPM2_HANDLE(index), session1=self._session2, session2=self._session3
)
except TSS2_Exception as e:
if e.rc == 0x18B:
Expand Down Expand Up @@ -399,7 +402,7 @@ def __call__(self, index: Union[int, TPM2_RH]) -> bytes:

def create_ek_template(
ektype: str, nv_read_cb: Callable[[Union[int, TPM2_RH]], bytes]
) -> Tuple[bytes, TPM2B_PUBLIC]:
) -> Tuple[Optional[bytes], TPM2B_PUBLIC]:
"""Creates an Endorsenment Key template which when created matches the EK certificate
The template is created according to TCG EK Credential Profile For TPM Family 2.0:
Expand Down

0 comments on commit b1da22e

Please sign in to comment.