Skip to content

Commit

Permalink
- various small fixes to static typing, as well as other changes to
Browse files Browse the repository at this point in the history
  allow mypy to run;
  • Loading branch information
jaltmayerpizzorno committed Sep 27, 2023
1 parent 0dd21b2 commit 59ce907
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 27 deletions.
22 changes: 11 additions & 11 deletions src/slipcover/branch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import ast
import sys
from typing import List, Union

BRANCH_NAME = "_slipcover_branches"
PYTHON_VERSION = sys.version_info[0:2]

def preinstrument(tree: ast.AST) -> ast.AST:
"""Prepares an AST for Slipcover instrumentation, inserting assignments indicating where branches happen."""
Expand All @@ -11,25 +11,25 @@ class SlipcoverTransformer(ast.NodeTransformer):
def __init__(self):
pass

def _mark_branch(self, from_line: int, to_line: int) -> ast.AST:
def _mark_branch(self, from_line: int, to_line: int) -> List[ast.stmt]:
mark = ast.Assign([ast.Name(BRANCH_NAME, ast.Store())],
ast.Tuple([ast.Constant(from_line), ast.Constant(to_line)], ast.Load()))

for node in ast.walk(mark):
# we ignore line 0, so this avoids generating extra line probes
node.lineno = 0 if PYTHON_VERSION >= (3,11) else from_line
node.lineno = 0 if sys.version_info >= (3,11) else from_line

return [mark]

def visit_FunctionDef(self, node: ast.AST) -> ast.AST:
def visit_FunctionDef(self, node: Union[ast.AsyncFunctionDef, ast.FunctionDef]) -> ast.AST:
# Mark BRANCH_NAME global, so that our assignment are easier to find (only STORE_NAME/STORE_GLOBAL,
# but not STORE_FAST, etc.)
has_docstring = ast.get_docstring(node, clean=False) is not None
node.body.insert(1 if has_docstring else 0, ast.Global([BRANCH_NAME]))
super().generic_visit(node)
return node

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
return self.visit_FunctionDef(node)

def _mark_branches(self, node: ast.AST) -> ast.AST:
Expand All @@ -44,19 +44,19 @@ def _mark_branches(self, node: ast.AST) -> ast.AST:
super().generic_visit(node)
return node

def visit_If(self, node: ast.If) -> ast.If:
def visit_If(self, node: ast.If) -> ast.AST:
return self._mark_branches(node)

def visit_For(self, node: ast.For) -> ast.For:
def visit_For(self, node: ast.For) -> ast.AST:
return self._mark_branches(node)

def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AsyncFor:
def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AST:
return self._mark_branches(node)

def visit_While(self, node: ast.While) -> ast.While:
def visit_While(self, node: ast.While) -> ast.AST:
return self._mark_branches(node)

if PYTHON_VERSION >= (3,10): # new in Python 3.10
if sys.version_info >= (3,10): # new in Python 3.10
def visit_Match(self, node: ast.Match) -> ast.Match:
for case in node.cases:
case.body = self._mark_branch(node.lineno, case.body[0].lineno) + case.body
Expand All @@ -72,7 +72,7 @@ def visit_Match(self, node: ast.Match) -> ast.Match:
super().generic_visit(node)
return node

if PYTHON_VERSION >= (3,10):
if sys.version_info >= (3,10):
def is_Match(node: ast.AST) -> bool:
return isinstance(node, ast.Match)
else:
Expand Down
30 changes: 14 additions & 16 deletions src/slipcover/bytecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
import sys
import dis
import types
from typing import List

PYTHON_VERSION = sys.version_info[0:2]
from typing import List, Tuple

# FIXME provide __all__

# Python 3.10a7 changed branch opcodes' argument to mean instruction
# (word) offset, rather than bytecode offset.
if PYTHON_VERSION >= (3,10):
if sys.version_info >= (3,10):
def offset2branch(offset: int) -> int:
assert offset % 2 == 0
return offset//2
Expand All @@ -30,7 +28,7 @@ def branch2offset(arg: int) -> int:
op_LOAD_CONST = dis.opmap["LOAD_CONST"]
op_LOAD_GLOBAL = dis.opmap["LOAD_GLOBAL"]

