Skip to content

Commit

Permalink
Reformat code
Browse files Browse the repository at this point in the history
  • Loading branch information
riptl committed Apr 4, 2024
1 parent 943914d commit 3e14b13
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 29 deletions.
52 changes: 39 additions & 13 deletions src/test_suite/codec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,27 @@ def decode_input(instruction_context: pb.InstrContext):
- instruction_context (pb.InstrContext): Instruction context (will be modified).
"""
if instruction_context.program_id:
instruction_context.program_id = superbased58.decode_32(instruction_context.program_id)
instruction_context.program_id = superbased58.decode_32(
instruction_context.program_id
)
if instruction_context.loader_id:
instruction_context.loader_id = superbased58.decode_32(instruction_context.loader_id)
instruction_context.loader_id = superbased58.decode_32(
instruction_context.loader_id
)

for i in range(len(instruction_context.accounts)):
if instruction_context.accounts[i].address:
instruction_context.accounts[i].address = superbased58.decode_32(instruction_context.accounts[i].address)
instruction_context.accounts[i].address = superbased58.decode_32(
instruction_context.accounts[i].address
)
if instruction_context.accounts[i].data:
instruction_context.accounts[i].data = base64.b64decode(instruction_context.accounts[i].data)
instruction_context.accounts[i].data = base64.b64decode(
instruction_context.accounts[i].data
)
if instruction_context.accounts[i].owner:
instruction_context.accounts[i].owner = superbased58.decode_32(instruction_context.accounts[i].owner)
instruction_context.accounts[i].owner = superbased58.decode_32(
instruction_context.accounts[i].owner
)

if instruction_context.data:
instruction_context.data = base64.b64decode(instruction_context.data)
Expand All @@ -37,17 +47,27 @@ def encode_input(instruction_context: pb.InstrContext):
- instruction_context (pb.InstrContext): Instruction context (will be modified).
"""
if instruction_context.program_id:
instruction_context.program_id = superbased58.encode_32(instruction_context.program_id)
instruction_context.program_id = superbased58.encode_32(
instruction_context.program_id
)
if instruction_context.loader_id:
instruction_context.loader_id = superbased58.encode_32(instruction_context.loader_id)
instruction_context.loader_id = superbased58.encode_32(
instruction_context.loader_id
)

for i in range(len(instruction_context.accounts)):
if instruction_context.accounts[i].address:
instruction_context.accounts[i].address = superbased58.encode_32(instruction_context.accounts[i].address)
instruction_context.accounts[i].address = superbased58.encode_32(
instruction_context.accounts[i].address
)
if instruction_context.accounts[i].data:
instruction_context.accounts[i].data = base64.b64encode(instruction_context.accounts[i].data)
instruction_context.accounts[i].data = base64.b64encode(
instruction_context.accounts[i].data
)
if instruction_context.accounts[i].owner:
instruction_context.accounts[i].owner = superbased58.encode_32(instruction_context.accounts[i].owner)
instruction_context.accounts[i].owner = superbased58.encode_32(
instruction_context.accounts[i].owner
)

