diff --git a/src/commands/node_manager_start.py b/src/commands/node_manager_start.py index 52aa1737..f8634d3f 100644 --- a/src/commands/node_manager_start.py +++ b/src/commands/node_manager_start.py @@ -6,12 +6,18 @@ import click from eth_typing import ChecksumAddress from sw_utils import InterruptHandler +from web3.types import Gwei from src.common.clients import close_clients, setup_clients from src.common.logging import LOG_LEVELS, setup_logging +from src.common.protocol_config import update_oracles_cache +from src.common.typings import ValidatorType from src.common.utils import log_verbose -from src.common.validators import validate_eth_address -from src.config.networks import AVAILABLE_NETWORKS, NETWORKS +from src.common.validators import ( + validate_eth_address, + validate_max_validator_balance_gwei, +) +from src.config.networks import AVAILABLE_NETWORKS, MAINNET, NETWORKS from src.config.settings import ( DEFAULT_CONSENSUS_ENDPOINT, DEFAULT_EXECUTION_ENDPOINT, @@ -21,14 +27,68 @@ ) from src.node_manager.startup_check import startup_checks from src.node_manager.tasks import NodeManagerTask +from src.validators.database import NetworkValidatorCrud +from src.validators.keystores.load import load_keystore logger = logging.getLogger(__name__) @click.option( - '--withdrawals-address', + '--keystores-password-file', + type=click.Path(exists=True, file_okay=True, dir_okay=False), + envvar='KEYSTORES_PASSWORD_FILE', + help='Absolute path to the password file for decrypting keystores.', +) +@click.option( + '--keystores-dir', + type=click.Path(exists=True, file_okay=False, dir_okay=True), + envvar='KEYSTORES_DIR', + help='Absolute path to the directory with all the encrypted keystores.', +) +@click.option( + '--wallet-password-file', + type=click.Path(exists=True, file_okay=True, dir_okay=False), + envvar='WALLET_PASSWORD_FILE', + help='Absolute path to the wallet password file.', +) +@click.option( + '--wallet-file', + type=click.Path(exists=True, file_okay=True, dir_okay=False), + envvar='WALLET_FILE', + help='Absolute path to the wallet.', +) +@click.option( + '--max-validator-balance-gwei', + type=int, + envvar='MAX_VALIDATOR_BALANCE_GWEI', + help=f'The maximum validator balance in Gwei. ' + f'Default is {NETWORKS[MAINNET].MAX_VALIDATOR_BALANCE_GWEI} Gwei', + callback=validate_max_validator_balance_gwei, +) +@click.option( + '--validator-type', + help='Type of the validators to register:' + f' {ValidatorType.V1.value} or {ValidatorType.V2.value}.', + envvar='VALIDATOR_TYPE', + default=ValidatorType.V2.value, + type=click.Choice( + [x.value for x in ValidatorType], + case_sensitive=False, + ), + callback=lambda ctx, param, value: ValidatorType(value), + show_default=True, +) +@click.option( + '--max-fee-per-gas-gwei', + type=int, + envvar='MAX_FEE_PER_GAS_GWEI', + help=f'Maximum fee per gas for transactions. ' + f'Default is {NETWORKS[MAINNET].MAX_FEE_PER_GAS_GWEI} Gwei', +) +@click.option( + '--operator-address', callback=validate_eth_address, - envvar='WITHDRAWALS_ADDRESS', + envvar='OPERATOR_ADDRESS', prompt='Enter your operator withdrawals (cold wallet) address', help='The operator withdrawals (cold wallet) address.', ) @@ -96,7 +156,7 @@ show_default=True, ) @click.command(help='Start node manager operator service') -# pylint: disable-next=too-many-arguments +# pylint: disable-next=too-many-arguments,too-many-locals def node_manager_start( consensus_endpoints: str, execution_endpoints: str, @@ -106,7 +166,14 @@ def node_manager_start( log_level: str, log_format: str, network: str, - withdrawals_address: ChecksumAddress, + operator_address: ChecksumAddress, + max_fee_per_gas_gwei: int | None, + validator_type: ValidatorType, + max_validator_balance_gwei: int | None, + wallet_file: str | None, + wallet_password_file: str | None, + keystores_dir: str | None, + keystores_password_file: str | None, ) -> None: network_config = NETWORKS[network] vault = network_config.COMMUNITY_VAULT_CONTRACT_ADDRESS @@ -122,27 +189,50 @@ def node_manager_start( verbose=verbose, log_level=log_level, log_format=log_format, + max_fee_per_gas_gwei=max_fee_per_gas_gwei, + validator_type=validator_type, + max_validator_balance_gwei=( + Gwei(max_validator_balance_gwei) if max_validator_balance_gwei else None + ), + keystores_dir=keystores_dir, + keystores_password_file=keystores_password_file, + wallet_file=wallet_file, + wallet_password_file=wallet_password_file, ) try: - asyncio.run(_start(withdrawals_address)) + asyncio.run(_start(operator_address)) except Exception as e: log_verbose(e) sys.exit(1) -async def _start(withdrawals_address: ChecksumAddress) -> None: +async def _start( + operator_address: ChecksumAddress, +) -> None: setup_logging() await setup_clients() if not settings.skip_startup_checks: - await startup_checks(withdrawals_address) + await startup_checks(operator_address) try: + NetworkValidatorCrud().setup() + + keystore = await load_keystore() + + # start operator tasks + logger.info('Updating oracles cache...') + await update_oracles_cache() + logger.info( 'Started node manager service, polling eligibility for %s', - withdrawals_address, + operator_address, ) with InterruptHandler() as interrupt_handler: - await NodeManagerTask(withdrawals_address).run(interrupt_handler) + task = NodeManagerTask( + operator_address=operator_address, + keystore=keystore, + ) + await task.run(interrupt_handler) finally: await close_clients() diff --git a/src/node_manager/oracles.py b/src/node_manager/oracles.py index 13da5819..f6914e5e 100644 --- a/src/node_manager/oracles.py +++ b/src/node_manager/oracles.py @@ -1,20 +1,50 @@ import asyncio +import dataclasses import logging import random +from collections import defaultdict +from typing import Callable, Sequence, TypeVar from aiohttp import ClientError, ClientSession, ClientTimeout +from eth_account.messages import encode_defunct +from eth_typing import ChecksumAddress, HexStr from sw_utils.common import urljoin from sw_utils.typings import ProtocolConfig from web3 import Web3 from web3.types import Wei -from src.common.utils import format_error, warning_verbose +from src.common.contracts import validators_registry_contract +from src.common.exceptions import ( + InvalidOraclesRequestError, + NotEnoughOracleApprovalsError, +) +from src.common.utils import ( + RateLimiter, + format_error, + get_current_timestamp, + warning_verbose, +) +from src.common.wallet import wallet from src.config.settings import ORACLES_VALIDATORS_TIMEOUT -from src.node_manager.typings import EligibleOperator +from src.node_manager.typings import ( + EligibleOperator, + NodeManagerApprovalRequest, + NodeManagerRegistrationApproval, + NodeManagerRegistrationOraclesApproval, +) +from src.validators.execution import get_validators_start_index +from src.validators.keystores.base import BaseKeystore +from src.validators.signing.common import get_encrypted_exit_signature_shards +from src.validators.typings import Validator + +T = TypeVar('T') logger = logging.getLogger(__name__) ELIGIBLE_OPERATORS_PATH = '/nodes-manager/eligible-operators' +REGISTER_VALIDATORS_PATH = '/nodes-manager/register-validators' + +# Eligible operators polling async def poll_eligible_operators( @@ -83,3 +113,294 @@ async def _fetch_eligible_operators( ) for item in data ] + + +async def poll_registration_approval( + keystore: BaseKeystore, + validators: Sequence[Validator], + operator_address: ChecksumAddress, + protocol_config: ProtocolConfig, +) -> tuple[NodeManagerApprovalRequest, NodeManagerRegistrationOraclesApproval]: + """Poll oracles until registration approval is obtained.""" + oracles_request: NodeManagerApprovalRequest | None = None + deadline: int | None = None + validators_registry_root = await validators_registry_contract.get_registry_root() + + approvals_min_interval = 1 + rate_limiter = RateLimiter(approvals_min_interval) + + while True: + await rate_limiter.ensure_interval() + + current_registry_root = await validators_registry_contract.get_registry_root() + if current_registry_root != validators_registry_root: + validators_registry_root = current_registry_root + oracles_request = None + + current_timestamp = get_current_timestamp() + if oracles_request is None or deadline is None or deadline <= current_timestamp: + deadline = current_timestamp + protocol_config.signature_validity_period + + oracles_request = await create_approval_request( + protocol_config=protocol_config, + keystore=keystore, + validators=validators, + registry_root=current_registry_root, + deadline=deadline, + operator_address=operator_address, + ) + + try: + raw_approvals = await send_registration_requests(protocol_config, oracles_request) + oracles_approval = process_registration_approvals( + raw_approvals, protocol_config.validators_threshold + ) + return oracles_request, oracles_approval + except NotEnoughOracleApprovalsError as e: + logger.error( + 'Not enough oracle approvals for community vault registration: %d.' + ' Threshold is %d.', + e.num_votes, + e.threshold, + ) + except InvalidOraclesRequestError: + logger.error('All oracles failed to respond for community vault registration') + + +# Registration approval polling + + +# pylint: disable-next=too-many-arguments +async def create_approval_request( + protocol_config: ProtocolConfig, + keystore: BaseKeystore, + validators: Sequence[Validator], + registry_root: HexStr, + deadline: int, + operator_address: ChecksumAddress, +) -> NodeManagerApprovalRequest: + """Build a NodesManager approval request with exit signature shards.""" + validators_start_index = await get_validators_start_index() + logger.debug( + 'Next validator index for community vault exit signature: %d', validators_start_index + ) + + request = NodeManagerApprovalRequest( + validator_index=validators_start_index, + operator_address=operator_address, + validators_root=registry_root, + public_keys=[], + deposit_signatures=[], + public_key_shards=[], + exit_signature_shards=[], + deadline=deadline, + amounts=[], + validators_manager_signature=_sign_deadline(deadline), + ) + + for validator_index, validator in enumerate(validators, validators_start_index): + shards = validator.exit_signature_shards + + if not shards: + shards = await get_encrypted_exit_signature_shards( + keystore=keystore, + public_key=validator.public_key, + validator_index=validator_index, + protocol_config=protocol_config, + exit_signature=validator.exit_signature, + ) + + if not shards: + logger.warning( + 'Failed to get exit signature shards for validator %s', validator.public_key + ) + break + + if validator.deposit_signature is None: + raise ValueError('Deposit signature is required for validator') + + request.public_keys.append(validator.public_key) + request.deposit_signatures.append(validator.deposit_signature) + request.public_key_shards.append(shards.public_keys) + request.exit_signature_shards.append(shards.exit_signatures) + request.amounts.append(validator.amount) + + if not request.public_keys: + raise ValueError( + 'Failed to build validator registration request:' + ' no validators with valid exit signature shards' + ) + + return request + + +# Generic oracle request helpers + + +async def send_registration_requests( + protocol_config: ProtocolConfig, + request: NodeManagerApprovalRequest, +) -> dict[ChecksumAddress, NodeManagerRegistrationApproval]: + """Request registration approval from all oracles in parallel.""" + return await _send_oracle_requests( + protocol_config, + dataclasses.asdict(request), + REGISTER_VALIDATORS_PATH, + _parse_registration_response, + ) + + +def _parse_registration_response(data: dict) -> NodeManagerRegistrationApproval: + """Parse oracle response containing both keeper and Nodes Manager signatures.""" + keeper_params = data['keeper_params'] + return NodeManagerRegistrationApproval( + keeper_signature=HexStr(keeper_params['signature']), + nodes_manager_signature=HexStr(data['nodes_manager_signature']), + ipfs_hash=keeper_params['ipfs_hash'], + deadline=keeper_params['deadline'], + ) + + +def process_registration_approvals( + approvals: dict[ChecksumAddress, NodeManagerRegistrationApproval], + votes_threshold: int, +) -> NodeManagerRegistrationOraclesApproval: + """Combine registration approvals into separate keeper and Nodes Manager signature blobs.""" + candidates: dict[ + tuple[str, int], list[tuple[ChecksumAddress, NodeManagerRegistrationApproval]] + ] = defaultdict(list) + for address, approval in approvals.items(): + candidates[approval.ipfs_hash, approval.deadline].append((address, approval)) + + if not candidates: + raise InvalidOraclesRequestError() + + winner = max(candidates, key=lambda x: len(candidates[x])) + votes = candidates[winner] + if len(votes) < votes_threshold: + raise NotEnoughOracleApprovalsError(num_votes=len(votes), threshold=votes_threshold) + + sorted_votes = sorted(votes, key=lambda x: Web3.to_int(hexstr=x[0]))[:votes_threshold] + + keeper_signatures: list[HexStr] = [] + signatures: list[HexStr] = [] + for _, approval in sorted_votes: + keeper_signatures.append(approval.keeper_signature) + signatures.append(approval.nodes_manager_signature) + + return NodeManagerRegistrationOraclesApproval( + nodes_manager_signatures=signatures, + keeper_signatures=keeper_signatures, + ipfs_hash=winner[0], + deadline=winner[1], + ) + + +async def _send_oracle_requests( + protocol_config: ProtocolConfig, + payload: dict, + path: str, + parser: Callable[[dict], T], +) -> dict[ChecksumAddress, T]: + """Send a NodesManager request to all oracles in parallel and collect approvals.""" + endpoints = [(oracle.address, oracle.endpoints) for oracle in protocol_config.oracles] + + async with ClientSession(timeout=ClientTimeout(ORACLES_VALIDATORS_TIMEOUT)) as session: + results = await asyncio.gather( + *[ + _send_request_to_replicas( + session=session, + replicas=replicas, + payload=payload, + path=path, + parser=parser, + ) + for _, replicas in endpoints + ], + return_exceptions=True, + ) + + approvals: dict[ChecksumAddress, T] = {} + failed_endpoints: list[str] = [] + + for (address, replicas), result in zip(endpoints, results): + if isinstance(result, BaseException): + warning_verbose( + 'All endpoints for oracle %s failed to sign community vault request (%s). ' + 'Last error: %s', + address, + path, + format_error(result), + ) + failed_endpoints.extend(replicas) + continue + + approvals[address] = result + + logger.info( + 'Fetched oracle approvals for community vault request %s: ' + 'deadline=%d. Received %d out of %d approvals.', + path, + payload.get('deadline', 0), + len(approvals), + len(protocol_config.oracles), + ) + + if failed_endpoints: + logger.error( + 'The oracles with endpoints %s have failed to respond.', ', '.join(failed_endpoints) + ) + + return approvals + + +# pylint: disable=duplicate-code +async def _send_request_to_replicas( + session: ClientSession, + replicas: list[str], + payload: dict, + path: str, + parser: Callable[[dict], T], +) -> T: + """Try replicas in random order, return first success.""" + last_error: BaseException | None = None + replicas = random.sample(replicas, len(replicas)) # nosec + + for endpoint in replicas: + try: + return await _send_request(session, endpoint, payload, path, parser) + except (ClientError, asyncio.TimeoutError) as e: + warning_verbose('%s for endpoint %s', format_error(e), endpoint) + last_error = e + + if last_error: + raise last_error + + raise RuntimeError('Failed to get response from replicas') + + +async def _send_request( + session: ClientSession, + endpoint: str, + payload: dict, + path: str, + parser: Callable[[dict], T], +) -> T: + """Send a NodesManager POST request to a single oracle endpoint.""" + url = urljoin(endpoint, path) + logger.debug('Sending community vault request to %s', url) + + async with session.post(url=url, json=payload) as response: + if response.status == 400: + logger.warning('%s response: %s', url, await response.json()) + response.raise_for_status() + data = await response.json() + + logger.debug('Received community vault response from %s: %s', url, data) + return parser(data) + + +def _sign_deadline(deadline: int) -> HexStr: + """EIP-191 personal_sign of the deadline timestamp.""" + message = encode_defunct(text=str(deadline)) + return HexStr(wallet.sign_message(message).signature.hex()) diff --git a/src/node_manager/register_validators.py b/src/node_manager/register_validators.py new file mode 100644 index 00000000..ae6e894c --- /dev/null +++ b/src/node_manager/register_validators.py @@ -0,0 +1,89 @@ +import logging +from typing import Sequence + +from eth_typing import ChecksumAddress, HexStr +from sw_utils.typings import Bytes32 +from web3 import Web3 +from web3.exceptions import ContractLogicError + +from src.common.clients import execution_client +from src.common.contracts import nodes_manager_contract, validators_registry_contract +from src.common.execution import build_gas_manager +from src.common.utils import format_error +from src.config.settings import settings +from src.node_manager.typings import NodeManagerRegistrationOraclesApproval +from src.validators.execution import get_validators_start_index +from src.validators.signing.common import encode_tx_validator_list +from src.validators.typings import Validator + +logger = logging.getLogger(__name__) + + +# pylint: disable=too-many-locals +async def register_validators( + operator_address: ChecksumAddress, + approval: NodeManagerRegistrationOraclesApproval, + validators: Sequence[Validator], + validators_registry_root: HexStr, + validator_index: int, +) -> HexStr | None: + """Submit registerValidators transaction to NodesManager contract.""" + registry_root = await validators_registry_contract.get_registry_root() + if registry_root != validators_registry_root: + logger.info('Validators registry root has changed. Retrying...') + return None + + current_validator_index = await get_validators_start_index() + if current_validator_index != validator_index: + logger.info('Validator index has changed. Retrying...') + return None + + tx_validators = [ + Web3.to_bytes(tx_validator) + for tx_validator in encode_tx_validator_list(validators=validators) + ] + signatures = b''.join(Web3.to_bytes(hexstr=s) for s in approval.nodes_manager_signatures) + keeper_params = ( + Bytes32(Web3.to_bytes(hexstr=validators_registry_root)), + approval.deadline, + b''.join(tx_validators), + b''.join(Web3.to_bytes(hexstr=s) for s in approval.keeper_signatures), + approval.ipfs_hash, + ) + + logger.info('Submitting community vault validator registration transaction') + + try: + await nodes_manager_contract.functions.registerValidators( + operator_address, keeper_params, signatures + ).estimate_gas() + except (ValueError, ContractLogicError) as e: + logger.error('Failed to register community vault validator(s): %s', format_error(e)) + if settings.verbose: + logger.exception(e) + return None + + try: + gas_manager = build_gas_manager() + tx_params = await gas_manager.get_high_priority_tx_params() + tx = await nodes_manager_contract.functions.registerValidators( + operator_address, keeper_params, signatures + ).transact(tx_params) + except Exception as e: + logger.error('Failed to register community vault validator(s): %s', format_error(e)) + if settings.verbose: + logger.exception(e) + return None + + tx_hash = Web3.to_hex(tx) + logger.info( + 'Waiting for register community vault validator(s) transaction %s confirmation', tx_hash + ) + tx_receipt = await execution_client.eth.wait_for_transaction_receipt( + tx, timeout=settings.execution_transaction_timeout + ) + if not tx_receipt['status']: + logger.error('Register community vault validator(s) transaction failed') + return None + + return tx_hash diff --git a/src/node_manager/startup_check.py b/src/node_manager/startup_check.py index b38c6f05..1a271427 100644 --- a/src/node_manager/startup_check.py +++ b/src/node_manager/startup_check.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -async def startup_checks(withdrawals_address: ChecksumAddress) -> None: +async def startup_checks(operator_address: ChecksumAddress) -> None: validate_settings() logger.info('Checking connection to database...') @@ -79,7 +79,7 @@ async def startup_checks(withdrawals_address: ChecksumAddress) -> None: wait_for_keystores_dir() logger.info('Found keystores dir') - await _check_validators_manager(withdrawals_address) + await _check_validators_manager(operator_address) await _check_community_vault() @@ -93,8 +93,8 @@ async def _check_community_vault() -> None: ) -async def _check_validators_manager(withdrawals_address: ChecksumAddress) -> None: - validators_manager = await nodes_manager_contract.validators_manager(withdrawals_address) +async def _check_validators_manager(operator_address: ChecksumAddress) -> None: + validators_manager = await nodes_manager_contract.validators_manager(operator_address) if validators_manager != wallet.account.address: raise ClickException( diff --git a/src/node_manager/tasks.py b/src/node_manager/tasks.py index 8bd45bec..264849c0 100644 --- a/src/node_manager/tasks.py +++ b/src/node_manager/tasks.py @@ -2,33 +2,131 @@ from eth_typing import ChecksumAddress from sw_utils import InterruptHandler +from sw_utils.typings import ProtocolConfig from web3 import Web3 +from web3.types import Gwei, Wei +from src.common.execution import check_gas_price from src.common.protocol_config import get_protocol_config from src.common.tasks import BaseTask -from src.node_manager.oracles import poll_eligible_operators +from src.common.typings import ValidatorType +from src.config.settings import settings +from src.node_manager.oracles import poll_eligible_operators, poll_registration_approval +from src.node_manager.register_validators import register_validators +from src.validators.keystores.base import BaseKeystore +from src.validators.tasks import get_deposits_amounts +from src.validators.utils import get_validators_for_registration logger = logging.getLogger(__name__) class NodeManagerTask(BaseTask): - """Periodically polls oracles to check operator eligibility.""" + """Periodically polls oracles to check operator eligibility and register validators.""" - def __init__(self, withdrawals_address: ChecksumAddress) -> None: - self.withdrawals_address = withdrawals_address + def __init__( + self, + operator_address: ChecksumAddress, + keystore: BaseKeystore, + ) -> None: + self.operator_address = operator_address + self.keystore = keystore async def process_block(self, interrupt_handler: InterruptHandler) -> None: + if not await check_gas_price(high_priority=True): + logger.debug('Gas price too high, skipping validators registration') + return + protocol_config = await get_protocol_config() + + eligible_amount = await self._get_eligible_amount(protocol_config) + if eligible_amount is None: + return + + logger.info( + 'Operator %s is eligible to register/fund %s ETH worth of validators', + self.operator_address, + eligible_amount, + ) + + amount_gwei = Gwei(int(Web3.from_wei(eligible_amount, 'gwei'))) + + if settings.validator_type == ValidatorType.V1: + if not settings.disable_validators_registration: + await self._process_registration( + amount=amount_gwei, + protocol_config=protocol_config, + ) + return + + # Fund existing compounding validators first + if not settings.disable_validators_funding: + amount_gwei = await self._process_funding( + amount=amount_gwei, + operator_address=self.operator_address, + protocol_config=protocol_config, + ) + + # Register new validators with remaining amount + if not settings.disable_validators_registration: + await self._process_registration( + amount=amount_gwei, + protocol_config=protocol_config, + ) + + async def _get_eligible_amount(self, protocol_config: ProtocolConfig) -> Wei | None: eligible_operators = await poll_eligible_operators(protocol_config) for operator in eligible_operators: - if operator.address == self.withdrawals_address: - amount_eth = Web3.from_wei(operator.amount, 'ether') - logger.info( - 'Operator %s is eligible to register/fund %s ETH worth of validators', - self.withdrawals_address, - amount_eth, - ) - return + if operator.address == self.operator_address: + return operator.amount + return None + + async def _process_registration( + self, + amount: Gwei, + protocol_config: ProtocolConfig, + ) -> None: + """Register new validators with the eligible amount.""" + amounts = get_deposits_amounts(amount, settings.validator_type) + if not amounts: + logger.info('No remaining amount for new validator registration') + return + + batch_limit = protocol_config.validators_approval_batch_limit + amounts = amounts[:batch_limit] + + validators = await get_validators_for_registration(self.keystore, amounts) + if not validators: + logger.warning('No available validators for registration') + return + + request, approval = await poll_registration_approval( + keystore=self.keystore, + validators=validators, + operator_address=self.operator_address, + protocol_config=protocol_config, + ) + + tx_hash = await register_validators( + operator_address=self.operator_address, + approval=approval, + validators=validators, + validators_registry_root=request.validators_root, + validator_index=request.validator_index, + ) + + if tx_hash: + pub_keys = ', '.join([v.public_key for v in validators]) + logger.info('Registered community vault validators %s: tx=%s', pub_keys, tx_hash) - logger.debug('Operator %s is not eligible', self.withdrawals_address) + async def _process_funding( + self, + amount: Gwei, + operator_address: ChecksumAddress, + protocol_config: ProtocolConfig, + ) -> Gwei: + # linter mock + logger.info( + 'amount: %s; operator: %s; protocol: %s', amount, operator_address, protocol_config + ) + return Gwei(amount) diff --git a/src/node_manager/tests/__init__.py b/src/node_manager/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/node_manager/tests/test_oracles.py b/src/node_manager/tests/test_oracles.py new file mode 100644 index 00000000..0617c263 --- /dev/null +++ b/src/node_manager/tests/test_oracles.py @@ -0,0 +1,146 @@ +import pytest +from eth_typing import ChecksumAddress, HexStr +from sw_utils.tests.factories import faker + +from src.common.exceptions import ( + InvalidOraclesRequestError, + NotEnoughOracleApprovalsError, +) +from src.node_manager.oracles import ( + _parse_registration_response, + process_registration_approvals, +) +from src.node_manager.typings import ( + NodeManagerRegistrationApproval, + NodeManagerRegistrationOraclesApproval, +) + +ORACLE_ADDRESSES: list[ChecksumAddress] = sorted( + [faker.eth_address() for _ in range(5)], +) + + +class TestProcessRegistrationApprovals: + def test_basic_consensus(self) -> None: + """All oracles agree on the same ipfs_hash and deadline.""" + approvals = { + ORACLE_ADDRESSES[0]: _make_registration_approval( + keeper_sig=faker.account_signature(), + sig=faker.account_signature(), + ), + ORACLE_ADDRESSES[1]: _make_registration_approval( + keeper_sig=faker.account_signature(), + sig=faker.account_signature(), + ), + ORACLE_ADDRESSES[2]: _make_registration_approval( + keeper_sig=faker.account_signature(), + sig=faker.account_signature(), + ), + } + result = process_registration_approvals(approvals, votes_threshold=2) + + assert isinstance(result, NodeManagerRegistrationOraclesApproval) + assert result.ipfs_hash == 'QmTest123' + assert result.deadline == 1000 + # Signatures are sorted by oracle address (ascending int value) and truncated to threshold + assert len(result.keeper_signatures) == 2 + assert len(result.nodes_manager_signatures) == 2 + + def test_exact_threshold(self) -> None: + """Exactly threshold votes should succeed.""" + approvals = { + ORACLE_ADDRESSES[0]: _make_registration_approval(), + ORACLE_ADDRESSES[1]: _make_registration_approval(), + } + result = process_registration_approvals(approvals, votes_threshold=2) + assert len(result.keeper_signatures) == 2 + + def test_below_threshold_raises(self) -> None: + approvals = { + ORACLE_ADDRESSES[0]: _make_registration_approval(), + } + with pytest.raises(NotEnoughOracleApprovalsError) as exc_info: + process_registration_approvals(approvals, votes_threshold=2) + assert exc_info.value.num_votes == 1 + assert exc_info.value.threshold == 2 + + def test_empty_approvals_raises(self) -> None: + with pytest.raises(InvalidOraclesRequestError): + process_registration_approvals({}, votes_threshold=1) + + def test_split_vote_picks_majority(self) -> None: + """When oracles disagree on ipfs_hash/deadline, pick the group with the most votes.""" + approvals = { + ORACLE_ADDRESSES[0]: _make_registration_approval(ipfs_hash='QmA', deadline=100), + ORACLE_ADDRESSES[1]: _make_registration_approval(ipfs_hash='QmB', deadline=200), + ORACLE_ADDRESSES[2]: _make_registration_approval(ipfs_hash='QmA', deadline=100), + } + result = process_registration_approvals(approvals, votes_threshold=2) + assert result.ipfs_hash == 'QmA' + assert result.deadline == 100 + + def test_split_vote_below_threshold(self) -> None: + """No group reaches threshold → error.""" + approvals = { + ORACLE_ADDRESSES[0]: _make_registration_approval(ipfs_hash='QmA', deadline=100), + ORACLE_ADDRESSES[1]: _make_registration_approval(ipfs_hash='QmB', deadline=200), + ORACLE_ADDRESSES[2]: _make_registration_approval(ipfs_hash='QmC', deadline=300), + } + with pytest.raises(NotEnoughOracleApprovalsError): + process_registration_approvals(approvals, votes_threshold=2) + + def test_signatures_sorted_by_address(self) -> None: + """Signatures are concatenated in ascending oracle address order.""" + addr_low = ORACLE_ADDRESSES[0] + addr_high = ORACLE_ADDRESSES[-1] + keeper_sig_low = faker.account_signature() + keeper_sig_high = faker.account_signature() + sig_low = faker.account_signature() + sig_high = faker.account_signature() + approvals = { + addr_high: _make_registration_approval(keeper_sig=keeper_sig_high, sig=sig_high), + addr_low: _make_registration_approval(keeper_sig=keeper_sig_low, sig=sig_low), + } + result = process_registration_approvals(approvals, votes_threshold=2) + assert result.keeper_signatures[0] == keeper_sig_low + assert result.keeper_signatures[1] == keeper_sig_high + + +class TestParsers: + def test_parse_registration_response(self) -> None: + keeper_sig_hex = faker.account_signature() + sig_hex = faker.account_signature() + data = { + 'keeper_params': { + 'signature': keeper_sig_hex, + 'ipfs_hash': 'QmTest', + 'deadline': 12345, + }, + 'nodes_manager_signature': sig_hex, + } + result = _parse_registration_response(data) + assert result.keeper_signature == keeper_sig_hex + assert result.nodes_manager_signature == sig_hex + assert result.ipfs_hash == 'QmTest' + assert result.deadline == 12345 + + +# --- Helpers --- + + +def _make_registration_approval( + keeper_sig: HexStr | None = None, + sig: HexStr | None = None, + ipfs_hash: str = 'QmTest123', + deadline: int = 1000, +) -> NodeManagerRegistrationApproval: + if keeper_sig is None: + keeper_sig = faker.account_signature() + if sig is None: + sig = faker.account_signature() + return NodeManagerRegistrationApproval( + keeper_signature=keeper_sig, + nodes_manager_signature=sig, + ipfs_hash=ipfs_hash, + deadline=deadline, + ) diff --git a/src/node_manager/tests/test_oracles_http.py b/src/node_manager/tests/test_oracles_http.py new file mode 100644 index 00000000..0ec025d4 --- /dev/null +++ b/src/node_manager/tests/test_oracles_http.py @@ -0,0 +1,190 @@ +import pytest +from aioresponses import aioresponses +from eth_typing import HexStr +from sw_utils.tests.factories import faker, get_mocked_protocol_config +from sw_utils.typings import Oracle, ProtocolConfig +from web3 import Web3 +from web3.types import Wei + +from src.common.tests.utils import ether_to_gwei +from src.node_manager.oracles import poll_eligible_operators, send_registration_requests +from src.node_manager.typings import NodeManagerApprovalRequest + +# --- poll_eligible_operators tests --- + + +@pytest.mark.usefixtures('fake_settings') +class TestPollEligibleOperators: + async def test_returns_eligible_operators(self) -> None: + config = _make_protocol_config([['http://oracle1']]) + operator_address = faker.eth_address() + response_data = [ + {'address': operator_address.lower(), 'amount': Web3.to_wei(32, 'ether')}, + ] + + with aioresponses() as m: + m.get( + 'http://oracle1/nodes-manager/eligible-operators', + payload=response_data, + ) + result = await poll_eligible_operators(config) + + assert len(result) == 1 + assert result[0].address == operator_address + assert result[0].amount == Web3.to_wei(32, 'ether') + + async def test_returns_empty_on_all_failures(self) -> None: + config = _make_protocol_config([['http://oracle1']]) + + with aioresponses() as m: + m.get( + 'http://oracle1/nodes-manager/eligible-operators', + status=500, + ) + result = await poll_eligible_operators(config) + + assert result == [] + + async def test_falls_back_to_next_oracle(self) -> None: + config = _make_protocol_config( + [['http://oracle1'], ['http://oracle2']], + ) + address = faker.eth_address().lower() + response_data = [ + {'address': address, 'amount': 100}, + ] + + with aioresponses() as m: + m.get('http://oracle1/nodes-manager/eligible-operators', status=500) + m.get('http://oracle2/nodes-manager/eligible-operators', payload=response_data) + result = await poll_eligible_operators(config) + + assert len(result) == 1 + assert result[0].address == Web3.to_checksum_address(address) + assert result[0].amount == Wei(100) + + async def test_replica_fallback(self) -> None: + """If first replica fails, tries the next one.""" + config = _make_protocol_config([['http://replica1', 'http://replica2']]) + address = faker.eth_address().lower() + response_data = [{'address': address, 'amount': 100}] + + with aioresponses() as m: + m.get('http://replica1/nodes-manager/eligible-operators', status=500) + m.get('http://replica2/nodes-manager/eligible-operators', payload=response_data) + result = await poll_eligible_operators(config) + + assert len(result) == 1 + assert result[0].address == Web3.to_checksum_address(address) + assert result[0].amount == Wei(100) + + +# --- send_registration_requests tests --- + + +@pytest.mark.usefixtures('fake_settings') +class TestSendRegistrationRequests: + async def test_collects_approvals(self) -> None: + config = _make_protocol_config( + [['http://oracle1'], ['http://oracle2']], + threshold=2, + ) + request = _make_registration_request() + keeper_signature = faker.account_signature() + signature = faker.account_signature() + + oracle_response = { + 'keeper_params': { + 'signature': keeper_signature, + 'ipfs_hash': faker.ipfs_hash(), + 'deadline': 1000, + }, + 'nodes_manager_signature': signature, + } + + with aioresponses() as m: + m.post( + 'http://oracle1/nodes-manager/register-validators', + payload=oracle_response, + ) + m.post( + 'http://oracle2/nodes-manager/register-validators', + payload=oracle_response, + ) + approvals = await send_registration_requests(config, request) + + assert len(approvals) == 2 + for approval in approvals.values(): + assert approval.keeper_signature == keeper_signature + assert approval.nodes_manager_signature == signature + + async def test_partial_failure(self) -> None: + """One oracle fails, the other succeeds — still returns what we got.""" + config = _make_protocol_config( + [['http://oracle1'], ['http://oracle2']], + threshold=1, + ) + request = _make_registration_request() + keeper_signature = faker.account_signature() + signature = faker.account_signature() + oracle_response = { + 'keeper_params': { + 'signature': keeper_signature, + 'ipfs_hash': faker.ipfs_hash(), + 'deadline': 1000, + }, + 'nodes_manager_signature': signature, + } + + with aioresponses() as m: + m.post('http://oracle1/nodes-manager/register-validators', status=500) + m.post( + 'http://oracle2/nodes-manager/register-validators', + payload=oracle_response, + ) + approvals = await send_registration_requests(config, request) + + assert len(approvals) == 1 + approval = next(iter(approvals.values())) + assert approval.keeper_signature == keeper_signature + assert approval.nodes_manager_signature == signature + + +# --- Helpers --- + +_ORACLE_PUBKEYS: list[HexStr] = [faker.account_public_key() for _ in range(9)] + + +def _make_protocol_config( + oracle_endpoints: list[list[str]], + threshold: int = 2, +) -> ProtocolConfig: + oracles = [] + for i, endpoints in enumerate(oracle_endpoints): + oracles.append( + Oracle( + public_key=_ORACLE_PUBKEYS[i], + endpoints=endpoints, + ) + ) + return get_mocked_protocol_config( + oracles=oracles, + validators_threshold=threshold, + signature_validity_period=60, + validators_approval_batch_limit=10, + ) + + +def _make_registration_request() -> NodeManagerApprovalRequest: + return NodeManagerApprovalRequest( + validator_index=0, + operator_address=faker.eth_address(), + validators_root=faker.merkle_root(), + public_keys=[faker.validator_public_key()], + deposit_signatures=[faker.validator_signature()], + public_key_shards=[[faker.validator_public_key()]], + exit_signature_shards=[[faker.validator_signature()]], + deadline=1000, + amounts=[ether_to_gwei(32)], + validators_manager_signature=faker.account_signature(), + ) diff --git a/src/node_manager/tests/test_register_validators.py b/src/node_manager/tests/test_register_validators.py new file mode 100644 index 00000000..7d0d5546 --- /dev/null +++ b/src/node_manager/tests/test_register_validators.py @@ -0,0 +1,185 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from eth_typing import ChecksumAddress +from sw_utils.tests.factories import faker +from sw_utils.typings import Bytes32 +from web3 import Web3 +from web3.exceptions import ContractLogicError + +from src.node_manager.register_validators import register_validators +from src.node_manager.typings import NodeManagerRegistrationOraclesApproval + +MODULE = 'src.node_manager.register_validators' + +OPERATOR_ADDR: ChecksumAddress = faker.eth_address() + + +@pytest.mark.usefixtures('fake_settings') +class TestRegisterValidators: + @patch(f'{MODULE}.get_validators_start_index', new_callable=AsyncMock, return_value=5) + @patch(f'{MODULE}.validators_registry_contract') + async def test_returns_none_on_root_change( + self, + mock_registry: MagicMock, + mock_index: AsyncMock, + ) -> None: + """If registry root changed since approval, return None.""" + mock_registry.get_registry_root = AsyncMock(return_value=faker.merkle_root()) + + result = await register_validators( + operator_address=OPERATOR_ADDR, + approval=_make_approval(), + validators=[], + validators_registry_root=faker.merkle_root(), + validator_index=5, + ) + assert result is None + + @patch(f'{MODULE}.get_validators_start_index', new_callable=AsyncMock, return_value=10) + @patch(f'{MODULE}.validators_registry_contract') + async def test_returns_none_on_index_change( + self, + mock_registry: MagicMock, + mock_index: AsyncMock, + ) -> None: + """If validator index changed, return None.""" + root = faker.merkle_root() + mock_registry.get_registry_root = AsyncMock(return_value=root) + + result = await register_validators( + operator_address=OPERATOR_ADDR, + approval=_make_approval(), + validators=[], + validators_registry_root=root, + validator_index=5, # doesn't match mock return of 10 + ) + assert result is None + + @patch(f'{MODULE}.encode_tx_validator_list', return_value=[b'\x00' * 100]) + @patch(f'{MODULE}.get_validators_start_index', new_callable=AsyncMock, return_value=5) + @patch(f'{MODULE}.validators_registry_contract') + async def test_success( + self, + mock_registry: MagicMock, + mock_index: AsyncMock, + mock_encode: MagicMock, + ) -> None: + root = faker.merkle_root() + mock_registry.get_registry_root = AsyncMock(return_value=root) + approval = _make_approval() + + mock_fn = MagicMock() + mock_fn.return_value.estimate_gas = AsyncMock() + mock_fn.return_value.transact = AsyncMock(return_value=b'\x01' * 32) + + mock_contract = MagicMock() + mock_contract.functions.registerValidators = mock_fn + + gm = MagicMock() + gm.get_high_priority_tx_params = AsyncMock(return_value={}) + + mock_exec = MagicMock() + mock_exec.eth.wait_for_transaction_receipt = AsyncMock(return_value={'status': 1}) + + with ( + patch(f'{MODULE}.nodes_manager_contract', mock_contract), + patch(f'{MODULE}.build_gas_manager', return_value=gm), + patch(f'{MODULE}.execution_client', mock_exec), + ): + result = await register_validators( + operator_address=OPERATOR_ADDR, + approval=approval, + validators=[], + validators_registry_root=root, + validator_index=5, + ) + assert result is not None + + expected_signatures = b''.join( + Web3.to_bytes(hexstr=s) for s in approval.nodes_manager_signatures + ) + expected_keeper_params = ( + Bytes32(Web3.to_bytes(hexstr=root)), + approval.deadline, + b''.join(Web3.to_bytes(v) for v in mock_encode.return_value), + b''.join(Web3.to_bytes(hexstr=s) for s in approval.keeper_signatures), + approval.ipfs_hash, + ) + mock_fn.assert_called_with(OPERATOR_ADDR, expected_keeper_params, expected_signatures) + + @patch(f'{MODULE}.encode_tx_validator_list', return_value=[b'\x00' * 100]) + @patch(f'{MODULE}.get_validators_start_index', new_callable=AsyncMock, return_value=5) + @patch(f'{MODULE}.validators_registry_contract') + async def test_gas_estimation_error_returns_none( + self, + mock_registry: MagicMock, + mock_index: AsyncMock, + mock_encode: MagicMock, + ) -> None: + root = faker.merkle_root() + mock_registry.get_registry_root = AsyncMock(return_value=root) + + mock_fn = MagicMock() + mock_fn.return_value.estimate_gas = AsyncMock(side_effect=ContractLogicError('revert')) + + mock_contract = MagicMock() + mock_contract.functions.registerValidators = mock_fn + + with patch(f'{MODULE}.nodes_manager_contract', mock_contract): + result = await register_validators( + operator_address=OPERATOR_ADDR, + approval=_make_approval(), + validators=[], + validators_registry_root=root, + validator_index=5, + ) + assert result is None + + @patch(f'{MODULE}.encode_tx_validator_list', return_value=[b'\x00' * 100]) + @patch(f'{MODULE}.get_validators_start_index', new_callable=AsyncMock, return_value=5) + @patch(f'{MODULE}.validators_registry_contract') + async def test_failed_tx_receipt_returns_none( + self, + mock_registry: MagicMock, + mock_index: AsyncMock, + mock_encode: MagicMock, + ) -> None: + root = faker.merkle_root() + mock_registry.get_registry_root = AsyncMock(return_value=root) + + mock_fn = MagicMock() + mock_fn.return_value.estimate_gas = AsyncMock() + mock_fn.return_value.transact = AsyncMock(return_value=b'\x01' * 32) + + mock_contract = MagicMock() + mock_contract.functions.registerValidators = mock_fn + + gm = MagicMock() + gm.get_high_priority_tx_params = AsyncMock(return_value={}) + + mock_exec = MagicMock() + mock_exec.eth.wait_for_transaction_receipt = AsyncMock(return_value={'status': 0}) + + with ( + patch(f'{MODULE}.nodes_manager_contract', mock_contract), + patch(f'{MODULE}.build_gas_manager', return_value=gm), + patch(f'{MODULE}.execution_client', mock_exec), + ): + result = await register_validators( + operator_address=OPERATOR_ADDR, + approval=_make_approval(), + validators=[], + validators_registry_root=root, + validator_index=5, + ) + assert result is None + + +def _make_approval() -> NodeManagerRegistrationOraclesApproval: + return NodeManagerRegistrationOraclesApproval( + keeper_signatures=[faker.account_signature()], + nodes_manager_signatures=[faker.account_signature()], + ipfs_hash=faker.ipfs_hash(), + deadline=1000, + ) diff --git a/src/node_manager/tests/test_tasks.py b/src/node_manager/tests/test_tasks.py new file mode 100644 index 00000000..b1752503 --- /dev/null +++ b/src/node_manager/tests/test_tasks.py @@ -0,0 +1,371 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from eth_typing import ChecksumAddress +from sw_utils.tests.factories import faker +from sw_utils.typings import ProtocolConfig +from web3 import Web3 +from web3.types import Gwei, Wei + +from src.common.tests.utils import ether_to_gwei +from src.common.typings import ValidatorType +from src.config.settings import settings +from src.node_manager.tasks import NodeManagerTask +from src.node_manager.typings import ( + EligibleOperator, + NodeManagerApprovalRequest, + NodeManagerRegistrationOraclesApproval, +) +from src.validators.typings import Validator + +OPERATOR_ADDR: ChecksumAddress = faker.eth_address() +OTHER_ADDR: ChecksumAddress = faker.eth_address() + +MODULE = 'src.node_manager.tasks' + + +@pytest.mark.usefixtures('fake_settings') +class TestProcessBlock: + @patch(f'{MODULE}.check_gas_price', new_callable=AsyncMock, return_value=False) + async def test_skips_when_gas_too_high(self, mock_gas: AsyncMock) -> None: + task = _make_task() + interrupt = MagicMock() + + await task.process_block(interrupt) + + mock_gas.assert_awaited_once_with(high_priority=True) + + @patch(f'{MODULE}.poll_eligible_operators', new_callable=AsyncMock, return_value=[]) + @patch(f'{MODULE}.get_protocol_config', new_callable=AsyncMock) + @patch(f'{MODULE}.check_gas_price', new_callable=AsyncMock, return_value=True) + async def test_no_eligible_operators( + self, + mock_gas: AsyncMock, + mock_config: AsyncMock, + mock_poll: AsyncMock, + ) -> None: + task = _make_task() + await task.process_block(MagicMock()) + mock_poll.assert_awaited_once() + + @patch(f'{MODULE}.poll_eligible_operators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_protocol_config', new_callable=AsyncMock) + @patch(f'{MODULE}.check_gas_price', new_callable=AsyncMock, return_value=True) + async def test_skips_other_operator( + self, + mock_gas: AsyncMock, + mock_config: AsyncMock, + mock_poll: AsyncMock, + ) -> None: + """Operators not matching operator_address are skipped.""" + mock_config.return_value = _make_protocol_config() + mock_poll.return_value = [ + EligibleOperator(address=OTHER_ADDR, amount=Wei(Web3.to_wei(32, 'ether'))), + ] + task = _make_task() + await task.process_block(MagicMock()) + # Should not proceed to registration for other operator + mock_poll.assert_awaited_once() + + @patch(f'{MODULE}.register_validators', new_callable=AsyncMock, return_value='0xtxhash') + @patch(f'{MODULE}.poll_registration_approval', new_callable=AsyncMock) + @patch(f'{MODULE}.get_validators_for_registration', new_callable=AsyncMock) + @patch(f'{MODULE}.get_deposits_amounts') + @patch(f'{MODULE}.poll_eligible_operators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_protocol_config', new_callable=AsyncMock) + @patch(f'{MODULE}.check_gas_price', new_callable=AsyncMock, return_value=True) + async def test_full_registration_flow( + self, + mock_gas: AsyncMock, + mock_config: AsyncMock, + mock_poll: AsyncMock, + mock_deposits: MagicMock, + mock_get_validators: AsyncMock, + mock_poll_reg: AsyncMock, + mock_register: AsyncMock, + ) -> None: + """Full flow: eligible operator → stub funding → register new validators.""" + mock_config.return_value = _make_protocol_config() + mock_poll.return_value = [ + EligibleOperator(address=OPERATOR_ADDR, amount=Wei(Web3.to_wei(32, 'ether'))), + ] + mock_deposits.return_value = [ether_to_gwei(32)] + + validator = Validator( + public_key=faker.validator_public_key(), + amount=ether_to_gwei(32), + deposit_signature=faker.validator_signature(), + ) + mock_get_validators.return_value = [validator] + + request = MagicMock(spec=NodeManagerApprovalRequest) + request.validators_root = faker.merkle_root() + request.validator_index = 0 + approval = MagicMock(spec=NodeManagerRegistrationOraclesApproval) + mock_poll_reg.return_value = (request, approval) + + task = _make_task() + await task.process_block(MagicMock()) + + mock_register.assert_awaited_once() + + +@pytest.mark.usefixtures('fake_settings') +class TestProcessBlockV1: + """V1 validator type: skips funding, goes straight to registration.""" + + @patch(f'{MODULE}.register_validators', new_callable=AsyncMock, return_value='0xtxhash') + @patch(f'{MODULE}.poll_registration_approval', new_callable=AsyncMock) + @patch(f'{MODULE}.get_validators_for_registration', new_callable=AsyncMock) + @patch(f'{MODULE}.get_deposits_amounts') + @patch(f'{MODULE}.poll_eligible_operators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_protocol_config', new_callable=AsyncMock) + @patch(f'{MODULE}.check_gas_price', new_callable=AsyncMock, return_value=True) + async def test_v1_skips_funding_and_registers( + self, + mock_gas: AsyncMock, + mock_config: AsyncMock, + mock_poll: AsyncMock, + mock_deposits: MagicMock, + mock_get_validators: AsyncMock, + mock_poll_reg: AsyncMock, + mock_register: AsyncMock, + ) -> None: + settings.validator_type = ValidatorType.V1 + mock_config.return_value = _make_protocol_config() + mock_poll.return_value = [ + EligibleOperator(address=OPERATOR_ADDR, amount=Wei(Web3.to_wei(32, 'ether'))), + ] + mock_deposits.return_value = [ether_to_gwei(32)] + + validator = Validator( + public_key=faker.validator_public_key(), + amount=ether_to_gwei(32), + deposit_signature=faker.validator_signature(), + ) + mock_get_validators.return_value = [validator] + + request = MagicMock(spec=NodeManagerApprovalRequest) + request.validators_root = faker.merkle_root() + request.validator_index = 0 + approval = MagicMock(spec=NodeManagerRegistrationOraclesApproval) + mock_poll_reg.return_value = (request, approval) + + task = _make_task() + await task.process_block(MagicMock()) + + mock_register.assert_awaited_once() + + @patch(f'{MODULE}.register_validators', new_callable=AsyncMock) + @patch(f'{MODULE}.poll_eligible_operators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_protocol_config', new_callable=AsyncMock) + @patch(f'{MODULE}.check_gas_price', new_callable=AsyncMock, return_value=True) + async def test_v1_disable_registration_skips( + self, + mock_gas: AsyncMock, + mock_config: AsyncMock, + mock_poll: AsyncMock, + mock_register: AsyncMock, + ) -> None: + """V1 with disable_validators_registration returns early.""" + settings.validator_type = ValidatorType.V1 + settings.disable_validators_registration = True + mock_config.return_value = _make_protocol_config() + mock_poll.return_value = [ + EligibleOperator(address=OPERATOR_ADDR, amount=Wei(Web3.to_wei(32, 'ether'))), + ] + + task = _make_task() + await task.process_block(MagicMock()) + + mock_register.assert_not_awaited() + + +@pytest.mark.usefixtures('fake_settings') +class TestProcessBlockV2: + """V2 validator type: calls funding first, then registration.""" + + @patch(f'{MODULE}.register_validators', new_callable=AsyncMock, return_value='0xtxhash') + @patch(f'{MODULE}.poll_registration_approval', new_callable=AsyncMock) + @patch(f'{MODULE}.get_validators_for_registration', new_callable=AsyncMock) + @patch(f'{MODULE}.get_deposits_amounts') + @patch(f'{MODULE}.poll_eligible_operators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_protocol_config', new_callable=AsyncMock) + @patch(f'{MODULE}.check_gas_price', new_callable=AsyncMock, return_value=True) + async def test_v2_calls_funding_then_registration( + self, + mock_gas: AsyncMock, + mock_config: AsyncMock, + mock_poll: AsyncMock, + mock_deposits: MagicMock, + mock_get_validators: AsyncMock, + mock_poll_reg: AsyncMock, + mock_register: AsyncMock, + ) -> None: + settings.validator_type = ValidatorType.V2 + mock_config.return_value = _make_protocol_config() + mock_poll.return_value = [ + EligibleOperator(address=OPERATOR_ADDR, amount=Wei(Web3.to_wei(32, 'ether'))), + ] + mock_deposits.return_value = [ether_to_gwei(32)] + + validator = Validator( + public_key=faker.validator_public_key(), + amount=ether_to_gwei(32), + deposit_signature=faker.validator_signature(), + ) + mock_get_validators.return_value = [validator] + + request = MagicMock(spec=NodeManagerApprovalRequest) + request.validators_root = faker.merkle_root() + request.validator_index = 0 + approval = MagicMock(spec=NodeManagerRegistrationOraclesApproval) + mock_poll_reg.return_value = (request, approval) + + task = _make_task() + await task.process_block(MagicMock()) + + mock_register.assert_awaited_once() + + @patch(f'{MODULE}.register_validators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_deposits_amounts', return_value=[]) + @patch(f'{MODULE}.poll_eligible_operators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_protocol_config', new_callable=AsyncMock) + @patch(f'{MODULE}.check_gas_price', new_callable=AsyncMock, return_value=True) + async def test_v2_disable_funding_skips_to_registration( + self, + mock_gas: AsyncMock, + mock_config: AsyncMock, + mock_poll: AsyncMock, + mock_deposits: MagicMock, + mock_register: AsyncMock, + ) -> None: + """V2 with disable_validators_funding skips funding but still tries registration.""" + settings.validator_type = ValidatorType.V2 + settings.disable_validators_funding = True + mock_config.return_value = _make_protocol_config() + mock_poll.return_value = [ + EligibleOperator(address=OPERATOR_ADDR, amount=Wei(Web3.to_wei(32, 'ether'))), + ] + + task = _make_task() + await task.process_block(MagicMock()) + + # Registration path was entered (get_deposits_amounts called with original amount) + mock_deposits.assert_called_once() + + @patch(f'{MODULE}.register_validators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_deposits_amounts') + @patch(f'{MODULE}.poll_eligible_operators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_protocol_config', new_callable=AsyncMock) + @patch(f'{MODULE}.check_gas_price', new_callable=AsyncMock, return_value=True) + async def test_v2_disable_registration_skips_registration( + self, + mock_gas: AsyncMock, + mock_config: AsyncMock, + mock_poll: AsyncMock, + mock_deposits: MagicMock, + mock_register: AsyncMock, + ) -> None: + """V2 with disable_validators_registration skips registration after funding.""" + settings.validator_type = ValidatorType.V2 + settings.disable_validators_registration = True + mock_config.return_value = _make_protocol_config() + mock_poll.return_value = [ + EligibleOperator(address=OPERATOR_ADDR, amount=Wei(Web3.to_wei(32, 'ether'))), + ] + + task = _make_task() + await task.process_block(MagicMock()) + + mock_deposits.assert_not_called() + mock_register.assert_not_awaited() + + +@pytest.mark.usefixtures('fake_settings') +class TestProcessFunding: + async def test_stub_returns_full_amount(self) -> None: + """Current stub _process_funding returns full amount unchanged.""" + task = _make_task() + result = await task._process_funding( + amount=Gwei(64000000000), + operator_address=OPERATOR_ADDR, + protocol_config=_make_protocol_config(), + ) + assert result == Gwei(64000000000) + + +@pytest.mark.usefixtures('fake_settings') +class TestProcessRegistration: + @patch(f'{MODULE}.register_validators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_deposits_amounts', return_value=[]) + async def test_no_amounts_skips( + self, + mock_deposits: MagicMock, + mock_register: AsyncMock, + ) -> None: + task = _make_task() + await task._process_registration( + amount=Gwei(100), + protocol_config=_make_protocol_config(), + ) + mock_register.assert_not_awaited() + + @patch(f'{MODULE}.register_validators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_validators_for_registration', new_callable=AsyncMock, return_value=[]) + @patch(f'{MODULE}.get_deposits_amounts', return_value=[ether_to_gwei(32)]) + async def test_no_validators_skips( + self, + mock_deposits: MagicMock, + mock_get_validators: AsyncMock, + mock_register: AsyncMock, + ) -> None: + task = _make_task() + await task._process_registration( + amount=ether_to_gwei(32), + protocol_config=_make_protocol_config(), + ) + mock_get_validators.assert_awaited_once() + mock_register.assert_not_awaited() + + @patch(f'{MODULE}.register_validators', new_callable=AsyncMock) + @patch(f'{MODULE}.get_validators_for_registration', new_callable=AsyncMock, return_value=[]) + @patch(f'{MODULE}.get_deposits_amounts') + async def test_amounts_truncated_to_batch_limit( + self, + mock_deposits: MagicMock, + mock_get_validators: AsyncMock, + mock_register: AsyncMock, + ) -> None: + """Amounts list is truncated to validators_approval_batch_limit.""" + mock_deposits.return_value = [ether_to_gwei(32)] * 5 + + config = _make_protocol_config() + config.validators_approval_batch_limit = 2 + + task = _make_task() + await task._process_registration( + amount=ether_to_gwei(160), + protocol_config=config, + ) + # get_validators_for_registration receives only 2 amounts (truncated from 5) + call_args = mock_get_validators.call_args + amounts_arg = call_args.kwargs.get( + 'amounts', call_args.args[1] if len(call_args.args) > 1 else None + ) + assert len(amounts_arg) == 2 + mock_register.assert_not_awaited() + + +def _make_protocol_config() -> MagicMock: + config = MagicMock(spec=ProtocolConfig) + config.validators_approval_batch_limit = 10 + return config + + +def _make_task() -> NodeManagerTask: + keystore = MagicMock() + return NodeManagerTask( + operator_address=OPERATOR_ADDR, + keystore=keystore, + ) diff --git a/src/node_manager/typings.py b/src/node_manager/typings.py index 91c4d7c7..46ae284e 100644 --- a/src/node_manager/typings.py +++ b/src/node_manager/typings.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from eth_typing import ChecksumAddress +from eth_typing import ChecksumAddress, HexStr from web3.types import Wei @@ -10,3 +10,40 @@ class EligibleOperator: address: ChecksumAddress amount: Wei + + +@dataclass +# pylint: disable-next=too-many-instance-attributes +class NodeManagerApprovalRequest: + """Approval request for NodesManager validator registration.""" + + validator_index: int + operator_address: ChecksumAddress + validators_root: HexStr + public_keys: list[HexStr] + deposit_signatures: list[HexStr] + public_key_shards: list[list[HexStr]] + exit_signature_shards: list[list[HexStr]] + deadline: int + amounts: list[int] + validators_manager_signature: HexStr + + +@dataclass +class NodeManagerRegistrationApproval: + """Single oracle's registration response with both signature types.""" + + nodes_manager_signature: HexStr + keeper_signature: HexStr + ipfs_hash: str + deadline: int + + +@dataclass +class NodeManagerRegistrationOraclesApproval: + """Combined registration approval from multiple oracles.""" + + nodes_manager_signatures: list[HexStr] + keeper_signatures: list[HexStr] + ipfs_hash: str + deadline: int diff --git a/src/validators/tasks.py b/src/validators/tasks.py index 809cc1df..169d7a6b 100644 --- a/src/validators/tasks.py +++ b/src/validators/tasks.py @@ -100,7 +100,7 @@ async def process_funding( Raises FundingException on failure. """ compounding_validators_balances = await fetch_compounding_validators_balances() - funding_amounts = _get_funding_amounts( + funding_amounts = get_funding_amounts( compounding_validators_balances=compounding_validators_balances, vault_assets=vault_assets, ) @@ -184,7 +184,7 @@ async def register_new_validators( keystore: BaseKeystore | None, relayer: RelayerClient | None = None, ) -> HexStr | None: - validators_amounts = _get_deposits_amounts(vault_assets, settings.validator_type) + validators_amounts = get_deposits_amounts(vault_assets, settings.validator_type) validators_count = len(validators_amounts) if not validators_amounts: # not enough balance to register validators @@ -302,7 +302,7 @@ async def load_genesis_validators() -> None: logger.info('Loaded %d genesis validators', len(genesis_validators)) -def _get_deposits_amounts(vault_assets: Gwei, validator_type: ValidatorType) -> list[Gwei]: +def get_deposits_amounts(vault_assets: Gwei, validator_type: ValidatorType) -> list[Gwei]: """Returns a list of amounts in Gwei for each validator to be registered.""" if vault_assets < MIN_ACTIVATION_BALANCE_GWEI: return [] @@ -317,7 +317,7 @@ def _get_deposits_amounts(vault_assets: Gwei, validator_type: ValidatorType) -> return amounts -def _get_funding_amounts( +def get_funding_amounts( compounding_validators_balances: dict[HexStr, Gwei], vault_assets: Gwei ) -> dict[HexStr, Gwei]: result = {} diff --git a/src/validators/tests/test_tasks.py b/src/validators/tests/test_tasks.py index e1c6ec55..e5335b93 100644 --- a/src/validators/tests/test_tasks.py +++ b/src/validators/tests/test_tasks.py @@ -14,75 +14,75 @@ from src.validators.exceptions import FundingException from src.validators.tasks import ( ValidatorRegistrationSubtask, - _get_deposits_amounts, - _get_funding_amounts, + get_deposits_amounts, + get_funding_amounts, ) from src.validators.typings import VaultValidator @pytest.mark.usefixtures('fake_settings') def test_get_deposits_amounts(): - assert _get_deposits_amounts(0, ValidatorType.V1) == [] - assert _get_deposits_amounts(0, ValidatorType.V2) == [] + assert get_deposits_amounts(0, ValidatorType.V1) == [] + assert get_deposits_amounts(0, ValidatorType.V2) == [] - assert _get_deposits_amounts(ether_to_gwei(32), ValidatorType.V1) == [ + assert get_deposits_amounts(ether_to_gwei(32), ValidatorType.V1) == [ MIN_ACTIVATION_BALANCE_GWEI ] - assert _get_deposits_amounts(ether_to_gwei(32), ValidatorType.V2) == [ + assert get_deposits_amounts(ether_to_gwei(32), ValidatorType.V2) == [ MIN_ACTIVATION_BALANCE_GWEI ] - assert _get_deposits_amounts(ether_to_gwei(33), ValidatorType.V1) == [ + assert get_deposits_amounts(ether_to_gwei(33), ValidatorType.V1) == [ MIN_ACTIVATION_BALANCE_GWEI, ] - assert _get_deposits_amounts(ether_to_gwei(33), ValidatorType.V2) == [ether_to_gwei(33)] + assert get_deposits_amounts(ether_to_gwei(33), ValidatorType.V2) == [ether_to_gwei(33)] - assert _get_deposits_amounts(ether_to_gwei(64), ValidatorType.V1) == [ + assert get_deposits_amounts(ether_to_gwei(64), ValidatorType.V1) == [ MIN_ACTIVATION_BALANCE_GWEI, MIN_ACTIVATION_BALANCE_GWEI, ] - assert _get_deposits_amounts(ether_to_gwei(64), ValidatorType.V2) == [ether_to_gwei(64)] + assert get_deposits_amounts(ether_to_gwei(64), ValidatorType.V2) == [ether_to_gwei(64)] - assert _get_deposits_amounts(ether_to_gwei(66), ValidatorType.V1) == [ + assert get_deposits_amounts(ether_to_gwei(66), ValidatorType.V1) == [ MIN_ACTIVATION_BALANCE_GWEI, MIN_ACTIVATION_BALANCE_GWEI, ] - assert _get_deposits_amounts(ether_to_gwei(66), ValidatorType.V2) == [ether_to_gwei(66)] + assert get_deposits_amounts(ether_to_gwei(66), ValidatorType.V2) == [ether_to_gwei(66)] assert ( - _get_deposits_amounts(ether_to_gwei(2048), ValidatorType.V1) + get_deposits_amounts(ether_to_gwei(2048), ValidatorType.V1) == [MIN_ACTIVATION_BALANCE_GWEI] * 64 ) - assert _get_deposits_amounts(settings.max_validator_balance_gwei, ValidatorType.V2) == [ + assert get_deposits_amounts(settings.max_validator_balance_gwei, ValidatorType.V2) == [ settings.max_validator_balance_gwei, ] - assert _get_deposits_amounts(ether_to_gwei(2048), ValidatorType.V2) == [ + assert get_deposits_amounts(ether_to_gwei(2048), ValidatorType.V2) == [ settings.max_validator_balance_gwei, ether_to_gwei(2048) - settings.max_validator_balance_gwei, ] assert ( - _get_deposits_amounts(ether_to_gwei(2050), ValidatorType.V1) + get_deposits_amounts(ether_to_gwei(2050), ValidatorType.V1) == [MIN_ACTIVATION_BALANCE_GWEI] * 64 ) assert ( - _get_deposits_amounts(ether_to_gwei(2081), ValidatorType.V1) + get_deposits_amounts(ether_to_gwei(2081), ValidatorType.V1) == [MIN_ACTIVATION_BALANCE_GWEI] * 65 ) - assert _get_deposits_amounts(ether_to_gwei(2050), ValidatorType.V2) == [ + assert get_deposits_amounts(ether_to_gwei(2050), ValidatorType.V2) == [ settings.max_validator_balance_gwei, ether_to_gwei(2050) - settings.max_validator_balance_gwei, ] - assert _get_deposits_amounts(ether_to_gwei(2081), ValidatorType.V2) == [ + assert get_deposits_amounts(ether_to_gwei(2081), ValidatorType.V2) == [ settings.max_validator_balance_gwei, ether_to_gwei(2081) - settings.max_validator_balance_gwei, ] assert ( - _get_deposits_amounts(ether_to_gwei(4096), ValidatorType.V1) + get_deposits_amounts(ether_to_gwei(4096), ValidatorType.V1) == [MIN_ACTIVATION_BALANCE_GWEI] * 128 ) - assert _get_deposits_amounts(settings.max_validator_balance_gwei * 2, ValidatorType.V2) == [ + assert get_deposits_amounts(settings.max_validator_balance_gwei * 2, ValidatorType.V2) == [ settings.max_validator_balance_gwei, settings.max_validator_balance_gwei, ] @@ -93,13 +93,13 @@ def test_get_funding_amounts(data_dir): public_key_1 = faker.eth_address() public_key_2 = faker.eth_address() - data = _get_funding_amounts({public_key_1: ether_to_gwei(32)}, vault_assets=ether_to_gwei(1)) + data = get_funding_amounts({public_key_1: ether_to_gwei(32)}, vault_assets=ether_to_gwei(1)) assert data == {public_key_1: ether_to_gwei(1)} - data = _get_funding_amounts({public_key_1: ether_to_gwei(32)}, vault_assets=ether_to_gwei(100)) + data = get_funding_amounts({public_key_1: ether_to_gwei(32)}, vault_assets=ether_to_gwei(100)) assert data == {public_key_1: ether_to_gwei(100)} - data = _get_funding_amounts( + data = get_funding_amounts( {public_key_1: ether_to_gwei(32), public_key_2: ether_to_gwei(33)}, vault_assets=ether_to_gwei(2100), ) @@ -108,7 +108,7 @@ def test_get_funding_amounts(data_dir): public_key_1: ether_to_gwei(188), } - data = _get_funding_amounts( + data = get_funding_amounts( {public_key_1: ether_to_gwei(1934), public_key_2: ether_to_gwei(32)}, vault_assets=ether_to_gwei(11.5), ) @@ -116,7 +116,7 @@ def test_get_funding_amounts(data_dir): public_key_1: ether_to_gwei(11), } - data = _get_funding_amounts( + data = get_funding_amounts( {public_key_1: ether_to_gwei(32), public_key_2: ether_to_gwei(33)}, vault_assets=ether_to_gwei(2100.5), )