From 59ce907c6adb00bb6a2e6db01e66b389abc9a972 Mon Sep 17 00:00:00 2001 From: Juan Altmayer Pizzorno Date: Wed, 27 Sep 2023 12:06:50 -0400 Subject: [PATCH] - various small fixes to static typing, as well as other changes to allow mypy to run; --- src/slipcover/branch.py | 22 +++++++++++----------- src/slipcover/bytecode.py | 30 ++++++++++++++---------------- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/src/slipcover/branch.py b/src/slipcover/branch.py index 597665e..d3c73a5 100644 --- a/src/slipcover/branch.py +++ b/src/slipcover/branch.py @@ -1,8 +1,8 @@ import ast import sys +from typing import List, Union BRANCH_NAME = "_slipcover_branches" -PYTHON_VERSION = sys.version_info[0:2] def preinstrument(tree: ast.AST) -> ast.AST: """Prepares an AST for Slipcover instrumentation, inserting assignments indicating where branches happen.""" @@ -11,17 +11,17 @@ class SlipcoverTransformer(ast.NodeTransformer): def __init__(self): pass - def _mark_branch(self, from_line: int, to_line: int) -> ast.AST: + def _mark_branch(self, from_line: int, to_line: int) -> List[ast.stmt]: mark = ast.Assign([ast.Name(BRANCH_NAME, ast.Store())], ast.Tuple([ast.Constant(from_line), ast.Constant(to_line)], ast.Load())) for node in ast.walk(mark): # we ignore line 0, so this avoids generating extra line probes - node.lineno = 0 if PYTHON_VERSION >= (3,11) else from_line + node.lineno = 0 if sys.version_info >= (3,11) else from_line return [mark] - def visit_FunctionDef(self, node: ast.AST) -> ast.AST: + def visit_FunctionDef(self, node: Union[ast.AsyncFunctionDef, ast.FunctionDef]) -> ast.AST: # Mark BRANCH_NAME global, so that our assignment are easier to find (only STORE_NAME/STORE_GLOBAL, # but not STORE_FAST, etc.) has_docstring = ast.get_docstring(node, clean=False) is not None @@ -29,7 +29,7 @@ def visit_FunctionDef(self, node: ast.AST) -> ast.AST: super().generic_visit(node) return node - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef: + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST: return self.visit_FunctionDef(node) def _mark_branches(self, node: ast.AST) -> ast.AST: @@ -44,19 +44,19 @@ def _mark_branches(self, node: ast.AST) -> ast.AST: super().generic_visit(node) return node - def visit_If(self, node: ast.If) -> ast.If: + def visit_If(self, node: ast.If) -> ast.AST: return self._mark_branches(node) - def visit_For(self, node: ast.For) -> ast.For: + def visit_For(self, node: ast.For) -> ast.AST: return self._mark_branches(node) - def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AsyncFor: + def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AST: return self._mark_branches(node) - def visit_While(self, node: ast.While) -> ast.While: + def visit_While(self, node: ast.While) -> ast.AST: return self._mark_branches(node) - if PYTHON_VERSION >= (3,10): # new in Python 3.10 + if sys.version_info >= (3,10): # new in Python 3.10 def visit_Match(self, node: ast.Match) -> ast.Match: for case in node.cases: case.body = self._mark_branch(node.lineno, case.body[0].lineno) + case.body @@ -72,7 +72,7 @@ def visit_Match(self, node: ast.Match) -> ast.Match: super().generic_visit(node) return node - if PYTHON_VERSION >= (3,10): + if sys.version_info >= (3,10): def is_Match(node: ast.AST) -> bool: return isinstance(node, ast.Match) else: diff --git a/src/slipcover/bytecode.py b/src/slipcover/bytecode.py index 166d472..291b8db 100644 --- a/src/slipcover/bytecode.py +++ b/src/slipcover/bytecode.py @@ -2,15 +2,13 @@ import sys import dis import types -from typing import List - -PYTHON_VERSION = sys.version_info[0:2] +from typing import List, Tuple # FIXME provide __all__ # Python 3.10a7 changed branch opcodes' argument to mean instruction # (word) offset, rather than bytecode offset. -if PYTHON_VERSION >= (3,10): +if sys.version_info >= (3,10): def offset2branch(offset: int) -> int: assert offset % 2 == 0 return offset//2 @@ -30,7 +28,7 @@ def branch2offset(arg: int) -> int: op_LOAD_CONST = dis.opmap["LOAD_CONST"] op_LOAD_GLOBAL = dis.opmap["LOAD_GLOBAL"] -if PYTHON_VERSION >= (3,11): +if sys.version_info >= (3,11): op_RESUME = dis.opmap["RESUME"] op_PUSH_NULL = dis.opmap["PUSH_NULL"] op_PRECALL = dis.opmap["PRECALL"] @@ -64,12 +62,12 @@ def opcode_arg(opcode: int, arg: int, min_ext : int = 0) -> List[int]: [op_EXTENDED_ARG, (arg >> (ext - i) * 8) & 0xFF] ) bytecode.extend([opcode, arg & 0xFF]) - if PYTHON_VERSION >= (3,11): + if sys.version_info >= (3,11): bytecode.extend([op_CACHE, 0] * dis._inline_cache_entries[opcode]) return bytecode -def unpack_opargs(code: bytes) -> List[(int, int, int, int)]: +def unpack_opargs(code: bytes) -> Tuple[int, int, int, int]: """Unpacks opcodes and their arguments, returning: - the beginning offset, including that of the first EXTENDED_ARG, if any @@ -86,7 +84,7 @@ def unpack_opargs(code: bytes) -> List[(int, int, int, int)]: ext_arg = (ext_arg | code[off+1]) << 8 else: arg = (ext_arg | code[off+1]) - if PYTHON_VERSION >= (3,11): + if sys.version_info >= (3,11): while off+2 < len(code) and code[off+2] == op_CACHE: off += 2 yield (next_off, off+2-next_off, op, arg) @@ -155,7 +153,7 @@ def adjust_length(self) -> int: return change - def code(self) -> bytes: + def code(self) -> List[int]: """Emits this branch's code.""" assert self.length >= 2 + 2*arg_ext_needed(self.arg()) return opcode_arg(self.opcode, self.arg(), (self.length-2)//2) @@ -239,7 +237,7 @@ def adjust(self, insert_offset: int, insert_length: int) -> None: def from_code(code: types.CodeType) -> List[ExceptionTableEntry]: """Returns a list of exception table entries from a code object.""" - if PYTHON_VERSION < (3,11): return [] + if sys.version_info < (3,11): return [] entries = [] it = iter(code.co_exceptiontable) @@ -345,7 +343,7 @@ def make_lnotab(firstlineno : int, lines : List[LineEntry]) -> bytes: return bytes(lnotab) - if PYTHON_VERSION == (3,10): + if sys.version_info >= (3,9) and sys.version_info < (3,11): # 3.10 @staticmethod def make_linetable(firstlineno : int, lines : List[LineEntry]) -> bytes: """Generates the line number table used by Python 3.10 to map offsets to line numbers.""" @@ -393,7 +391,7 @@ def make_linetable(firstlineno : int, lines : List[LineEntry]) -> bytes: return bytes(linetable) - if PYTHON_VERSION >= (3,11): + if sys.version_info >= (3,11): @staticmethod def make_linetable(firstlineno : int, lines : List[LineEntry]) -> bytes: """Generates the positions table used by Python 3.11+ to map offsets to line numbers.""" @@ -490,7 +488,7 @@ def insert_function_call(self, offset, function, args, repl_length=0): insert = bytearray() - if PYTHON_VERSION >= (3,11): + if sys.version_info >= (3,11): insert.extend([op_NOP, 0, # for disabling op_PUSH_NULL, 0] + opcode_arg(op_LOAD_CONST, function)) @@ -610,7 +608,7 @@ def replace_global_with_const(self, global_name, const_index): def find_load_globals(): for op_off, op_len, op, op_arg in unpack_opargs(self.patch): if op == op_LOAD_GLOBAL: - if PYTHON_VERSION >= (3,11): + if sys.version_info >= (3,11): if (op_arg>>1) == name_index: yield (op_off, op_len, op, op_arg) else: @@ -704,12 +702,12 @@ def finish(self): replace["co_code"] = bytes(self.patch) if self.branches is not None: - if PYTHON_VERSION < (3,10): + if sys.version_info < (3,10): replace["co_lnotab"] = LineEntry.make_lnotab(self.orig_code.co_firstlineno, self.lines) else: replace["co_linetable"] = LineEntry.make_linetable(self.orig_code.co_firstlineno, self.lines) - if PYTHON_VERSION >= (3,11): + if sys.version_info >= (3,11): replace["co_exceptiontable"] = ExceptionTableEntry.make_exceptiontable(self.ex_table) return self.orig_code.replace(**replace)