diff --git a/src/kakarot/accounts/library.cairo b/src/kakarot/accounts/library.cairo index 6dfd56e77..1eb6e69d1 100644 --- a/src/kakarot/accounts/library.cairo +++ b/src/kakarot/accounts/library.cairo @@ -3,6 +3,7 @@ from openzeppelin.access.ownable.library import Ownable, Ownable_owner from starkware.cairo.common.alloc import alloc from starkware.cairo.common.bool import FALSE +from starkware.cairo.common.dict_access import DictAccess from starkware.cairo.common.cairo_builtins import HashBuiltin, BitwiseBuiltin from starkware.cairo.common.math import unsigned_div_rem, split_int, split_felt from starkware.cairo.common.memcpy import memcpy @@ -33,6 +34,7 @@ from utils.eth_transaction import EthTransaction from utils.uint256 import uint256_add from utils.bytes import bytes_to_bytes8_little_endian from utils.signature import Signature +from utils.utils import Helpers // @dev: should always be zero for EOAs @storage_var @@ -67,6 +69,10 @@ func Account_cairo1_helpers_class_hash() -> (res: felt) { func Account_valid_jumpdests() -> (is_valid: felt) { } +@storage_var +func Account_jumpdests_initialized() -> (initialized: felt) { +} + @event func transaction_executed(response_len: felt, response: felt*, success: felt, gas_used: felt) { } @@ -431,17 +437,23 @@ namespace AccountContract { return (latest_account_class, latest_cairo1_helpers_class); } + // @notice Writes an array of valid jumpdests indexes to storage. + // @param jumpdests_len The length of the jumpdests array. + // @param jumpdests The jumpdests array. func write_jumpdests{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( jumpdests_len: felt, jumpdests: felt* ) { // Recursively store the jumpdests. - Internals.write_jumpdests(jumpdests_len=jumpdests_len, jumpdests=jumpdests); + Internals.write_jumpdests( + jumpdests_len=jumpdests_len, jumpdests=jumpdests, iteration_size=1 + ); return (); } func is_valid_jumpdest{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( index: felt ) -> felt { + alloc_locals; let (base_address) = Account_valid_jumpdests.addr(); let index_address = base_address + index; @@ -454,7 +466,32 @@ namespace AccountContract { tempvar syscall_ptr = syscall_ptr + StorageRead.SIZE; tempvar value = response.value; - return value; + if (value != 0) { + return value; + } + + // Jumpdest is invalid - we verify that the jumpdests have been stored, and if not, + // we store them. call the appropriate function check & store + let (initialized) = Account_jumpdests_initialized.read(); + + if (initialized != FALSE) { + return value; + } + + let (bytecode_len) = Account_bytecode_len.read(); + let (bytecode) = Internals.load_bytecode(bytecode_len); + let (valid_jumpdests_start, valid_jumpdests) = Helpers.initialize_jumpdests( + bytecode_len, bytecode + ); + let (jumpdests_len, _) = unsigned_div_rem( + valid_jumpdests - valid_jumpdests_start, DictAccess.SIZE + ); + Internals.write_jumpdests( + jumpdests_len=jumpdests_len, + jumpdests=cast(valid_jumpdests_start, felt*), + iteration_size=DictAccess.SIZE, + ); + return is_valid_jumpdest(index=index); } } @@ -523,10 +560,13 @@ namespace Internals { } // @notice Store the jumpdests of the contract. - // @param jumpdests_len The length of the jumpdests. - // @param jumpdests The jumpdests of the contract. + // @dev This function can be used by either passing an array of valid jumpdests, + // or a dict that only contains valid entries (i.e. no invalid index has been read). + // @param jumpdests_len The length of the valid jumpdests. + // @param jumpdests The jumpdests of the contract. Can be an array of valid indexes or a dict. + // @param iteration_size The size of the object we are iterating over. func write_jumpdests{syscall_ptr: felt*, pedersen_ptr: HashBuiltin*, range_check_ptr}( - jumpdests_len: felt, jumpdests: felt* + jumpdests_len: felt, jumpdests: felt*, iteration_size: felt ) { alloc_locals; @@ -546,6 +586,7 @@ namespace Internals { let jumpdests = cast([ap - 2], felt*); let remaining = [ap - 1]; let base_address = [fp]; + let iteration_size = [fp - 3]; let index_to_store = [jumpdests]; tempvar storage_address = base_address + index_to_store; @@ -555,11 +596,12 @@ namespace Internals { ); %{ syscall_handler.storage_write(segments=segments, syscall_ptr=ids.syscall_ptr) %} tempvar syscall_ptr = syscall_ptr + StorageWrite.SIZE; - tempvar jumpdests = jumpdests + 1; + tempvar jumpdests = jumpdests + iteration_size; tempvar remaining = remaining - 1; jmp body if remaining != 0; + Account_jumpdests_initialized.write(1); return (); } diff --git a/tests/src/kakarot/accounts/test_contract_account.py b/tests/src/kakarot/accounts/test_contract_account.py index 22aaca3ac..9a034d061 100644 --- a/tests/src/kakarot/accounts/test_contract_account.py +++ b/tests/src/kakarot/accounts/test_contract_account.py @@ -157,7 +157,7 @@ def test__should_store_valid_jumpdests(self, cairo_run): class TestReadJumpdests: @pytest.fixture - def storage(self, jumpdests): + def store_jumpdests(self, jumpdests): base_address = get_storage_var_address("Account_valid_jumpdests") valid_addresses = [base_address + jumpdest for jumpdest in jumpdests] @@ -168,9 +168,11 @@ def _storage(address): @pytest.mark.parametrize("jumpdests", [[0x02, 0x10, 0xFF]]) def test__should_return_if_jumpdest_valid( - self, cairo_run, jumpdests, storage + self, cairo_run, jumpdests, store_jumpdests ): - with patch.object(SyscallHandler, "mock_storage", side_effect=storage): + with patch.object( + SyscallHandler, "mock_storage", side_effect=store_jumpdests + ): for jumpdest in jumpdests: assert cairo_run("test__is_valid_jumpdest", index=jumpdest) == 1 @@ -180,6 +182,66 @@ def test__should_return_if_jumpdest_valid( ] SyscallHandler.mock_storage.assert_has_calls(calls) + @pytest.fixture + def patch_account_storage(self, account_code): + code_len_address = get_storage_var_address("Account_bytecode_len") + base_jumpdests_address = get_storage_var_address( + "Account_valid_jumpdests" + ) + chunks = wrap(account_code, 2 * 31) + + def _storage(address, value=None): + if value is not None: + SyscallHandler.patches[address] = value + return + if address == code_len_address: + return len(bytes.fromhex(account_code)) + elif address >= base_jumpdests_address: + return 0 + return int(chunks[address], 16) + + return _storage + + # Code contains both valid and invalid jumpdests + # PUSH1 4 // Offset 0 + # JUMP // Offset 2 (previous instruction occupies 2 bytes) + # INVALID // Offset 3 + # JUMPDEST // Offset 4 + # PUSH1 1 // Offset 5 + # PUSH1 0x5B // invalid jumpdest + @pytest.mark.parametrize( + "account_code, jumpdests, results", + [("600456fe5b6001605b", [0x04, 0x08], [1, 0])], + ) + def test__should_return_if_jumpdest_valid_when_not_stored( + self, cairo_run, account_code, jumpdests, results, patch_account_storage + ): + with patch.object( + SyscallHandler, "mock_storage", side_effect=patch_account_storage + ): + for jumpdest, result in zip(jumpdests, results): + assert ( + cairo_run("test__is_valid_jumpdest", index=jumpdest) + == result + ) + + base_address = get_storage_var_address("Account_valid_jumpdests") + jumpdests_initialized_address = get_storage_var_address( + "Account_jumpdests_initialized" + ) + expected_read_calls = [ + call(address=base_address + jumpdest) for jumpdest in jumpdests + ] + [call(address=jumpdests_initialized_address)] + + expected_write_calls = [ + call(address=base_address + jumpdest, value=1) + for jumpdest, result in zip(jumpdests, results) + if result == 1 + ] + [call(address=jumpdests_initialized_address, value=1)] + + SyscallHandler.mock_storage.assert_has_calls(expected_read_calls) + SyscallHandler.mock_storage.assert_has_calls(expected_write_calls) + class TestValidate: @pytest.mark.parametrize("seed", (41, 42)) @pytest.mark.parametrize("transaction", TRANSACTIONS)