Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/minimal_chainfl_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def main():
agents=agents,
coordinator=coordinator,
blockchain=blockchain,
rounds=config['experiment']['rounds']
rounds=config['experiment']['rounds'],
scheduler=scheduler,
)
runner.scheduler = scheduler # inject if needed
runner.run()

# Show blockchain
Expand Down
13 changes: 11 additions & 2 deletions src/chainfl/simulator/simulation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ class SimulationRunner:
Runs the ChainFL simulation loop across multiple federated agents.
"""

def __init__(self, agents, coordinator, blockchain, rounds=5):
def __init__(self, agents, coordinator, blockchain, rounds=5, scheduler=None):
"""
Initializes the simulation environment.

Expand All @@ -12,12 +12,15 @@ def __init__(self, agents, coordinator, blockchain, rounds=5):
coordinator (object): Central aggregation and publishing manager.
blockchain (BlockchainSimulator): Ledger instance.
rounds (int): Number of federated training rounds.
scheduler (Scheduler, optional): Strategy that selects
participating agents per round.
"""
self.agents = agents
self.coordinator = coordinator
self.blockchain = blockchain
self.rounds = rounds
self.logs = []
self.scheduler = scheduler

def run(self):
"""
Expand All @@ -27,7 +30,13 @@ def run(self):
print(f"\n🔄 Round {r+1} ------------------")
round_models = []

for agent in self.agents:
participants = (
self.scheduler.select_agents(self.agents, r)
if self.scheduler
else self.agents
)

for agent in participants:
# Local training
X, y = agent.load_data()
agent.trainer.train(X, y)
Expand Down
121 changes: 121 additions & 0 deletions tests/test_simulation_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import numpy as np

from chainfl.simulator.simulation_runner import SimulationRunner
from chainfl.simulator.scheduler import Scheduler


class StubTrainer:
def __init__(self, agent_id):
self.agent_id = agent_id
self.train_count = 0
self.set_weights_calls = []

def train(self, X, y):
self.train_count += 1

def get_weights(self):
# Return deterministic weights per agent so aggregation is stable
return self.agent_id, float(self.agent_id)

def set_weights(self, coef, intercept):
self.set_weights_calls.append((coef, intercept))


class StubHasher:
def hash_weights(self, coef, intercept):
return f"{coef}-{intercept}".encode()


class StubSigner:
def sign(self, message):
return b"signature"


class StubConsensus:
def validate_block(self, _block):
return True

def simulate_latency(self):
pass


class StubAgent:
def __init__(self, agent_id):
self.agent_id = agent_id
self.trainer = StubTrainer(agent_id)
self.hasher = StubHasher()
self.signer = StubSigner()
self.consensus = StubConsensus()

def load_data(self):
# Data contents are irrelevant for these tests
return np.array([self.agent_id]), np.array([self.agent_id])


class StubBlockchain:
def __init__(self):
self.blocks = []

def add_block(self, block):
self.blocks.append(block)


class StubPublisher:
def __init__(self):
self.published = []

def publish(self, coef, intercept):
self.published.append((coef, intercept))


class StubCoordinator:
def __init__(self):
self.aggregator = AggregatorSpy()
self.publisher = StubPublisher()


class AggregatorSpy:
def __init__(self):
self.seen_models = []

def aggregate(self, models):
self.seen_models.append(models)
coefs = [coef for coef, _ in models]
intercepts = [intercept for _, intercept in models]
return np.mean(coefs, axis=0), np.mean(intercepts, axis=0)


def build_runner(agent_count, rounds, scheduler=None):
agents = [StubAgent(i) for i in range(agent_count)]
coordinator = StubCoordinator()
blockchain = StubBlockchain()
return SimulationRunner(
agents=agents,
coordinator=coordinator,
blockchain=blockchain,
rounds=rounds,
scheduler=scheduler,
), agents, coordinator


def test_all_agents_participate_without_scheduler():
runner, agents, coordinator = build_runner(agent_count=3, rounds=2)

runner.run()

# Every agent should have trained once per round
assert [agent.trainer.train_count for agent in agents] == [2, 2, 2]
# Aggregator should see contributions from each agent per round
assert all(len(models) == len(agents) for models in coordinator.aggregator.seen_models)


def test_scheduler_limits_participation_round_robin():
scheduler = Scheduler(mode="round_robin", sample_ratio=0.5)
runner, agents, coordinator = build_runner(agent_count=4, rounds=3, scheduler=scheduler)

runner.run()

# Round-robin with 4 agents and sample_ratio=0.5 selects 2 agents per round
assert [agent.trainer.train_count for agent in agents] == [2, 2, 1, 1]
# Each aggregation should only use the participating subset
assert all(len(models) == 2 for models in coordinator.aggregator.seen_models)