if PYTHON_VERSION >= (3,11):
if sys.version_info >= (3,11):
op_RESUME = dis.opmap["RESUME"]
op_PUSH_NULL = dis.opmap["PUSH_NULL"]
op_PRECALL = dis.opmap["PRECALL"]
Expand Down Expand Up @@ -64,12 +62,12 @@ def opcode_arg(opcode: int, arg: int, min_ext : int = 0) -> List[int]:
[op_EXTENDED_ARG, (arg >> (ext - i) * 8) & 0xFF]
)
bytecode.extend([opcode, arg & 0xFF])
if PYTHON_VERSION >= (3,11):
if sys.version_info >= (3,11):
bytecode.extend([op_CACHE, 0] * dis._inline_cache_entries[opcode])
return bytecode


def unpack_opargs(code: bytes) -> List[(int, int, int, int)]:
def unpack_opargs(code: bytes) -> Tuple[int, int, int, int]:
"""Unpacks opcodes and their arguments, returning:
- the beginning offset, including that of the first EXTENDED_ARG, if any
Expand All @@ -86,7 +84,7 @@ def unpack_opargs(code: bytes) -> List[(int, int, int, int)]:
ext_arg = (ext_arg | code[off+1]) << 8
else:
arg = (ext_arg | code[off+1])
if PYTHON_VERSION >= (3,11):
if sys.version_info >= (3,11):
while off+2 < len(code) and code[off+2] == op_CACHE:
off += 2
yield (next_off, off+2-next_off, op, arg)
Expand Down Expand Up @@ -155,7 +153,7 @@ def adjust_length(self) -> int:

return change

def code(self) -> bytes:
def code(self) -> List[int]:
"""Emits this branch's code."""
assert self.length >= 2 + 2*arg_ext_needed(self.arg())
return opcode_arg(self.opcode, self.arg(), (self.length-2)//2)
Expand Down Expand Up @@ -239,7 +237,7 @@ def adjust(self, insert_offset: int, insert_length: int) -> None:
def from_code(code: types.CodeType) -> List[ExceptionTableEntry]:
"""Returns a list of exception table entries from a code object."""

if PYTHON_VERSION < (3,11): return []
if sys.version_info < (3,11): return []

entries = []
it = iter(code.co_exceptiontable)
Expand Down Expand Up @@ -345,7 +343,7 @@ def make_lnotab(firstlineno : int, lines : List[LineEntry]) -> bytes:
return bytes(lnotab)


if PYTHON_VERSION == (3,10):
if sys.version_info >= (3,9) and sys.version_info < (3,11): # 3.10
@staticmethod
def make_linetable(firstlineno : int, lines : List[LineEntry]) -> bytes:
"""Generates the line number table used by Python 3.10 to map offsets to line numbers."""
Expand Down Expand Up @@ -393,7 +391,7 @@ def make_linetable(firstlineno : int, lines : List[LineEntry]) -> bytes:
return bytes(linetable)


if PYTHON_VERSION >= (3,11):
if sys.version_info >= (3,11):
@staticmethod
def make_linetable(firstlineno : int, lines : List[LineEntry]) -> bytes:
"""Generates the positions table used by Python 3.11+ to map offsets to line numbers."""
Expand Down Expand Up @@ -490,7 +488,7 @@ def insert_function_call(self, offset, function, args, repl_length=0):

insert = bytearray()

if PYTHON_VERSION >= (3,11):
if sys.version_info >= (3,11):
insert.extend([op_NOP, 0, # for disabling
op_PUSH_NULL, 0] +
opcode_arg(op_LOAD_CONST, function))
Expand Down Expand Up @@ -610,7 +608,7 @@ def replace_global_with_const(self, global_name, const_index):
def find_load_globals():
for op_off, op_len, op, op_arg in unpack_opargs(self.patch):
if op == op_LOAD_GLOBAL:
if PYTHON_VERSION >= (3,11):
if sys.version_info >= (3,11):
if (op_arg>>1) == name_index:
yield (op_off, op_len, op, op_arg)
else:
Expand Down Expand Up @@ -704,12 +702,12 @@ def finish(self):
replace["co_code"] = bytes(self.patch)

if self.branches is not None:
if PYTHON_VERSION < (3,10):
if sys.version_info < (3,10):
replace["co_lnotab"] = LineEntry.make_lnotab(self.orig_code.co_firstlineno, self.lines)
else:
replace["co_linetable"] = LineEntry.make_linetable(self.orig_code.co_firstlineno, self.lines)

if PYTHON_VERSION >= (3,11):
if sys.version_info >= (3,11):
replace["co_exceptiontable"] = ExceptionTableEntry.make_exceptiontable(self.ex_table)

return self.orig_code.replace(**replace)

0 comments on commit 59ce907

Please sign in to comment.