Skip to content

Commit

Permalink
mypy: skip casting for constants
Browse files Browse the repository at this point in the history
Signed-off-by: Erik Larsson <who+github@cnackers.org>
  • Loading branch information
whooo committed Jan 21, 2024
1 parent 6ad2ca1 commit 6b2d85a
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 105 deletions.
26 changes: 24 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,25 @@ class type_hints_generator(type_generator):
callbacks = dict()
functions = dict()

macro_types = (
("TPM2_ALG_", "TPM2_ALG"),
("ESYS_TR_", "ESYS_TR"),
("TPM2_ECC_", "TPM2_ECC"),
("TPM2_RH_", "TPM2_RH"),
("TPM2_SU_", "TPM2_SU"),
("TPMA_OBJECT_", "TPMA_OBJECT"),
)

def macro_to_type(self, macro):
mt = "int"
ml = 0
for prefix, tn in self.macro_types:
pl = len(prefix)
if macro.startswith(prefix) and pl > ml:
mt = tn
ml = pl
return mt

def _make_callback_output(self, cname):
callback = self.callbacks[cname]
rt, args = callback
Expand Down Expand Up @@ -358,9 +377,12 @@ def write_type_hints(self, macros):
"""
)

# assume all defines are ints
mtl = [x for _, x in self.macro_types]
output += f"from ..constants import {', '.join(mtl)}\n"

for m in macros:
output += f"{m}: int\n"
mt = self.macro_to_type(m)
output += f"{m}: \"{mt}\"\n"

output += "\n# Callback definitions\n"
for cname in self.callbacks:
Expand Down
129 changes: 64 additions & 65 deletions src/tpm2_pytss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
TYPE_CHECKING,
Any,
SupportsIndex,
cast,
)

try:
Expand Down Expand Up @@ -124,105 +123,105 @@ def __str__(self) -> str:
return k.lower()
return str(int(self))

def __abs__(self) -> "TPM_FRIENDLY_INT":
def __abs__(self) -> "Self":
return self.__class__(int(self).__abs__())

def __add__(self, value: int) -> "TPM_FRIENDLY_INT":
def __add__(self, value: int) -> "Self":
return self.__class__(int(self).__add__(value))

def __and__(self, value: int) -> "TPM_FRIENDLY_INT":
def __and__(self, value: int) -> "Self":
return self.__class__(int(self).__and__(value))

def __ceil__(self) -> "TPM_FRIENDLY_INT":
def __ceil__(self) -> "Self":
return self.__class__(int(self).__ceil__())

def __divmod__(self, value: int) -> Tuple["TPM_FRIENDLY_INT", "TPM_FRIENDLY_INT"]:
def __divmod__(self, value: int) -> Tuple["Self", "Self"]:
a, b = int(self).__divmod__(value)
return self.__class__(a), self.__class__(b)

def __floor__(self) -> "TPM_FRIENDLY_INT":
def __floor__(self) -> "Self":
return self.__class__(int(self).__floor__())

def __floordiv__(self, value: int) -> "TPM_FRIENDLY_INT":
def __floordiv__(self, value: int) -> "Self":
return self.__class__(int(self).__floordiv__(value))

def __invert__(self) -> "TPM_FRIENDLY_INT":
def __invert__(self) -> "Self":
return self.__class__(int(self).__invert__())

def __lshift__(self, value: int) -> "TPM_FRIENDLY_INT":
def __lshift__(self, value: int) -> "Self":
return self.__class__(int(self).__lshift__(value))

def __mod__(self, value: int) -> "TPM_FRIENDLY_INT":
def __mod__(self, value: int) -> "Self":
return self.__class__(int(self).__mod__(value))

def __mul__(self, value: int) -> "TPM_FRIENDLY_INT":
def __mul__(self, value: int) -> "Self":
return self.__class__(int(self).__mul__(value))

def __neg__(self) -> "TPM_FRIENDLY_INT":
def __neg__(self) -> "Self":
return self.__class__(int(self).__neg__())

def __or__(self, value: int) -> "TPM_FRIENDLY_INT":
def __or__(self, value: int) -> "Self":
return self.__class__(int(self).__or__(value))

def __pos__(self) -> "TPM_FRIENDLY_INT":
def __pos__(self) -> "Self":
return self.__class__(int(self).__pos__())

def __pow__(self, value: int, mod: Optional[int] = None) -> Any:
return self.__class__(int(self).__pow__(value, mod))

def __radd__(self, value: int) -> "TPM_FRIENDLY_INT":
def __radd__(self, value: int) -> "Self":
return self.__class__(int(self).__radd__(value))

def __rand__(self, value: int) -> "TPM_FRIENDLY_INT":
def __rand__(self, value: int) -> "Self":
return self.__class__(int(self).__rand__(value))

def __rdivmod__(self, value: int) -> Tuple["TPM_FRIENDLY_INT", "TPM_FRIENDLY_INT"]:
def __rdivmod__(self, value: int) -> Tuple["Self", "Self"]:
a, b = int(self).__rdivmod__(value)
return self.__class__(a), self.__class__(b)

def __rfloordiv__(self, value: int) -> "TPM_FRIENDLY_INT":
def __rfloordiv__(self, value: int) -> "Self":
return self.__class__(int(self).__rfloordiv__(value))

def __rlshift__(self, value: int) -> "TPM_FRIENDLY_INT":
def __rlshift__(self, value: int) -> "Self":
return self.__class__(int(self).__rlshift__(value))

def __rmod__(self, value: int) -> "TPM_FRIENDLY_INT":
def __rmod__(self, value: int) -> "Self":
return self.__class__(int(self).__rmod__(value))

def __rmul__(self, value: int) -> "TPM_FRIENDLY_INT":
def __rmul__(self, value: int) -> "Self":
return self.__class__(int(self).__rmul__(value))

def __ror__(self, value: int) -> "TPM_FRIENDLY_INT":
def __ror__(self, value: int) -> "Self":
return self.__class__(int(self).__ror__(value))

def __round__(self, ndigits: SupportsIndex = False) -> "TPM_FRIENDLY_INT":
def __round__(self, ndigits: SupportsIndex = False) -> "Self":
return self.__class__(int(self).__round__(ndigits))

def __rpow__(self, value: int, mod: Optional[int] = None) -> Any:
return self.__class__(int(self).__rpow__(value, mod))

def __rrshift__(self, value: int) -> "TPM_FRIENDLY_INT":
def __rrshift__(self, value: int) -> "Self":
return self.__class__(int(self).__rrshift__(value))

def __rshift__(self, value: int) -> "TPM_FRIENDLY_INT":
def __rshift__(self, value: int) -> "Self":
return self.__class__(int(self).__rshift__(value))

def __rsub__(self, value: int) -> "TPM_FRIENDLY_INT":
def __rsub__(self, value: int) -> "Self":
return self.__class__(int(self).__rsub__(value))

def __rtruediv__(self, value: int) -> Any:
return self.__class__(int(self).__rtruediv__(value))

def __rxor__(self, value: int) -> "TPM_FRIENDLY_INT":
def __rxor__(self, value: int) -> "Self":
return self.__class__(int(self).__rxor__(value))

def __sub__(self, value: int) -> "TPM_FRIENDLY_INT":
def __sub__(self, value: int) -> "Self":
return self.__class__(int(self).__sub__(value))

def __truediv__(self, value: int) -> "TPM_FRIENDLY_INT":
def __truediv__(self, value: int) -> "Self":
return self.__class__(int(self).__truediv__(value))

def __xor__(self, value: int) -> "TPM_FRIENDLY_INT":
def __xor__(self, value: int) -> "Self":
return self.__class__(int(self).__xor__(value))

@staticmethod
Expand Down Expand Up @@ -282,7 +281,7 @@ def marshal(self) -> bytes:
return bytes(buf[0 : offset[0]])

@classmethod
def unmarshal(cls, buf: bytes) -> Tuple["TPM_FRIENDLY_INT", int]:
def unmarshal(cls, buf: bytes) -> Tuple["Self", int]:
"""Unmarshal bytes into type instance.
Args:
Expand Down Expand Up @@ -423,8 +422,8 @@ class ESYS_TR(TPM_FRIENDLY_INT):
or a persistent key use :func:`tpm2_pytss.ESAPI.tr_from_tpmpublic`
"""

NONE = cast("ESYS_TR", lib.ESYS_TR_NONE)
PASSWORD = cast("ESYS_TR", lib.ESYS_TR_PASSWORD)
NONE = lib.ESYS_TR_NONE
PASSWORD = lib.ESYS_TR_PASSWORD
PCR0 = lib.ESYS_TR_PCR0
PCR1 = lib.ESYS_TR_PCR1
PCR2 = lib.ESYS_TR_PCR2
Expand Down Expand Up @@ -457,11 +456,11 @@ class ESYS_TR(TPM_FRIENDLY_INT):
PCR29 = lib.ESYS_TR_PCR29
PCR30 = lib.ESYS_TR_PCR30
PCR31 = lib.ESYS_TR_PCR31
OWNER = cast("ESYS_TR", lib.ESYS_TR_RH_OWNER)
NULL = cast("ESYS_TR", lib.ESYS_TR_RH_NULL)
LOCKOUT = cast("ESYS_TR", lib.ESYS_TR_RH_LOCKOUT)
ENDORSEMENT = cast("ESYS_TR", lib.ESYS_TR_RH_ENDORSEMENT)
PLATFORM = cast("ESYS_TR", lib.ESYS_TR_RH_PLATFORM)
OWNER = lib.ESYS_TR_RH_OWNER
NULL = lib.ESYS_TR_RH_NULL
LOCKOUT = lib.ESYS_TR_RH_LOCKOUT
ENDORSEMENT = lib.ESYS_TR_RH_ENDORSEMENT
PLATFORM = lib.ESYS_TR_RH_PLATFORM
PLATFORM_NV = lib.ESYS_TR_RH_PLATFORM_NV
RH_OWNER = lib.ESYS_TR_RH_OWNER
RH_NULL = lib.ESYS_TR_RH_NULL
Expand All @@ -474,7 +473,7 @@ def marshal(self) -> bytes:
raise NotImplementedError("Use serialize() instead")

@classmethod
def unmarshal(cls, buf: bytes) -> Tuple["TPM_FRIENDLY_INT", int]:
def unmarshal(cls, buf: bytes) -> Tuple["Self", int]:
raise NotImplementedError("Use deserialize() instead")

def serialize(self, ectx: "ESAPI") -> bytes:
Expand Down Expand Up @@ -547,21 +546,21 @@ def parts_to_blob(handle: "TPM2_HANDLE", public: "TPM2B_PUBLIC") -> bytes:
@TPM_FRIENDLY_INT._fix_const_type
class TPM2_RH(TPM_FRIENDLY_INT):
SRK = lib.TPM2_RH_SRK
OWNER = cast("TPM2_RH", lib.TPM2_RH_OWNER)
OWNER = lib.TPM2_RH_OWNER
REVOKE = lib.TPM2_RH_REVOKE
TRANSPORT = lib.TPM2_RH_TRANSPORT
OPERATOR = lib.TPM2_RH_OPERATOR
ADMIN = lib.TPM2_RH_ADMIN
EK = lib.TPM2_RH_EK
NULL = cast("TPM2_RH", lib.TPM2_RH_NULL)
NULL = lib.TPM2_RH_NULL
UNASSIGNED = lib.TPM2_RH_UNASSIGNED
try:
PW = lib.TPM2_RS_PW
except AttributeError:
PW = lib.TPM2_RH_PW
LOCKOUT = lib.TPM2_RH_LOCKOUT
ENDORSEMENT = cast("TPM2_RH", lib.TPM2_RH_ENDORSEMENT)
PLATFORM = cast("TPM2_RH", lib.TPM2_RH_PLATFORM)
ENDORSEMENT = lib.TPM2_RH_ENDORSEMENT
PLATFORM = lib.TPM2_RH_PLATFORM
PLATFORM_NV = lib.TPM2_RH_PLATFORM_NV


Expand All @@ -572,18 +571,18 @@ class TPM2_ALG(TPM_FRIENDLY_INT):
RSA = lib.TPM2_ALG_RSA
TDES = lib.TPM2_ALG_TDES
SHA = lib.TPM2_ALG_SHA
SHA1 = cast("TPM2_ALG", lib.TPM2_ALG_SHA1)
SHA1 = lib.TPM2_ALG_SHA1
HMAC = lib.TPM2_ALG_HMAC
AES = cast("TPM2_ALG", lib.TPM2_ALG_AES)
AES = lib.TPM2_ALG_AES
MGF1 = lib.TPM2_ALG_MGF1
KEYEDHASH = lib.TPM2_ALG_KEYEDHASH
XOR = lib.TPM2_ALG_XOR
SHA256 = cast("TPM2_ALG", lib.TPM2_ALG_SHA256)
SHA384 = cast("TPM2_ALG", lib.TPM2_ALG_SHA384)
SHA512 = cast("TPM2_ALG", lib.TPM2_ALG_SHA512)
SHA256 = lib.TPM2_ALG_SHA256
SHA384 = lib.TPM2_ALG_SHA384
SHA512 = lib.TPM2_ALG_SHA512
NULL = lib.TPM2_ALG_NULL
SM3_256 = cast("TPM2_ALG", lib.TPM2_ALG_SM3_256)
SM4 = cast("TPM2_ALG", lib.TPM2_ALG_SM4)
SM3_256 = lib.TPM2_ALG_SM3_256
SM4 = lib.TPM2_ALG_SM4
RSASSA = lib.TPM2_ALG_RSASSA
RSAES = lib.TPM2_ALG_RSAES
RSAPSS = lib.TPM2_ALG_RSAPSS
Expand All @@ -599,11 +598,11 @@ class TPM2_ALG(TPM_FRIENDLY_INT):
KDF1_SP800_108 = lib.TPM2_ALG_KDF1_SP800_108
ECC = lib.TPM2_ALG_ECC
SYMCIPHER = lib.TPM2_ALG_SYMCIPHER
CAMELLIA = cast("TPM2_ALG", lib.TPM2_ALG_CAMELLIA)
CAMELLIA = lib.TPM2_ALG_CAMELLIA
CTR = lib.TPM2_ALG_CTR
SHA3_256 = cast("TPM2_ALG", lib.TPM2_ALG_SHA3_256)
SHA3_384 = cast("TPM2_ALG", lib.TPM2_ALG_SHA3_384)
SHA3_512 = cast("TPM2_ALG", lib.TPM2_ALG_SHA3_512)
SHA3_256 = lib.TPM2_ALG_SHA3_256
SHA3_384 = lib.TPM2_ALG_SHA3_384
SHA3_512 = lib.TPM2_ALG_SHA3_512
OFB = lib.TPM2_ALG_OFB
CBC = lib.TPM2_ALG_CBC
CFB = lib.TPM2_ALG_CFB
Expand All @@ -619,12 +618,12 @@ class TPM2_ALG_ID(TPM2_ALG):

@TPM_FRIENDLY_INT._fix_const_type
class TPM2_ECC(TPM_FRIENDLY_INT):
NONE = cast("TPM2_ECC", lib.TPM2_ECC_NONE)
NIST_P192 = cast("TPM2_ECC", lib.TPM2_ECC_NIST_P192)
NIST_P224 = cast("TPM2_ECC", lib.TPM2_ECC_NIST_P224)
NIST_P256 = cast("TPM2_ECC", lib.TPM2_ECC_NIST_P256)
NIST_P384 = cast("TPM2_ECC", lib.TPM2_ECC_NIST_P384)
NIST_P521 = cast("TPM2_ECC", lib.TPM2_ECC_NIST_P521)
NONE = lib.TPM2_ECC_NONE
NIST_P192 = lib.TPM2_ECC_NIST_P192
NIST_P224 = lib.TPM2_ECC_NIST_P224
NIST_P256 = lib.TPM2_ECC_NIST_P256
NIST_P384 = lib.TPM2_ECC_NIST_P384
NIST_P521 = lib.TPM2_ECC_NIST_P521
BN_P256 = lib.TPM2_ECC_BN_P256
BN_P638 = lib.TPM2_ECC_BN_P638
SM2_P256 = lib.TPM2_ECC_SM2_P256
Expand Down Expand Up @@ -1132,8 +1131,8 @@ class TPM2_ST(TPM_FRIENDLY_INT):

@TPM_FRIENDLY_INT._fix_const_type
class TPM2_SU(TPM_FRIENDLY_INT):
CLEAR = cast("TPM2_SU", lib.TPM2_SU_CLEAR)
STATE = cast("TPM2_SU", lib.TPM2_SU_STATE)
CLEAR = lib.TPM2_SU_CLEAR
STATE = lib.TPM2_SU_STATE


@TPM_FRIENDLY_INT._fix_const_type
Expand Down Expand Up @@ -1350,14 +1349,14 @@ class TPMA_LOCALITY(TPMA_FRIENDLY_INTLIST):
EXTENDED_SHIFT = lib.TPMA_LOCALITY_EXTENDED_SHIFT

@classmethod
def create_extended(cls, value: int) -> "TPMA_LOCALITY":
def create_extended(cls, value: int) -> "Self":
x = (1 << cls.EXTENDED_SHIFT) + value
if x > 255:
raise ValueError("Extended Localities must be less than 256")
return cls(x)

@classmethod
def parse(cls, value: str) -> "TPMA_LOCALITY":
def parse(cls, value: str) -> "Self":
"""Converts a string of | separated localities or an extended locality into a TPMA_LOCALITY instance
Args:
Expand Down
Loading

0 comments on commit 6b2d85a

Please sign in to comment.