if instruction_context.data:
instruction_context.data = base64.b64encode(instruction_context.data)
Expand All @@ -63,8 +83,14 @@ def encode_output(instruction_effects: pb.InstrEffects):
"""
for i in range(len(instruction_effects.modified_accounts)):
if instruction_effects.modified_accounts[i].address:
instruction_effects.modified_accounts[i].address = superbased58.encode_32(instruction_effects.modified_accounts[i].address)
instruction_effects.modified_accounts[i].address = superbased58.encode_32(
instruction_effects.modified_accounts[i].address
)
if instruction_effects.modified_accounts[i].data:
instruction_effects.modified_accounts[i].data = base64.b64encode(instruction_effects.modified_accounts[i].data)
instruction_effects.modified_accounts[i].data = base64.b64encode(
instruction_effects.modified_accounts[i].data
)
if instruction_effects.modified_accounts[i].owner:
instruction_effects.modified_accounts[i].owner = superbased58.encode_32(instruction_effects.modified_accounts[i].owner)
instruction_effects.modified_accounts[i].owner = superbased58.encode_32(
instruction_effects.modified_accounts[i].owner
)
10 changes: 8 additions & 2 deletions src/test_suite/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import signal
import subprocess
import os
from test_suite.multiprocessing_utils import initialize_process_output_buffers, process_instruction
from test_suite.multiprocessing_utils import (
initialize_process_output_buffers,
process_instruction,
)


def debug_target(shared_library, test_input, pipe):
initialize_process_output_buffers()
Expand Down Expand Up @@ -43,7 +47,9 @@ def debug_host(shared_library, instruction_context, gdb):

# Spawn the Python interpreter
pipe, child_pipe = Pipe()
target = multiprocessing.Process(target=debug_target, args=(shared_library, instruction_context, child_pipe))
target = multiprocessing.Process(
target=debug_target, args=(shared_library, instruction_context, child_pipe)
)
target.start()
# Wait for a signal that the child process is ready
assert pipe.recv() == "started"
Expand Down
51 changes: 37 additions & 14 deletions src/test_suite/multiprocessing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@


def process_instruction(
library: ctypes.CDLL,
serialized_instruction_context: str
library: ctypes.CDLL, serialized_instruction_context: str
) -> pb.InstrEffects | None:
"""
Process an instruction through a provided shared library and return the result.
Expand All @@ -27,9 +26,9 @@ def process_instruction(
# Define argument and return types
library.sol_compat_instr_execute_v1.argtypes = [
POINTER(ctypes.c_uint8), # out_ptr
POINTER(c_uint64), # out_psz
POINTER(c_uint64), # out_psz
POINTER(ctypes.c_uint8), # in_ptr
c_uint64 # in_sz
c_uint64, # in_sz
]
library.sol_compat_instr_execute_v1.restype = c_int

Expand All @@ -40,14 +39,16 @@ def process_instruction(
out_sz = ctypes.c_uint64(OUTPUT_BUFFER_SIZE)

# Call the function
result = library.sol_compat_instr_execute_v1(globals.output_buffer_pointer, ctypes.byref(out_sz), in_ptr, in_sz)
result = library.sol_compat_instr_execute_v1(
globals.output_buffer_pointer, ctypes.byref(out_sz), in_ptr, in_sz
)

# Result == 0 means execution failed
if result == 0:
return None

# Process the output
output_data = bytearray(globals.output_buffer_pointer[:out_sz.value])
output_data = bytearray(globals.output_buffer_pointer[: out_sz.value])
output_object = pb.InstrEffects()
output_object.ParseFromString(output_data)

Expand Down Expand Up @@ -90,7 +91,9 @@ def generate_test_case(test_file: Path) -> tuple[Path, str | None]:
return test_file, instruction_context.SerializeToString(deterministic=True)


def process_single_test_case(file: Path, serialized_instruction_context: str | None) -> tuple[str, dict[str, str | None] | None]:
def process_single_test_case(
file: Path, serialized_instruction_context: str | None
) -> tuple[str, dict[str, str | None] | None]:
"""
Process a single execution context (file, serialized instruction context) through
all target libraries and returns serialized instruction effects. This
Expand All @@ -111,8 +114,14 @@ def process_single_test_case(file: Path, serialized_instruction_context: str | N
# Execute test case on each target library
results = {}
for target in globals.target_libraries:
instruction_effects = process_instruction(globals.target_libraries[target], serialized_instruction_context)
result = instruction_effects.SerializeToString(deterministic=True) if instruction_effects else None
instruction_effects = process_instruction(
globals.target_libraries[target], serialized_instruction_context
)
result = (
instruction_effects.SerializeToString(deterministic=True)
if instruction_effects
else None
)
results[target] = result

return file.stem, results
Expand Down Expand Up @@ -175,13 +184,22 @@ def check_consistency_in_results(file_stem: Path, results: dict) -> dict[str, bo
protobuf_structures[iteration] = protobuf_struct

# Write output Protobuf struct to logs
with open(globals.output_dir / target.stem / str(iteration) / (file_stem + ".txt"), "w") as f:
with open(
globals.output_dir
/ target.stem
/ str(iteration)
/ (file_stem + ".txt"),
"w",
) as f:
if protobuf_struct:
f.write(text_format.MessageToString(protobuf_struct))
else:
f.write(str(None))

test_case_passed = all(protobuf_structures[iteration] == protobuf_structures[0] for iteration in range(globals.n_iterations))
test_case_passed = all(
protobuf_structures[iteration] == protobuf_structures[0]
for iteration in range(globals.n_iterations)
)
results_per_target[target] = 1 if test_case_passed else -1

return results_per_target
Expand Down Expand Up @@ -222,13 +240,16 @@ def build_test_results(file_stem: Path, results: dict[str, str | None]) -> int:
else:
f.write(str(None))

test_case_passed = all(protobuf_structures[globals.solana_shared_library] == result for result in protobuf_structures.values())
test_case_passed = all(
protobuf_structures[globals.solana_shared_library] == result
for result in protobuf_structures.values()
)

# 1 = passed, -1 = failed
return 1 if test_case_passed else -1


def initialize_process_output_buffers(randomize_output_buffer = False):
def initialize_process_output_buffers(randomize_output_buffer=False):
"""
Initialize shared memory and pointers for output buffers for each process.
Expand All @@ -239,4 +260,6 @@ def initialize_process_output_buffers(randomize_output_buffer = False):

if randomize_output_buffer:
output_buffer_random_bytes = os.urandom(OUTPUT_BUFFER_SIZE)
globals.output_buffer_pointer = (ctypes.c_uint8 * OUTPUT_BUFFER_SIZE)(*output_buffer_random_bytes)
globals.output_buffer_pointer = (ctypes.c_uint8 * OUTPUT_BUFFER_SIZE)(
*output_buffer_random_bytes
)

0 comments on commit 3e14b13

Please sign in to comment.