Skip to content

Commit

Permalink
feat: uninitialized jumpdests fallback (#1148)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

<!-- Give an estimate of the time you spent on this PR in terms of work
days.
Did you spend 0.5 days on this PR or rather 2 days?  -->

Time spent on this PR: 0.5d

## Pull request type

<!-- Please try to limit your pull request to one type,
submit multiple pull requests if needed. -->

Please check the type of change your PR introduces:

- [ ] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

<!-- Please describe the current behavior that you are modifying,
or link to a relevant issue. -->

Resolves #<Issue number>

## What is the new behavior?

<!-- Please describe the behavior or changes that are being added by
this PR. -->

- Adds a fallback so that if a jumpdest index read is 0 and the
jumpdests have not been initialized, initialize them and re-query the
index
-
-

<!-- Reviewable:start -->
- - -
This change is [<img src="https://reviewable.io/review_button.svg"
height="34" align="absmiddle"
alt="Reviewable"/>](https://reviewable.io/reviews/kkrt-labs/kakarot/1148)
<!-- Reviewable:end -->
  • Loading branch information
enitrat authored May 21, 2024
1 parent 98b26fd commit 3aacc3a
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 9 deletions.
54 changes: 48 additions & 6 deletions src/kakarot/accounts/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from openzeppelin.access.ownable.library import Ownable, Ownable_owner
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.bool import FALSE
from starkware.cairo.common.dict_access import DictAccess
from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin
from starkware.cairo.common.math import unsigned_div_rem, split_int, split_felt
from starkware.cairo.common.memcpy import memcpy
Expand Down Expand Up @@ -33,6 +34,7 @@ from utils.eth_transaction import EthTransaction
from utils.uint256 import uint256_add
from utils.bytes import bytes_to_bytes8_little_endian
from utils.signature import Signature
from utils.utils import Helpers

// @dev: should always be zero for EOAs
@storage_var
Expand Down Expand Up @@ -67,6 +69,10 @@ func Account_cairo1_helpers_class_hash() -> (res: felt) {
func Account_valid_jumpdests() -> (is_valid: felt) {
}

@storage_var
func Account_jumpdests_initialized() -> (initialized: felt) {
}

@event
func transaction_executed(response_len: felt, response: felt*, success: felt, gas_used: felt) {
}
Expand Down Expand Up @@ -431,17 +437,23 @@ namespace AccountContract {
return (latest_account_class, latest_cairo1_helpers_class);
}

// @notice Writes an array of valid jumpdests indexes to storage.
// @param jumpdests_len The length of the jumpdests array.
// @param jumpdests The jumpdests array.
func write_jumpdests{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
jumpdests_len: felt, jumpdests: felt*
) {
// Recursively store the jumpdests.
Internals.write_jumpdests(jumpdests_len=jumpdests_len, jumpdests=jumpdests);
Internals.write_jumpdests(
jumpdests_len=jumpdests_len, jumpdests=jumpdests, iteration_size=1
);
return ();
}

func is_valid_jumpdest{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
index: felt
) -> felt {
alloc_locals;
let (base_address) = Account_valid_jumpdests.addr();
let index_address = base_address + index;

Expand All @@ -454,7 +466,32 @@ namespace AccountContract {
tempvar syscall_ptr = syscall_ptr + StorageRead.SIZE;
tempvar value = response.value;

return value;
if (value != 0) {
return value;
}

// Jumpdest is invalid - we verify that the jumpdests have been stored, and if not,
// we store them. call the appropriate function check & store
let (initialized) = Account_jumpdests_initialized.read();

if (initialized != FALSE) {
return value;
}

let (bytecode_len) = Account_bytecode_len.read();
let (bytecode) = Internals.load_bytecode(bytecode_len);
let (valid_jumpdests_start, valid_jumpdests) = Helpers.initialize_jumpdests(
bytecode_len, bytecode
);
let (jumpdests_len, _) = unsigned_div_rem(
valid_jumpdests - valid_jumpdests_start, DictAccess.SIZE
);
Internals.write_jumpdests(
jumpdests_len=jumpdests_len,
jumpdests=cast(valid_jumpdests_start, felt*),
iteration_size=DictAccess.SIZE,
);
return is_valid_jumpdest(index=index);
}
}

Expand Down Expand Up @@ -523,10 +560,13 @@ namespace Internals {
}

// @notice Store the jumpdests of the contract.
// @param jumpdests_len The length of the jumpdests.
// @param jumpdests The jumpdests of the contract.
// @dev This function can be used by either passing an array of valid jumpdests,
// or a dict that only contains valid entries (i.e. no invalid index has been read).
// @param jumpdests_len The length of the valid jumpdests.
// @param jumpdests The jumpdests of the contract. Can be an array of valid indexes or a dict.
// @param iteration_size The size of the object we are iterating over.
func write_jumpdests{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}(
jumpdests_len: felt, jumpdests: felt*
jumpdests_len: felt, jumpdests: felt*, iteration_size: felt
) {
alloc_locals;

Expand All @@ -546,6 +586,7 @@ namespace Internals {
let jumpdests = cast([ap - 2], felt*);
let remaining = [ap - 1];
let base_address = [fp];
let iteration_size = [fp - 3];

let index_to_store = [jumpdests];
tempvar storage_address = base_address + index_to_store;
Expand All @@ -555,11 +596,12 @@ namespace Internals {
);
%{ syscall_handler.storage_write(segments=segments, syscall_ptr=ids.syscall_ptr) %}
tempvar syscall_ptr = syscall_ptr + StorageWrite.SIZE;
tempvar jumpdests = jumpdests + 1;
tempvar jumpdests = jumpdests + iteration_size;
tempvar remaining = remaining - 1;

jmp body if remaining != 0;

Account_jumpdests_initialized.write(1);
return ();
}

Expand Down
68 changes: 65 additions & 3 deletions tests/src/kakarot/accounts/test_contract_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test__should_store_valid_jumpdests(self, cairo_run):

class TestReadJumpdests:
@pytest.fixture
def storage(self, jumpdests):
def store_jumpdests(self, jumpdests):
base_address = get_storage_var_address("Account_valid_jumpdests")
valid_addresses = [base_address + jumpdest for jumpdest in jumpdests]

Expand All @@ -168,9 +168,11 @@ def _storage(address):

@pytest.mark.parametrize("jumpdests", [[0x02, 0x10, 0xFF]])
def test__should_return_if_jumpdest_valid(
self, cairo_run, jumpdests, storage
self, cairo_run, jumpdests, store_jumpdests
):
with patch.object(SyscallHandler, "mock_storage", side_effect=storage):
with patch.object(
SyscallHandler, "mock_storage", side_effect=store_jumpdests
):
for jumpdest in jumpdests:
assert cairo_run("test__is_valid_jumpdest", index=jumpdest) == 1

Expand All @@ -180,6 +182,66 @@ def test__should_return_if_jumpdest_valid(
]
SyscallHandler.mock_storage.assert_has_calls(calls)

@pytest.fixture
def patch_account_storage(self, account_code):
code_len_address = get_storage_var_address("Account_bytecode_len")
base_jumpdests_address = get_storage_var_address(
"Account_valid_jumpdests"
)
chunks = wrap(account_code, 2 * 31)

def _storage(address, value=None):
if value is not None:
SyscallHandler.patches[address] = value
return
if address == code_len_address:
return len(bytes.fromhex(account_code))
elif address >= base_jumpdests_address:
return 0
return int(chunks[address], 16)

return _storage

# Code contains both valid and invalid jumpdests
# PUSH1 4 // Offset 0
# JUMP // Offset 2 (previous instruction occupies 2 bytes)
# INVALID // Offset 3
# JUMPDEST // Offset 4
# PUSH1 1 // Offset 5
# PUSH1 0x5B // invalid jumpdest
@pytest.mark.parametrize(
"account_code, jumpdests, results",
[("600456fe5b6001605b", [0x04, 0x08], [1, 0])],
)
def test__should_return_if_jumpdest_valid_when_not_stored(
self, cairo_run, account_code, jumpdests, results, patch_account_storage
):
with patch.object(
SyscallHandler, "mock_storage", side_effect=patch_account_storage
):
for jumpdest, result in zip(jumpdests, results):
assert (
cairo_run("test__is_valid_jumpdest", index=jumpdest)
== result
)

base_address = get_storage_var_address("Account_valid_jumpdests")
jumpdests_initialized_address = get_storage_var_address(
"Account_jumpdests_initialized"
)
expected_read_calls = [
call(address=base_address + jumpdest) for jumpdest in jumpdests
] + [call(address=jumpdests_initialized_address)]

expected_write_calls = [
call(address=base_address + jumpdest, value=1)
for jumpdest, result in zip(jumpdests, results)
if result == 1
] + [call(address=jumpdests_initialized_address, value=1)]

SyscallHandler.mock_storage.assert_has_calls(expected_read_calls)
SyscallHandler.mock_storage.assert_has_calls(expected_write_calls)

class TestValidate:
@pytest.mark.parametrize("seed", (41, 42))
@pytest.mark.parametrize("transaction", TRANSACTIONS)
Expand Down

0 comments on commit 3aacc3a

Please sign in to comment.