diff --git a/.ci/run.sh b/.ci/run.sh index 785df5d2..4f85abd1 100755 --- a/.ci/run.sh +++ b/.ci/run.sh @@ -103,6 +103,11 @@ function run_style() { "${PYTHON}" -m black --diff --check "${SRC_ROOT}" } +function run_mypy() { + python3 -m pip install --user -e .[dev] + mypy --strict src/tpm2_pytss +} + if [ "x${TEST}" != "x" ]; then run_test elif [ "x${WHITESPACE}" != "x" ]; then @@ -111,4 +116,6 @@ elif [ "x${STYLE}" != "x" ]; then run_style elif [ "x${PUBLISH_PKG}" != "x" ]; then run_publish_pkg +elif [ "x${MYPY}" != "x" ]; then + run_mypy fi diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index de6fe93d..a475ea60 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -67,6 +67,28 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} files: /tmp/coverage.xml + mypy: + runs-on: ubuntu-20.04 + + steps: + - name: Checkout Repository + uses: actions/checkout@v2 + + - name: Set up Python 3.x + uses: actions/setup-python@v2 + with: + python-version: 3.x + + - name: Install dependencies + env: + TPM2_TSS_VERSION: 4.0.1 + run: ./.ci/install-deps.sh + + - name: Check + env: + MYPY: 1 + run: ./.ci/run.sh + whitespace-check: runs-on: ubuntu-latest steps: diff --git a/.gitignore b/.gitignore index 396cfdbe..c27f1cf2 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ htmlcov /.pytest_cache/ src/tpm2_pytss/internal/type_mapping.py src/tpm2_pytss/internal/versions.py +src/tpm2_pytss/_libtpm2_pytss/lib.pyi diff --git a/mypy_stubs/_cffi_backend.pyi b/mypy_stubs/_cffi_backend.pyi new file mode 100644 index 00000000..e69de29b diff --git a/mypy_stubs/asn1crypto/core.pyi b/mypy_stubs/asn1crypto/core.pyi new file mode 100644 index 00000000..3ca5a5e6 --- /dev/null +++ b/mypy_stubs/asn1crypto/core.pyi @@ -0,0 +1,16 @@ +from typing import Optional, Any, Dict + +class Boolean: + _trailer: bytes +class ObjectIdentifier: + def __init__(self, oid: str): ... + native: Optional[str] +class Sequence: + def __setitem__(self, key: str, value: Any) -> None: ... + def __getitem__(self, key: str) -> Any: ... + def dump(self, force: bool = False) -> bytes: ... + @classmethod + def load(cls, encoded_data: bytes, strict: bool = False, **kwargs: Dict[str, Any]) -> "Sequence": ... + +class Integer: ... +class OctetString: ... diff --git a/mypy_stubs/asn1crypto/pem.pyi b/mypy_stubs/asn1crypto/pem.pyi new file mode 100644 index 00000000..efba0bbb --- /dev/null +++ b/mypy_stubs/asn1crypto/pem.pyi @@ -0,0 +1,4 @@ +from typing import Optional, Dict, Tuple + +def armor(type_name: str, der_bytes: bytes, headers: Optional[Dict[str, str]] = None) -> bytes: ... +def unarmor(pem_bytes: bytes , multiple: bool = False) -> Tuple[str, Dict[str, str], bytes]: ... diff --git a/pyproject.toml b/pyproject.toml index 80a15f9b..a865c414 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=44", "wheel", "setuptools_scm[toml]>=3.4.3", "pycparser", "pkgconfig"] +requires = ["setuptools>=44", "wheel", "setuptools_scm[toml]>=3.4.3", "pycparser", "pkgconfig", "cffi"] build-backend = "setuptools.build_meta" [tool.setuptools_scm] @@ -19,6 +19,13 @@ exclude = ''' | build | dist | esys_binding.py + | .*\.pyi$ ) ) ''' + +[tool.mypy] +mypy_path = "mypy_stubs" +exclude = [ + 'src/tpm2_pytss/encoding.py', +] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 1b7d1313..6376b7f1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,3 +55,4 @@ dev = myst-parser build installer + mypy diff --git a/setup.py b/setup.py index 1b5f5135..b5f80c51 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ Enum, ) from textwrap import dedent +from cffi import cparser # workaround bug https://github.com/pypa/pip/issues/7953 site.ENABLE_USER_SITE = "--user" in sys.argv[1:] @@ -282,8 +283,220 @@ def run(self): self.copy_file(vp, svp) +class type_hints_generator(type_generator): + is_int = set(("int",)) + callbacks = dict() + functions = dict() + + macro_types = ( + ("TPM2_ALG_", "TPM2_ALG"), + ("ESYS_TR_", "ESYS_TR"), + ("TPM2_ECC_", "TPM2_ECC"), + ("TPM2_RH_", "TPM2_RH"), + ("TPM2_SU_", "TPM2_SU"), + ("TPMA_OBJECT_", "TPMA_OBJECT"), + ("TPM2_CC_", "TPM2_CC"), + ("TPM2_SPEC_", "TPM2_SPEC"), + ("TPM2_GENERATED_", "TPM2_GENERATED"), + ("TPM2_RC_", "TPM2_RC"), + ("TSS2_RC_", "TSS2_RC"), + ("TPM2_EO_", "TPM2_EO"), + ("TPM2_ST_", "TPM2_ST"), + ("TPM2_SE_", "TPM2_SE"), + ("TPM2_CAP_", "TPM2_CAP"), + ("TPM_AT_", "TPM_AT"), + ("TPM2_PT_", "TPM2_PT"), + ("TPM2_PT_VENDOR_", "TPM2_PT_VENDOR"), + ("TPM2_PT_FIRMWARE_", "TPM2_PT_FIRMWARE"), + ("TPM2_PT_HR_", "TPM2_PT_HR"), + ("TPM2_PT_NV_", "TPM2_PT_NV"), + ("TPM2_PT_CONTEXT_", "TPM2_PT_CONTEXT"), + ("TPM2_PT_PS_", "TPM2_PT_PS"), + ("TPM2_PT_AUDIT_", "TPM2_PT_AUDIT"), + ("TPM2_PT_PCR_", "TPM2_PT_PCR"), + ("TPM2_PS_", "TPM2_PS"), + ("TPM2_HT_", "TPM2_HT"), + ("TPMA_SESSION_", "TPMA_SESSION"), + ("TPMA_LOCALITY_", "TPMA_LOCALITY"), + ("TPM2_NT_", "TPM2_NT"), + ("TPM2_HR_", "TPM2_HR"), + ("TPM2_HC_", "TPM2_HC"), + ("TPM2_CLOCK_", "TPM2_CLOCK"), + ("TPMA_NV_", "TPMA_NV"), + ("TPMA_CC_", "TPMA_CC"), + ("TPMA_ALGORITHM_", "TPMA_ALGORITHM"), + ("TPMA_PERMANENT_", "TPMA_PERMANENT"), + ("TPMA_STARTUP_", "TPMA_STARTUP"), + ("TPMA_MEMORY_", "TPMA_MEMORY"), + ("TPM2_MAX_", "TPM2_MAX"), + ("TPMA_MODES_", "TPMA_MODES"), + ) + + def macro_to_type(self, macro): + mt = "int" + ml = 0 + for prefix, tn in self.macro_types: + pl = len(prefix) + if macro.startswith(prefix) and pl > ml: + mt = tn + ml = pl + return mt + + def _make_callback_output(self, cname): + callback = self.callbacks[cname] + rt, args = callback + paramtypes = list() + for _, at in args: + paramtypes.append(at) + cbdef = f"Callable[[{', '.join(paramtypes)}], {rt}]" + return cbdef + + def build_function(self, d): + args = list() + for param in d.args.params: + pn = param.name + if pn is None: + # if the param doesn't have a name, ignore + continue + elif isinstance( + param.type, + (cparser.pycparser.c_ast.PtrDecl, cparser.pycparser.c_ast.ArrayDecl), + ): + ft = "CData" + if isinstance( + param.type.type, cparser.pycparser.c_ast.TypeDecl + ) and isinstance( + param.type.type.type, cparser.pycparser.c_ast.IdentifierType + ): + tn = ( + param.type.type.type.names[0] + if param.type.type.type.names + else None + ) + if tn in ("char", "uint8_t"): + ft = "CData | bytes" + elif isinstance( + param.type, cparser.pycparser.c_ast.TypeDecl + ) and isinstance(param.type.type, cparser.pycparser.c_ast.IdentifierType): + tn = param.type.type.names[0] + if tn in self.is_int: + ft = "int" + elif tn in self.callbacks: + ft = self._make_callback_output(tn) + else: + raise ValueError(f"unable to handle C type {param.type}") + args.append((pn, ft)) + rt = "CData" + if isinstance(d.type, cparser.pycparser.c_ast.TypeDecl) and isinstance( + d.type.type, cparser.pycparser.c_ast.IdentifierType + ): + tn = d.type.type.names[0] + if tn in self.is_int: + rt = "int" + elif tn == "void": + rt = "None" + else: + rt = "CData" + + return (rt, tuple(args)) + + def write_type_hints(self, macros): + output = dedent( + """ + # SPDX-License-Identifier: BSD-2 + # This file is generated during the build process. + from typing import Callable + from .ffi import CData + + # Defines + """ + ) + + mtl = [x for _, x in self.macro_types] + output += f"from ..constants import {', '.join(mtl)}\n" + + for m in macros: + mt = self.macro_to_type(m) + output += f'{m}: "{mt}"\n' + + output += "\n# Callback definitions\n" + for cname in self.callbacks: + cbdef = self._make_callback_output(cname) + output += f"{cname}: {cbdef}\n" + + output += "\n# Function definitions\n" + for fname, function in self.functions.items(): + rt, args = function + params = list() + for an, at in args: + if an == "in": + an = "in_" + params.append(f"{an}: {at}") + output += f"def {fname}({', '.join(params)}) -> {rt}:...\n" + + p = os.path.join(self.build_lib, "tpm2_pytss/_libtpm2_pytss/lib.pyi") + sp = os.path.join( + os.path.dirname(__file__), "src/tpm2_pytss/_libtpm2_pytss/lib.pyi" + ) + + if not self.dry_run: + self.mkpath(os.path.dirname(p)) + with open(p, "wt") as tf: + tf.seek(0) + tf.truncate(0) + tf.write(output) + + if self.inplace: + self.copy_file(p, sp) + + def run(self): + super().run() + + with open("libesys.h", "r") as sf: + cdata = sf.read() + + parser = cparser.Parser() + ast, macros, _ = parser._parse(cdata) + + for d in ast: + name = d.name + if isinstance(name, str) and name.startswith("__cffi_"): + # internal cffi stuff, ignore + continue + if isinstance(d, cparser.pycparser.c_ast.Typedef): + d = d.type + if isinstance( + d.type, + ( + cparser.pycparser.c_ast.Struct, + cparser.pycparser.c_ast.Union, + cparser.pycparser.c_ast.Enum, + ), + ): + # ignore unions, structs and enums + pass + elif isinstance(d.type, cparser.pycparser.c_ast.IdentifierType): + # check if type is int, otherwise ignore + if len(d.type.names) == 0: + break + tn = d.type.names[0] + if tn in self.is_int: + self.is_int.add(name) + elif isinstance(d, cparser.pycparser.c_ast.PtrDecl) and isinstance( + d.type, cparser.pycparser.c_ast.FuncDecl + ): + rt, args = self.build_function(d.type) + self.callbacks[name] = (rt, args) + elif isinstance(d, cparser.pycparser.c_ast.Decl) and isinstance( + d.type, cparser.pycparser.c_ast.FuncDecl + ): + rt, args = self.build_function(d.type) + self.functions[name] = (rt, args) + self.write_type_hints(macros) + + setup( use_scm_version=True, cffi_modules=["scripts/libtss2_build.py:ffibuilder"], - cmdclass={"build_ext": type_generator}, + cmdclass={"build_ext": type_hints_generator}, ) diff --git a/src/tpm2_pytss/ESAPI.py b/src/tpm2_pytss/ESAPI.py index 7cb88eab..fd029a40 100644 --- a/src/tpm2_pytss/ESAPI.py +++ b/src/tpm2_pytss/ESAPI.py @@ -10,15 +10,30 @@ ) from .TCTI import TCTI from .TCTILdr import TCTILdr +from ._libtpm2_pytss import ffi, lib -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union, Any, Type, Callable, Sequence + +try: + from typing import Self +except ImportError: + # assume mypy is running on python 3.11+ + pass +from types import TracebackType # Work around this FAPI dependency if FAPI is not present with the constant value _fapi_installed_ = _lib_version_atleast("tss2-fapi", "3.0.0") _DEFAULT_LOAD_BLOB_SELECTOR = FAPI_ESYSBLOB.CONTEXTLOAD if _fapi_installed_ else 1 -def _get_cdata(value, expected, varname, allow_none=False, *args, **kwargs): +def _get_cdata( + value: Any, + expected: Type["TPM_OBJECT"], + varname: str, + allow_none: bool = False, + *args: Any, + **kwargs: Any, +) -> ffi.CData: tname = expected.__name__ if value is None and allow_none: @@ -36,12 +51,12 @@ def _get_cdata(value, expected, varname, allow_none=False, *args, **kwargs): return value vname = type(value).__name__ - parse_method = getattr(expected, "parse", None) + parse_method: Optional[Callable[..., ffi.CData]] = getattr(expected, "parse", None) if isinstance(value, (bytes, str)) and issubclass(expected, TPM2B_SIMPLE_OBJECT): bo = expected(value) return bo._cdata elif isinstance(value, str) and parse_method and callable(parse_method): - return expected.parse(value, *args, **kwargs)._cdata + return parse_method(value, *args, **kwargs)._cdata elif issubclass(expected, TPML_OBJECT) and isinstance(value, list): return expected(value)._cdata elif not isinstance(value, expected): @@ -50,7 +65,9 @@ def _get_cdata(value, expected, varname, allow_none=False, *args, **kwargs): return value._cdata -def _check_handle_type(handle, varname, expected=None): +def _check_handle_type( + handle: ESYS_TR, varname: str, expected: Optional[Sequence[ESYS_TR]] = None +) -> None: if not isinstance(handle, ESYS_TR): raise TypeError(f"expected {varname} to be type ESYS_TR, got {type(handle)}") @@ -129,10 +146,12 @@ def __init__(self, tcti: Union[TCTI, str, None] = None): _chkrc(lib.Esys_Initialize(self._ctx_pp, tctx, ffi.NULL)) self._ctx = self._ctx_pp[0] - def __enter__(self): + def __enter__(self) -> "Self": return self - def __exit__(self, _type, value, traceback) -> None: + def __exit__( + self, _type: Type[Exception], value: Exception, traceback: TracebackType + ) -> None: self.close() # @@ -154,7 +173,7 @@ def close(self) -> None: self._ctx = ffi.NULL self._ctx_pp = ffi.NULL if self._did_load_tcti and self._tcti is not None: - self._tcti.close() + self._tcti.finalize() self._tcti = None def get_tcti(self) -> Optional[TCTI]: @@ -856,7 +875,7 @@ def load( def load_external( self, in_public: TPM2B_PUBLIC, - in_private: TPM2B_SENSITIVE = None, + in_private: Optional[TPM2B_SENSITIVE] = None, hierarchy: ESYS_TR = ESYS_TR.NULL, session1: ESYS_TR = ESYS_TR.NONE, session2: ESYS_TR = ESYS_TR.NONE, @@ -901,7 +920,7 @@ def load_external( in_public_cdata = _get_cdata(in_public, TPM2B_PUBLIC, "in_public") - hierarchy = ESAPI._fixup_hierarchy(hierarchy) + fixed_hierarchy = ESAPI._fixup_hierarchy(hierarchy) object_handle = ffi.new("ESYS_TR *") _chkrc( @@ -912,7 +931,7 @@ def load_external( session3, in_private_cdata, in_public_cdata, - hierarchy, + fixed_hierarchy, object_handle, ) ) @@ -2103,7 +2122,7 @@ def hash( _check_friendly_int(hash_alg, "hash_alg", TPM2_ALG) _check_friendly_int(hierarchy, "hierarchy", ESYS_TR) - hierarchy = ESAPI._fixup_hierarchy(hierarchy) + fixed_hierarchy = ESAPI._fixup_hierarchy(hierarchy) data_cdata = _get_cdata(data, TPM2B_MAX_BUFFER, "data") @@ -2117,7 +2136,7 @@ def hash( session3, data_cdata, hash_alg, - hierarchy, + fixed_hierarchy, out_hash, validation, ) @@ -2616,7 +2635,7 @@ def sequence_complete( _check_handle_type(session3, "session3") _check_friendly_int(hierarchy, "hierarchy", ESYS_TR) - hierarchy = ESAPI._fixup_hierarchy(hierarchy) + fixed_hierarchy = ESAPI._fixup_hierarchy(hierarchy) buffer_cdata = _get_cdata(buffer, TPM2B_MAX_BUFFER, "buffer", allow_none=True) @@ -2630,7 +2649,7 @@ def sequence_complete( session2, session3, buffer_cdata, - hierarchy, + fixed_hierarchy, result, validation, ) @@ -5107,7 +5126,7 @@ def hierarchy_control( "enable", expected=(ESYS_TR.ENDORSEMENT, ESYS_TR.OWNER, ESYS_TR.PLATFORM), ) - enable = ESAPI._fixup_hierarchy(enable) + fixed_enable = ESAPI._fixup_hierarchy(enable) if not isinstance(state, bool): raise TypeError(f"Expected state to be a bool, got {type(state)}") @@ -5118,7 +5137,13 @@ def hierarchy_control( _chkrc( lib.Esys_HierarchyControl( - self._ctx, auth_handle, session1, session2, session3, enable, state + self._ctx, + auth_handle, + session1, + session2, + session3, + fixed_enable, + state, ) ) @@ -5554,7 +5579,7 @@ def pp_commands( def set_algorithm_set( self, - algorithm_set: Union[List[int], int], + algorithm_set: int, auth_handle: ESYS_TR = ESYS_TR.PLATFORM, session1: ESYS_TR = ESYS_TR.PASSWORD, session2: ESYS_TR = ESYS_TR.NONE, @@ -5567,7 +5592,7 @@ def set_algorithm_set( available. Args: - algorithm_set (Union[List[int], int]): A TPM vendor-dependent value indicating the + algorithm_set (int): A TPM vendor-dependent value indicating the algorithm set selection. auth_handle (ESYS_TR): ESYS_TR.PLATFORM. Defaults to ESYS_TR.PLATFORM. session1 (ESYS_TR): A session for securing the TPM command (optional). Defaults to ESYS_TR.PASSWORD. @@ -6138,7 +6163,7 @@ def ac_get_capability( _check_friendly_int(capability, "capability", TPM_AT) if not isinstance(count, int): - raise TypeError(f"Expected count to be an int, got {type(prop)}") + raise TypeError(f"Expected count to be an int, got {type(count)}") _check_handle_type(ac, "ac") _check_handle_type(session1, "session1") @@ -6228,7 +6253,7 @@ def ac_send( ac_data_out, ) ) - return TPMS_AC_OUTPUT(_get_dptr(acDataOut, lib.Esys_Free)) + return TPMS_AC_OUTPUT(_get_dptr(ac_data_out, lib.Esys_Free)) def policy_ac_send_select( self, @@ -7169,7 +7194,7 @@ def tr_serialize(self, esys_handle: ESYS_TR) -> bytes: _chkrc(lib.Esys_TR_Serialize(self._ctx, esys_handle, buffer, buffer_size)) buffer_size = buffer_size[0] buffer = _get_dptr(buffer, lib.Esys_Free) - return bytes(ffi.buffer(buffer, buffer_size)) + return bytes(ffi.buffer(buffer, int(buffer_size))) def tr_deserialize(self, buffer: bytes) -> ESYS_TR: """Deserialization of an ESYS_TR from a byte buffer. @@ -7235,6 +7260,6 @@ def _fixup_hierarchy(hierarchy: ESYS_TR) -> Union[TPM2_RH, ESYS_TR]: "Expected hierarchy to be one of ESYS_TR.NULL, ESYS_TR.PLATFORM, ESYS_TR.OWNER, ESYS_TR.ENDORSMENT" ) - hierarchy = fixup_map[hierarchy] - - return hierarchy + return fixup_map[hierarchy] + else: + return hierarchy diff --git a/src/tpm2_pytss/FAPI.py b/src/tpm2_pytss/FAPI.py index 916eda62..8e4f477e 100644 --- a/src/tpm2_pytss/FAPI.py +++ b/src/tpm2_pytss/FAPI.py @@ -9,7 +9,18 @@ import logging import os import tempfile -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + List, + Optional, + Tuple, + Union, + Type, + Literal, + Dict, + cast, +) from ._libtpm2_pytss import ffi, lib from .fapi_info import FapiInfo @@ -18,6 +29,13 @@ from .types import TPM2B_PUBLIC, TPM2B_PRIVATE from .constants import TSS2_RC, TPM2_ALG from .TSS2_Exception import TSS2_Exception +from types import TracebackType + +try: + from typing import Self +except ImportError: + # assume mypy is running on python 3.11+ + pass logger = logging.getLogger(__name__) @@ -33,7 +51,12 @@ class FAPIConfig(contextlib.ExitStack): """Context to create a temporary Fapi environment.""" - def __init__(self, config: Optional[dict] = None, temp_dirs: bool = True, **kwargs): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + temp_dirs: bool = True, + **kwargs: Any, + ): f"""Create a temporary Fapi environment. Get the fapi_conf in this order: * `config` if given * File specified with environment variable `{FAPI_CONFIG_ENV}` if defined @@ -107,35 +130,46 @@ def __init__(self, config: Optional[dict] = None, temp_dirs: bool = True, **kwar self.config_env_backup = os.environ.get(FAPI_CONFIG_ENV, None) os.environ[FAPI_CONFIG_ENV] = self.config_tmp_path - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Literal[False]: super().__exit__(exc_type, exc_val, exc_tb) del os.environ[FAPI_CONFIG_ENV] if self.config_tmp_path is not None: os.unlink(self.config_tmp_path) + return False class _FAPI_CB_UDATA: - def __init__(self, cb, udata): - self.cur_exc = None + def __init__(self, cb: Optional[Callable[..., Any]], udata: Any): + self.cur_exc: Optional[Exception] = None self.udata = udata self.cb = cb @ffi.def_extern() -def _fapi_auth_callback(object_path, description, auth, user_data): +def _fapi_auth_callback( + object_path: ffi.CData, + description: ffi.CData, + auth: ffi.CData, + user_data: ffi.CData, +) -> int: cb_udata: _FAPI_CB_UDATA = ffi.from_handle(user_data) if not cb_udata.cb: return TSS2_RC.FAPI_RC_NOT_IMPLEMENTED try: - got_auth: Optional[bytes, str] = cb_udata.cb( + got_auth: Optional[Union[bytes, str]] = cb_udata.cb( ffi.string(object_path), ffi.string(description), cb_udata.udata ) - auth_bytes = got_auth.decode() if isinstance(got_auth, str) else got_auth + auth_bytes = got_auth.encode() if isinstance(got_auth, str) else got_auth auth[0] = _cffi_malloc("char[]", auth_bytes) except Exception as e: rc = e.rc if isinstance(e, TSS2_Exception) else TSS2_RC.FAPI_RC_NOT_IMPLEMENTED @@ -146,7 +180,9 @@ def _fapi_auth_callback(object_path, description, auth, user_data): @ffi.def_extern() -def _fapi_policy_action_callback(object_path, action, user_data): +def _fapi_policy_action_callback( + object_path: ffi.CData, action: ffi.CData, user_data: ffi.CData +) -> int: cb_udata: _FAPI_CB_UDATA = ffi.from_handle(user_data) if not cb_udata.cb: @@ -167,17 +203,17 @@ def _fapi_policy_action_callback(object_path, action, user_data): @ffi.def_extern() def _fapi_sign_callback( - object_path, - description, - publickey, - publickey_hint, - hashalg, - data_to_sign, - data_to_sign_size, - signature, - signature_size, - user_data, -): + object_path: ffi.CData, + description: ffi.CData, + publickey: ffi.CData, + publickey_hint: ffi.CData, + hashalg: int, + data_to_sign: ffi.CData, + data_to_sign_size: int, + signature: ffi.CData, + signature_size: ffi.CData, + user_data: ffi.CData, +) -> int: cb_udata: _FAPI_CB_UDATA = ffi.from_handle(user_data) if not cb_udata.cb: @@ -204,8 +240,13 @@ def _fapi_sign_callback( @ffi.def_extern() def _fapi_branch_callback( - object_path, description, branch_names, num_branches, selected_branch, user_data -): + object_path: ffi.CData, + description: ffi.CData, + branch_names: ffi.CData, + num_branches: int, + selected_branch: ffi.CData, + user_data: ffi.CData, +) -> int: cb_udata: _FAPI_CB_UDATA = ffi.from_handle(user_data) if not cb_udata.cb: @@ -213,7 +254,8 @@ def _fapi_branch_callback( try: branch_list = [ - ffi.string(x).decode() for x in ffi.unpack(branch_names, num_branches) + ffi.string(cast(ffi.CData, x)).decode() + for x in ffi.unpack(branch_names, num_branches) ] position: int = cb_udata.cb( @@ -240,25 +282,27 @@ def __init__(self, uri: Optional[Union[bytes, str]] = None): self.encoding = "utf-8" self._ctx_pp = ffi.new("FAPI_CONTEXT **") - uri = _to_bytes_or_null(uri) - ret = lib.Fapi_Initialize(self._ctx_pp, uri) + curi = _to_bytes_or_null(uri) + ret = lib.Fapi_Initialize(self._ctx_pp, curi) _chkrc(ret) - self._callback_metadata = {} + self._callback_metadata: Dict[str, Any] = {} @property - def _ctx(self): + def _ctx(self) -> ffi.CData: """Get the Feature API C context used by the library to hold state. Returns: The Feature API C context. """ - return self._ctx_pp[0] + return cast(ffi.CData, self._ctx_pp[0]) - def __enter__(self): + def __enter__(self) -> "Self": return self - def __exit__(self, _type, value, traceback): + def __exit__( + self, _type: Type[Exception], value: Exception, traceback: TracebackType + ) -> None: self.close() def close(self) -> None: @@ -267,7 +311,7 @@ def close(self) -> None: # TODO flesh out info class @property - def version(self): + def version(self) -> str: """ Get the tpm2-tss library version. @@ -275,15 +319,15 @@ def version(self): str: The Feature API C context. """ info = json.loads(self.get_info()) - return FapiInfo(info).version + return str(FapiInfo(info).version) @property - def config(self): # TODO doc, test + def config(self) -> FAPIConfig: # TODO doc, test info = json.loads(self.get_info()) - return FapiInfo(info).fapi_config + return cast(FAPIConfig, FapiInfo(info).fapi_config) @property - def tcti(self): # TODO doc, test + def tcti(self) -> TCTI: # TODO doc, test tcti = ffi.new("TSS2_TCTI_CONTEXT **") # returns the actual tcti context, not a copy (so no extra memory is allocated by the fapi) ret = lib.Fapi_GetTcti(self._ctx, tcti) @@ -312,11 +356,11 @@ def provision( Returns: bool: True if Fapi was provisioned, False otherwise. """ - auth_value_eh = _to_bytes_or_null(auth_value_eh) - auth_value_sh = _to_bytes_or_null(auth_value_sh) - auth_value_lockout = _to_bytes_or_null(auth_value_lockout, allow_null=False) + cauth_value_eh = _to_bytes_or_null(auth_value_eh) + cauth_value_sh = _to_bytes_or_null(auth_value_sh) + cauth_value_lockout = _to_bytes_or_null(auth_value_lockout, allow_null=False) ret = lib.Fapi_Provision( - self._ctx, auth_value_eh, auth_value_sh, auth_value_lockout + self._ctx, cauth_value_eh, cauth_value_sh, cauth_value_lockout ) _chkrc( ret, @@ -373,9 +417,9 @@ def list(self, search_path: Optional[Union[bytes, str]] = None) -> List[str]: Returns: List[str]: List of all current Fapi object paths. """ - search_path = _to_bytes_or_null(search_path, allow_null=False) + csearch_path = _to_bytes_or_null(search_path, allow_null=False) path_list = ffi.new("char **") - ret = lib.Fapi_List(self._ctx, search_path, path_list) + ret = lib.Fapi_List(self._ctx, csearch_path, path_list) _chkrc(ret) return ( ffi.string(_get_dptr(path_list, lib.Fapi_Free)) @@ -405,11 +449,11 @@ def create_key( Returns: bool: True if the key was created. False otherwise. """ - path = _to_bytes_or_null(path) - type_ = _to_bytes_or_null(type_) - policy_path = _to_bytes_or_null(policy_path) - auth_value = _to_bytes_or_null(auth_value) - ret = lib.Fapi_CreateKey(self._ctx, path, type_, policy_path, auth_value) + cpath = _to_bytes_or_null(path) + ctype_ = _to_bytes_or_null(type_) + cpolicy_path = _to_bytes_or_null(policy_path) + cauth_value = _to_bytes_or_null(auth_value) + ret = lib.Fapi_CreateKey(self._ctx, cpath, ctype_, cpolicy_path, cauth_value) _chkrc( ret, acceptable=lib.TSS2_FAPI_RC_PATH_ALREADY_EXISTS if exists_ok else None ) @@ -420,7 +464,7 @@ def sign( path: Union[bytes, str], digest: bytes, padding: Optional[Union[bytes, str]] = None, # TODO enum - ) -> Tuple[bytes, str, str]: + ) -> Tuple[bytes, bytes, bytes]: """Create a signature over a given digest. Args: @@ -432,11 +476,11 @@ def sign( TSS2_Exception: If Fapi returned an error code. Returns: - Tuple[bytes, str, str]: (signature (DER), public key (PEM), certificate (PEM)) + Tuple[bytes, bytes, bytes]: (signature (DER), public key (PEM), certificate (PEM)) """ - path = _to_bytes_or_null(path) - padding = _to_bytes_or_null(padding) # enum - digest = _to_bytes_or_null(digest) + cpath = _to_bytes_or_null(path) + cpadding = _to_bytes_or_null(padding) # enum + cdigest = _to_bytes_or_null(digest) signature = ffi.new("uint8_t **") signature_size = ffi.new("size_t *") public_key = ffi.new("char **") @@ -444,9 +488,9 @@ def sign( ret = lib.Fapi_Sign( self._ctx, - path, - padding, - digest, + cpath, + cpadding, + cdigest, len(digest), signature, signature_size, @@ -462,7 +506,7 @@ def sign( def verify_signature( self, path: Union[bytes, str], digest: bytes, signature: bytes - ): + ) -> None: """Verify a signature on a given digest. Args: @@ -473,9 +517,9 @@ def verify_signature( Raises: TSS2_Exception: If Fapi returned an error code, e.g. if the signature cannot be verified successfully. """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) ret = lib.Fapi_VerifySignature( - self._ctx, path, digest, len(digest), signature, len(signature) + self._ctx, cpath, digest, len(digest), signature, len(signature) ) _chkrc(ret) @@ -499,12 +543,12 @@ def encrypt( backports=["2.4.7", "3.0.5", "3.1.1"], details="Faulty free of FAPI Encrypt might lead to Segmentation Fault. See https://github.com/tpm2-software/tpm2-tss/issues/2092", ) - path = _to_bytes_or_null(path) - plaintext = _to_bytes_or_null(plaintext) + cpath = _to_bytes_or_null(path) + cplaintext = _to_bytes_or_null(plaintext) ciphertext = ffi.new("uint8_t **") ciphertext_size = ffi.new("size_t *") ret = lib.Fapi_Encrypt( - self._ctx, path, plaintext, len(plaintext), ciphertext, ciphertext_size + self._ctx, cpath, cplaintext, len(plaintext), ciphertext, ciphertext_size ) _chkrc(ret) return bytes( @@ -524,11 +568,11 @@ def decrypt(self, path: Union[bytes, str], ciphertext: bytes) -> bytes: Returns: bytes: The plaintext. """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) plaintext = ffi.new("uint8_t **") plaintext_size = ffi.new("size_t *") ret = lib.Fapi_Decrypt( - self._ctx, path, ciphertext, len(ciphertext), plaintext, plaintext_size + self._ctx, cpath, ciphertext, len(ciphertext), plaintext, plaintext_size ) _chkrc(ret) return bytes(ffi.unpack(plaintext[0], plaintext_size[0])) @@ -560,21 +604,22 @@ def create_seal( Returns: bool: True if the sealed object was created. False otherwise. """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) if data is not None and size is not None: raise ValueError("Parameters data and size cannot be given at same time.") if data is None and size is None: raise ValueError("Either parameter data or parameter size must be given.") - if data is None: + data_len = 0 + if data is None and isinstance(size, int): data_len = size - else: + elif isinstance(data, (bytes, str)): data_len = len(data) - data = _to_bytes_or_null(data) - type_ = _to_bytes_or_null(type_) - policy_path = _to_bytes_or_null(policy_path) - auth_value = _to_bytes_or_null(auth_value) + cdata = _to_bytes_or_null(data) + ctype_ = _to_bytes_or_null(type_) + cpolicy_path = _to_bytes_or_null(policy_path) + cauth_value = _to_bytes_or_null(auth_value) ret = lib.Fapi_CreateSeal( - self._ctx, path, type_, data_len, policy_path, auth_value, data + self._ctx, cpath, ctype_, data_len, cpolicy_path, cauth_value, cdata ) _chkrc( ret, acceptable=lib.TSS2_FAPI_RC_PATH_ALREADY_EXISTS if exists_ok else None @@ -593,10 +638,10 @@ def unseal(self, path: Union[bytes, str]) -> bytes: Returns: bytes: The unsealed data in plaintext. """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) data = ffi.new("uint8_t **") data_size = ffi.new("size_t *") - ret = lib.Fapi_Unseal(self._ctx, path, data, data_size) + ret = lib.Fapi_Unseal(self._ctx, cpath, data, data_size) _chkrc(ret) return bytes(ffi.unpack(_get_dptr(data, lib.Fapi_Free), data_size[0])) @@ -623,9 +668,9 @@ def import_object( fixed_in="3.2", details="FAPI Import will overwrite existing objects with same path silently. See https://github.com/tpm2-software/tpm2-tss/issues/2028", ) - path = _to_bytes_or_null(path) - import_data = _to_bytes_or_null(import_data) - ret = lib.Fapi_Import(self._ctx, path, import_data) + cpath = _to_bytes_or_null(path) + cimport_data = _to_bytes_or_null(import_data) + ret = lib.Fapi_Import(self._ctx, cpath, cimport_data) _chkrc( ret, acceptable=lib.TSS2_FAPI_RC_PATH_ALREADY_EXISTS if exists_ok else None ) @@ -640,8 +685,8 @@ def delete(self, path: Union[bytes, str]) -> None: Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - ret = lib.Fapi_Delete(self._ctx, path) + cpath = _to_bytes_or_null(path) + ret = lib.Fapi_Delete(self._ctx, cpath) _chkrc(ret) def change_auth( @@ -656,13 +701,13 @@ def change_auth( Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - auth_value = _to_bytes_or_null(auth_value) - ret = lib.Fapi_ChangeAuth(self._ctx, path, auth_value) + cpath = _to_bytes_or_null(path) + cauth_value = _to_bytes_or_null(auth_value) + ret = lib.Fapi_ChangeAuth(self._ctx, cpath, cauth_value) _chkrc(ret) def export_key( - self, path: Union[bytes, str], new_path: Union[bytes, str] = None + self, path: Union[bytes, str], new_path: Optional[Union[bytes, str]] = None ) -> str: """Export a Fapi object as a JSON-encoded string. @@ -676,10 +721,10 @@ def export_key( Returns: str: The exported data. """ - path = _to_bytes_or_null(path) - new_path = _to_bytes_or_null(new_path) + cpath = _to_bytes_or_null(path) + cnew_path = _to_bytes_or_null(new_path) exported_data = ffi.new("char **") - ret = lib.Fapi_ExportKey(self._ctx, path, new_path, exported_data) + ret = lib.Fapi_ExportKey(self._ctx, cpath, cnew_path, exported_data) _chkrc(ret) return ffi.string(_get_dptr(exported_data, lib.Fapi_Free)).decode(self.encoding) @@ -695,12 +740,12 @@ def set_description( Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - description = _to_bytes_or_null(description) - ret = lib.Fapi_SetDescription(self._ctx, path, description) + cpath = _to_bytes_or_null(path) + cdescription = _to_bytes_or_null(description) + ret = lib.Fapi_SetDescription(self._ctx, cpath, cdescription) _chkrc(ret) - def get_description(self, path: Union[bytes, str] = None) -> str: + def get_description(self, path: Optional[Union[bytes, str]] = None) -> str: """Get the description of a Fapi object. Args: @@ -712,9 +757,9 @@ def get_description(self, path: Union[bytes, str] = None) -> str: Returns: str: The description of the Fapi object. """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) description = ffi.new("char **") - ret = lib.Fapi_GetDescription(self._ctx, path, description) + ret = lib.Fapi_GetDescription(self._ctx, cpath, description) _chkrc(ret) # description is guaranteed to be a null-terminated string return ffi.string(_get_dptr(description, lib.Fapi_Free)).decode() @@ -731,13 +776,13 @@ def set_app_data( Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) if app_data is None: app_data_len = 0 else: app_data_len = len(app_data) - app_data = _to_bytes_or_null(app_data) - ret = lib.Fapi_SetAppData(self._ctx, path, app_data, app_data_len) + capp_data = _to_bytes_or_null(app_data) + ret = lib.Fapi_SetAppData(self._ctx, cpath, capp_data, app_data_len) _chkrc(ret) def get_app_data(self, path: Union[bytes, str]) -> Optional[bytes]: @@ -752,10 +797,10 @@ def get_app_data(self, path: Union[bytes, str]) -> Optional[bytes]: Returns: Optional[bytes]: The application data or None. """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) app_data = ffi.new("uint8_t **") app_data_size = ffi.new("size_t *") - ret = lib.Fapi_GetAppData(self._ctx, path, app_data, app_data_size) + ret = lib.Fapi_GetAppData(self._ctx, cpath, app_data, app_data_size) _chkrc(ret) if app_data[0] == ffi.NULL: return None @@ -773,9 +818,9 @@ def set_certificate( Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - certificate = _to_bytes_or_null(certificate) - ret = lib.Fapi_SetCertificate(self._ctx, path, certificate) + cpath = _to_bytes_or_null(path) + ccertificate = _to_bytes_or_null(certificate) + ret = lib.Fapi_SetCertificate(self._ctx, cpath, ccertificate) _chkrc(ret) def get_certificate(self, path: Union[bytes, str]) -> str: @@ -790,9 +835,9 @@ def get_certificate(self, path: Union[bytes, str]) -> str: Returns: bytes: The application data. """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) certificate = ffi.new("char **") - ret = lib.Fapi_GetCertificate(self._ctx, path, certificate) + ret = lib.Fapi_GetCertificate(self._ctx, cpath, certificate) _chkrc(ret) # certificate is guaranteed to be a null-terminated string return ffi.string(_get_dptr(certificate, lib.Fapi_Free)).decode() @@ -826,7 +871,7 @@ def get_platform_certificates(self, no_cert_ok: bool = False) -> bytes: if no_cert_ok and ret == lib.TSS2_FAPI_RC_NO_CERT: return b"" return bytes( - ffi.unpack(_get_dptr(certificate, lib.Fapi_Free), certificates_size) + ffi.unpack(_get_dptr(certificate, lib.Fapi_Free), int(certificates_size)) ) def get_tpm_blobs( @@ -843,7 +888,7 @@ def get_tpm_blobs( Returns: Tuple[TPM2B_PUBLIC, TPM2B_PRIVATE, str]: (tpm_2b_public, tpm_2b_private, policy) """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) tpm_2b_public = ffi.new("uint8_t **") tpm_2b_public_size = ffi.new("size_t *") tpm_2b_private = ffi.new("uint8_t **") @@ -851,7 +896,7 @@ def get_tpm_blobs( policy = ffi.new("char **") ret = lib.Fapi_GetTpmBlobs( self._ctx, - path, + cpath, tpm_2b_public, tpm_2b_public_size, tpm_2b_private, @@ -892,11 +937,11 @@ def get_esys_blob(self, path: Union[bytes, str]) -> Tuple[bytes, Any]: Returns: Tuple[bytes, Any]: A tuple of the binary blob and its type (:const:`constants.FAPI_ESYSBLOB.CONTEXTLOAD` or :const:`constants.FAPI_ESYSBLOB.DESERIALIZE`) """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) type_ = ffi.new("uint8_t *") data = ffi.new("uint8_t **") length = ffi.new("size_t *") - ret = lib.Fapi_GetEsysBlob(self._ctx, path, type_, data, length) + ret = lib.Fapi_GetEsysBlob(self._ctx, cpath, type_, data, length) _chkrc(ret) return bytes(ffi.unpack(_get_dptr(data, lib.Fapi_Free), length[0])), type_[0] @@ -912,9 +957,9 @@ def export_policy(self, path: Union[bytes, str]) -> str: Returns: str: JSON-encoded policy. """ - path = _to_bytes_or_null(path) + cpath = _to_bytes_or_null(path) policy = ffi.new("char **") - ret = lib.Fapi_ExportPolicy(self._ctx, path, policy) + ret = lib.Fapi_ExportPolicy(self._ctx, cpath, policy) _chkrc(ret) return ffi.string(_get_dptr(policy, lib.Fapi_Free)).decode() @@ -923,7 +968,7 @@ def authorize_policy( policy_path: Union[bytes, str], key_path: Union[bytes, str], policy_ref: Optional[Union[bytes, str]] = None, - ): + ) -> None: """Specify the underlying policy/policies for a policy Authorize. Args: @@ -934,15 +979,15 @@ def authorize_policy( Raises: TSS2_Exception: If Fapi returned an error code. """ - policy_path = _to_bytes_or_null(policy_path) - key_path = _to_bytes_or_null(key_path) + cpolicy_path = _to_bytes_or_null(policy_path) + ckey_path = _to_bytes_or_null(key_path) if policy_ref is None: policy_ref_len = 0 else: policy_ref_len = len(policy_ref) - policy_ref = _to_bytes_or_null(policy_ref) + cpolicy_ref = _to_bytes_or_null(policy_ref) ret = lib.Fapi_AuthorizePolicy( - self._ctx, policy_path, key_path, policy_ref, policy_ref_len + self._ctx, cpolicy_path, ckey_path, cpolicy_ref, policy_ref_len ) _chkrc(ret) @@ -993,9 +1038,9 @@ def pcr_extend( Tuple[bytes, str]: PCR value and its associated event log. """ # TODO "extend", formula in doc - log = _to_bytes_or_null(log) - data = _to_bytes_or_null(data) - ret = lib.Fapi_PcrExtend(self._ctx, index, data, len(data), log) + clog = _to_bytes_or_null(log) + cdata = _to_bytes_or_null(data) + ret = lib.Fapi_PcrExtend(self._ctx, index, cdata, len(data), clog) _chkrc(ret) def quote( @@ -1025,13 +1070,14 @@ def quote( details="Multiple calls of FAPI Quote might lead to TPM out of memory errors. See https://github.com/tpm2-software/tpm2-tss/issues/2084", ) - path = _to_bytes_or_null(path) - quote_type = _to_bytes_or_null(quote_type) + cpcrs = ffi.new("uint32_t []", pcrs) + cpath = _to_bytes_or_null(path) + cquote_type = _to_bytes_or_null(quote_type) if qualifying_data is None: qualifying_data_len = 0 else: qualifying_data_len = len(qualifying_data) - qualifying_data = _to_bytes_or_null(qualifying_data) + cqualifying_data = _to_bytes_or_null(qualifying_data) quote_info = ffi.new("char **") signature = ffi.new("uint8_t **") @@ -1040,11 +1086,11 @@ def quote( certificate = ffi.new("char **") ret = lib.Fapi_Quote( self._ctx, - pcrs, + cpcrs, len(pcrs), - path, - quote_type, - qualifying_data, + cpath, + cquote_type, + cqualifying_data, qualifying_data_len, quote_info, signature, @@ -1069,7 +1115,7 @@ def verify_quote( quote_info: Union[bytes, str], qualifying_data: Optional[Union[bytes, str]] = None, pcr_log: Optional[Union[bytes, str]] = None, - ): + ) -> None: """Verify the signature to a TPM quote. Args: @@ -1082,24 +1128,24 @@ def verify_quote( Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - signature = _to_bytes_or_null(signature) + cpath = _to_bytes_or_null(path) + csignature = _to_bytes_or_null(signature) if qualifying_data is None: qualifying_data_len = 0 else: qualifying_data_len = len(qualifying_data) - qualifying_data = _to_bytes_or_null(qualifying_data) - quote_info = _to_bytes_or_null(quote_info) - pcr_log = _to_bytes_or_null(pcr_log) + cqualifying_data = _to_bytes_or_null(qualifying_data) + cquote_info = _to_bytes_or_null(quote_info) + cpcr_log = _to_bytes_or_null(pcr_log) ret = lib.Fapi_VerifyQuote( self._ctx, - path, - qualifying_data, + cpath, + cqualifying_data, qualifying_data_len, - quote_info, - signature, + cquote_info, + csignature, len(signature), - pcr_log, + cpcr_log, ) _chkrc(ret) @@ -1123,11 +1169,13 @@ def create_nv( Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - type_ = _to_bytes_or_null(type_) - policy_path = _to_bytes_or_null(policy_path) - auth_value = _to_bytes_or_null(auth_value) - ret = lib.Fapi_CreateNv(self._ctx, path, type_, size, policy_path, auth_value) + cpath = _to_bytes_or_null(path) + ctype_ = _to_bytes_or_null(type_) + cpolicy_path = _to_bytes_or_null(policy_path) + cauth_value = _to_bytes_or_null(auth_value) + ret = lib.Fapi_CreateNv( + self._ctx, cpath, ctype_, size, cpolicy_path, cauth_value + ) _chkrc(ret) def nv_read(self, path: Union[bytes, str]) -> Tuple[bytes, str]: @@ -1142,15 +1190,15 @@ def nv_read(self, path: Union[bytes, str]) -> Tuple[bytes, str]: Returns: Tuple[bytes, str]: Data stored in the NV storage area and its associated event log. """ - path = _to_bytes_or_null(path) - data = ffi.new("uint8_t **") - data_size = ffi.new("size_t *") - log = ffi.new("char **") - ret = lib.Fapi_NvRead(self._ctx, path, data, data_size, log) + cpath = _to_bytes_or_null(path) + cdata = ffi.new("uint8_t **") + cdata_size = ffi.new("size_t *") + clog = ffi.new("char **") + ret = lib.Fapi_NvRead(self._ctx, cpath, cdata, cdata_size, clog) _chkrc(ret) return ( - bytes(ffi.unpack(_get_dptr(data, lib.Fapi_Free), data_size[0])), - ffi.string(_get_dptr(log, lib.Fapi_Free)).decode(), + bytes(ffi.unpack(_get_dptr(cdata, lib.Fapi_Free), cdata_size[0])), + ffi.string(_get_dptr(clog, lib.Fapi_Free)).decode(), ) def nv_write(self, path: Union[bytes, str], data: Union[bytes, str]) -> None: @@ -1163,9 +1211,9 @@ def nv_write(self, path: Union[bytes, str], data: Union[bytes, str]) -> None: Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - data = _to_bytes_or_null(data) - ret = lib.Fapi_NvWrite(self._ctx, path, data, len(data)) + cpath = _to_bytes_or_null(path) + cdata = _to_bytes_or_null(data) + ret = lib.Fapi_NvWrite(self._ctx, cpath, cdata, len(data)) _chkrc(ret) def nv_extend( @@ -1186,10 +1234,10 @@ def nv_extend( Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - data = _to_bytes_or_null(data) - log = _to_bytes_or_null(log) - ret = lib.Fapi_NvExtend(self._ctx, path, data, len(data), log) + cpath = _to_bytes_or_null(path) + cdata = _to_bytes_or_null(data) + clog = _to_bytes_or_null(log) + ret = lib.Fapi_NvExtend(self._ctx, cpath, cdata, len(data), clog) _chkrc(ret) def nv_increment(self, path: Union[bytes, str]) -> None: @@ -1201,8 +1249,8 @@ def nv_increment(self, path: Union[bytes, str]) -> None: Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - ret = lib.Fapi_NvIncrement(self._ctx, path) + cpath = _to_bytes_or_null(path) + ret = lib.Fapi_NvIncrement(self._ctx, cpath) _chkrc(ret) def nv_set_bits(self, path: Union[bytes, str], bitmap: int) -> None: @@ -1215,8 +1263,8 @@ def nv_set_bits(self, path: Union[bytes, str], bitmap: int) -> None: Raises: TSS2_Exception: If Fapi returned an error code. """ - path = _to_bytes_or_null(path) - ret = lib.Fapi_NvSetBits(self._ctx, path, bitmap) + cpath = _to_bytes_or_null(path) + ret = lib.Fapi_NvSetBits(self._ctx, cpath, bitmap) _chkrc(ret) def write_authorize_nv( @@ -1231,9 +1279,9 @@ def write_authorize_nv( Raises: TSS2_Exception: If Fapi returned an error code. """ - nv_path = _to_bytes_or_null(nv_path) - policy_path = _to_bytes_or_null(policy_path) - ret = lib.Fapi_WriteAuthorizeNv(self._ctx, nv_path, policy_path) + cnv_path = _to_bytes_or_null(nv_path) + cpolicy_path = _to_bytes_or_null(policy_path) + ret = lib.Fapi_WriteAuthorizeNv(self._ctx, cnv_path, cpolicy_path) _chkrc(ret) def set_auth_callback( @@ -1276,7 +1324,7 @@ def set_branch_callback( self, callback: Optional[Callable[[str, str, List[str], Optional[Any]], int]] = None, user_data: Optional[Any] = None, - ): + ) -> None: """Set the Fapi policy branch callback, called to decide which policy path to take in a policy Or. If `callback` is None, the callback function is reset. Args: @@ -1317,7 +1365,7 @@ def set_sign_callback( Callable[[str, str, str, str, int, bytes, Optional[Any]], bytes] ] = None, user_data: Optional[Any] = None, - ): + ) -> None: """Set the Fapi signing callback which is called to satisfy the policy Signed. If `callback` is None, the callback function is reset. Args: @@ -1354,7 +1402,7 @@ def set_policy_action_callback( self, callback: Optional[Callable[[str, str, Optional[bytes]], None]] = None, user_data: Optional[Any] = None, - ): + ) -> None: """Set the policy Action callback which is called to satisfy the policy Action. If `callback` is None, the callback function is reset. Args: diff --git a/src/tpm2_pytss/TCTI.py b/src/tpm2_pytss/TCTI.py index 360e5280..c4dc796c 100644 --- a/src/tpm2_pytss/TCTI.py +++ b/src/tpm2_pytss/TCTI.py @@ -7,7 +7,8 @@ from .TSS2_Exception import TSS2_Exception import os -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Callable, Any, Iterable, Dict, Type +from types import TracebackType class PollData(object): @@ -67,10 +68,10 @@ def events(self) -> int: return self._events -def common_checks(version=1, null_ok=False): - def decorator(func): - def wrapper(self, *args, **kwargs): - def camel_case(s): +def common_checks(version: int = 1, null_ok: bool = False) -> Callable[..., Any]: + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(self: Any, *args: Iterable[Any], **kwargs: Dict[str, Any]) -> Any: + def camel_case(s: str) -> str: from re import sub s = sub(r"(_|-)+", " ", s).title().replace(" ", "") @@ -117,7 +118,8 @@ class TCTI: def __init__(self, ctx: ffi.CData): self._v1 = ffi.cast("TSS2_TCTI_CONTEXT_COMMON_V1 *", ctx) - if self._v1.version == 2: + self._v2: Optional[ffi.CData] + if int(self._v1.version) == 2: self._v2 = ffi.cast("TSS2_TCTI_CONTEXT_COMMON_V2 *", ctx) else: self._v2 = None @@ -127,13 +129,13 @@ def __init__(self, ctx: ffi.CData): # Normal TCTIs cannot make use of this, by Python TCTIs can. Add it to the base class # for subordinate TCTIs to use. This way the TCTI fn calls return the most helpful # error. - self._last_exception = None + self._last_exception: Optional[Exception] = None - def _set_last_exception(self, exc): + def _set_last_exception(self, exc: Exception) -> None: self._last_exception = exc @property - def _tcti_context(self): + def _tcti_context(self) -> ffi.CData: return self._ctx @property @@ -146,7 +148,7 @@ def magic(self) -> bytes: # uint64_t in C land by default or let subclass control it magic_len = getattr(self, "_magic_len", 8) - return self._v1.magic.to_bytes(magic_len, "big") + return int(self._v1.magic).to_bytes(magic_len, "big") @property def version(self) -> int: @@ -160,12 +162,12 @@ def version(self) -> int: The TCTI version number. """ - return self._v1.version + return int(self._v1.version) - def _clear_exceptions(self): + def _clear_exceptions(self) -> None: self._last_exception = None - def _get_current_exception(self, e: Exception): + def _get_current_exception(self, e: Exception) -> Exception: x = self._last_exception return x if x is not None else e @@ -216,7 +218,7 @@ def receive(self, size: int = 4096, timeout: int = -1) -> bytes: return bytes(ffi.buffer(resp, rsize[0])) @common_checks(null_ok=True) - def finalize(self): + def finalize(self) -> None: """Cleans up a TCTI's state and resources.""" if self._v1.finalize != ffi.NULL: @@ -241,7 +243,7 @@ def cancel(self) -> None: _chkrc(self._v1.cancel(self._ctx)) @common_checks() - def get_poll_handles(self) -> Tuple[PollData]: + def get_poll_handles(self) -> Tuple[PollData, ...]: """Gets the poll handles from the TPM. Returns: @@ -300,19 +302,22 @@ def make_sticky(self, handle: int, sticky: Union[bool, int]) -> None: """ hptr = ffi.new("TPM2_HANDLE *", handle) + if self._v2 is None: + raise ValueError(f"TCTI module does not have make_sticky") _chkrc(self._v2.makeSticky(self._ctx, hptr, sticky)) - return hptr[0] - def __enter__(self): + def __enter__(self) -> "TCTI": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, exc_type: Type[Exception], exc_value: Exception, traceback: TracebackType + ) -> None: self.finalize() # Global callbacks @ffi.def_extern() -def _tcti_transmit_wrapper(ctx, size, command): +def _tcti_transmit_wrapper(ctx: ffi.CData, size: int, command: ffi.CData) -> int: pi = PyTCTI._cffi_cast(ctx) if not hasattr(pi, "do_transmit"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -327,7 +332,9 @@ def _tcti_transmit_wrapper(ctx, size, command): @ffi.def_extern() -def _tcti_receive_wrapper(ctx, size, response, timeout): +def _tcti_receive_wrapper( + ctx: ffi.CData, size: ffi.CData, response: ffi.CData, timeout: int +) -> int: # Let the allocator know how much we need. pi = PyTCTI._cffi_cast(ctx) @@ -354,7 +361,7 @@ def _tcti_receive_wrapper(ctx, size, response, timeout): @ffi.def_extern() -def _tcti_cancel_wrapper(ctx): +def _tcti_cancel_wrapper(ctx: ffi.CData) -> int: pi = PyTCTI._cffi_cast(ctx) if not hasattr(pi, "do_cancel"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -369,7 +376,9 @@ def _tcti_cancel_wrapper(ctx): @ffi.def_extern() -def _tcti_get_pollfds_wrapper(ctx, handles, cnt): +def _tcti_get_pollfds_wrapper( + ctx: ffi.CData, handles: ffi.CData, cnt: ffi.CData +) -> int: pi = PyTCTI._cffi_cast(ctx) if not hasattr(pi, "do_get_poll_handles"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -386,7 +395,7 @@ def _tcti_get_pollfds_wrapper(ctx, handles, cnt): if handles == ffi.NULL: cnt[0] = len(pi._poll_handle_cache) elif cnt[0] < len(pi._poll_handle_cache): - raise TSS2_RC.TCTI_RC_INSUFFICIENT_BUFFER + return TSS2_RC.TCTI_RC_INSUFFICIENT_BUFFER else: cnt[0] = len(pi._poll_handle_cache) # Enumerate didn't work here @@ -407,7 +416,7 @@ def _tcti_get_pollfds_wrapper(ctx, handles, cnt): @ffi.def_extern() -def _tcti_set_locality_wrapper(ctx, locality): +def _tcti_set_locality_wrapper(ctx: ffi.CData, locality: int) -> int: pi = PyTCTI._cffi_cast(ctx) if not hasattr(pi, "do_set_locality"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -422,7 +431,7 @@ def _tcti_set_locality_wrapper(ctx, locality): @ffi.def_extern() -def _tcti_make_sticky_wrapper(ctx, handle, sticky): +def _tcti_make_sticky_wrapper(ctx: ffi.CData, handle: int, sticky: int) -> int: pi = PyTCTI._cffi_cast(ctx) if not hasattr(pi, "do_make_sticky"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -437,7 +446,7 @@ def _tcti_make_sticky_wrapper(ctx, handle, sticky): @ffi.def_extern() -def _tcti_finalize_wrapper(ctx): +def _tcti_finalize_wrapper(ctx: ffi.CData) -> None: pi = PyTCTI._cffi_cast(ctx) if not hasattr(pi, "do_finalize"): return @@ -492,7 +501,7 @@ def __init__(self, max_size: int = 4096, magic: bytes = b"PYTCTI\x00\x00"): cdata = self._cdata = ffi.new("PYTCTI_CONTEXT *") self._max_size = max_size - self._poll_handle_cache = None + self._poll_handle_cache: Optional[Tuple[PollData, ...]] = None self._magic_len = len(magic) cdata.common.v1.version = 2 cdata.common.v1.magic = int.from_bytes(magic, "big") @@ -514,9 +523,12 @@ def __init__(self, max_size: int = 4096, magic: bytes = b"PYTCTI\x00\x00"): super().__init__(opaque) @staticmethod - def _cffi_cast(ctx): + def _cffi_cast(ctx: ffi.CData) -> "PyTCTI": ctx = ffi.cast("PYTCTI_CONTEXT *", ctx) - return ffi.from_handle(ctx.thiz) + thiz = ffi.from_handle(ctx.thiz) + if not isinstance(thiz, PyTCTI): + raise ValueError(f"expected an instance of PyTCTI, got {type(thiz)}") + return thiz def do_transmit(self, command: bytes) -> None: """This method transmits a command buffer to the TPM. This method IS REQUIRED. @@ -556,7 +568,7 @@ def do_cancel(self) -> None: """ pass - def do_get_poll_handles(self) -> Optional[Tuple[PollData]]: + def do_get_poll_handles(self) -> Optional[Tuple[PollData, ...]]: """Retrieves PollData objects from the TCTI used for async I/O. This method is OPTIONAL. Returns: diff --git a/src/tpm2_pytss/TCTILdr.py b/src/tpm2_pytss/TCTILdr.py index 6e5423cf..5a320038 100644 --- a/src/tpm2_pytss/TCTILdr.py +++ b/src/tpm2_pytss/TCTILdr.py @@ -3,10 +3,16 @@ from ._libtpm2_pytss import lib, ffi from .TCTI import TCTI from .internal.utils import _chkrc +from typing import Optional, Union, Type +from types import TracebackType class TCTILdr(TCTI): - def __init__(self, name=None, conf=None): + def __init__( + self, + name: Optional[Union[bytes, ffi.CData, str]] = None, + conf: Optional[Union[bytes, ffi.CData, str]] = None, + ): self._ctx_pp = ffi.new("TSS2_TCTI_CONTEXT **") @@ -29,21 +35,23 @@ def __init__(self, name=None, conf=None): _chkrc(lib.Tss2_TctiLdr_Initialize_Ex(name, conf, self._ctx_pp)) super().__init__(self._ctx_pp[0]) - self._name = name.decode() if name else "" - self._conf = conf.decode() if conf else "" + self._name = name.decode() if isinstance(name, bytes) else "" + self._conf = conf.decode() if isinstance(conf, bytes) else "" - def __enter__(self): + def __enter__(self) -> "TCTILdr": return self - def __exit__(self, _type, value, traceback): + def __exit__( + self, _type: Type[Exception], value: Exception, traceback: TracebackType + ) -> None: self.close() - def close(self): + def close(self) -> None: lib.Tss2_TctiLdr_Finalize(self._ctx_pp) self._ctx = ffi.NULL @classmethod - def parse(cls, tcti_name_conf: str): + def parse(cls, tcti_name_conf: str) -> "TCTILdr": chunks = tcti_name_conf.split(":", 1) if len(chunks) > 2: @@ -54,22 +62,22 @@ def parse(cls, tcti_name_conf: str): return cls(name, conf) @property - def name(self): + def name(self) -> str: return self._name @property - def conf(self): + def conf(self) -> str: return self._conf @property - def name_conf(self): + def name_conf(self) -> str: return f"{self.name}:{self.conf}" if self.conf else self.name - def __str__(self): + def __str__(self) -> str: return self.name_conf @staticmethod - def is_available(name=None) -> bool: + def is_available(name: Optional[Union[ffi.CData, str, bytes]] = None) -> bool: """Lookup the TCTI and return its availability Returns: diff --git a/src/tpm2_pytss/TCTISPIHelper.py b/src/tpm2_pytss/TCTISPIHelper.py index a3a1f262..4ef857bf 100644 --- a/src/tpm2_pytss/TCTISPIHelper.py +++ b/src/tpm2_pytss/TCTISPIHelper.py @@ -5,13 +5,14 @@ from .constants import TSS2_RC, TPM2_RC from .TSS2_Exception import TSS2_Exception from .TCTI import TCTI +from typing import Optional if not _lib_version_atleast("tss2-tcti-spi-helper", "0.0.0"): raise NotImplementedError("Package tss2-tcti-spi-helper not present") @ffi.def_extern() -def _tcti_spi_helper_sleep_ms(userdata, milliseconds): +def _tcti_spi_helper_sleep_ms(userdata: ffi.CData, milliseconds: int) -> int: thiz = TCTISPIHelper._cffi_cast(userdata) if not hasattr(thiz, "on_sleep_ms"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -26,7 +27,7 @@ def _tcti_spi_helper_sleep_ms(userdata, milliseconds): @ffi.def_extern() -def _tcti_spi_helper_start_timeout(userdata, milliseconds): +def _tcti_spi_helper_start_timeout(userdata: ffi.CData, milliseconds: int) -> int: thiz = TCTISPIHelper._cffi_cast(userdata) if not hasattr(thiz, "on_start_timeout"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -41,7 +42,9 @@ def _tcti_spi_helper_start_timeout(userdata, milliseconds): @ffi.def_extern() -def _tcti_spi_helper_timeout_expired(userdata, is_time_expired) -> bool: +def _tcti_spi_helper_timeout_expired( + userdata: ffi.CData, is_time_expired: ffi.CData +) -> int: thiz = TCTISPIHelper._cffi_cast(userdata) if not hasattr(thiz, "on_timeout_expired"): @@ -58,7 +61,7 @@ def _tcti_spi_helper_timeout_expired(userdata, is_time_expired) -> bool: @ffi.def_extern() -def _tcti_spi_helper_spi_acquire(userdata): +def _tcti_spi_helper_spi_acquire(userdata: ffi.CData) -> int: thiz = TCTISPIHelper._cffi_cast(userdata) if not hasattr(thiz, "on_start_timeout"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -73,7 +76,7 @@ def _tcti_spi_helper_spi_acquire(userdata): @ffi.def_extern() -def _tcti_spi_helper_spi_release(userdata): +def _tcti_spi_helper_spi_release(userdata: ffi.CData) -> int: thiz = TCTISPIHelper._cffi_cast(userdata) if not hasattr(thiz, "on_spi_release"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -88,7 +91,9 @@ def _tcti_spi_helper_spi_release(userdata): @ffi.def_extern() -def _tcti_spi_helper_spi_transfer(userdata, data_out, data_in, cnt): +def _tcti_spi_helper_spi_transfer( + userdata: ffi.CData, data_out: ffi.CData, data_in: ffi.CData, cnt: int +) -> int: thiz = TCTISPIHelper._cffi_cast(userdata) if not hasattr(thiz, "on_spi_transfer"): return TSS2_RC.TCTI_RC_NOT_IMPLEMENTED @@ -125,10 +130,10 @@ def _tcti_spi_helper_spi_transfer(userdata, data_out, data_in, cnt): @ffi.def_extern() -def _tcti_spi_helper_finalize(userdata): +def _tcti_spi_helper_finalize(userdata: ffi.CData) -> None: thiz = TCTISPIHelper._cffi_cast(userdata) if hasattr(thiz, "on_finalize"): - thiz.on_finalize(thiz) + thiz.on_finalize() class TCTISPIHelper(TCTI): @@ -153,7 +158,7 @@ class TCTISPIHelper(TCTI): with_wait_state (bool): True if you intend to use wait states. Defaults to False. """ - def __init__(self, with_wait_state=False): + def __init__(self, with_wait_state: bool = False): self._with_wait_state = with_wait_state size = ffi.new("size_t *") @@ -212,7 +217,7 @@ def __init__(self, with_wait_state=False): super().__init__(self._opaque_tcti_ctx) @property - def waitstate(self): + def waitstate(self) -> bool: """Gets the wait state property. Returns(bool): @@ -221,8 +226,13 @@ def waitstate(self): return self._with_wait_state @staticmethod - def _cffi_cast(userdata): - return ffi.from_handle(userdata) + def _cffi_cast(userdata: ffi.CData) -> "TCTISPIHelper": + helper = ffi.from_handle(userdata) + if not isinstance(helper, TCTISPIHelper): + raise ValueError( + f"expected an instance of TCTISPIHelper, got {type(helper)}" + ) + return helper def on_sleep_ms(self, milliseconds: int) -> None: """Sleeps for a specified amount of time in millisecons. @@ -258,9 +268,9 @@ def on_timeout_expired(self) -> bool: This callback is REQUIRED. No errors may occur across this boundary. """ - pass + raise NotImplementedError - def on_spi_transfer(self, data_in: bytes) -> bytes: + def on_spi_transfer(self, data_in: Optional[bytes]) -> bytes: """Called to transfer data across the SPI bus. This callback is REQUIRED. @@ -275,7 +285,7 @@ def on_spi_transfer(self, data_in: bytes) -> bytes: Exception: Implementations are free to raise any Exception. Exceptions are retained across the native boundary. """ - pass + raise NotImplementedError def on_finalize(self) -> None: """Called when the TCTI is finalized. diff --git a/src/tpm2_pytss/TSS2_Exception.py b/src/tpm2_pytss/TSS2_Exception.py index 309e25db..58327174 100644 --- a/src/tpm2_pytss/TSS2_Exception.py +++ b/src/tpm2_pytss/TSS2_Exception.py @@ -1,5 +1,9 @@ from ._libtpm2_pytss import lib, ffi -from typing import Union +from typing import Union, TYPE_CHECKING + + +if TYPE_CHECKING: + from .constants import TSS2_RC, TPM2_RC class TSS2_Exception(RuntimeError): @@ -25,7 +29,7 @@ def __init__(self, rc: Union["TSS2_RC", "TPM2_RC", int]): else: self._error = self._rc - def _parse_fmt1(self): + def _parse_fmt1(self) -> None: self._error = lib.TPM2_RC_FMT1 + (self.rc & 0x3F) if self.rc & lib.TPM2_RC_P: @@ -36,31 +40,31 @@ def _parse_fmt1(self): self._handle = (self.rc & lib.TPM2_RC_N_MASK) >> 8 @property - def rc(self): + def rc(self) -> int: """int: The return code from the API call.""" return self._rc @property - def handle(self): + def handle(self) -> int: """int: The handle related to the error, 0 if not related to any handle.""" return self._handle @property - def parameter(self): + def parameter(self) -> int: """int: The parameter related to the error, 0 if not related to any parameter.""" return self._parameter @property - def session(self): + def session(self) -> int: """int: The session related to the error, 0 if not related to any session.""" return self._session @property - def error(self): + def error(self) -> int: """int: The error with handle, parameter and session stripped.""" return self._error @property - def fmt1(self): + def fmt1(self) -> bool: """bool: True if the error is related to a handle, parameter or session """ return bool(self._rc & lib.TPM2_RC_FMT1) diff --git a/src/tpm2_pytss/__init__.py b/src/tpm2_pytss/__init__.py index 83f401c8..c931ee43 100644 --- a/src/tpm2_pytss/__init__.py +++ b/src/tpm2_pytss/__init__.py @@ -4,13 +4,13 @@ # if we can't, provide a better message. try: from ._libtpm2_pytss import lib -except ImportError as e: - parts = e.msg.split(": ", 2) +except ImportError as ie: + parts = ie.msg.split(": ", 2) if len(parts) != 3: - raise e + raise ie path, error, symbol = parts if error != "undefined symbol": - raise e + raise ie raise ImportError( f"failed to load tpm2-tss bindigs in {path} due to missing symbol {symbol}, " + "ensure that you are using the same libraries the python module was built against." diff --git a/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi b/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi new file mode 100644 index 00000000..f128a164 --- /dev/null +++ b/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi @@ -0,0 +1,33 @@ +from typing import Optional, Callable, Iterable, Any, Dict, Tuple, Union + +error: type[Exception] +class CData: + def __getitem__(self, index: int | slice) -> Any: ... + def __setitem__(self, index: int | slice, value: "CData" | Callable[..., int] | int) -> None: ... + def __bytes__(self) -> bytes: ... + def __setattr__(self, key: str, value: "CData" | Any) -> None: ... + def __getattr__(self, key: str) -> "CData": ... + __call__: Callable[..., Any] + +class CType: + kind: str + cname: str + item: "CType" + fields: Iterable[Tuple[str, Any]] + +NULL: CData +def gc(cdata: CData, destructor: Callable[[CData], None], size: int = 0)-> CData: ... +def typeof(cdata: CData | str) -> CType: ... +def new(cdecl: str, init: Any = None) -> CData: ... +def string(cdata: CData, maxlen: Optional[int] = None) -> bytes: ... +def sizeof(cdata: CData | str) -> int: ... +def def_extern() -> Callable[..., int | None]: ... +def buffer(cdata: CData, size: Optional[int]) -> Any: ... +def from_buffer(python_buffer: Any, require_writable: bool = False) -> CData: ... +def from_handle(handle: CData) -> Any: ... +def new_handle(python_object: Any) -> CData: ... +def cast(ctype: str, value: CData) -> CData: ... +def memmove(dest: CData | bytes, src: CData | bytes, n: int) -> None: ... +def addressof(cdata: CData, *fields_or_indexes: str | int) -> CData: ... +def unpack(cdata: CData, maxlen: Optional[int]) -> bytes: ... +def new_allocator(**kwargs: Any) -> Callable[[str, Optional[Callable[[CData], CData] | bytes | int]], CData]: ... diff --git a/src/tpm2_pytss/constants.py b/src/tpm2_pytss/constants.py index 291a9ca5..7e0e002d 100644 --- a/src/tpm2_pytss/constants.py +++ b/src/tpm2_pytss/constants.py @@ -12,20 +12,41 @@ _lib_version_atleast, _chkrc, ) +from typing import ( + Dict, + Iterable, + Tuple, + Optional, + Type, + TYPE_CHECKING, + Any, + SupportsIndex, +) + +try: + from typing import Self +except ImportError: + # assume mypy is running on python 3.11+ + pass + + +if TYPE_CHECKING: + from .ESAPI import ESAPI + from .types import TPM2B_NAME, TPM2_HANDLE, TPM2B_PUBLIC class TPM_FRIENDLY_INT(int): - _FIXUP_MAP = {} + _FIXUP_MAP: Dict[str, str] = {} @classmethod - def parse(cls, value: str) -> int: + def parse(cls, value: str) -> "Self": # If it's a string initializer value, see if it matches anything in the list if isinstance(value, str): try: x = _CLASS_INT_ATTRS_from_string(cls, value, cls._FIXUP_MAP) if not isinstance(x, int): raise KeyError(f'Expected int got: "{type(x)}"') - return x + return cls(x) except KeyError: raise ValueError( f'Could not convert friendly name to value, got: "{value}"' @@ -34,7 +55,7 @@ def parse(cls, value: str) -> int: raise TypeError(f'Expected value to be a str object, got: "{type(value)}"') @classmethod - def iterator(cls) -> filter: + def iterator(cls) -> Iterable[int]: """ Returns the constants in the class. Returns: @@ -102,109 +123,111 @@ def __str__(self) -> str: return k.lower() return str(int(self)) - def __abs__(self): + def __abs__(self) -> "Self": return self.__class__(int(self).__abs__()) - def __add__(self, value): + def __add__(self, value: int) -> "Self": return self.__class__(int(self).__add__(value)) - def __and__(self, value): + def __and__(self, value: int) -> "Self": return self.__class__(int(self).__and__(value)) - def __ceil__(self): + def __ceil__(self) -> "Self": return self.__class__(int(self).__ceil__()) - def __divmod__(self, value): + def __divmod__(self, value: int) -> Tuple["Self", "Self"]: a, b = int(self).__divmod__(value) return self.__class__(a), self.__class__(b) - def __floor__(self): + def __floor__(self) -> "Self": return self.__class__(int(self).__floor__()) - def __floordiv__(self, value): + def __floordiv__(self, value: int) -> "Self": return self.__class__(int(self).__floordiv__(value)) - def __invert__(self): + def __invert__(self) -> "Self": return self.__class__(int(self).__invert__()) - def __lshift__(self, value): + def __lshift__(self, value: int) -> "Self": return self.__class__(int(self).__lshift__(value)) - def __mod__(self, value): + def __mod__(self, value: int) -> "Self": return self.__class__(int(self).__mod__(value)) - def __mul__(self, value): + def __mul__(self, value: int) -> "Self": return self.__class__(int(self).__mul__(value)) - def __neg__(self): + def __neg__(self) -> "Self": return self.__class__(int(self).__neg__()) - def __or__(self, value): + def __or__(self, value: int) -> "Self": return self.__class__(int(self).__or__(value)) - def __pos__(self): + def __pos__(self) -> "Self": return self.__class__(int(self).__pos__()) - def __pow__(self, value, mod=None): + def __pow__(self, value: int, mod: Optional[int] = None) -> Any: return self.__class__(int(self).__pow__(value, mod)) - def __radd__(self, value): + def __radd__(self, value: int) -> "Self": return self.__class__(int(self).__radd__(value)) - def __rand__(self, value): + def __rand__(self, value: int) -> "Self": return self.__class__(int(self).__rand__(value)) - def __rdivmod__(self, value): + def __rdivmod__(self, value: int) -> Tuple["Self", "Self"]: a, b = int(self).__rdivmod__(value) return self.__class__(a), self.__class__(b) - def __rfloordiv__(self, value): + def __rfloordiv__(self, value: int) -> "Self": return self.__class__(int(self).__rfloordiv__(value)) - def __rlshift__(self, value): + def __rlshift__(self, value: int) -> "Self": return self.__class__(int(self).__rlshift__(value)) - def __rmod__(self, value): + def __rmod__(self, value: int) -> "Self": return self.__class__(int(self).__rmod__(value)) - def __rmul__(self, value): + def __rmul__(self, value: int) -> "Self": return self.__class__(int(self).__rmul__(value)) - def __ror__(self, value): + def __ror__(self, value: int) -> "Self": return self.__class__(int(self).__ror__(value)) - def __round__(self): - return self.__class__(int(self).__round__()) + def __round__(self, ndigits: SupportsIndex = False) -> "Self": + return self.__class__(int(self).__round__(ndigits)) - def __rpow__(self, value, mod=None): + def __rpow__(self, value: int, mod: Optional[int] = None) -> Any: return self.__class__(int(self).__rpow__(value, mod)) - def __rrshift__(self, value): + def __rrshift__(self, value: int) -> "Self": return self.__class__(int(self).__rrshift__(value)) - def __rshift__(self, value): + def __rshift__(self, value: int) -> "Self": return self.__class__(int(self).__rshift__(value)) - def __rsub__(self, value): + def __rsub__(self, value: int) -> "Self": return self.__class__(int(self).__rsub__(value)) - def __rtruediv__(self, value): + def __rtruediv__(self, value: int) -> Any: return self.__class__(int(self).__rtruediv__(value)) - def __rxor__(self, value): + def __rxor__(self, value: int) -> "Self": return self.__class__(int(self).__rxor__(value)) - def __sub__(self, value): + def __sub__(self, value: int) -> "Self": return self.__class__(int(self).__sub__(value)) - def __truediv__(self, value): + def __truediv__(self, value: int) -> "Self": return self.__class__(int(self).__truediv__(value)) - def __xor__(self, value): + def __xor__(self, value: int) -> "Self": return self.__class__(int(self).__xor__(value)) @staticmethod - def _copy_and_set(dstcls, srccls): + def _copy_and_set( + dstcls: Type["TPM_FRIENDLY_INT"], srccls: Type["TPM_FRIENDLY_INT"] + ) -> None: """Copy class variables from srccls to dstcls srccls must be a subclass on TPM_FRIENDLY_INT. @@ -222,7 +245,7 @@ def _copy_and_set(dstcls, srccls): setattr(dstcls, k, fv) @staticmethod - def _fix_const_type(cls): + def _fix_const_type(cls: Type["TPM_FRIENDLY_INT"]) -> Type["TPM_FRIENDLY_INT"]: """Ensure constants in a TPM2 constant class have the correct type We also copy constants from a superclass in case it's of the correct type. @@ -233,7 +256,7 @@ def _fix_const_type(cls): TPM_FRIENDLY_INT._copy_and_set(cls, sc) return cls - def marshal(self): + def marshal(self) -> bytes: """Marshal instance into bytes. Returns: @@ -258,7 +281,7 @@ def marshal(self): return bytes(buf[0 : offset[0]]) @classmethod - def unmarshal(cls, buf): + def unmarshal(cls, buf: bytes) -> Tuple["Self", int]: """Unmarshal bytes into type instance. Args: @@ -285,10 +308,10 @@ def unmarshal(cls, buf): class TPMA_FRIENDLY_INTLIST(TPM_FRIENDLY_INT): - _MASKS = tuple() + _MASKS: Tuple[Tuple[int, int, str], ...] = tuple() @classmethod - def parse(cls, value: str) -> int: + def parse(cls, value: str) -> "Self": """ Converts a string of | separated constant values into it's integer value. Given a pipe "|" separated list of string constant values that represent the @@ -344,9 +367,9 @@ def parse(cls, value: str) -> int: f'Could not convert friendly name to value, got: "{k}"' ) - return intvalue + return cls(intvalue) - def __str__(self): + def __str__(self) -> str: """Given a constant, return the string bitwise representation. Each constant is seperated by the "|" (pipe) character. @@ -446,11 +469,11 @@ class ESYS_TR(TPM_FRIENDLY_INT): RH_PLATFORM = lib.ESYS_TR_RH_PLATFORM RH_PLATFORM_NV = lib.ESYS_TR_RH_PLATFORM_NV - def marshal(self): + def marshal(self) -> bytes: raise NotImplementedError("Use serialize() instead") @classmethod - def unmarshal(cls, buf): + def unmarshal(cls, buf: bytes) -> Tuple["Self", int]: raise NotImplementedError("Use deserialize() instead") def serialize(self, ectx: "ESAPI") -> bytes: @@ -489,7 +512,7 @@ def get_name(self, ectx: "ESAPI") -> "TPM2B_NAME": """ return ectx.tr_get_name(self) - def close(self, ectx: "ESAPI"): + def close(self, ectx: "ESAPI") -> None: """Same as see tpm2_pytss.ESAPI.tr_close Args: @@ -757,7 +780,7 @@ class TPM2_GENERATED(TPM_FRIENDLY_INT): class TPM_BASE_RC(TPM_FRIENDLY_INT): - def decode(self): + def decode(self) -> str: return ffi.string(lib.Tss2_RC_Decode(self)).decode() @@ -1326,14 +1349,14 @@ class TPMA_LOCALITY(TPMA_FRIENDLY_INTLIST): EXTENDED_SHIFT = lib.TPMA_LOCALITY_EXTENDED_SHIFT @classmethod - def create_extended(cls, value): + def create_extended(cls, value: int) -> "Self": x = (1 << cls.EXTENDED_SHIFT) + value if x > 255: raise ValueError("Extended Localities must be less than 256") - return x + return cls(x) @classmethod - def parse(cls, value: str) -> "TPMA_LOCALITY": + def parse(cls, value: str) -> "Self": """Converts a string of | separated localities or an extended locality into a TPMA_LOCALITY instance Args: @@ -1354,7 +1377,7 @@ def parse(cls, value: str) -> "TPMA_LOCALITY": return cls(value, base=0) except ValueError: pass - return super().parse(value) + return cls(super().parse(value)) def __str__(self) -> str: """Given a set of localities or an extended locality, return the string representation diff --git a/src/tpm2_pytss/fapi_info.py b/src/tpm2_pytss/fapi_info.py index 3e6c4728..85273f4d 100644 --- a/src/tpm2_pytss/fapi_info.py +++ b/src/tpm2_pytss/fapi_info.py @@ -5,20 +5,21 @@ """Interface to make TPM info dict structure more accessible via dot notation.""" from collections import defaultdict +from typing import Any, Iterable, Dict, List class Traversable: """Attributes are traversable recursively.""" - def __init__(self, data): + def __init__(self, data: Dict[Any, Any]): self.data = data - def __str__(self): + def __str__(self) -> str: return str(self.data) - def attrs_recursive(self, parent=""): + def attrs_recursive(self, parent: str = "") -> Iterable[Any]: """Return a generator to all attributes.""" - attrs_rec = [] + attrs_rec: List[str] = [] sep = "." if parent else "" for attr in dir(self): @@ -35,10 +36,10 @@ def attrs_recursive(self, parent=""): class BasicDict(Traversable): """Takes a dict and makes values accessible via dot notation.""" - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: return self.data[attr] - def __dir__(self): + def __dir__(self) -> Iterable[str]: return self.data.keys() @@ -60,13 +61,19 @@ class NamedKVPList(Traversable): instance of that class is returned (passing the value to __init__()). """ - def __init__(self, data, key_name, value_name, value_class=None): + def __init__( + self, + data: Dict[Any, Any], + key_name: str, + value_name: str, + value_class: Any = None, + ): super().__init__(data) self.key_name = key_name self.value_name = value_name self.value_class = value_class - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: value = next( item[self.value_name] for item in self.data @@ -78,19 +85,19 @@ def __getattr__(self, attr): return value - def __dir__(self): + def __dir__(self) -> Iterable[str]: return [item[self.key_name].lower() for item in self.data] class Capabilities(Traversable): """Takes a list of capability dicts and makes them accessible via dot notation.""" - def _get_cap_data(self, description): + def _get_cap_data(self, description: str) -> Any: return next(cap for cap in self.data if cap["description"] == description)[ "info" ]["data"] - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: # some caps are accessed via '_' but their names contain '-' attr = attr.replace("_", "-") cap_data = self._get_cap_data(attr) @@ -111,11 +118,11 @@ def __getattr__(self, attr): return cap - def __dir__(self): + def __dir__(self) -> Iterable[str]: return [item["description"] for item in self.data] -def str_from_int_list(int_list): +def str_from_int_list(int_list: Iterable[int]) -> str: """Cast integers to bytes and decode as string.""" string = b"".join( integer.to_bytes(4, byteorder="big") for integer in int_list @@ -133,7 +140,7 @@ def str_from_int_list(int_list): class FapiInfo(Traversable): """Takes a FAPI info dict and and makes its values accessible via dot notation.""" - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: item_data = self.data[attr] return defaultdict( @@ -145,7 +152,7 @@ def __getattr__(self, attr): )[attr] @property - def vendor_string(self): + def vendor_string(self) -> str: """Get the TPM Vendor String.""" return str_from_int_list( [ @@ -157,12 +164,12 @@ def vendor_string(self): ) @property - def manufacturer(self): + def manufacturer(self) -> str: """Get the TPM Manufacturer.""" return str_from_int_list([self.capabilities.properties_fixed.manufacturer]) @property - def firmware_version(self): + def firmware_version(self) -> str: """Get the TPM Firmware Version (formatted according to vendor conventions).""" key = f"{self.manufacturer}.{self.vendor_string}" ver1 = self.capabilities.properties_fixed.firmware_version_1 @@ -173,13 +180,13 @@ def firmware_version(self): )[key] @property - def spec_revision(self): + def spec_revision(self) -> str: """Get the TPM Specification Revision.""" rev = self.capabilities.properties_fixed.ps_revision # Add '.' after first digit - rev = f"{rev // 100}.{rev % 100}" + revstr = f"{rev // 100}.{rev % 100}" - return rev + return revstr - def __dir__(self): + def __dir__(self) -> Iterable[str]: return self.data.keys() diff --git a/src/tpm2_pytss/internal/crypto.py b/src/tpm2_pytss/internal/crypto.py index 42030c56..278fea10 100644 --- a/src/tpm2_pytss/internal/crypto.py +++ b/src/tpm2_pytss/internal/crypto.py @@ -20,13 +20,26 @@ from cryptography.hazmat.primitives.kdf.kbkdf import CounterLocation, KBKDFHMAC, Mode from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash from cryptography.hazmat.primitives.ciphers.algorithms import AES, Camellia -from cryptography.hazmat.primitives.ciphers import modes, Cipher, CipherAlgorithm +from cryptography.hazmat.primitives.ciphers import modes, Cipher from cryptography.hazmat.backends import default_backend from cryptography.exceptions import UnsupportedAlgorithm, InvalidSignature -from typing import Tuple, Type, Any +from typing import Tuple, Type, Any, Union, Optional, TYPE_CHECKING import secrets import sys + +if TYPE_CHECKING: + from ..types import ( + TPMT_PUBLIC, + TPMT_SIGNATURE, + TPMT_SYM_DEF, + TPMT_SENSITIVE, + TPMS_NV_PUBLIC, + TPM2B_PUBLIC, + TPM2B_SIMPLE_OBJECT, + ) + + _curvetable = ( (TPM2_ECC.NIST_P192, ec.SECP192R1), (TPM2_ECC.NIST_P224, ec.SECP224R1), @@ -35,7 +48,7 @@ (TPM2_ECC.NIST_P521, ec.SECP521R1), ) -_digesttable = ( +_digesttable: Tuple[Tuple[TPM2_ALG, Type[hashes.HashAlgorithm]], ...] = ( (TPM2_ALG.SHA1, hashes.SHA1), (TPM2_ALG.SHA256, hashes.SHA256), (TPM2_ALG.SHA384, hashes.SHA384), @@ -48,55 +61,68 @@ if hasattr(hashes, "SM3"): _digesttable += ((TPM2_ALG.SM3_256, hashes.SM3),) -_algtable = ( +_symtable: Tuple[ + Tuple[TPM2_ALG, Union[Type[AES], Type[Camellia], Type["SM4"]]], ... +] = ( (TPM2_ALG.AES, AES), (TPM2_ALG.CAMELLIA, Camellia), - (TPM2_ALG.CFB, modes.CFB), ) +_modetable = ((TPM2_ALG.CFB, modes.CFB),) + try: from cryptography.hazmat.primitives.ciphers.algorithms import SM4 - _algtable += ((TPM2_ALG.SM4, SM4),) + _symtable += ((TPM2_ALG.SM4, SM4),) except ImportError: # SM4 not implemented by cryptography package, ignore, no SM4 support. pass -def _get_curveid(curve): +def _get_curveid(curve: ec.EllipticCurve) -> Optional[TPM2_ECC]: for (algid, c) in _curvetable: if isinstance(curve, c): return algid return None -def _get_curve(curveid): +def _get_curve(curveid: TPM2_ECC) -> Optional[Type[ec.EllipticCurve]]: for (algid, c) in _curvetable: if algid == curveid: return c return None -def _get_digest(digestid): +def _get_digest(digestid: TPM2_ALG) -> Optional[Type[hashes.HashAlgorithm]]: for (algid, d) in _digesttable: if algid == digestid: return d return None -def _get_alg(alg): - for (algid, a) in _algtable: +def _get_symmetric( + alg: TPM2_ALG, +) -> Optional[Union[Type[AES], Type[Camellia], Type["SM4"]]]: + for (algid, a) in _symtable: + if algid == alg: + return a + return None + + +def _get_symmetric_mode(alg: TPM2_ALG) -> Optional[Type[modes.CFB]]: + for (algid, a) in _modetable: if algid == alg: return a return None -def _int_to_buffer(i, b): +def _int_to_buffer(i: int, b: "TPM2B_SIMPLE_OBJECT") -> None: s = ceil(i.bit_length() / 8) b.buffer = i.to_bytes(length=s, byteorder="big") -def key_from_encoding(data, password=None): +def key_from_encoding(data: bytes, password: Optional[bytes] = None) -> Any: + key: Any try: cert = load_pem_x509_certificate(data, backend=default_backend()) key = cert.public_key() @@ -140,30 +166,33 @@ def key_from_encoding(data, password=None): raise ValueError("Unsupported key format") -def _public_from_encoding(data, obj, password=None): +def _public_from_encoding( + data: bytes, obj: "TPMT_PUBLIC", password: Optional[bytes] = None +) -> None: key = key_from_encoding(data, password) - nums = key.public_numbers() if isinstance(key, rsa.RSAPublicKey): + rsanums = key.public_numbers() obj.type = TPM2_ALG.RSA obj.parameters.rsaDetail.keyBits = key.key_size - _int_to_buffer(nums.n, obj.unique.rsa) - if nums.e != 65537: - obj.parameters.rsaDetail.exponent = nums.e + _int_to_buffer(rsanums.n, obj.unique.rsa) + if rsanums.e != 65537: + obj.parameters.rsaDetail.exponent = rsanums.e else: obj.parameters.rsaDetail.exponent = 0 elif isinstance(key, ec.EllipticCurvePublicKey): + ecnums = key.public_numbers() obj.type = TPM2_ALG.ECC curveid = _get_curveid(key.curve) if curveid is None: raise ValueError(f"unsupported curve: {key.curve.name}") obj.parameters.eccDetail.curveID = curveid - _int_to_buffer(nums.x, obj.unique.ecc.x) - _int_to_buffer(nums.y, obj.unique.ecc.y) + _int_to_buffer(ecnums.x, obj.unique.ecc.x) + _int_to_buffer(ecnums.y, obj.unique.ecc.y) else: raise ValueError(f"unsupported key type: {key.__class__.__name__}") -def private_key_from_encoding(data, password=None): +def private_key_from_encoding(data: bytes, password: Optional[bytes] = None) -> Any: try: key = load_pem_private_key(data, password=password, backend=default_backend()) return key @@ -183,7 +212,9 @@ def private_key_from_encoding(data, password=None): raise ValueError("Unsupported key format") -def _private_from_encoding(data, obj, password=None): +def _private_from_encoding( + data: bytes, obj: "TPMT_SENSITIVE", password: Optional[bytes] = None +) -> None: key = private_key_from_encoding(data, password) nums = key.private_numbers() if isinstance(key, rsa.RSAPrivateKey): @@ -196,28 +227,29 @@ def _private_from_encoding(data, obj, password=None): raise ValueError(f"unsupported key type: {key.__class__.__name__}") -def public_to_key(obj): - key = None +def public_to_key( + obj: "TPMT_PUBLIC", +) -> Union[rsa.RSAPublicKey, ec.EllipticCurvePublicKey]: if obj.type == TPM2_ALG.RSA: b = obj.unique.rsa.buffer n = int.from_bytes(b, byteorder="big") e = obj.parameters.rsaDetail.exponent if e == 0: e = 65537 - nums = rsa.RSAPublicNumbers(e, n) - key = nums.public_key(backend=default_backend()) + rsanums = rsa.RSAPublicNumbers(e, n) + rsakey = rsanums.public_key(backend=default_backend()) + return rsakey elif obj.type == TPM2_ALG.ECC: curve = _get_curve(obj.parameters.eccDetail.curveID) if curve is None: raise ValueError(f"unsupported curve: {obj.parameters.eccDetail.curveID}") x = int.from_bytes(obj.unique.ecc.x, byteorder="big") y = int.from_bytes(obj.unique.ecc.y, byteorder="big") - nums = ec.EllipticCurvePublicNumbers(x, y, curve()) - key = nums.public_key(backend=default_backend()) - else: - raise ValueError(f"unsupported key type: {obj.type}") + ecnums = ec.EllipticCurvePublicNumbers(x, y, curve()) + eckey = ecnums.public_key(backend=default_backend()) + return eckey - return key + raise ValueError(f"unsupported key type: {obj.type}") class _MyRSAPrivateNumbers: @@ -255,11 +287,11 @@ def _xgcd(a: int, b: int) -> Tuple[int, int, int]: # - https://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm#Iterative_algorithm_3 # @staticmethod - def _modinv(a, m): + def _modinv(a: int, m: int) -> int: return pow(a, -1, m) @staticmethod - def _generate_d(p, q, e, n): + def _generate_d(p: int, q: int, e: int, n: int) -> int: # P most always be larger so we don't go negative if p < q: @@ -271,8 +303,9 @@ def _generate_d(p, q, e, n): return d -def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC"): - key = None +def private_to_key( + private: "TPMT_SENSITIVE", public: "TPMT_PUBLIC" +) -> Union[rsa.RSAPrivateKey, ec.EllipticCurvePrivateKey]: if private.sensitiveType == TPM2_ALG.RSA: p = int.from_bytes(bytes(private.sensitive.rsa), byteorder="big") @@ -283,9 +316,10 @@ def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC") else 65537 ) - key = _MyRSAPrivateNumbers(p, n, e, rsa.RSAPublicNumbers(e, n)).private_key( + rsakey = _MyRSAPrivateNumbers(p, n, e, rsa.RSAPublicNumbers(e, n)).private_key( backend=default_backend() ) + return rsakey elif private.sensitiveType == TPM2_ALG.ECC: curve = _get_curve(public.parameters.eccDetail.curveID) @@ -298,16 +332,15 @@ def private_to_key(private: "types.TPMT_SENSITIVE", public: "types.TPMT_PUBLIC") x = int.from_bytes(bytes(public.unique.ecc.x), byteorder="big") y = int.from_bytes(bytes(public.unique.ecc.y), byteorder="big") - key = ec.EllipticCurvePrivateNumbers( + eckey = ec.EllipticCurvePrivateNumbers( p, ec.EllipticCurvePublicNumbers(x, y, curve()) ).private_key(backend=default_backend()) - else: - raise ValueError(f"unsupported key type: {private.sensitiveType}") + return eckey - return key + raise ValueError(f"unsupported key type: {private.sensitiveType}") -def _public_to_pem(obj, encoding="pem"): +def _public_to_pem(obj: "TPMT_PUBLIC", encoding: str = "pem") -> bytes: encoding = encoding.lower() key = public_to_key(obj) if encoding == "pem": @@ -320,7 +353,7 @@ def _public_to_pem(obj, encoding="pem"): raise ValueError(f"unsupported encoding: {encoding}") -def _getname(obj): +def _getname(obj: Union["TPMT_PUBLIC", "TPMS_NV_PUBLIC"]) -> bytes: dt = _get_digest(obj.nameAlg) if dt is None: raise ValueError(f"unsupported digest algorithm: {obj.nameAlg}") @@ -333,7 +366,14 @@ def _getname(obj): return name -def _kdfa(hashAlg, key, label, contextU, contextV, bits): +def _kdfa( + hashAlg: "TPM2_ALG", + key: bytes, + label: bytes, + contextU: bytes, + contextV: bytes, + bits: int, +) -> bytes: halg = _get_digest(hashAlg) if halg is None: raise ValueError(f"unsupported digest algorithm: {hashAlg}") @@ -356,7 +396,14 @@ def _kdfa(hashAlg, key, label, contextU, contextV, bits): return kdf.derive(key) -def kdfe(hashAlg, z, use, partyuinfo, partyvinfo, bits): +def kdfe( + hashAlg: TPM2_ALG, + z: bytes, + use: bytes, + partyuinfo: bytes, + partyvinfo: bytes, + bits: int, +) -> bytes: halg = _get_digest(hashAlg) if halg is None: raise ValueError(f"unsupported digest algorithm: {hashAlg}") @@ -370,18 +417,20 @@ def kdfe(hashAlg, z, use, partyuinfo, partyvinfo, bits): return kdf.derive(z) -def _symdef_to_crypt(symdef): - alg = _get_alg(symdef.algorithm) +def _symdef_to_crypt( + symdef: "TPMT_SYM_DEF", +) -> Tuple[Union[Type[AES], Type[Camellia], Type["SM4"]], Type[modes.CFB], int]: + alg = _get_symmetric(symdef.algorithm) if alg is None: raise ValueError(f"unsupported symmetric algorithm {symdef.algorithm}") - mode = _get_alg(symdef.mode.sym) + mode = _get_symmetric_mode(symdef.mode.sym) if mode is None: raise ValueError(f"unsupported symmetric mode {symdef.mode.sym}") bits = symdef.keyBits.sym return (alg, mode, bits) -def _calculate_sym_unique(nameAlg, secret, seed): +def _calculate_sym_unique(nameAlg: TPM2_ALG, secret: bytes, seed: bytes) -> bytes: dt = _get_digest(nameAlg) if dt is None: raise ValueError(f"unsupported digest algorithm: {nameAlg}") @@ -391,15 +440,15 @@ def _calculate_sym_unique(nameAlg, secret, seed): return d.finalize() -def _get_digest_size(alg): +def _get_digest_size(alg: TPM2_ALG) -> int: dt = _get_digest(alg) if dt is None: raise ValueError(f"unsupported digest algorithm: {alg}") - return dt.digest_size + return dt().digest_size -def _get_signature_bytes(sig): +def _get_signature_bytes(sig: "TPMT_SIGNATURE") -> bytes: if sig.sigAlg in (TPM2_ALG.RSAPSS, TPM2_ALG.RSASSA): rb = bytes(sig.signature.rsapss.sig) elif sig.sigAlg == TPM2_ALG.ECDSA: @@ -414,17 +463,20 @@ def _get_signature_bytes(sig): return rb -def verify_signature_rsa(signature, key, data): +def verify_signature_rsa( + signature: "TPMT_SIGNATURE", key: rsa.RSAPublicKey, data: bytes +) -> None: dt = _get_digest(signature.signature.any.hashAlg) if dt is None: raise ValueError( f"unsupported digest algorithm: {signature.signature.rsapss.hash}" ) mpad = None + pad: Union[padding.PKCS1v15, padding.PSS] if signature.sigAlg == TPM2_ALG.RSASSA: pad = padding.PKCS1v15() elif signature.sigAlg == TPM2_ALG.RSAPSS: - pad = padding.PSS(mgf=padding.MGF1(dt()), salt_length=dt.digest_size) + pad = padding.PSS(mgf=padding.MGF1(dt()), salt_length=dt().digest_size) mpad = padding.PSS(mgf=padding.MGF1(dt()), salt_length=padding.PSS.MAX_LENGTH) else: raise ValueError(f"unsupported RSA signature algorithm: {signature.sigAlg}") @@ -439,7 +491,9 @@ def verify_signature_rsa(signature, key, data): raise -def verify_signature_ecc(signature, key, data): +def verify_signature_ecc( + signature: "TPMT_SIGNATURE", key: ec.EllipticCurvePublicKey, data: bytes +) -> None: dt = _get_digest(signature.signature.any.hashAlg) if dt is None: raise ValueError( @@ -451,7 +505,7 @@ def verify_signature_ecc(signature, key, data): key.verify(sig, data, ec.ECDSA(dt())) -def verify_signature_hmac(signature, key, data): +def verify_signature_hmac(signature: "TPMT_SIGNATURE", key: bytes, data: bytes) -> None: dt = _get_digest(signature.signature.hmac.hashAlg) if dt is None: raise ValueError( @@ -466,24 +520,31 @@ def verify_signature_hmac(signature, key, data): h.verify(sig) -def _verify_signature(signature, key, data): +def _verify_signature( + signature: "TPMT_SIGNATURE", + key: Union["TPMT_PUBLIC", "TPM2B_PUBLIC", bytes], + data: bytes, +) -> None: if hasattr(key, "publicArea"): - key = key.publicArea - kt = getattr(key, "type", None) + pa = key.publicArea + else: + pa = key + kt = getattr(pa, "type", None) + pkey: Any = key if kt in (TPM2_ALG.RSA, TPM2_ALG.ECC): - key = public_to_key(key) + pkey = public_to_key(pa) if signature.sigAlg in (TPM2_ALG.RSASSA, TPM2_ALG.RSAPSS): - if not isinstance(key, rsa.RSAPublicKey): + if not isinstance(pkey, rsa.RSAPublicKey): raise ValueError( - f"bad key type for {signature.sigAlg}, expected RSA public key, got {key.__class__.__name__}" + f"bad key type for {signature.sigAlg}, expected RSA public key, got {pkey.__class__.__name__}" ) - verify_signature_rsa(signature, key, data) + verify_signature_rsa(signature, pkey, data) elif signature.sigAlg == TPM2_ALG.ECDSA: - if not isinstance(key, ec.EllipticCurvePublicKey): + if not isinstance(pkey, ec.EllipticCurvePublicKey): raise ValueError( - f"bad key type for {signature.sigAlg}, expected ECC public key, got {key.__class__.__name__}" + f"bad key type for {signature.sigAlg}, expected ECC public key, got {pkey.__class__.__name__}" ) - verify_signature_ecc(signature, key, data) + verify_signature_ecc(signature, pkey, data) elif signature.sigAlg == TPM2_ALG.HMAC: if not isinstance(key, bytes): raise ValueError( @@ -495,12 +556,12 @@ def _verify_signature(signature, key, data): def _generate_rsa_seed( - key: rsa.RSAPublicKey, hashAlg: int, label: bytes + key: rsa.RSAPublicKey, hashAlg: TPM2_ALG, label: bytes ) -> Tuple[bytes, bytes]: halg = _get_digest(hashAlg) if halg is None: raise ValueError(f"unsupported digest algorithm {hashAlg}") - seed = secrets.token_bytes(halg.digest_size) + seed = secrets.token_bytes(halg().digest_size) mgf = padding.MGF1(halg()) padd = padding.OAEP(mgf, halg(), label) enc_seed = key.encrypt(seed, padd) @@ -508,7 +569,7 @@ def _generate_rsa_seed( def _generate_ecc_seed( - key: ec.EllipticCurvePublicKey, hashAlg: int, label: bytes + key: ec.EllipticCurvePublicKey, hashAlg: TPM2_ALG, label: bytes ) -> Tuple[bytes, bytes]: halg = _get_digest(hashAlg) if halg is None: @@ -528,21 +589,22 @@ def _generate_ecc_seed( shared_key = ekey.exchange(ec.ECDH(), key) pubnum = key.public_numbers() xbytes = pubnum.x.to_bytes(plength, "big") - seed = kdfe(hashAlg, shared_key, label, exbytes, xbytes, halg.digest_size * 8) + seed = kdfe(hashAlg, shared_key, label, exbytes, xbytes, halg().digest_size * 8) return (seed, secret) -def _generate_seed(public: "types.TPMT_PUBLIC", label: bytes) -> Tuple[bytes, bytes]: +def _generate_seed(public: "TPMT_PUBLIC", label: bytes) -> Tuple[bytes, bytes]: key = public_to_key(public) - if public.type == TPM2_ALG.RSA: + if isinstance(key, rsa.RSAPublicKey): return _generate_rsa_seed(key, public.nameAlg, label) - elif public.type == TPM2_ALG.ECC: + elif isinstance(key, ec.EllipticCurvePublicKey): return _generate_ecc_seed(key, public.nameAlg, label) - else: - raise ValueError(f"unsupported seed algorithm {public.type}") + raise ValueError(f"unsupported seed algorithm {public.type}") -def __rsa_secret_to_seed(key, hashAlg: int, label: bytes, outsymseed: bytes): +def __rsa_secret_to_seed( + key: rsa.RSAPrivateKey, hashAlg: TPM2_ALG, label: bytes, outsymseed: bytes +) -> bytes: halg = _get_digest(hashAlg) if halg is None: raise ValueError(f"unsupported digest algorithm {hashAlg}") @@ -553,8 +615,8 @@ def __rsa_secret_to_seed(key, hashAlg: int, label: bytes, outsymseed: bytes): def __ecc_secret_to_seed( - key: ec.EllipticCurvePrivateKey, hashAlg: int, label: bytes, outsymseed: bytes -) -> Tuple[bytes, bytes]: + key: ec.EllipticCurvePrivateKey, hashAlg: TPM2_ALG, label: bytes, outsymseed: bytes +) -> bytes: halg = _get_digest(hashAlg) if halg is None: raise ValueError(f"unsupported digest algorithm {hashAlg}") @@ -580,16 +642,17 @@ def __ecc_secret_to_seed( pubnum = key.public_key().public_numbers() xbytes = pubnum.x.to_bytes(key.key_size // 8, "big") - seed = kdfe(hashAlg, shared_key, label, exbytes, xbytes, halg.digest_size * 8) + seed = kdfe(hashAlg, shared_key, label, exbytes, xbytes, halg().digest_size * 8) return seed def _secret_to_seed( - private: "types.TPMT_SENSITIVE", - public: "types.TPMT_PUBLIC", + private: "TPMT_SENSITIVE", + public: "TPMT_PUBLIC", label: bytes, - outsymseed: bytes, -): + outsymseed: Union[bytes, "TPM2B_SIMPLE_OBJECT"], +) -> bytes: + outsymseed = bytes(outsymseed) key = private_to_key(private, public) if isinstance(key, rsa.RSAPrivateKey): return __rsa_secret_to_seed(key, public.nameAlg, label, outsymseed) @@ -600,7 +663,7 @@ def _secret_to_seed( def _hmac( - halg: hashes.HashAlgorithm, hmackey: bytes, enc_cred: bytes, name: bytes + halg: Type[hashes.HashAlgorithm], hmackey: bytes, enc_cred: bytes, name: bytes ) -> bytes: h = HMAC(hmackey, halg(), backend=default_backend()) h.update(enc_cred) @@ -609,12 +672,12 @@ def _hmac( def _check_hmac( - halg: hashes.HashAlgorithm, + halg: Type[hashes.HashAlgorithm], hmackey: bytes, enc_cred: bytes, name: bytes, expected: bytes, -): +) -> None: h = HMAC(hmackey, halg(), backend=default_backend()) h.update(enc_cred) h.update(name) @@ -622,7 +685,10 @@ def _check_hmac( def _encrypt( - cipher: Type[CipherAlgorithm], mode: Type[modes.Mode], key: bytes, data: bytes + cipher: Union[Type[AES], Type[Camellia], Type["SM4"]], + mode: Type[modes.CFB], + key: bytes, + data: bytes, ) -> bytes: iv = len(key) * b"\x00" ci = cipher(key) @@ -633,7 +699,10 @@ def _encrypt( def _decrypt( - cipher: Type[CipherAlgorithm], mode: Type[modes.Mode], key: bytes, data: bytes + cipher: Union[Type[AES], Type[Camellia], Type["SM4"]], + mode: Type[modes.CFB], + key: bytes, + data: bytes, ) -> bytes: iv = len(key) * b"\x00" ci = cipher(key) diff --git a/src/tpm2_pytss/internal/utils.py b/src/tpm2_pytss/internal/utils.py index 80a8ee8c..80199bd2 100644 --- a/src/tpm2_pytss/internal/utils.py +++ b/src/tpm2_pytss/internal/utils.py @@ -1,11 +1,24 @@ # SPDX-License-Identifier: BSD-2 import logging import sys -from typing import List +from typing import ( + List, + Optional, + Union, + TYPE_CHECKING, + Callable, + Any, + Dict, + Type, + Tuple, +) from .._libtpm2_pytss import ffi, lib from ..TSS2_Exception import TSS2_Exception +if TYPE_CHECKING: + from ..constants import TPM_FRIENDLY_INT + try: from .versions import _versions except ImportError as e: @@ -53,7 +66,7 @@ def __init__(self, version: str): hunks = version.split(".") extra_data = version.split("-")[1:] - def handle_extra(): + def handle_extra() -> None: nonlocal extra_data nonlocal commits nonlocal rc @@ -93,7 +106,7 @@ def handle_extra(): else: raise ValueError(f'Invalid version string, got: "{version}"') - def cleanse(xstr): + def cleanse(xstr: str) -> str: if "-" in xstr: return xstr[: xstr.find("-")] @@ -124,49 +137,60 @@ def cleanse(xstr): raise ValueError(f'Invalid version string, got: "{version}"') # Convert to int - major = int(major, 0).to_bytes(4, byteorder="big") - minor = int(minor, 0).to_bytes(4, byteorder="big") - patch = int(patch, 0).to_bytes(4, byteorder="big") - rc = int(rc, 0).to_bytes(4, byteorder="big") - commits = int(commits, 0).to_bytes(4, byteorder="big") - dirty = int(is_dirty).to_bytes(1, byteorder="big") + major_bytes = int(major, 0).to_bytes(4, byteorder="big") + minor_bytes = int(minor, 0).to_bytes(4, byteorder="big") + patch_bytes = int(patch, 0).to_bytes(4, byteorder="big") + rc_bytes = int(rc, 0).to_bytes(4, byteorder="big") + commits_bytes = int(commits, 0).to_bytes(4, byteorder="big") + dirty_bytes = int(is_dirty).to_bytes(1, byteorder="big") # TO make reasoning easy we lay out a big int where each field # can hold 4 bytes of data, except for dirty which is a byte # MAJOR : MINOR : PATCH : RC : COMMITS : DIRTY - concatenated = major + minor + patch + rc + commits + dirty + concatenated = ( + major_bytes + + minor_bytes + + patch_bytes + + rc_bytes + + commits_bytes + + dirty_bytes + ) v = int.from_bytes(concatenated, byteorder="big") self._value = v - def __str__(self): + def __str__(self) -> str: return self._version - def __lt__(self, other): + def __lt__(self, other: Union["TSS2Version", int]) -> bool: x = other if isinstance(other, int) else other._value return self._value < x - def __lte__(self, other): + def __lte__(self, other: Union["TSS2Version", int]) -> bool: x = other if isinstance(other, int) else other._value return self._value <= x - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, (int, self.__class__)): + return False x = other if isinstance(other, int) else other._value return self._value == x - def __ne__(self, other): + def __ne__(self, other: object) -> bool: + if not isinstance(other, (int, self.__class__)): + return False x = other if isinstance(other, int) else other._value return self._value != x - def __ge__(self, other): + def __ge__(self, other: Union["TSS2Version", int]) -> bool: x = other if isinstance(other, int) else other._value return self._value >= x - def __gt__(self, other): + def __gt__(self, other: Union["TSS2Version", int]) -> bool: x = other if isinstance(other, int) else other._value return self._value > x -def _chkrc(rc, acceptable=None): +def _chkrc(rc: int, acceptable: Optional[Union[List[int], int]] = None) -> None: if acceptable is None: acceptable = [] elif isinstance(acceptable, int): @@ -176,7 +200,9 @@ def _chkrc(rc, acceptable=None): raise TSS2_Exception(rc) -def _to_bytes_or_null(value, allow_null=True, encoding=None): +def _to_bytes_or_null( + value: Union[None, bytes, str], allow_null: bool = True, encoding: str = "utf-8" +) -> Union[bytes, ffi.CData]: """Convert to cdata input. None: ffi.NULL (if allow_null == True) @@ -199,7 +225,9 @@ def _to_bytes_or_null(value, allow_null=True, encoding=None): #### Utilities #### -def _CLASS_INT_ATTRS_from_string(cls, str_value, fixup_map=None): +def _CLASS_INT_ATTRS_from_string( + cls: object, str_value: str, fixup_map: Optional[Dict[str, str]] = None +) -> int: """ Given a class, lookup int attributes by name and return that attribute value. :param cls: The class to search. @@ -218,18 +246,20 @@ def _CLASS_INT_ATTRS_from_string(cls, str_value, fixup_map=None): return friendly[str_value.upper()] -def _cpointer_to_ctype(x): +def _cpointer_to_ctype(x: ffi.CData) -> ffi.CType: tipe = ffi.typeof(x) if tipe.kind == "pointer": tipe = tipe.item return tipe -def _fixup_cdata_kwargs(this, _cdata, kwargs): +def _fixup_cdata_kwargs( + this: Any, _cdata: Any, kwargs: Dict[str, Any] +) -> Tuple[ffi.CData, Dict[str, Any]]: # folks may call this routine without a keyword argument which means it may # end up in _cdata, so we want to try and work this out - unknown = None + unknown: Optional[ffi.CData] = None try: # is _cdata actual ffi data? ffi.typeof(_cdata) @@ -254,6 +284,8 @@ def _fixup_cdata_kwargs(this, _cdata, kwargs): # ignore the field that is size or count, and get the one for the data size_field_name = "size" if "TPM2B_" in tipe.cname else "count" field_name = next((v[0] for v in tipe.fields if v[0] != size_field_name), None) + if field_name is None: + raise AttributeError("No non size/could field found") if len(kwargs) != 0: raise RuntimeError( @@ -277,18 +309,20 @@ def _fixup_cdata_kwargs(this, _cdata, kwargs): return (_cdata, kwargs) -def _ref_parent(data, parent): +def _ref_parent(data: ffi.CData, parent: ffi.CData) -> ffi.CData: tipe = ffi.typeof(parent) if tipe.kind != "pointer": return data - def deconstructor(ptr): + def deconstructor(ptr: ffi.CData) -> None: parent return ffi.gc(data, deconstructor) -def _convert_to_python_native(global_map, data, parent=None): +def _convert_to_python_native( + global_map: Dict[str, Any], data: ffi.CData, parent: Optional[ffi.CData] = None +) -> Any: if not isinstance(data, ffi.CData): return data @@ -314,7 +348,7 @@ def _convert_to_python_native(global_map, data, parent=None): return obj -def _fixup_classname(tipe): +def _fixup_classname(tipe: ffi.CType) -> str: # Some versions of tpm2-tss had anonymous structs, so the kind will be struct # but the name will not contain it if tipe.cname.startswith(tipe.kind): @@ -323,15 +357,17 @@ def _fixup_classname(tipe): return tipe.cname -def _mock_bail(): +def _mock_bail() -> bool: return __MOCK__ -def _get_dptr(dptr, free_func): +def _get_dptr(dptr: ffi.CData, free_func: Callable[[ffi.CData], None]) -> ffi.CData: return ffi.gc(dptr[0], free_func) -def _check_friendly_int(friendly, varname, clazz): +def _check_friendly_int( + friendly: int, varname: str, clazz: Type["TPM_FRIENDLY_INT"] +) -> None: if not isinstance(friendly, int): raise TypeError(f"expected {varname} to be type int, got {type(friendly)}") @@ -343,7 +379,9 @@ def _check_friendly_int(friendly, varname, clazz): def is_bug_fixed( - fixed_in=None, backports: List[str] = None, lib: str = "tss2-fapi" + fixed_in: Optional[str] = None, + backports: Optional[List[str]] = None, + lib: str = "tss2-fapi", ) -> bool: """Use pkg-config to determine if a bug was fixed in the currently installed tpm2-tss version.""" if fixed_in and _lib_version_atleast(lib, fixed_in): @@ -368,9 +406,9 @@ def is_bug_fixed( def _check_bug_fixed( - details, - fixed_in=None, - backports: List[str] = None, + details: str, + fixed_in: Optional[str] = None, + backports: Optional[List[str]] = None, lib: str = "tss2-fapi", error: bool = False, ) -> None: @@ -385,7 +423,7 @@ def _check_bug_fixed( logger.warning(message) -def _lib_version_atleast(tss2_lib, version): +def _lib_version_atleast(tss2_lib: str, version: str) -> bool: if tss2_lib not in _versions: return False diff --git a/src/tpm2_pytss/policy.py b/src/tpm2_pytss/policy.py index cc6beadf..70402185 100644 --- a/src/tpm2_pytss/policy.py +++ b/src/tpm2_pytss/policy.py @@ -19,7 +19,8 @@ from ._libtpm2_pytss import ffi, lib from .ESAPI import ESAPI from enum import Enum -from typing import Callable, Union +from typing import Callable, Union, Any, Type, Optional, Dict +from types import TracebackType class policy_cb_types(Enum): @@ -39,7 +40,12 @@ class policy_cb_types(Enum): @ffi.def_extern() -def _policy_cb_calc_pcr(selection, out_selection, out_digest, userdata): +def _policy_cb_calc_pcr( + selection: ffi.CData, + out_selection: ffi.CData, + out_digest: ffi.CData, + userdata: ffi.CData, +) -> int: """Callback wrapper for policy PCR calculations Args: @@ -73,7 +79,7 @@ def _policy_cb_calc_pcr(selection, out_selection, out_digest, userdata): @ffi.def_extern() -def _policy_cb_calc_name(path, name, userdata): +def _policy_cb_calc_name(path: ffi.CData, name: ffi.CData, userdata: ffi.CData) -> int: """Callback wrapper for policy name calculations Args: @@ -100,7 +106,9 @@ def _policy_cb_calc_name(path, name, userdata): @ffi.def_extern() -def _policy_cb_calc_public(path, public, userdata): +def _policy_cb_calc_public( + path: ffi.CData, public: ffi.CData, userdata: ffi.CData +) -> int: """Callback wrapper for getting the public part for a key path Args: @@ -132,7 +140,9 @@ def _policy_cb_calc_public(path, public, userdata): @ffi.def_extern() -def _policy_cb_calc_nvpublic(path, nv_index, nv_public, userdata): +def _policy_cb_calc_nvpublic( + path: ffi.CData, nv_index: int, nv_public: ffi.CData, userdata: ffi.CData +) -> int: """Callback wrapper for getting the public part for a NV path Args: @@ -165,7 +175,13 @@ def _policy_cb_calc_nvpublic(path, nv_index, nv_public, userdata): @ffi.def_extern() -def _policy_cb_exec_auth(name, object_handle, auth_handle, auth_session, userdata): +def _policy_cb_exec_auth( + name: ffi.CData, + object_handle: ffi.CData, + auth_handle: ffi.CData, + auth_session: ffi.CData, + userdata: ffi.CData, +) -> int: """Callback wrapper for getting authorization sessions for a name Args: @@ -180,7 +196,7 @@ def _policy_cb_exec_auth(name, object_handle, auth_handle, auth_session, userdat if not cb: return TSS2_RC.POLICY_RC_NULL_CALLBACK try: - nb = ffi.unpack(name.name, name.size) + nb = ffi.unpack(name.name, int(name.size)) name2b = TPM2B_NAME(nb) cb_object_handle, cb_auth_handle, cb_auth_session = cb(name2b) object_handle[0] = cb_object_handle @@ -197,8 +213,12 @@ def _policy_cb_exec_auth(name, object_handle, auth_handle, auth_session, userdat @ffi.def_extern() def _policy_cb_exec_polsel( - auth_object, branch_names, branch_count, branch_idx, userdata -): + auth_object: ffi.CData, + branch_names: ffi.CData, + branch_count: int, + branch_idx: ffi.CData, + userdata: ffi.CData, +) -> int: """Callback wrapper selection of a policy branch Args: @@ -233,15 +253,15 @@ def _policy_cb_exec_polsel( @ffi.def_extern() def _policy_cb_exec_sign( - key_pem, - public_key_hint, - key_pem_hash_alg, - buf, - buf_size, - signature, - signature_size, - userdata, -): + key_pem: ffi.CData, + public_key_hint: ffi.CData, + key_pem_hash_alg: int, + buf: ffi.CData, + buf_size: int, + signature: ffi.CData, + signature_size: ffi.CData, + userdata: ffi.CData, +) -> int: """Callback wrapper to signing an operation Args: @@ -277,8 +297,13 @@ def _policy_cb_exec_sign( @ffi.def_extern() def _policy_cb_exec_polauth( - key_public, hash_alg, digest, policy_ref, signature, userdata -): + key_public: ffi.CData, + hash_alg: int, + digest: ffi.CData, + policy_ref: ffi.CData, + signature: ffi.CData, + userdata: ffi.CData, +) -> int: """Callback for signing a policy Args: @@ -296,8 +321,8 @@ def _policy_cb_exec_polauth( try: key_pub = TPMT_PUBLIC(_cdata=key_public) halg = TPM2_ALG(hash_alg) - db = ffi.unpack(digest.buffer, digest.size) - pb = ffi.unpack(policy_ref.buffer, policy_ref.size) + db = ffi.unpack(digest.buffer, int(digest.size)) + pb = ffi.unpack(policy_ref.buffer, int(policy_ref.size)) dig = TPM2B_DIGEST(db) polref = TPM2B_NONCE(pb) cb_signature = cb(key_pub, halg, dig, polref) @@ -313,7 +338,9 @@ def _policy_cb_exec_polauth( @ffi.def_extern() -def _policy_cb_exec_polauthnv(nv_public, hash_alg, userdata): +def _policy_cb_exec_polauthnv( + nv_public: ffi.CData, hash_alg: int, userdata: ffi.CData +) -> int: """Callback wrapper for NV policy authorization Args: @@ -339,7 +366,7 @@ def _policy_cb_exec_polauthnv(nv_public, hash_alg, userdata): @ffi.def_extern() -def _policy_cb_exec_poldup(name, userdata): +def _policy_cb_exec_poldup(name: ffi.CData, userdata: ffi.CData) -> int: """Callback wrapper to get name for duplication selection Args: @@ -364,7 +391,7 @@ def _policy_cb_exec_poldup(name, userdata): @ffi.def_extern() -def _policy_cb_exec_polaction(action, userdata): +def _policy_cb_exec_polaction(action: ffi.CData, userdata: ffi.CData) -> int: """Callback wrapper for policy action Args: @@ -408,7 +435,7 @@ def __init__(self, policy: Union[bytes, str], hash_alg: TPM2_ALG): policy = policy.encode() self._policy = policy self._hash_alg = hash_alg - self._callbacks = dict() + self._callbacks: Dict[policy_cb_types, Optional[Callable[..., Any]]] = dict() self._callback_exception = None self._ctx_pp = ffi.new("TSS2_POLICY_CTX **") _chkrc(lib.Tss2_PolicyInit(policy, hash_alg, self._ctx_pp)) @@ -417,13 +444,15 @@ def __init__(self, policy: Union[bytes, str], hash_alg: TPM2_ALG): self._calc_callbacks = ffi.new("TSS2_POLICY_CALC_CALLBACKS *") self._exec_callbacks = ffi.new("TSS2_POLICY_EXEC_CALLBACKS *") - def __enter__(self): + def __enter__(self) -> "policy": return self - def __exit__(self, _type, value, traceback): + def __exit__( + self, _type: Type[Exception], value: Exception, traceback: TracebackType + ) -> None: self.close() - def close(self): + def close(self) -> None: """Finalize the policy instance""" lib.Tss2_PolicyFinalize(self._ctx_pp) self._ctx_pp = ffi.NULL @@ -439,12 +468,14 @@ def hash_alg(self) -> TPM2_ALG: """TPM2_ALG: The hash algorithm to be used during policy calculcation.""" return self._hash_alg - def _get_callback(self, callback_type: policy_cb_types) -> Callable: + def _get_callback( + self, callback_type: policy_cb_types + ) -> Optional[Callable[..., Any]]: return self._callbacks.get(callback_type) def set_callback( - self, callback_type: policy_cb_types, callback: Union[None, Callable] - ): + self, callback_type: policy_cb_types, callback: Union[None, Callable[..., Any]] + ) -> None: """Set callback for policy calculaction or execution Args: @@ -522,7 +553,7 @@ def set_callback( elif update_exec: _chkrc(lib.Tss2_PolicySetExecCallbacks(self._ctx, self._exec_callbacks)) - def execute(self, esys_ctx: ESAPI, session: ESYS_TR): + def execute(self, esys_ctx: ESAPI, session: ESYS_TR) -> None: """Executes the policy Args: @@ -541,7 +572,7 @@ def execute(self, esys_ctx: ESAPI, session: ESYS_TR): finally: self._callback_exception = None - def calculate(self): + def calculate(self) -> None: """Calculate the policy Raises: diff --git a/src/tpm2_pytss/tsskey.py b/src/tpm2_pytss/tsskey.py index 5ad216f9..ae1a9033 100644 --- a/src/tpm2_pytss/tsskey.py +++ b/src/tpm2_pytss/tsskey.py @@ -3,9 +3,11 @@ import warnings from ._libtpm2_pytss import lib from .types import * -from .constants import TPM2_ECC, TPM2_CAP, ESYS_TR +from .constants import TPM2_ECC, TPM2_CAP, ESYS_TR, TPM2_ALG, TPMA_OBJECT, TPM2_RH +from .ESAPI import ESAPI from asn1crypto.core import ObjectIdentifier, Sequence, Boolean, OctetString, Integer from asn1crypto import pem +from typing import Optional, Union _parent_rsa_template = TPMT_PUBLIC( @@ -100,7 +102,7 @@ # _BooleanOne is used to encode True in the same way as tpm2-tss-engine class _BooleanOne(Boolean): - def set(self, value): + def set(self, value: bool) -> None: self._native = bool(value) self.contents = b"\x00" if not value else b"\x01" self._header = None @@ -124,7 +126,13 @@ class _tssprivkey_der(Sequence): ("private", OctetString), ] - def __init__(self, private, public, empty_auth=True, parent=lib.TPM2_RH_OWNER): + def __init__( + self, + private: TPM2B_PRIVATE, + public: TPM2B_PUBLIC, + empty_auth: bool = True, + parent: Union[TPM2_RH, TPM2_HANDLE] = TPM2_RH.OWNER, + ): """Initialize TSSPrivKey using raw values. Args: @@ -139,26 +147,26 @@ def __init__(self, private, public, empty_auth=True, parent=lib.TPM2_RH_OWNER): self._parent = parent @property - def private(self): + def private(self) -> TPM2B_PRIVATE: """TPM2B_PRIVATE: The private part of the TPM key.""" return self._private @property - def public(self): + def public(self) -> TPM2B_PUBLIC: """TPM2B_PUBLIC: The public part of the TPM key.""" return self._public @property - def empty_auth(self): + def empty_auth(self) -> bool: """bool: Defines if the authorization is a empty password.""" return self._empty_auth @property - def parent(self): + def parent(self) -> Union[TPM2_RH, TPM2_HANDLE]: """int: Handle of the parent key.""" return self._parent - def to_der(self): + def to_der(self) -> bytes: """Encode the TSSPrivKey as DER encoded ASN.1. Returns: @@ -174,7 +182,7 @@ def to_der(self): seq["private"] = priv return seq.dump() - def to_pem(self): + def to_pem(self) -> bytes: """Encode the TSSPrivKey as PEM encoded ASN.1. Returns: @@ -184,11 +192,13 @@ def to_pem(self): return pem.armor("TSS2 PRIVATE KEY", der) @staticmethod - def _getparenttemplate(ectx): + def _getparenttemplate(ectx: ESAPI) -> Optional[TPMT_PUBLIC]: more = True al = list() while more: - more, data = ectx.get_capability(TPM2_CAP.ALGS, 0, lib.TPM2_MAX_CAP_ALGS) + more, data = ectx.get_capability( + TPM2_CAP(TPM2_CAP.ALGS), 0, lib.TPM2_MAX_CAP_ALGS + ) algs = data.data.algorithms for i in range(0, algs.count): al.append(algs.algProperties[i].alg) @@ -199,11 +209,13 @@ def _getparenttemplate(ectx): return None @staticmethod - def _getparent(ectx, keytype, parent): - if parent == lib.TPM2_RH_OWNER: + def _getparent( + ectx: ESAPI, keytype: bool, parent: Union[TPM2_RH, TPM2_HANDLE] + ) -> ESYS_TR: + if parent == TPM2_RH.OWNER: template = TSSPrivKey._getparenttemplate(ectx) else: - return ectx.tr_from_tpmpublic(parent) + return ectx.tr_from_tpmpublic(TPM2_HANDLE(parent)) if template is None: raise RuntimeError("Unable to find supported parent key type") inpub = TPM2B_PUBLIC(publicArea=template) @@ -213,11 +225,11 @@ def _getparent(ectx, keytype, parent): in_public=inpub, outside_info=TPM2B_DATA(), creation_pcr=TPML_PCR_SELECTION(), - session1=ESYS_TR.PASSWORD, + session1=ESYS_TR(ESYS_TR.PASSWORD), ) return phandle - def load(self, ectx, password=None): + def load(self, ectx: ESAPI, password: Optional[bytes] = None) -> ESYS_TR: """Load the TSSPrivKey. Args: @@ -237,7 +249,13 @@ def load(self, ectx, password=None): return handle @classmethod - def create(cls, ectx, template, parent=lib.TPM2_RH_OWNER, password=None): + def create( + cls, + ectx: ESAPI, + template: TPMT_PUBLIC, + parent: Union[TPM2_RH, TPM2_HANDLE] = TPM2_RH.OWNER, + password: Optional[bytes] = None, + ) -> "TSSPrivKey": """Create a TssPrivKey using a template. Note: @@ -245,7 +263,7 @@ def create(cls, ectx, template, parent=lib.TPM2_RH_OWNER, password=None): Args: ectx (ESAPI): The ESAPI instance to use for creating the key. - template (TPM2B_PUBLIC): The key template. + template (TPMT_PUBLIC): The key template. parent (int): The parent of the key, default is TPM2_RH_OWNER. password (bytes): The password to set for the key, default is None. @@ -269,8 +287,13 @@ def create(cls, ectx, template, parent=lib.TPM2_RH_OWNER, password=None): @classmethod def create_rsa( - cls, ectx, keyBits=2048, exponent=0, parent=lib.TPM2_RH_OWNER, password=None - ): + cls, + ectx: ESAPI, + keyBits: int = 2048, + exponent: int = 0, + parent: Union[TPM2_RH, TPM2_HANDLE] = TPM2_RH.OWNER, + password: Optional[bytes] = None, + ) -> "TSSPrivKey": """Create a RSA TssPrivKey using a standard RSA key template. Args: @@ -290,8 +313,12 @@ def create_rsa( @classmethod def create_ecc( - cls, ectx, curveID=TPM2_ECC.NIST_P256, parent=lib.TPM2_RH_OWNER, password=None - ): + cls, + ectx: ESAPI, + curveID: TPM2_ECC = TPM2_ECC.NIST_P256, + parent: Union[TPM2_RH, TPM2_HANDLE] = TPM2_RH.OWNER, + password: Optional[bytes] = None, + ) -> "TSSPrivKey": """Create an ECC TssPrivKey using a standard ECC key template. Args: @@ -308,7 +335,7 @@ def create_ecc( return cls.create(ectx, template, parent, password) @classmethod - def from_der(cls, data): + def from_der(cls, data: bytes) -> "TSSPrivKey": """Load a TSSPrivKey from DER ASN.1. Args: @@ -327,7 +354,7 @@ def from_der(cls, data): return cls(private, public, empty_auth, parent) @classmethod - def from_pem(cls, data): + def from_pem(cls, data: bytes) -> "TSSPrivKey": """Load a TSSPrivKey from PEM ASN.1. Args: diff --git a/src/tpm2_pytss/types.py b/src/tpm2_pytss/types.py index b155e6de..f9659b28 100644 --- a/src/tpm2_pytss/types.py +++ b/src/tpm2_pytss/types.py @@ -40,7 +40,13 @@ TPM2_SE, TPM2_HR, ) -from typing import Union, Tuple, Optional +from typing import Union, Tuple, Optional, Any, Iterable, List + +try: + from typing import Self +except ImportError: + # assume mypy is running on python 3.11+ + pass import sys try: @@ -71,7 +77,9 @@ class TPM2_HANDLE(int): class TPM_OBJECT(object): """ Abstract Base class for all TPM Objects. Not suitable for direct instantiation.""" - def __init__(self, _cdata=None, **kwargs): + _cdata: ffi.CData + + def __init__(self, _cdata: Optional[Any] = None, **kwargs: Any): # Rather than trying to mock the FFI interface, just avoid it and return # the base object. This is really only needed for documentation, and it @@ -114,7 +122,7 @@ def __init__(self, _cdata=None, **kwargs): v = subobj TPM_OBJECT.__setattr__(self, k, v) - def __getattribute__(self, key): + def __getattribute__(self, key: str) -> Any: try: # go through object to avoid invoking THIS objects __getattribute__ call # and thus infinite recursion @@ -137,7 +145,7 @@ def __getattribute__(self, key): obj = _convert_to_python_native(globals(), x, parent=self._cdata) return obj - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: _value = value _cdata = object.__getattribute__(self, "_cdata") @@ -199,10 +207,10 @@ def __setattr__(self, key, value): # recurse so we can get handling of setattr with Python wrapped data setattr(self, key, value) - def __dir__(self): - return object.__dir__(self) + dir(self._cdata) + def __dir__(self) -> Iterable[str]: + return list(object.__dir__(self)) + dir(self._cdata) - def marshal(self): + def marshal(self) -> bytes: """Marshal instance into bytes. Returns: @@ -223,7 +231,7 @@ def marshal(self): return bytes(buf[0 : offset[0]]) @classmethod - def unmarshal(cls, buf): + def unmarshal(cls, buf: bytes) -> Tuple["Self", int]: """Unmarshal bytes into type instance. Args: @@ -247,7 +255,9 @@ class TPM2B_SIMPLE_OBJECT(TPM_OBJECT): """ Abstract Base class for all TPM2B Simple Objects. A Simple object contains only a size and byte buffer fields. This is not suitable for direct instantiation.""" - def __init__(self, _cdata=None, **kwargs): + def __init__( + self, _cdata: Optional[Union[ffi.CData, bytes, str]] = None, **kwargs: Any + ): _cdata, kwargs = _fixup_cdata_kwargs(self, _cdata, kwargs) _bytefield = type(self)._get_bytefield() @@ -267,14 +277,14 @@ def __init__(self, _cdata=None, **kwargs): super().__init__(_cdata=_cdata) @classmethod - def _get_bytefield(cls): + def _get_bytefield(cls) -> Optional[str]: tipe = ffi.typeof(f"{cls.__name__}") for f in tipe.fields: if f[0] != "size": return f[0] return None - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: if key == "size": raise AttributeError(f"{key} is read only") @@ -288,19 +298,21 @@ def __setattr__(self, key, value): else: super().__setattr__(key, value) - def __getattribute__(self, key): + def __getattribute__(self, key: str) -> Any: _bytefield = type(self)._get_bytefield() if key == _bytefield: b = getattr(self._cdata, _bytefield) rb = _ref_parent(b, self._cdata) - return memoryview(ffi.buffer(rb, self._cdata.size)) + return memoryview(ffi.buffer(rb, int(self._cdata.size))) return super().__getattribute__(key) - def __len__(self): - return self._cdata.size + def __len__(self) -> int: + return int(self._cdata.size) - def __getitem__(self, index): + def __getitem__(self, index: slice) -> Any: _bytefield = type(self)._get_bytefield() + if _bytefield is None: + raise RuntimeError("unable to find byte field") buf = getattr(self, _bytefield) if isinstance(index, int): if index >= self._cdata.size: @@ -311,8 +323,10 @@ def __getitem__(self, index): else: raise TypeError("index must an int or a slice") - def __bytes__(self): + def __bytes__(self) -> bytes: _bytefield = type(self)._get_bytefield() + if _bytefield is None: + raise RuntimeError("unable to find byte field") buf = getattr(self, _bytefield) return bytes(buf) @@ -331,7 +345,7 @@ def __str__(self) -> str: b = self.__bytes__() return binascii.hexlify(b).decode() - def __eq__(self, value): + def __eq__(self, value: object) -> bool: b = self.__bytes__() return b == value @@ -346,14 +360,14 @@ class TPML_Iterator(object): do_something(alg) """ - def __init__(self, tpml): + def __init__(self, tpml: "TPML_OBJECT"): self._tpml = tpml self._index = 0 - def __iter__(self): + def __iter__(self) -> "Self": return self - def __next__(self): + def __next__(self) -> Any: if self._index > self._tpml.count - 1: raise StopIteration @@ -367,7 +381,7 @@ class TPML_OBJECT(TPM_OBJECT): """ Abstract Base class for all TPML Objects. A TPML object is an object that contains a list of objects. This is not suitable for direct instantiation.""" - def __init__(self, _cdata=None, **kwargs): + def __init__(self, _cdata: Optional[Any] = None, **kwargs: Any): _cdata, kwargs = _fixup_cdata_kwargs(self, _cdata, kwargs) super().__init__(_cdata=_cdata) @@ -411,7 +425,7 @@ def __init__(self, _cdata=None, **kwargs): self._cdata.count = len(kwargs[key]) - def __getattribute__(self, key): + def __getattribute__(self, key: str) -> Any: try: # Can the parent handle it? @@ -449,10 +463,10 @@ def __getattribute__(self, key): return l - def __getitem__(self, item): + def __getitem__(self, item: Union[int, slice]) -> Any: item_was_int = isinstance(item, int) try: - return object.__getitem__(self, item) + return getattr(object, "__getitem__")(self, item) except AttributeError: pass @@ -467,6 +481,8 @@ def __getitem__(self, item): tipe = tipe.item field_name = next((v[0] for v in tipe.fields if v[0] != "count"), None) + if field_name is None: + raise AttributeError("No non count field found") if isinstance(item, int): item = slice(item, item + 1) @@ -496,11 +512,11 @@ def __getitem__(self, item): return objects[0] if item_was_int else objects - def __len__(self): + def __len__(self) -> int: - return self._cdata.count + return int(self._cdata.count) - def __setitem__(self, key, value): + def __setitem__(self, key: Union[int, slice], value: Any) -> None: if not isinstance(key, (int, slice)): raise TypeError(f"list indices must be integers or slices, not {type(key)}") @@ -515,6 +531,8 @@ def __setitem__(self, key, value): tipe = tipe.item field_name = next((v[0] for v in tipe.fields if v[0] != "count"), None) + if field_name is None: + raise AttributeError("No non count field found") cdata_list = self._cdata.__getattribute__(field_name) @@ -537,7 +555,7 @@ def __setitem__(self, key, value): if key.stop > self._cdata.count: self._cdata.count = key.stop - def __iter__(self): + def __iter__(self) -> TPML_Iterator: return TPML_Iterator(self) @@ -557,29 +575,29 @@ class TPM2B_NAME(TPM2B_SIMPLE_OBJECT): pass -def _handle_sym_common(objstr, default_mode="null"): +def _handle_sym_common(objstr: str, default_mode: str = "null") -> Tuple[int, TPM2_ALG]: if objstr is None or len(objstr) == 0: objstr = "128" - bits = objstr[:3] + bitstr = objstr[:3] expected = ["128", "192", "256"] - if bits not in expected: - raise ValueError(f'Expected bits to be one of {expected}, got: "{bits}"') + if bitstr not in expected: + raise ValueError(f'Expected bits to be one of {expected}, got: "{bitstr}"') - bits = int(bits) + bits = int(bitstr) # go past bits objstr = objstr[3:] if len(objstr) == 0: - mode = default_mode + modestr = default_mode else: expected = ["cfb", "cbc", "ofb", "ctr", "ecb"] if objstr not in expected: raise ValueError(f'Expected mode to be one of {expected}, got: "{objstr}"') - mode = objstr + modestr = objstr - mode = TPM2_ALG.parse(mode) + mode = TPM2_ALG.parse(modestr) return (bits, mode) @@ -587,7 +605,7 @@ def _handle_sym_common(objstr, default_mode="null"): class TPMT_SYM_DEF(TPM_OBJECT): @classmethod def parse( - cls, alg: str, is_restricted: bool = False, is_rsapss: bool = False + cls, alg: Optional[str], is_restricted: bool = False, is_rsapss: bool = False ) -> "TPMT_SYM_DEF": """Builds a TPMT_SYM_DEF from a tpm2-tools like specifier strings. @@ -648,8 +666,10 @@ class TPMT_SYM_DEF_OBJECT(TPMT_SYM_DEF): class TPMT_PUBLIC(TPM_OBJECT): + nameAlg: TPM2_ALG + @staticmethod - def _handle_rsa(objstr, templ): + def _handle_rsa(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.RSA if objstr is None or objstr == "": @@ -667,7 +687,7 @@ def _handle_rsa(objstr, templ): return True @staticmethod - def _handle_ecc(objstr, templ): + def _handle_ecc(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.ECC if objstr is None or objstr == "": @@ -683,7 +703,7 @@ def _handle_ecc(objstr, templ): return True @staticmethod - def _handle_aes(objstr, templ): + def _handle_aes(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.SYMCIPHER templ.parameters.symDetail.sym.algorithm = TPM2_ALG.AES @@ -693,7 +713,7 @@ def _handle_aes(objstr, templ): return False @staticmethod - def _handle_camellia(objstr, templ): + def _handle_camellia(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.SYMCIPHER templ.parameters.symDetail.sym.algorithm = TPM2_ALG.CAMELLIA @@ -704,7 +724,7 @@ def _handle_camellia(objstr, templ): return False @staticmethod - def _handle_sm4(objstr, templ): + def _handle_sm4(objstr: str, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.SYMCIPHER templ.parameters.symDetail.sym.algorithm = TPM2_ALG.SM4 @@ -717,28 +737,28 @@ def _handle_sm4(objstr, templ): return False @staticmethod - def _handle_xor(_, templ): + def _handle_xor(_: Any, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.KEYEDHASH templ.parameters.keyedHashDetail.scheme.scheme = TPM2_ALG.XOR return True @staticmethod - def _handle_hmac(_, templ): + def _handle_hmac(_: Any, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.KEYEDHASH templ.parameters.keyedHashDetail.scheme.scheme = TPM2_ALG.HMAC return True @staticmethod - def _handle_keyedhash(_, templ): + def _handle_keyedhash(_: Any, templ: "TPMT_PUBLIC") -> bool: templ.type = TPM2_ALG.KEYEDHASH templ.parameters.keyedHashDetail.scheme.scheme = TPM2_ALG.NULL return False @staticmethod - def _error_on_conflicting_sign_attrs(templ): + def _error_on_conflicting_sign_attrs(templ: "TPMT_PUBLIC") -> None: """ If the scheme is set, both the encrypt and decrypt attributes cannot be set, check to see if this is the case, and turn down: @@ -764,7 +784,7 @@ def _error_on_conflicting_sign_attrs(templ): ) @staticmethod - def _handle_scheme_rsa(scheme, templ): + def _handle_scheme_rsa(scheme: Optional[str], templ: "TPMT_PUBLIC") -> bool: if scheme is None or len(scheme) == 0: scheme = "null" @@ -801,7 +821,7 @@ def _handle_scheme_rsa(scheme, templ): return True @staticmethod - def _handle_scheme_ecc(scheme, templ): + def _handle_scheme_ecc(scheme: Optional[str], templ: "TPMT_PUBLIC") -> bool: if scheme is None or len(scheme) == 0: scheme = "null" @@ -841,7 +861,7 @@ def _handle_scheme_ecc(scheme, templ): return True @staticmethod - def _handle_scheme_keyedhash(scheme, templ): + def _handle_scheme_keyedhash(scheme: Optional[str], templ: "TPMT_PUBLIC") -> None: if scheme is None or scheme == "": scheme = "sha256" @@ -860,7 +880,7 @@ def _handle_scheme_keyedhash(scheme, templ): ) @staticmethod - def _handle_scheme(scheme, templ): + def _handle_scheme(scheme: Optional[str], templ: "TPMT_PUBLIC") -> None: if templ.type == TPM2_ALG.RSA: TPMT_PUBLIC._handle_scheme_rsa(scheme, templ) elif templ.type == TPM2_ALG.ECC: @@ -874,7 +894,7 @@ def _handle_scheme(scheme, templ): ) @staticmethod - def _handle_asymdetail(detail, templ): + def _handle_asymdetail(detail: Optional[str], templ: "TPMT_PUBLIC") -> None: if templ.type == TPM2_ALG.KEYEDHASH: if detail is not None: @@ -901,8 +921,8 @@ def parse( objectAttributes: Union[ TPMA_OBJECT, int, str ] = TPMA_OBJECT.DEFAULT_TPM2_TOOLS_CREATE_ATTRS, - nameAlg: Union[TPM2_ALG, int, str] = "sha256", - authPolicy: bytes = None, + nameAlg: Union[TPM2_ALG, str] = "sha256", + authPolicy: Optional[bytes] = None, ) -> "TPMT_PUBLIC": """Builds a TPMT_PUBLIC from a tpm2-tools like specifier strings. @@ -969,9 +989,9 @@ def parse( keep_processing = False prefix = tuple(filter(lambda x: objstr.startswith(x), expected)) if len(prefix) == 1: - prefix = prefix[0] - keep_processing = getattr(TPMT_PUBLIC, f"_handle_{prefix}")( - objstr[len(prefix) :], templ + prefixstr = prefix[0] + keep_processing = getattr(TPMT_PUBLIC, f"_handle_{prefixstr}")( + objstr[len(prefixstr) :], templ ) else: raise ValueError( @@ -1000,13 +1020,13 @@ def parse( def from_pem( cls, data: bytes, - nameAlg: Union[TPM2_ALG, int] = TPM2_ALG.SHA256, - objectAttributes: Union[TPMA_OBJECT, int] = ( + nameAlg: TPM2_ALG = TPM2_ALG.SHA256, + objectAttributes: TPMA_OBJECT = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), - symmetric: TPMT_SYM_DEF_OBJECT = None, - scheme: TPMT_ASYM_SCHEME = None, - password: bytes = None, + symmetric: Optional[TPMT_SYM_DEF_OBJECT] = None, + scheme: Optional[TPMT_ASYM_SCHEME] = None, + password: Optional[bytes] = None, ) -> "TPMT_PUBLIC": """Decode the public part from standard key encodings. @@ -1187,6 +1207,8 @@ class TPM2B_MAX_NV_BUFFER(TPM2B_SIMPLE_OBJECT): class TPM2B_NV_PUBLIC(TPM_OBJECT): + nvPublic: "TPMS_NV_PUBLIC" + def get_name(self) -> TPM2B_NAME: """Get the TPM name of the NV public area. @@ -1215,17 +1237,19 @@ class TPM2B_PRIVATE_VENDOR_SPECIFIC(TPM2B_SIMPLE_OBJECT): class TPM2B_PUBLIC(TPM_OBJECT): + publicArea: TPMT_PUBLIC + @classmethod def from_pem( cls, data: bytes, - nameAlg: Union[TPM2_ALG, int] = TPM2_ALG.SHA256, - objectAttributes: Union[TPMA_OBJECT, int] = ( + nameAlg: TPM2_ALG = TPM2_ALG.SHA256, + objectAttributes: TPMA_OBJECT = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), - symmetric: TPMT_SYM_DEF_OBJECT = None, - scheme: TPMT_ASYM_SCHEME = None, - password: bytes = None, + symmetric: Optional[TPMT_SYM_DEF_OBJECT] = None, + scheme: Optional[TPMT_ASYM_SCHEME] = None, + password: Optional[bytes] = None, ) -> "TPM2B_PUBLIC": """Decode the public part from standard key encodings. @@ -1335,10 +1359,10 @@ def get_name(self) -> TPM2B_NAME: @classmethod def parse( cls, - alg="rsa", - objectAttributes=TPMA_OBJECT.DEFAULT_TPM2_TOOLS_CREATE_ATTRS, - nameAlg="sha256", - authPolicy=None, + alg: str = "rsa", + objectAttributes: TPMA_OBJECT = TPMA_OBJECT.DEFAULT_TPM2_TOOLS_CREATE_ATTRS, + nameAlg: Union[TPM2_ALG, str] = "sha256", + authPolicy: Optional[bytes] = None, ) -> "TPM2B_PUBLIC": """Builds a TPM2B_PUBLIC from a tpm2-tools like specifier strings. @@ -1393,6 +1417,8 @@ class TPMT_KEYEDHASH_SCHEME(TPM_OBJECT): class TPM2B_SENSITIVE(TPM_OBJECT): + sensitiveArea: "TPMT_SENSITIVE" + @classmethod def from_pem( cls, data: bytes, password: Optional[bytes] = None @@ -1424,12 +1450,12 @@ def from_pem( def keyedhash_from_secret( cls, secret: bytes, - nameAlg: Union[TPM2_ALG, int] = TPM2_ALG.SHA256, - objectAttributes: Union[TPMA_OBJECT, int] = ( + nameAlg: TPM2_ALG = TPM2_ALG.SHA256, + objectAttributes: TPMA_OBJECT = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), - scheme: TPMT_KEYEDHASH_SCHEME = None, - seed: bytes = None, + scheme: Optional[TPMT_KEYEDHASH_SCHEME] = None, + seed: Optional[bytes] = None, ) -> Tuple["TPM2B_SENSITIVE", TPM2B_PUBLIC]: """Generate the private and public part for a keyed hash object from a secret. @@ -1465,13 +1491,13 @@ def keyedhash_from_secret( def symcipher_from_secret( cls, secret: bytes, - algorithm: Union[TPM2_ALG, int] = TPM2_ALG.AES, - mode: Union[TPM2_ALG, int] = TPM2_ALG.CFB, - nameAlg: Union[TPM2_ALG, int] = TPM2_ALG.SHA256, - objectAttributes: Union[TPMA_OBJECT, int] = ( + algorithm: TPM2_ALG = TPM2_ALG.AES, + mode: TPM2_ALG = TPM2_ALG.CFB, + nameAlg: TPM2_ALG = TPM2_ALG.SHA256, + objectAttributes: TPMA_OBJECT = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), - seed: bytes = None, + seed: Optional[bytes] = None, ) -> Tuple["TPM2B_SENSITIVE", TPM2B_PUBLIC]: """Generate the private and public part for a symcipher object from a secret. @@ -1499,7 +1525,7 @@ def symcipher_from_secret( pub = TPM2B_PUBLIC(publicArea=pa) return (priv, pub) - def to_pem(self, public: TPMT_PUBLIC, password=None) -> bytes: + def to_pem(self, public: TPMT_PUBLIC, password: Optional[bytes] = None) -> bytes: """Encode the key as PEM encoded ASN.1. Args: @@ -1544,7 +1570,7 @@ def to_der(self, public: TPMT_PUBLIC) -> bytes: return self.sensitiveArea.to_der(public) - def to_ssh(self, public: TPMT_PUBLIC, password: bytes = None) -> bytes: + def to_ssh(self, public: TPMT_PUBLIC, password: Optional[bytes] = None) -> bytes: """Encode the key as OPENSSH PEM format. Args: @@ -1717,9 +1743,9 @@ def parse(selections: str) -> "TPML_PCR_SELECTION": f"got {len(selectors)}" ) - selections = [TPMS_PCR_SELECTION.parse(x) for x in selectors] + parsed_selections = [TPMS_PCR_SELECTION.parse(x) for x in selectors] - return TPML_PCR_SELECTION(selections) + return TPML_PCR_SELECTION(parsed_selections) class TPML_TAGGED_PCR_PROPERTY(TPML_OBJECT): @@ -1818,7 +1844,11 @@ def from_tools(cls, data: bytes) -> "TPMS_CONTEXT": ctx.contextBlob, _ = TPM2B_CONTEXT_DATA.unmarshal(data[24:]) return ctx - def to_tools(self, session_type: TPM2_SE = None, auth_hash: TPM2_ALG = None): + def to_tools( + self, + session_type: Optional[TPM2_SE] = None, + auth_hash: Optional[TPM2_ALG] = None, + ) -> bytes: """Marshal the context into a tpm2-tools context blob. Args: @@ -1844,12 +1874,11 @@ def to_tools(self, session_type: TPM2_SE = None, auth_hash: TPM2_ALG = None): ) version = 1 - if session_type is not None: - version = 2 data = b"" - if version == 2: + if isinstance(session_type, TPM2_SE) and isinstance(auth_hash, TPM2_ALG): + version = 2 data = int(0xBADCC0DE).to_bytes(4, "big") + version.to_bytes(4, "big") data = data + session_type.to_bytes(1, "big") data = data + auth_hash.to_bytes(2, "big") @@ -1919,7 +1948,9 @@ class TPMS_PCR_SELECT(TPM_OBJECT): class TPMS_PCR_SELECTION(TPM_OBJECT): - def __init__(self, pcrs=None, **kwargs): + def __init__( + self, pcrs: Optional[Union[str, List[str], List[int]]] = None, **kwargs: Any + ): super().__init__(**kwargs) if not pcrs: @@ -1937,6 +1968,8 @@ def __init__(self, pcrs=None, **kwargs): return for pcr in pcrs: + if isinstance(pcr, str): + pcr = int(pcr) if pcr < 0 or pcr > lib.TPM2_PCR_LAST: raise ValueError(f"PCR Index out of range, got {pcr}") self._cdata.pcrSelect[pcr // 8] |= 1 << (pcr % 8) @@ -1981,6 +2014,7 @@ def parse(selection: str) -> "TPMS_PCR_SELECTION": except ValueError: halg = TPM2_ALG.parse(hunks[0]) + pcrs: Union[Iterable[int], str] if hunks[1] != "all": try: pcrs = [int(x.strip(), 0) for x in hunks[1].split(",")] @@ -2094,7 +2128,7 @@ class TPMU_PUBLIC_ID(TPM_OBJECT): class TPMT_SENSITIVE(TPM_OBJECT): @classmethod - def from_pem(cls, data, password: Optional[bytes] = None): + def from_pem(cls, data: bytes, password: Optional[bytes] = None) -> "Self": """Decode the private part from standard key encodings. Currently supports PEM, DER and SSH encoded private keys. @@ -2113,14 +2147,14 @@ def from_pem(cls, data, password: Optional[bytes] = None): @classmethod def keyedhash_from_secret( cls, - secret, - nameAlg=TPM2_ALG.SHA256, - objectAttributes=( + secret: bytes, + nameAlg: TPM2_ALG = TPM2_ALG.SHA256, + objectAttributes: TPMA_OBJECT = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), scheme: Optional[TPMT_KEYEDHASH_SCHEME] = None, seed: Optional[bytes] = None, - ): + ) -> Tuple["TPMT_SENSITIVE", TPMT_PUBLIC]: """Generate the private and public part for a keyed hash object from a secret. Args: @@ -2156,15 +2190,15 @@ def keyedhash_from_secret( @classmethod def symcipher_from_secret( cls, - secret, - algorithm=TPM2_ALG.AES, - mode=TPM2_ALG.CFB, - nameAlg=TPM2_ALG.SHA256, - objectAttributes=( + secret: bytes, + algorithm: TPM2_ALG = TPM2_ALG.AES, + mode: TPM2_ALG = TPM2_ALG.CFB, + nameAlg: TPM2_ALG = TPM2_ALG.SHA256, + objectAttributes: TPMA_OBJECT = ( TPMA_OBJECT.DECRYPT | TPMA_OBJECT.SIGN_ENCRYPT | TPMA_OBJECT.USERWITHAUTH ), seed: Optional[bytes] = None, - ): + ) -> Tuple["TPMT_SENSITIVE", TPMT_PUBLIC]: """ Generate the private and public part for a symcipher object from a secret. @@ -2207,11 +2241,11 @@ def symcipher_from_secret( def _serialize( self, - encoding: str, + encoding: serialization.Encoding, public: TPMT_PUBLIC, - format: str = serialization.PrivateFormat.TraditionalOpenSSL, - password: bytes = None, - ): + format: serialization.PrivateFormat = serialization.PrivateFormat.TraditionalOpenSSL, + password: Optional[bytes] = None, + ) -> bytes: k = private_to_key(self, public) enc_alg = ( @@ -2226,7 +2260,7 @@ def _serialize( return data - def to_pem(self, public: TPMT_PUBLIC, password: bytes = None): + def to_pem(self, public: TPMT_PUBLIC, password: Optional[bytes] = None) -> bytes: """Encode the key as PEM encoded ASN.1. public(TPMT_PUBLIC): The corresponding public key. @@ -2238,7 +2272,7 @@ def to_pem(self, public: TPMT_PUBLIC, password: bytes = None): return self._serialize(serialization.Encoding.PEM, public, password=password) - def to_der(self, public: TPMT_PUBLIC): + def to_der(self, public: TPMT_PUBLIC) -> bytes: """Encode the key as DER encoded ASN.1. public(TPMT_PUBLIC): The corresponding public key. @@ -2249,7 +2283,7 @@ def to_der(self, public: TPMT_PUBLIC): return self._serialize(serialization.Encoding.DER, public) - def to_ssh(self, public: TPMT_PUBLIC, password: bytes = None): + def to_ssh(self, public: TPMT_PUBLIC, password: Optional[bytes] = None) -> bytes: """Encode the key as SSH format. public(TPMT_PUBLIC): The corresponding public key. @@ -2312,7 +2346,9 @@ class TPMU_SIGNATURE(TPM_OBJECT): class TPMT_SIGNATURE(TPM_OBJECT): - def verify_signature(self, key, data): + def verify_signature( + self, key: Union[TPMT_PUBLIC, TPM2B_PUBLIC], data: bytes + ) -> None: """ Verify a TPM generated signature against a key. @@ -2325,7 +2361,7 @@ def verify_signature(self, key, data): """ _verify_signature(self, key, data) - def __bytes__(self): + def __bytes__(self) -> bytes: """Return the underlying bytes for the signature. For RSA and HMAC signatures return the signature bytes, for ECDSA return a ASN.1 encoded signature. diff --git a/src/tpm2_pytss/utils.py b/src/tpm2_pytss/utils.py index 51dc34ee..e877cf95 100644 --- a/src/tpm2_pytss/utils.py +++ b/src/tpm2_pytss/utils.py @@ -19,19 +19,22 @@ TPM2_ECC, TPM2_PT, TPM2_RH, + TPM2_ALG, ) from .internal.templates import _ek from .TSS2_Exception import TSS2_Exception from cryptography.hazmat.primitives import constant_time as ct from cryptography.hazmat.primitives import hashes from cryptography.hazmat.backends import default_backend -from typing import Optional, Tuple, Callable, List +from typing import Optional, Tuple, Callable, List, Union import secrets def make_credential( - public: TPM2B_PUBLIC, credential: bytes, name: TPM2B_NAME + public: Union[TPM2B_PUBLIC, TPMT_PUBLIC], + credential: Union[bytes, TPM2B_DIGEST], + name: Union[TPM2B_NAME, bytes], ) -> Tuple[TPM2B_ID_OBJECT, TPM2B_ENCRYPTED_SECRET]: """Encrypts credential for use with activate_credential @@ -62,7 +65,11 @@ def make_credential( enc_cred = _encrypt(cipher, symmode, symkey, credential.marshal()) halg = _get_digest(public.nameAlg) - hmackey = _kdfa(public.nameAlg, seed, b"INTEGRITY", b"", b"", halg.digest_size * 8) + if halg is None: + raise ValueError(f"unsupported digest algorithm {public.nameAlg}") + hmackey = _kdfa( + public.nameAlg, seed, b"INTEGRITY", b"", b"", halg().digest_size * 8 + ) outerhmac = _hmac(halg, hmackey, enc_cred, name) hmacdata = TPM2B_DIGEST(buffer=outerhmac).marshal() @@ -183,6 +190,8 @@ def wrap( klen = int(bits / 8) symkey = secrets.token_bytes(klen) halg = _get_digest(public.publicArea.nameAlg) + if halg is None: + raise ValueError(f"unsupported digest algorithm {public.nameAlg}") h = hashes.Hash(halg(), backend=default_backend()) h.update(sensb) h.update(name) @@ -198,8 +207,10 @@ def wrap( dupsens = _encrypt(cipher, mode, outerkey, encsens) halg = _get_digest(newparent.nameAlg) + if halg is None: + raise ValueError(f"unsupported digest algorithm {public.nameAlg}") hmackey = _kdfa( - newparent.nameAlg, seed, b"INTEGRITY", b"", b"", halg.digest_size * 8 + newparent.nameAlg, seed, b"INTEGRITY", b"", b"", halg().digest_size * 8 ) outerhmac = _hmac(halg, hmackey, dupsens, name) hmacdata = TPM2B_DIGEST(buffer=outerhmac).marshal() @@ -252,10 +263,12 @@ def unwrap( ValueError: If the public key type or symmetric algorithm are not supported """ halg = _get_digest(newparentpub.nameAlg) + if halg is None: + raise ValueError(f"unsupported digest algorithm {public.nameAlg}") seed = _secret_to_seed(newparentpriv, newparentpub, b"DUPLICATE\x00", outsymseed) hmackey = _kdfa( - newparentpub.nameAlg, seed, b"INTEGRITY", b"", b"", halg.digest_size * 8 + newparentpub.nameAlg, seed, b"INTEGRITY", b"", b"", halg().digest_size * 8 ) buffer = bytes(duplicate) @@ -280,11 +293,13 @@ def unwrap( cipher, mode, bits = _symdef_to_crypt(symdef) halg = _get_digest(public.publicArea.nameAlg) + if halg is None: + raise ValueError(f"unsupported digest algorithm {public.nameAlg}") # unwrap the inner encryption which is the integrity + TPM2B_SENSITIVE innerint_and_decsens = _decrypt(cipher, mode, symkey, sensb) - innerint, offset = TPM2B_DIGEST.unmarshal(innerint_and_decsens) - innerint = bytes(innerint) + innerint2b, offset = TPM2B_DIGEST.unmarshal(innerint_and_decsens) + innerint = bytes(innerint2b) decsensb = innerint_and_decsens[offset:] h = hashes.Hash(halg(), backend=default_backend()) @@ -315,11 +330,11 @@ class NoSuchIndex(Exception): index (int): The NV index requested """ - def __init__(self, index): + def __init__(self, index: int): self.index = index - def __str__(self): - return f"NV index 0x{index:08x} does not exist" + def __str__(self) -> str: + return f"NV index 0x{self.index:08x} does not exist" class NVReadEK: @@ -336,10 +351,10 @@ class NVReadEK: def __init__( self, ectx: ESAPI, - auth_handle: ESYS_TR = None, - session1: ESYS_TR = ESYS_TR.PASSWORD, - session2: ESYS_TR = ESYS_TR.NONE, - session3: ESYS_TR = ESYS_TR.NONE, + auth_handle: Optional[ESYS_TR] = None, + session1: ESYS_TR = ESYS_TR(ESYS_TR.PASSWORD), + session2: ESYS_TR = ESYS_TR(ESYS_TR.NONE), + session3: ESYS_TR = ESYS_TR(ESYS_TR.NONE), ): self._ectx = ectx self._auth_handle = auth_handle @@ -351,7 +366,7 @@ def __init__( more = True while more: more, data = self._ectx.get_capability( - TPM2_CAP.TPM_PROPERTIES, + TPM2_CAP(TPM2_CAP.TPM_PROPERTIES), TPM2_PT.FIXED, 4096, session1=session2, @@ -367,7 +382,7 @@ def __init__( def __call__(self, index: Union[int, TPM2_RH]) -> bytes: try: nvh = self._ectx.tr_from_tpmpublic( - index, session1=self._session2, session2=self._session3 + TPM2_HANDLE(index), session1=self._session2, session2=self._session3 ) except TSS2_Exception as e: if e.rc == 0x18B: @@ -399,7 +414,7 @@ def __call__(self, index: Union[int, TPM2_RH]) -> bytes: def create_ek_template( ektype: str, nv_read_cb: Callable[[Union[int, TPM2_RH]], bytes] -) -> Tuple[bytes, TPM2B_PUBLIC]: +) -> Tuple[Optional[bytes], TPM2B_PUBLIC]: """Creates an Endorsenment Key template which when created matches the EK certificate The template is created according to TCG EK Credential Profile For TPM Family 2.0: diff --git a/test/test_crypto.py b/test/test_crypto.py index b8a2a4e9..a3bf5900 100644 --- a/test/test_crypto.py +++ b/test/test_crypto.py @@ -522,11 +522,11 @@ def test_kdfe(self): str(e.exception), f"unsupported digest algorithm: {TPM2_ALG.LAST + 1}" ) - def test_get_alg(self): - alg = crypto._get_alg(TPM2_ALG.AES) + def test_get_symmetric(self): + alg = crypto._get_symmetric(TPM2_ALG.AES) self.assertEqual(alg, crypto.AES) - nalg = crypto._get_alg(TPM2_ALG.LAST + 1) + nalg = crypto._get_symmetric(TPM2_ALG.LAST + 1) self.assertEqual(nalg, None) def test_symdef_to_crypt(self): diff --git a/test/test_utils.py b/test/test_utils.py index fa16cde8..112b9255 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -6,7 +6,7 @@ from tpm2_pytss.internal.crypto import ( _generate_seed, public_to_key, - _get_alg, + _get_symmetric, _get_digest, ) from tpm2_pytss.utils import * @@ -231,7 +231,7 @@ def test_make_credential_ecc_camellia(self): self.assertEqual(b"credential data", bytes(certinfo)) def test_make_credential_ecc_sm4(self): - if _get_alg(TPM2_ALG.SM4) is None: + if _get_symmetric(TPM2_ALG.SM4) is None: self.skipTest("SM4 is not supported by the cryptography module") elif _get_digest(TPM2_ALG.SM3_256) is None: self.skipTest("SM3 is not supported by the cryptography module")