From 49f20a7470c0409147701a6170b36367fa000c68 Mon Sep 17 00:00:00 2001 From: Erik Larsson Date: Wed, 17 Jan 2024 20:42:50 +0100 Subject: [PATCH] ci: add mypy test Start small with checks for just some files Solve the problem with missing type hints for the cffi generated extension by handcrafting the type hints for ffi and generate the type hints for lib. Signed-off-by: Erik Larsson --- .ci/run.sh | 7 ++ .github/workflows/tests.yaml | 22 ++++ .gitignore | 1 + pyproject.toml | 22 +++- setup.cfg | 1 + setup.py | 138 +++++++++++++++++++++++++- src/tpm2_pytss/TSS2_Exception.py | 20 ++-- src/tpm2_pytss/__init__.py | 8 +- src/tpm2_pytss/_libtpm2_pytss/ffi.pyi | 17 ++++ src/tpm2_pytss/internal/utils.py | 98 +++++++++++------- 10 files changed, 285 insertions(+), 49 deletions(-) create mode 100644 src/tpm2_pytss/_libtpm2_pytss/ffi.pyi diff --git a/.ci/run.sh b/.ci/run.sh index 785df5d2..4f85abd1 100755 --- a/.ci/run.sh +++ b/.ci/run.sh @@ -103,6 +103,11 @@ function run_style() { "${PYTHON}" -m black --diff --check "${SRC_ROOT}" } +function run_mypy() { + python3 -m pip install --user -e .[dev] + mypy --strict src/tpm2_pytss +} + if [ "x${TEST}" != "x" ]; then run_test elif [ "x${WHITESPACE}" != "x" ]; then @@ -111,4 +116,6 @@ elif [ "x${STYLE}" != "x" ]; then run_style elif [ "x${PUBLISH_PKG}" != "x" ]; then run_publish_pkg +elif [ "x${MYPY}" != "x" ]; then + run_mypy fi diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index de6fe93d..a475ea60 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -67,6 +67,28 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} files: /tmp/coverage.xml + mypy: + runs-on: ubuntu-20.04 + + steps: + - name: Checkout Repository + uses: actions/checkout@v2 + + - name: Set up Python 3.x + uses: actions/setup-python@v2 + with: + python-version: 3.x + + - name: Install dependencies + env: + TPM2_TSS_VERSION: 4.0.1 + run: ./.ci/install-deps.sh + + - name: Check + env: + MYPY: 1 + run: ./.ci/run.sh + whitespace-check: runs-on: ubuntu-latest steps: diff --git a/.gitignore b/.gitignore index 396cfdbe..c27f1cf2 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,4 @@ htmlcov /.pytest_cache/ src/tpm2_pytss/internal/type_mapping.py src/tpm2_pytss/internal/versions.py +src/tpm2_pytss/_libtpm2_pytss/lib.pyi diff --git a/pyproject.toml b/pyproject.toml index 80a15f9b..4a2b4171 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=44", "wheel", "setuptools_scm[toml]>=3.4.3", "pycparser", "pkgconfig"] +requires = ["setuptools>=44", "wheel", "setuptools_scm[toml]>=3.4.3", "pycparser", "pkgconfig", "cffi"] build-backend = "setuptools.build_meta" [tool.setuptools_scm] @@ -19,6 +19,26 @@ exclude = ''' | build | dist | esys_binding.py + | .*\.pyi$ ) ) ''' + +[tool.mypy] +mypy_path = "mypy_stubs" +exclude = [ + 'src/tpm2_pytss/utils.py', + 'src/tpm2_pytss/internal/templates.py', + 'src/tpm2_pytss/encoding.py', + 'src/tpm2_pytss/tsskey.py', + 'src/tpm2_pytss/policy.py', + 'src/tpm2_pytss/ESAPI.py', + 'src/tpm2_pytss/FAPI.py', + 'src/tpm2_pytss/types.py', + 'src/tpm2_pytss/internal/crypto.py', + 'src/tpm2_pytss/TCTILdr.py', + 'src/tpm2_pytss/TCTISPIHelper.py', + 'src/tpm2_pytss/TCTI.py', + 'src/tpm2_pytss/constants.py', + 'src/tpm2_pytss/fapi_info.py', +] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 1b7d1313..6376b7f1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,3 +55,4 @@ dev = myst-parser build installer + mypy diff --git a/setup.py b/setup.py index 1b5f5135..e4c9f2bf 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ Enum, ) from textwrap import dedent +from cffi import cparser # workaround bug https://github.com/pypa/pip/issues/7953 site.ENABLE_USER_SITE = "--user" in sys.argv[1:] @@ -282,8 +283,143 @@ def run(self): self.copy_file(vp, svp) +class type_hints_generator(type_generator): + is_int = set(("int",)) + callbacks = dict() + functions = dict() + + def build_function(self, d): + args = list() + for param in d.args.params: + pn = param.name + if pn is None: + # if the param doesn't have a name, ignore + continue + elif isinstance( + param.type, + (cparser.pycparser.c_ast.PtrDecl, cparser.pycparser.c_ast.ArrayDecl), + ): + ft = "CData" + elif isinstance( + param.type, cparser.pycparser.c_ast.TypeDecl + ) and isinstance(param.type.type, cparser.pycparser.c_ast.IdentifierType): + tn = param.type.type.names[0] + if tn in self.is_int: + ft = "int" + else: + raise ValueError(f"unable to handle C type {param.type}") + args.append((pn, ft)) + if isinstance(d.type, cparser.pycparser.c_ast.TypeDecl) and isinstance( + d.type.type, cparser.pycparser.c_ast.IdentifierType + ): + tn = d.type.type.names[0] + if tn in self.is_int: + rt = "int" + elif tn == "void": + rt = "None" + else: + rt = "CData" + elif isinstance(d.type, cparser.pycparser.c_ast.PtrDecl): + rt = "CData" + return (rt, tuple(args)) + + def write_type_hints(self, macros): + output = dedent( + """# SPDX-License-Identifier: BSD-2 + # This file is generated during the build process. + from typing import Callable + from .ffi import CData + + # Defines + """ + ) + + # assume all defines are ints + for m in macros: + output += f"{m}: int\n" + + output += "\n# Callback definitions\n" + for cname, callback in self.callbacks.items(): + rt, args = callback + paramtypes = list() + for _, at in args: + paramtypes.append(at) + output += f"{cname}: Callable[[{', '.join(paramtypes)}], {rt}]\n" + + output += "\n# Function definitions\n" + for fname, function in self.functions.items(): + rt, args = function + params = list() + for an, at in args: + if an == "in": + an = "in_" + params.append(f"{an}: {at}") + output += f"def {fname}({', '.join(params)}) -> {rt}:...\n" + + p = os.path.join(self.build_lib, "tpm2_pytss/_libtpm2_pytss/lib.pyi") + sp = os.path.join( + os.path.dirname(__file__), "src/tpm2_pytss/_libtpm2_pytss/lib.pyi" + ) + + if not self.dry_run: + self.mkpath(os.path.dirname(p)) + with open(p, "wt") as tf: + tf.seek(0) + tf.truncate(0) + tf.write(output) + + if self.inplace: + self.copy_file(p, sp) + + def run(self): + super().run() + + with open("libesys.h", "r") as sf: + cdata = sf.read() + + parser = cparser.Parser() + ast, macros, _ = parser._parse(cdata) + + for d in ast: + name = d.name + if isinstance(name, str) and name.startswith("__cffi_"): + # internal cffi stuff, ignore + continue + if isinstance(d, cparser.pycparser.c_ast.Typedef): + d = d.type + if isinstance( + d.type, + ( + cparser.pycparser.c_ast.Struct, + cparser.pycparser.c_ast.Union, + cparser.pycparser.c_ast.Enum, + ), + ): + # ignore unions, structs and enums + pass + elif isinstance(d.type, cparser.pycparser.c_ast.IdentifierType): + # check if type is int, otherwise ignore + if len(d.type.names) == 0: + break + tn = d.type.names[0] + if tn in self.is_int: + self.is_int.add(name) + elif isinstance(d, cparser.pycparser.c_ast.PtrDecl) and isinstance( + d.type, cparser.pycparser.c_ast.FuncDecl + ): + rt, args = self.build_function(d.type) + self.callbacks[name] = (rt, args) + elif isinstance(d, cparser.pycparser.c_ast.Decl) and isinstance( + d.type, cparser.pycparser.c_ast.FuncDecl + ): + rt, args = self.build_function(d.type) + self.functions[name] = (rt, args) + + self.write_type_hints(macros) + + setup( use_scm_version=True, cffi_modules=["scripts/libtss2_build.py:ffibuilder"], - cmdclass={"build_ext": type_generator}, + cmdclass={"build_ext": type_hints_generator}, ) diff --git a/src/tpm2_pytss/TSS2_Exception.py b/src/tpm2_pytss/TSS2_Exception.py index 309e25db..58327174 100644 --- a/src/tpm2_pytss/TSS2_Exception.py +++ b/src/tpm2_pytss/TSS2_Exception.py @@ -1,5 +1,9 @@ from ._libtpm2_pytss import lib, ffi -from typing import Union +from typing import Union, TYPE_CHECKING + + +if TYPE_CHECKING: + from .constants import TSS2_RC, TPM2_RC class TSS2_Exception(RuntimeError): @@ -25,7 +29,7 @@ def __init__(self, rc: Union["TSS2_RC", "TPM2_RC", int]): else: self._error = self._rc - def _parse_fmt1(self): + def _parse_fmt1(self) -> None: self._error = lib.TPM2_RC_FMT1 + (self.rc & 0x3F) if self.rc & lib.TPM2_RC_P: @@ -36,31 +40,31 @@ def _parse_fmt1(self): self._handle = (self.rc & lib.TPM2_RC_N_MASK) >> 8 @property - def rc(self): + def rc(self) -> int: """int: The return code from the API call.""" return self._rc @property - def handle(self): + def handle(self) -> int: """int: The handle related to the error, 0 if not related to any handle.""" return self._handle @property - def parameter(self): + def parameter(self) -> int: """int: The parameter related to the error, 0 if not related to any parameter.""" return self._parameter @property - def session(self): + def session(self) -> int: """int: The session related to the error, 0 if not related to any session.""" return self._session @property - def error(self): + def error(self) -> int: """int: The error with handle, parameter and session stripped.""" return self._error @property - def fmt1(self): + def fmt1(self) -> bool: """bool: True if the error is related to a handle, parameter or session """ return bool(self._rc & lib.TPM2_RC_FMT1) diff --git a/src/tpm2_pytss/__init__.py b/src/tpm2_pytss/__init__.py index 83f401c8..c931ee43 100644 --- a/src/tpm2_pytss/__init__.py +++ b/src/tpm2_pytss/__init__.py @@ -4,13 +4,13 @@ # if we can't, provide a better message. try: from ._libtpm2_pytss import lib -except ImportError as e: - parts = e.msg.split(": ", 2) +except ImportError as ie: + parts = ie.msg.split(": ", 2) if len(parts) != 3: - raise e + raise ie path, error, symbol = parts if error != "undefined symbol": - raise e + raise ie raise ImportError( f"failed to load tpm2-tss bindigs in {path} due to missing symbol {symbol}, " + "ensure that you are using the same libraries the python module was built against." diff --git a/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi b/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi new file mode 100644 index 00000000..2a9c1ab7 --- /dev/null +++ b/src/tpm2_pytss/_libtpm2_pytss/ffi.pyi @@ -0,0 +1,17 @@ +from typing import Optional, Callable, Iterable + +error: type[Exception] +class CData: + def __getitem__(self, index: int) -> "CData": ... + +class CType: + kind: str + cname: str + item: "CType" + fields: Iterable[str] + +NULL: CData +def gc(cdata: CData, destructor: Callable[[CData], None], size: int = 0)-> CData: ... +def typeof(cdata: CData) -> CType: ... # FIXME, support str +def new(cdecl: str, init: Optional[Callable[[CData], CData]] = None) -> CData: ... +def string(cdata: CData, maxlen: Optional[int] = None) -> bytes: ... diff --git a/src/tpm2_pytss/internal/utils.py b/src/tpm2_pytss/internal/utils.py index 80a8ee8c..ab6ba840 100644 --- a/src/tpm2_pytss/internal/utils.py +++ b/src/tpm2_pytss/internal/utils.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: BSD-2 import logging import sys -from typing import List +from typing import List, Optional, Union, TYPE_CHECKING, Callable, Any from .._libtpm2_pytss import ffi, lib from ..TSS2_Exception import TSS2_Exception +if TYPE_CHECKING: + from ..constants import TPM_FRIENDLY_INT + try: from .versions import _versions except ImportError as e: @@ -53,7 +56,7 @@ def __init__(self, version: str): hunks = version.split(".") extra_data = version.split("-")[1:] - def handle_extra(): + def handle_extra() -> None: nonlocal extra_data nonlocal commits nonlocal rc @@ -93,7 +96,7 @@ def handle_extra(): else: raise ValueError(f'Invalid version string, got: "{version}"') - def cleanse(xstr): + def cleanse(xstr: str) -> str: if "-" in xstr: return xstr[: xstr.find("-")] @@ -124,49 +127,60 @@ def cleanse(xstr): raise ValueError(f'Invalid version string, got: "{version}"') # Convert to int - major = int(major, 0).to_bytes(4, byteorder="big") - minor = int(minor, 0).to_bytes(4, byteorder="big") - patch = int(patch, 0).to_bytes(4, byteorder="big") - rc = int(rc, 0).to_bytes(4, byteorder="big") - commits = int(commits, 0).to_bytes(4, byteorder="big") - dirty = int(is_dirty).to_bytes(1, byteorder="big") + major_bytes = int(major, 0).to_bytes(4, byteorder="big") + minor_bytes = int(minor, 0).to_bytes(4, byteorder="big") + patch_bytes = int(patch, 0).to_bytes(4, byteorder="big") + rc_bytes = int(rc, 0).to_bytes(4, byteorder="big") + commits_bytes = int(commits, 0).to_bytes(4, byteorder="big") + dirty_bytes = int(is_dirty).to_bytes(1, byteorder="big") # TO make reasoning easy we lay out a big int where each field # can hold 4 bytes of data, except for dirty which is a byte # MAJOR : MINOR : PATCH : RC : COMMITS : DIRTY - concatenated = major + minor + patch + rc + commits + dirty + concatenated = ( + major_bytes + + minor_bytes + + patch_bytes + + rc_bytes + + commits_bytes + + dirty_bytes + ) v = int.from_bytes(concatenated, byteorder="big") self._value = v - def __str__(self): + def __str__(self) -> str: return self._version - def __lt__(self, other): + def __lt__(self, other: Union["TSS2Version", int]) -> bool: x = other if isinstance(other, int) else other._value return self._value < x - def __lte__(self, other): + def __lte__(self, other: Union["TSS2Version", int]) -> bool: x = other if isinstance(other, int) else other._value return self._value <= x - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, (int, self.__class__)): + return False x = other if isinstance(other, int) else other._value return self._value == x - def __ne__(self, other): + def __ne__(self, other: object) -> bool: + if not isinstance(other, (int, self.__class__)): + return False x = other if isinstance(other, int) else other._value return self._value != x - def __ge__(self, other): + def __ge__(self, other: Union["TSS2Version", int]) -> bool: x = other if isinstance(other, int) else other._value return self._value >= x - def __gt__(self, other): + def __gt__(self, other: Union["TSS2Version", int]) -> bool: x = other if isinstance(other, int) else other._value return self._value > x -def _chkrc(rc, acceptable=None): +def _chkrc(rc: int, acceptable: Optional[Union[list[int], int]] = None) -> None: if acceptable is None: acceptable = [] elif isinstance(acceptable, int): @@ -176,7 +190,9 @@ def _chkrc(rc, acceptable=None): raise TSS2_Exception(rc) -def _to_bytes_or_null(value, allow_null=True, encoding=None): +def _to_bytes_or_null( + value: Union[None, bytes, str], allow_null: bool = True, encoding: str = "utf-8" +) -> Union[bytes, ffi.CData]: """Convert to cdata input. None: ffi.NULL (if allow_null == True) @@ -199,7 +215,9 @@ def _to_bytes_or_null(value, allow_null=True, encoding=None): #### Utilities #### -def _CLASS_INT_ATTRS_from_string(cls, str_value, fixup_map=None): +def _CLASS_INT_ATTRS_from_string( + cls: object, str_value: str, fixup_map: Optional[dict[str, str]] = None +) -> int: """ Given a class, lookup int attributes by name and return that attribute value. :param cls: The class to search. @@ -218,18 +236,20 @@ def _CLASS_INT_ATTRS_from_string(cls, str_value, fixup_map=None): return friendly[str_value.upper()] -def _cpointer_to_ctype(x): +def _cpointer_to_ctype(x: ffi.CData) -> ffi.CType: tipe = ffi.typeof(x) if tipe.kind == "pointer": tipe = tipe.item return tipe -def _fixup_cdata_kwargs(this, _cdata, kwargs): +def _fixup_cdata_kwargs( + this: Any, _cdata: ffi.CData, kwargs: dict[str, Any] +) -> tuple[ffi.CData, dict[str, Any]]: # folks may call this routine without a keyword argument which means it may # end up in _cdata, so we want to try and work this out - unknown = None + unknown: Optional[ffi.CData] = None try: # is _cdata actual ffi data? ffi.typeof(_cdata) @@ -254,6 +274,8 @@ def _fixup_cdata_kwargs(this, _cdata, kwargs): # ignore the field that is size or count, and get the one for the data size_field_name = "size" if "TPM2B_" in tipe.cname else "count" field_name = next((v[0] for v in tipe.fields if v[0] != size_field_name), None) + if field_name is None: + raise AttributeError("No non size/could field found") if len(kwargs) != 0: raise RuntimeError( @@ -277,18 +299,20 @@ def _fixup_cdata_kwargs(this, _cdata, kwargs): return (_cdata, kwargs) -def _ref_parent(data, parent): +def _ref_parent(data: ffi.CData, parent: ffi.CData) -> ffi.CData: tipe = ffi.typeof(parent) if tipe.kind != "pointer": return data - def deconstructor(ptr): + def deconstructor(ptr: ffi.CData) -> None: parent return ffi.gc(data, deconstructor) -def _convert_to_python_native(global_map, data, parent=None): +def _convert_to_python_native( + global_map: dict[str, Any], data: ffi.CData, parent: Optional[ffi.CData] = None +) -> Any: if not isinstance(data, ffi.CData): return data @@ -314,7 +338,7 @@ def _convert_to_python_native(global_map, data, parent=None): return obj -def _fixup_classname(tipe): +def _fixup_classname(tipe: ffi.CType) -> str: # Some versions of tpm2-tss had anonymous structs, so the kind will be struct # but the name will not contain it if tipe.cname.startswith(tipe.kind): @@ -323,15 +347,17 @@ def _fixup_classname(tipe): return tipe.cname -def _mock_bail(): +def _mock_bail() -> bool: return __MOCK__ -def _get_dptr(dptr, free_func): +def _get_dptr(dptr: ffi.CData, free_func: Callable[[ffi.CData], None]) -> ffi.CData: return ffi.gc(dptr[0], free_func) -def _check_friendly_int(friendly, varname, clazz): +def _check_friendly_int( + friendly: int, varname: int, clazz: type["TPM_FRIENDLY_INT"] +) -> None: if not isinstance(friendly, int): raise TypeError(f"expected {varname} to be type int, got {type(friendly)}") @@ -343,7 +369,9 @@ def _check_friendly_int(friendly, varname, clazz): def is_bug_fixed( - fixed_in=None, backports: List[str] = None, lib: str = "tss2-fapi" + fixed_in: Optional[str] = None, + backports: Optional[List[str]] = None, + lib: str = "tss2-fapi", ) -> bool: """Use pkg-config to determine if a bug was fixed in the currently installed tpm2-tss version.""" if fixed_in and _lib_version_atleast(lib, fixed_in): @@ -368,9 +396,9 @@ def is_bug_fixed( def _check_bug_fixed( - details, - fixed_in=None, - backports: List[str] = None, + details: str, + fixed_in: Optional[str] = None, + backports: Optional[List[str]] = None, lib: str = "tss2-fapi", error: bool = False, ) -> None: @@ -385,7 +413,7 @@ def _check_bug_fixed( logger.warning(message) -def _lib_version_atleast(tss2_lib, version): +def _lib_version_atleast(tss2_lib: str, version: str) -> bool: if tss2_lib not in _versions: return False