This repository explores how neural networks can adapt online to changing environments by learning fast, plastic connectivity rather than just memorising a fixed dataset.
The setting is intentionally simple but expressive: we generate random graphs, walk on them, and ask a model to predict the next node given the current node and an action (edge label).
Two model families are compared:

- A standard RNN-based graph sequence model (`RNNGraphNetwork`)
- A plastic fast-weight network (`PlasticGraphNetwork`) that updates an internal connectivity matrix during inference
To rigorously test generalisation, even on relatively large graphs (50 nodes, 15 actions), we constructed a dataset in which structural triplets (node, action, next-node) appearing in the test set were strictly held out from training. In this setting, the standard RNN largely memorised the training distribution and failed to generalise. The plastic architecture, by contrast, achieved near 100% accuracy on unseen graphs and adapted in real time when the underlying graph topology was switched mid-sequence, effectively learning the new structure on the fly.
At a high level, the plastic network learns to build and update a “lookup table” in its fast weights that maps each (current node, action) pair to the most likely next node, purely from online experience.
For concrete examples and plots from representative experiments (RNN vs plastic; standard vs disjoint; small vs large graphs), see the PDF run reports in `reports/`.
We work with undirected graphs `G = (V, E)`:

- Nodes: `V = {0, ..., num_nodes-1}`
- Edges `(u, v)` carry an action id `a ∈ {0, ..., num_actions-1}`.

A trajectory is generated by randomly walking on the graph:

- At each timestep `t`, we are at node `x_t`.
- We sample one of its neighbours `x_{t+1}` uniformly.
- The edge `(x_t, x_{t+1})` has an attached `action_id`, which we encode as a one-hot action vector.

The model sees the sequence:

- Input at timestep `t`: `[one_hot(x_t) || one_hot(action_t)]`
- Target: the next node `x_{t+1}`, again as a one-hot over nodes.

So the core prediction task is:

> Given `(current node, action)`, predict the next node.

This structure ensures that if action labels around a node are unique, the mapping `(node, action) → next node` is well-defined.
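The per-timestep encoding can be sketched in a few lines of NumPy. This is an illustration only; the helper names (`one_hot`, `encode_step`) are hypothetical and not from the codebase:

```python
import numpy as np

num_nodes, num_actions = 5, 3

def one_hot(index, size):
    v = np.zeros(size, dtype=np.float32)
    v[index] = 1.0
    return v

def encode_step(node, action):
    """[one_hot(x_t) || one_hot(action_t)] -> vector of length num_nodes + num_actions."""
    return np.concatenate([one_hot(node, num_nodes), one_hot(action, num_actions)])

# A 3-step walk given as (node, action) pairs; targets would be the next nodes.
walk = [(0, 2), (4, 0), (1, 1)]
X = np.stack([encode_step(n, a) for n, a in walk])  # shape [T, num_nodes + num_actions]
print(X.shape)  # (3, 8)
```

Each row contains exactly two ones: one in the node block, one in the action block.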
All datasets are built on top of `BaseGraphDataset` (`core/base/dataset.py`), which:

- Generates graphs using a configurable `BaseGraphGenerator` (by default, a NetworkX random graph).
- For each graph:
  - Runs multiple random explorations.
  - Converts each exploration into a tensor of shape `[num_timesteps, num_nodes + num_actions]`.
- Concatenates samples across graphs into:
  - `train_data`: shape `[N_train, T, num_nodes + num_actions]`
  - `test_data`: shape `[N_test, T, num_nodes + num_actions]`
- Optionally plots:
  - Example train/test graphs.
  - Per-node-pair action histograms.
  - Triplet statistics and constraints.
### `dataset/standard.py`

This is the "easy" dataset, where train and test share the same action semantics.

Key properties:

- Graphs are sampled from `gnp_random_graph(num_nodes, edge_prob)`, conditioned to:
  - Be connected.
  - Have at least one edge.
- Edges are labelled with actions such that:
  - Around each node `u`, the actions on edges incident to `u` are unique.
  - This makes `(u, action)` a deterministic key for the neighbour.
- The same action space is shared across:
  - All training graphs
  - All test graphs

Effectively, there is a single global notion of action semantics: "Action 3 from node 7 on graph A plays the same semantic role as action 3 from node 7 on graph B."
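One way to achieve the unique-actions-per-node property is a greedy assignment. The following is a hypothetical sketch (the actual dataset code may label edges differently):

```python
# Greedy edge labelling: each edge gets an action id that is unused at both
# of its endpoints, so (node, action) uniquely identifies the neighbour.
def label_edges(edges, num_actions):
    used = {}     # node -> set of action ids already used at that node
    labels = {}   # (u, v) -> action id
    for u, v in edges:
        for a in range(num_actions):
            if a not in used.setdefault(u, set()) and a not in used.setdefault(v, set()):
                labels[(u, v)] = a
                used[u].add(a)
                used[v].add(a)
                break
        else:
            raise ValueError("not enough actions for this node degree")
    return labels

edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
labels = label_edges(edges, num_actions=4)
# Sanity check: no node reuses an action on its incident edges.
for node in range(4):
    incident = [a for (u, v), a in labels.items() if node in (u, v)]
    assert len(incident) == len(set(incident))
print(labels)
```

Note that this greedy scheme only succeeds when `num_actions` comfortably exceeds the maximum node degree, which matches the degree constraints mentioned below.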
### `dataset/disjoint.py`

This is the "hard" dataset, designed to break global memorisation strategies.

Key idea: train and test do not share `(u, action, v)` triplets.

Mechanism:

- For each unordered node pair `{u, v}`:
  - The action set `{0, ..., num_actions-1}` is randomly partitioned into:
    - A train-only subset of size `k_train`
    - A test-only subset of size `k_test`
- For the train split:
  - All edges between `{u, v}` must use actions from the train subset.
- For the test split:
  - All edges between `{u, v}` must use actions from the test subset.
- Additional constraints guarantee that:
  - Graphs are connected, with node degrees bounded relative to the number of actions.
  - Each node still has unique actions on its incident edges.

This means:

> Any `(node, action, next-node)` pattern seen in training cannot appear in the test graphs.

A model that relies on a global mapping from action ids to structural meaning cannot generalise here. To do well, it must truly learn locally, on each graph, as it explores.
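The per-pair partition can be sketched as follows. This illustrates only the core split; the real `DisjointGraphDataset` layers connectivity and degree constraints on top of it:

```python
import random

def partition_actions(num_actions, k_train, k_test, rng):
    """Split the action set into disjoint train-only and test-only subsets."""
    actions = list(range(num_actions))
    rng.shuffle(actions)
    return set(actions[:k_train]), set(actions[k_train:k_train + k_test])

rng = random.Random(0)
train_acts, test_acts = partition_actions(num_actions=8, k_train=4, k_test=4, rng=rng)
assert train_acts.isdisjoint(test_acts)  # no (u, action, v) triplet can leak
print(sorted(train_acts), sorted(test_acts))
```

Because the subsets are disjoint for every node pair, any triplet built from a train-only action simply cannot occur in a test graph.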
All models inherit from `BaseGraphNetwork` (`base/network.py`), which provides:
- Deterministic seeding across NumPy, Python, and PyTorch.
- Model config serialization.
- Saving and loading checkpoints.
- A notion of slow parameters (trainable weights) vs fast weights (ephemeral buffers like plastic matrices).
### `network/rnn.py`

A straightforward baseline:

- Input dimension: `num_nodes + num_actions`
- Architecture:
  - `nn.RNN(input_dim, hidden_dim, batch_first=True)`
  - A small MLP decoder:
    - Hidden: `hidden_dim → intermediate_dim`
    - Output: `intermediate_dim → num_nodes`
- Forward pass:
  - Feed the whole sequence into the RNN: `Z = rnn(X)`
  - Drop the last RNN output (keeping `Z[:, :-1]`) to align outputs with "predict the next node".
  - Decode to logits over nodes at each timestep.
- Loss:
  - Cross-entropy between the predicted node logits and the ground-truth next node.
Intuitively, this model has all the expressive power of a generic RNN. It can, in principle, learn:
- A representation of the current node and recent history.
- A mapping from hidden state to next node.
But it has no explicit plasticity mechanism; all structure is stored in fixed weights.
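The architecture above can be sketched in a few lines of PyTorch. The class name `TinyRNNBaseline` and the default dimensions are illustrative assumptions; the real `RNNGraphNetwork` may differ in details:

```python
import torch
import torch.nn as nn

class TinyRNNBaseline(nn.Module):
    def __init__(self, num_nodes, num_actions, hidden_dim=32, intermediate_dim=16):
        super().__init__()
        input_dim = num_nodes + num_actions
        self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, intermediate_dim),
            nn.ReLU(),
            nn.Linear(intermediate_dim, num_nodes),
        )

    def forward(self, x):           # x: [B, T, num_nodes + num_actions]
        z, _ = self.rnn(x)          # z: [B, T, hidden_dim]
        z = z[:, :-1]               # drop last step: step t predicts x_{t+1}
        return self.decoder(z)      # logits: [B, T-1, num_nodes]

model = TinyRNNBaseline(num_nodes=5, num_actions=3)
logits = model(torch.randn(2, 10, 8))
print(tuple(logits.shape))  # (2, 9, 5)
```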
### `network/plastic.py`

This is the core of the project: a fast-weight network with online plasticity.

Key components:

- Input dimension: `num_nodes + num_actions + 2`
  - The extra `2` bits are mode bits:
    - `[1, 0]` = prediction mode
    - `[0, 1]` = acquisition (learning) mode
- Two linear layers:
  - `c1: Linear(input_dim, hidden_dim)`
  - `ca: Linear(num_nodes, hidden_dim)`
- A plastic weight matrix `Wp` of shape `[batch_size, hidden_dim, hidden_dim]`:
  - Registered as a buffer, not a parameter.
  - Initialised randomly, with each row normalised.
- A decoder MLP: `hidden_dim → intermediate_dim → num_nodes`

At each timestep `t` we do two things:

1. Acquire new experience from `(x_{t-1}, a_{t-1}) → x_t` (only for `t > 0`):

   - Construct the acquisition input with mode bits `[0, 1]`.
   - Compute:
     - A context vector from the previous node and action: `c1_acq = relu(c1(input_acq))`
     - A node embedding for the current node: `ca_out = relu(ca(x_t))`
   - Update `Wp` with an outer product:

     ```python
     delta_Wp = torch.bmm(
         ca_out.unsqueeze(2),  # [B, H, 1]
         c1_acq.unsqueeze(1),  # [B, 1, H]
     )
     Wp_current = Wp_current + delta_Wp
     ```

   - Row-normalise each fast-weight row to keep things stable.

   This acts like a learnable Hebbian rule: "When we see that `(node, action)` led to this next node, strengthen the connection between their representations."

2. Predict the next node for the current timestep:

   - Construct the prediction input with mode bits `[1, 0]`.
   - Compute `c1_pred = relu(c1(input_pred))`.
   - Apply the fast-weight matrix:

     ```python
     h = torch.bmm(Wp_current, c1_pred.unsqueeze(2)).squeeze(2)
     h = relu(h)
     logits = decoder(h)
     ```

   This is where the plastic matrix effectively maps states encoding `(node, action)` to the representation of the predicted next node.

At the end of the sequence, the updated fast weights for each batch element are written back to `self.Wp`. There is also a `reset_plastic_weights()` method, called:

- At the start of each training batch
- Before some evaluation phases

so that each new exploration starts from a clean slate of fast weights.

In short, the plastic network learns how to update its own fast-weight lookup table from experience, not the raw table entries themselves.
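The fast-weight lookup table can be demonstrated in isolation with a toy NumPy example, using one-hot vectors in place of the learned ReLU embeddings (a deliberate simplification of the mechanism above):

```python
import numpy as np

H = 8
Wp = np.full((H, H), 1.0 / H)  # row-normalised starting matrix

def one_hot(i, n=H):
    v = np.zeros(n)
    v[i] = 1.0
    return v

def acquire(Wp, key, value):
    """Hebbian outer-product update, then row normalisation for stability."""
    Wp = Wp + np.outer(value, key)            # strengthen key -> value
    return Wp / Wp.sum(axis=1, keepdims=True)

def predict(Wp, key):
    return Wp @ key                            # scores over value dimensions

# Store "key 2 maps to value 5", then read it back.
Wp = acquire(Wp, key=one_hot(2), value=one_hot(5))
print(int(predict(Wp, one_hot(2)).argmax()))  # 5
```

A single outer-product write is enough to retrieve the stored association, which is exactly the "lookup table" behaviour the plastic network learns to exploit.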
The orchestration logic lives in `utils/run.py`.

`RunConfig` is a dataclass that bundles all knobs:

- Dataset: `dataset_type` (`"standard"` or `"disjoint"`), `num_nodes`, `edge_prob`, `num_samples_per_graph`, `num_graphs`, `num_timesteps`, `num_actions`, `test_fraction`
- Network: `network_type` (`"rnn"` or `"plastic"`), `hidden_dim`, `batch_size` (required for the plastic network, ignored for the RNN)
- Training: `epochs`, `trainer_batch_size`, `lr`, `val_split`, `scheduler` (`"none"` or `"cosine"`)
- Logging: `wandb_enabled`, `wandb_project`, `wandb_entity`, `wandb_group`, `wandb_tags`
- Evaluation: `long_timesteps`, `long_window_size`, `boundary_num_graphs`, `boundary_repeats_per_graph`, `boundary_window_size`
### `base/trainer.py`

The `Trainer` class provides a generic training loop:

- Splits the data into train/validation NumPy arrays.
- Iterates over mini-batches:
  - Converts each batch to a PyTorch tensor on the chosen device.
  - Runs the forward pass and computes the cross-entropy loss.
  - Backpropagates, with gradient-norm clipping (the threshold is high enough that nothing is actually clipped, but the norm is logged).
  - Tracks accuracy per epoch.
- Optionally uses a cosine LR scheduler.
- Keeps track of:
  - Train/val loss
  - Train/val accuracy
  - LR over epochs
  - Gradient norms
  - Best validation loss and the corresponding model state dict
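A compact sketch of one epoch of such a loop, under stated assumptions (a plain classifier stands in for the graph models; `train_epoch` is a hypothetical name, and the real `Trainer` adds schedulers, callbacks, and best-state tracking):

```python
import numpy as np
import torch
import torch.nn as nn

def train_epoch(model, optimizer, data, targets, batch_size=4, device="cpu"):
    model.train()
    losses = []
    for start in range(0, len(data), batch_size):
        x = torch.as_tensor(data[start:start + batch_size], device=device)
        y = torch.as_tensor(targets[start:start + batch_size], device=device)
        logits = model(x)
        loss = nn.functional.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        # Very high max_norm: effectively no clipping, but the norm is available for logging.
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)
        optimizer.step()
        losses.append(loss.item())
    return float(np.mean(losses))

model = nn.Linear(8, 5)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
data = np.random.randn(16, 8).astype(np.float32)
targets = np.random.randint(0, 5, size=16)
loss = train_epoch(model, opt, data, targets)
print(loss)
```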
`Run.train()` does the following:

- Build the dataset (`StandardGraphDataset` or `DisjointGraphDataset`).
- Log dataset stats and configuration.
- Save:
  - Random train/test graph plots.
  - Action-id histograms per node pair for the train/test splits.
- Check structural constraints using `DatasetUtils`:
  - Triplet overlap fraction (should be zero for disjoint).
  - Consistency of the `(u, action) → v` mapping.
  - An estimate of the optimal prediction rate using an oracle-like lookup.
- Build the network (`RNNGraphNetwork` or `PlasticGraphNetwork`).
- Log parameter counts (slow vs fast).
- Create a `Trainer` with callbacks:
  - For plastic networks, to reset plastic weights at the start of each batch.
  - To log metrics into wandb.
- Train on the training data.
- Save:
  - The model checkpoint (`model.pth`) with model and dataset configs.
  - Dataset and trainer configs (`dataset_config.json`, `trainer_config.json`, `run_config.json`).
  - A training-curves figure.
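The structural checks (triplet overlap and `(u, action) → v` consistency) can be sketched with plain Python sets and dicts. The helper names are hypothetical; `DatasetUtils` may implement them differently:

```python
def triplet_overlap_fraction(train_triplets, test_triplets):
    """Fraction of test (u, action, v) triplets also seen in training."""
    if not test_triplets:
        return 0.0
    overlap = set(train_triplets) & set(test_triplets)
    return len(overlap) / len(set(test_triplets))

def is_consistent(triplets):
    """Each (u, action) key must map to exactly one next node v."""
    mapping = {}
    for u, a, v in triplets:
        if mapping.setdefault((u, a), v) != v:
            return False
    return True

train = [(0, 1, 2), (2, 0, 3)]
test = [(0, 2, 4), (2, 1, 0)]
print(triplet_overlap_fraction(train, test), is_consistent(train + test))  # 0.0 True
```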
`Run.evaluate()` then:
- Computes test next-node accuracy.
- Plots rolling accuracy during a long exploration on a single random test graph.
- Plots rolling accuracy across multiple graphs with visible boundaries, to see how quickly the model adapts after a graph switch.
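Rolling accuracy here is simply a moving average of per-timestep correctness; a minimal sketch (hypothetical helper, not from the codebase):

```python
import numpy as np

def rolling_accuracy(correct, window_size):
    """correct: 0/1 array of per-step prediction hits; returns windowed mean."""
    kernel = np.ones(window_size) / window_size
    return np.convolve(correct, kernel, mode="valid")

correct = np.array([0, 0, 1, 1, 1, 1, 0, 1], dtype=float)
print(rolling_accuracy(correct, window_size=4))
```

In the boundary plots, a dip followed by a fast recovery of this curve right after a graph switch is the signature of online adaptation.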
`Run.generate_report()` uses `RunReportBuilder` to produce a minimal LaTeX report (optionally compiled to PDF) summarising:
- Configuration.
- Key metrics.
- Training curves, rolling accuracy, graphs, and histograms.
All run artefacts are written into a fresh directory:

```
run/
  <uuid>/
    model.pth
    run_config.json
    dataset_config.json
    trainer_config.json
    train_graph.png
    test_graph.png
    train_action_hist.png
    test_action_hist.png
    training_curves.png
    rolling_accuracy.png
    rolling_accuracy_boundaries.png
    report.tex
    report.pdf   # if compilation succeeds
    run.log
```
- Python (3.10+ recommended)
- PyTorch, with CUDA if you want GPU acceleration
- The `uv` Python package manager (for environment and dependency management)
- Optional:
  - LaTeX (`tectonic` or `pdflatex`) for PDF report compilation
  - `wandb` for experiment tracking
From the repository root:

```bash
uv sync
```

This will create a virtual environment and install the dependencies specified in the project configuration (for example, a `pyproject.toml`).

You can then run commands inside that environment using:

```bash
uv run python main.py --config path/to/config.yaml
```

or, if you prefer manual activation, activate the environment that uv created and call `python` directly.

The main entry point is `main.py`:

```bash
uv run python main.py --config configs/rnn_standard.yaml
```

or, without uv:

```bash
python main.py --config configs/rnn_standard.yaml
```

This will:

- Parse the YAML config.
- Build a `RunConfig`.
- Create a `Run` object.
- Call `run.train()`, `run.evaluate()`, and `run.generate_report()`.
- Save all artefacts under `run/<uuid>/`.
`main.py` expects a YAML file structured into sections:

- `global`
- `dataset`
- `network`
- `trainer`
- `evaluation`
- `logging`

Each section is optional; keys are merged into the `RunConfig`. The `trainer` section is a bit special: it is flattened into the top-level config, but also allows `batch_size` to be used as an alias for `trainer_batch_size`.
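The merge-and-alias behaviour can be sketched as follows, assuming the YAML has already been parsed into a nested dict (`merge_config` is a hypothetical name; the real `main.py` may differ in details):

```python
def merge_config(sections):
    flat = {}
    for name in ("global", "dataset", "network", "trainer", "evaluation", "logging"):
        flat.update(sections.get(name, {}))
    # trainer-section alias: batch_size -> trainer_batch_size
    if "batch_size" in sections.get("trainer", {}):
        flat["trainer_batch_size"] = flat.pop("batch_size")
    return flat

cfg = merge_config({
    "global": {"seed": 42},
    "trainer": {"batch_size": 32, "lr": 5e-4},
})
print(cfg)  # {'seed': 42, 'lr': 0.0005, 'trainer_batch_size': 32}
```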
```yaml
# configs/rnn_standard.yaml
global:
  seed: 42
  dataset_type: standard      # "standard" or "disjoint"
  network_type: rnn           # "rnn" or "plastic"
dataset:
  num_nodes: 20
  edge_prob: 0.2
  num_samples_per_graph: 64
  num_graphs: 200
  num_timesteps: 32
  num_actions: 8
  test_fraction: 0.2
network:
  hidden_dim: 64              # RNN hidden dimension
trainer:
  epochs: 50
  trainer_batch_size: 32      # minibatch size for Trainer
  lr: 0.0005
  val_split: 0.2
  scheduler: cosine           # or "none"
evaluation:
  long_timesteps: 2000
  long_window_size: 100
  boundary_num_graphs: 5
  boundary_repeats_per_graph: 5
  boundary_window_size: 50
logging:
  wandb_enabled: false        # set true to enable Weights & Biases
  wandb_project: "plastic-graphs"
  wandb_entity: null
  wandb_group: null
  wandb_tags: ["rnn", "standard"]
```

Run it with:
```bash
uv run python main.py --config configs/rnn_standard.yaml
```

```yaml
# configs/plastic_disjoint.yaml
global:
  seed: 123
  dataset_type: disjoint
  network_type: plastic
dataset:
  num_nodes: 50
  edge_prob: 0.1
  num_samples_per_graph: 32
  num_graphs: 150
  num_timesteps: 40
  num_actions: 15
  test_fraction: 0.2
network:
  hidden_dim: 128
trainer:
  epochs: 80
  trainer_batch_size: 32
  lr: 0.0005
  val_split: 0.2
  scheduler: cosine
evaluation:
  long_timesteps: 2000
  long_window_size: 100
  boundary_num_graphs: 5
  boundary_repeats_per_graph: 5
  boundary_window_size: 50
logging:
  wandb_enabled: true
  wandb_project: "plastic-graphs"
  wandb_entity: "your_entity"
  wandb_group: "disjoint-50n15a"
  wandb_tags: ["plastic", "disjoint", "n50", "a15"]
```

Run it with:
```bash
uv run python main.py --config configs/plastic_disjoint.yaml
```

You should then find, in the corresponding `run/<uuid>/` directory:
- Near-perfect next-node test accuracy.
- Rolling accuracy plots that show quick adaptation when switching between graphs.
- A LaTeX/PDF report summarising the run.
Some natural directions:

- Try other sequence models (GRU, Transformer) as baselines.
- Add variants of plastic networks (e.g. gated Hebbian updates, learned decay).
- Introduce more realistic environments:
  - Temporal changes in the graph structure.
  - Non-uniform and state-dependent action distributions.
- Learn the exploration policy itself rather than walking randomly.

The current codebase is intentionally modular:

- New datasets inherit from `BaseGraphDataset`.
- New models inherit from `BaseGraphNetwork`.
- Everything plugs into the same `Run` and `Trainer` pipeline.
MIT License - see LICENSE file for details.
This project shows that biologically inspired plasticity mechanisms can give neural networks rapid, online adaptation to novel environments that standard architectures struggle with.