Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
from itertools import count
from pathlib import Path
from random import randint
from typing import Any, Dict, Generator, Iterator, List, Literal, Self, Tuple
from typing import Any, Dict, Generator, Iterator, List, Literal, Tuple

import pytest
import yaml
from filelock import FileLock
from pydantic import PrivateAttr

from execution_testing.base_types import (
Account,
Address,
Bytes,
EthereumTestRootModel,
Hash,
HexNumber,
Number,
Expand Down Expand Up @@ -45,6 +43,7 @@
from execution_testing.tools import Initcode
from execution_testing.vm import Bytecode, Op

from ..shared.address_stubs import AddressStubs
from ..shared.pre_alloc import Alloc as SharedAlloc
from ..shared.pre_alloc import AllocFlags
from .contracts import (
Expand All @@ -55,52 +54,6 @@
logger = get_logger(__name__)


class AddressStubs(EthereumTestRootModel[Dict[str, Address]]):
"""
Address stubs class.

The key represents the label that is used in the test to tag the contract,
and the value is the address where the contract is already located at in
the current network.
"""

root: Dict[str, Address]

def __contains__(self, item: str) -> bool:
"""Check if an item is in the address stubs."""
return item in self.root

def __getitem__(self, item: str) -> Address:
"""Get an item from the address stubs."""
return self.root[item]

@classmethod
def model_validate_json_or_file(cls, json_data_or_path: str) -> Self:
"""
Try to load from file if the value resembles a path that ends with
.json/.yml and the file exists.
"""
lower_json_data_or_path = json_data_or_path.lower()
if (
lower_json_data_or_path.endswith(".json")
or lower_json_data_or_path.endswith(".yml")
or lower_json_data_or_path.endswith(".yaml")
):
path = Path(json_data_or_path)
if path.is_file():
path_suffix = path.suffix.lower()
if path_suffix == ".json":
return cls.model_validate_json(path.read_text())
elif path_suffix in [".yml", ".yaml"]:
loaded_yaml = yaml.safe_load(path.read_text())
if loaded_yaml is None:
return cls(root={})
return cls.model_validate(loaded_yaml)
if json_data_or_path.strip() == "":
return cls(root={})
return cls.model_validate_json(json_data_or_path)


def pytest_addoption(parser: pytest.Parser) -> None:
"""Add command-line options to pytest."""
pre_alloc_group = parser.getgroup(
Expand Down Expand Up @@ -452,10 +405,11 @@ def _deploy_contract(
if not isinstance(storage, Storage):
storage = Storage(storage) # type: ignore

if stub is not None and self._address_stubs is not None:
if stub is not None:
if stub not in self._address_stubs:
raise ValueError(
f"Stub name {stub} not found in address stubs"
f"Stub '{stub}' not found in address stubs. "
"Provide --address-stubs with a mapping file."
)
contract_address = self._address_stubs[stub]
logger.info(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Pytest plugin to run the execute in remote-rpc-mode."""

import os
from pathlib import Path

import pytest
Expand All @@ -11,7 +10,7 @@
ChainConfigDefaults,
)

from ..pre_alloc import AddressStubs
from ...shared.helpers import get_rpc_endpoint
from .chain_builder_eth_rpc import ChainBuilderEthRPC, TestingRPC


Expand All @@ -25,7 +24,8 @@ def pytest_addoption(parser: pytest.Parser) -> None:
required=False,
action="store",
dest="rpc_endpoint",
help="RPC endpoint to an execution client",
default=None,
help="RPC endpoint to an execution client.",
)
remote_rpc_group.addoption(
"--tx-wait-timeout",
Expand All @@ -38,18 +38,6 @@ def pytest_addoption(parser: pytest.Parser) -> None:
"included in a block"
),
)
remote_rpc_group.addoption(
"--address-stubs",
action="store",
dest="address_stubs",
default=AddressStubs(root={}),
type=AddressStubs.model_validate_json_or_file,
help=(
"The address stubs for contracts that have already been placed "
"in the chain and to use for the test. Can be a JSON formatted "
"string or a path to a YAML or JSON file."
),
)

engine_rpc_group = parser.getgroup(
"engine_rpc", "Arguments defining engine RPC configuration"
Expand Down Expand Up @@ -107,9 +95,7 @@ def pytest_addoption(parser: pytest.Parser) -> None:
def pytest_configure(config: pytest.Config) -> None:
"""Check if a chain ID configuration is provided."""
# Verify chain ID config is consistent with the remote RPC endpoint
rpc_endpoint = config.getoption("rpc_endpoint") or os.environ.get(
"RPC_ENDPOINT"
)
rpc_endpoint = get_rpc_endpoint(config)
if rpc_endpoint is None:
pytest.fail(
"RPC endpoint must be provided with the --rpc-endpoint flag or "
Expand Down Expand Up @@ -174,9 +160,7 @@ def rpc_endpoint(request: pytest.FixtureRequest) -> str:
Return remote RPC endpoint to be used to make requests to the execution
client.
"""
rpc_endpoint = request.config.getoption("rpc_endpoint") or os.environ.get(
"RPC_ENDPOINT"
)
rpc_endpoint = get_rpc_endpoint(request.config)
assert rpc_endpoint is not None
return rpc_endpoint

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ def test_address_stubs(input_value: Any, expected: AddressStubs) -> None:
AddressStubs({}),
id="empty_address_stubs_json",
),
pytest.param(
"empty.yaml",
"",
AddressStubs({}),
id="empty_address_stubs_yaml",
),
pytest.param(
"one_address.json",
'{"DEPOSIT_CONTRACT_ADDRESS": "0x00000000219ab540356cbb839cbe05303d7705fa"}', # noqa: E501
Expand All @@ -62,18 +56,6 @@ def test_address_stubs(input_value: Any, expected: AddressStubs) -> None:
),
id="single_address_json",
),
pytest.param(
"one_address.yaml",
"DEPOSIT_CONTRACT_ADDRESS: 0x00000000219ab540356cbb839cbe05303d7705fa", # noqa: E501
AddressStubs(
{
"DEPOSIT_CONTRACT_ADDRESS": Address(
"0x00000000219ab540356cbb839cbe05303d7705fa"
),
}
),
id="single_address_yaml",
),
],
)
def test_address_stubs_from_files(
Expand All @@ -87,3 +69,10 @@ def test_address_stubs_from_files(
filename.write_text(file_contents)

assert AddressStubs.model_validate_json_or_file(str(filename)) == expected


def test_address_stubs_file_not_found(pytester: pytest.Pytester) -> None:
"""Test that a missing JSON file raises FileNotFoundError."""
missing_test = pytester.path.joinpath("nonexistent.json")
with pytest.raises(FileNotFoundError):
AddressStubs.model_validate_json_or_file(str(missing_test))
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from execution_testing.tools import Initcode

from ..shared.execute_fill import stub_accounts_key
from ..shared.pre_alloc import Alloc as SharedAlloc
from ..shared.pre_alloc import AllocFlags

Expand All @@ -48,9 +49,14 @@ def pytest_addoption(parser: pytest.Parser) -> None:
"pre_alloc",
"Arguments defining pre-allocation behavior during test filling.",
)

# No options for now
del pre_alloc_group
# TODO: consolidate with execute/rpc/remote.py
pre_alloc_group.addoption(
"--rpc-endpoint",
action="store",
dest="rpc_endpoint",
default=None,
help="RPC endpoint to an execution client.",
)


DELEGATION_DESIGNATION = b"\xef\x01\x00"
Expand All @@ -62,12 +68,20 @@ class Alloc(SharedAlloc):

_eoa_fund_amount_default: int = PrivateAttr(10**21)
_account_salt: Dict[Hash, int] = PrivateAttr(default_factory=dict)
_stub_accounts: Dict[str, Account] = PrivateAttr(default_factory=dict)

def __init__(
self, *args: Any, fork: Fork, flags: AllocFlags, **kwargs: Any
self,
*args: Any,
fork: Fork,
flags: AllocFlags,
stub_accounts: Dict[str, Account] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the pre-alloc."""
super().__init__(*args, fork=fork, flags=flags, **kwargs)
if stub_accounts is not None:
self._stub_accounts = stub_accounts

def get_next_account_salt(self, account_hash: Hash) -> int:
"""Retrieve the next salt for this account."""
Expand Down Expand Up @@ -232,28 +246,32 @@ def _deploy_contract(
label: str | None,
stub: str | None,
) -> Address:
"""
Filler implementation of contract deployment.
"""
del stub

if storage is None:
storage = {}
code = self.code_pre_processor(code)
code_bytes = (
bytes(code) if not isinstance(code, (bytes, str)) else code
)
max_code_size = self._fork.transitions_from().max_code_size()
assert len(code_bytes) <= max_code_size, (
f"code too large: {len(code_bytes)} > {max_code_size}"
)
"""Filler implementation of contract deployment."""
if stub is not None:
if stub not in self._stub_accounts:
raise ValueError(
f"Stub '{stub}' not found in address stubs. "
"Provide --address-stubs with a mapping file."
)
account = self._stub_accounts[stub]
else:
if storage is None:
storage = {}
code = self.code_pre_processor(code)
code_bytes = (
bytes(code) if not isinstance(code, (bytes, str)) else code
)
max_code_size = self._fork.transitions_from().max_code_size()
assert len(code_bytes) <= max_code_size, (
f"code too large: {len(code_bytes)} > {max_code_size}"
)

account = Account(
nonce=nonce,
balance=balance,
code=code,
storage=storage,
)
account = Account(
nonce=nonce,
balance=balance,
code=code,
storage=storage,
)

if address is not None:
assert address not in self, (
Expand Down Expand Up @@ -451,11 +469,20 @@ def eoa_by_index(i: int) -> EOA:
return EOA(key=TestPrivateKey + i if i != 1 else TestPrivateKey2, nonce=0)


@pytest.fixture(scope="session")
def stub_accounts(
request: pytest.FixtureRequest,
) -> Dict[str, Account]:
"""Return stub accounts pre-populated during configuration."""
return request.config.stash.get(stub_accounts_key, {})


@pytest.fixture(scope="function")
def pre(
alloc_flags: AllocFlags,
fork: Fork | None,
request: pytest.FixtureRequest,
stub_accounts: Dict[str, Account],
) -> Alloc:
"""Return default pre allocation for all tests (Empty alloc)."""
# FIXME: Static tests don't have a fork so we need to get it from the node.
Expand All @@ -467,4 +494,5 @@ def pre(
return Alloc(
flags=alloc_flags,
fork=actual_fork,
stub_accounts=stub_accounts,
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"--fork",
"--from",
"--until",
"--address-stubs",
"--help",
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Address stubs model shared by the filler and execute plugins.
This model maps stub labels to on-chain contract addresses.
"""

from pathlib import Path
from typing import Dict, Self

from execution_testing.base_types import Address, EthereumTestRootModel


class AddressStubs(EthereumTestRootModel[Dict[str, Address]]):
"""
Address stubs class.

The key represents the label that is used in the test to tag the contract,
and the value is the address where the contract is already located at in
the current network.
"""

root: Dict[str, Address]

def __contains__(self, item: str) -> bool:
"""Check if an item is in the address stubs."""
return item in self.root

def __getitem__(self, item: str) -> Address:
"""Get an item from the address stubs."""
return self.root[item]

@classmethod
def model_validate_json_or_file(cls, json_data_or_path: str) -> Self:
"""
Parse a JSON string or load from a JSON file.

If the value ends with `.json` and the file exists, the file
contents are loaded; otherwise the value is parsed as inline JSON.
"""
if json_data_or_path.lower().endswith(".json"):
path = Path(json_data_or_path)
if path.is_file():
return cls.model_validate_json(path.read_text())
else:
raise FileNotFoundError(
f"Address stubs file not found: {path}"
)
if json_data_or_path.strip() == "":
return cls(root={})
return cls.model_validate_json(json_data_or_path)
Loading
Loading