diff --git a/main.py b/main.py index dd51fd5..a72605c 100644 --- a/main.py +++ b/main.py @@ -288,7 +288,7 @@ async def cli_loop(sk, pk, chain, mempool, network): # Main entry point # ────────────────────────────────────────────── -async def run_node(port: int, connect_to: str | None, fund: int, datadir: str | None): +async def run_node(port: int, host: str, connect_to: str | None, fund: int, datadir: str | None): """Boot the node, optionally connect to a peer, then enter the CLI.""" sk, pk = create_wallet() @@ -326,7 +326,7 @@ async def on_peer_connected(writer): await writer.drain() logger.info("🔄 Sent state sync to new peer") - network.set_on_peer_connected(on_peer_connected) + network._on_peer_connected = on_peer_connected await network.start(port=port, host=host) @@ -373,7 +373,7 @@ def main(): ) try: - asyncio.run(run_node(args.port, args.connect, args.fund, args.datadir)) + asyncio.run(run_node(args.port, args.host, args.connect, args.fund, args.datadir)) except KeyboardInterrupt: print("\nNode shut down.") diff --git a/minichain/p2p.py b/minichain/p2p.py index ee52d7d..5e4950a 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -43,7 +43,7 @@ def register_handler(self, handler_callback): raise ValueError("handler_callback must be callable") self._handler_callback = handler_callback - async def start(self, port: int = 9000): + async def start(self, port: int = 9000, host: str = "127.0.0.1"): """Start listening for incoming peer connections on the given port.""" self._port = port self._server = await asyncio.start_server( @@ -206,7 +206,13 @@ def _validate_block_payload(self, payload): def _validate_message(self, message): if not isinstance(message, dict): return False - if set(message) != {"type", "data"}: + # Allow _peer_addr field added by _listen_to_peer + required_fields = {"type", "data"} + if not required_fields.issubset(set(message)): + return False + # Reject messages with unexpected fields (except _peer_addr) + allowed_fields = {"type", "data", "_peer_addr"} + if not set(message).issubset(allowed_fields): return False msg_type = message.get("type")