Neural Plasticity for Rapid Adaptation in Dynamic Environments

This repository explores how neural networks can adapt online to changing environments by learning fast, plastic connectivity rather than just memorising a fixed dataset.

The setting is intentionally simple but expressive: we generate random graphs, walk on them, and ask a model to predict the next node given the current node and an action (edge label).

Two model families are compared:

  • A standard RNN-based graph sequence model (RNNGraphNetwork)
  • A plastic fast-weight network (PlasticGraphNetwork) that updates an internal connectivity matrix during inference

To rigorously test generalisation, even on relatively large graphs (50 nodes, 15 actions), we constructed a dataset where structural triplets (node, action, next-node) in the test set were strictly held out from training. In this setting, the standard RNN largely memorised the training distribution and failed to generalise. The plastic architecture, by contrast, achieved near-100% accuracy on unseen graphs and adapted in real time when the underlying graph topology was switched mid-sequence, effectively learning the new structure on the fly.

At a high level, the plastic network learns to build and update a “lookup table” in its fast weights that maps each (current node, action) pair to the most likely next node, purely from online experience.

For concrete examples and plots from representative experiments (RNN vs plastic; standard vs disjoint; small vs large graphs), see the PDF run reports in reports/.


Conceptual overview

Environment and task

  • We work with undirected graphs G = (V, E):

    • Nodes: V = {0, ..., num_nodes-1}
    • Edges (u, v) carry an action id a ∈ {0, ..., num_actions-1}.
  • A trajectory is generated by randomly walking on the graph:

    • At each timestep t, we are at node x_t.
    • We sample one of its neighbours x_{t+1} uniformly.
    • The edge (x_t, x_{t+1}) has an attached action_id, which we encode as a one-hot action vector.
  • The model sees the sequence:

    • Input at timestep t: [one_hot(x_t) || one_hot(action_t)]
    • Target: the next node x_{t+1}, again as a one-hot over nodes.

So the core prediction task is:

Given (current node, action), predict the next node.

This structure ensures that if action labels around a node are unique, the mapping (node, action) → next node is well-defined.
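The exploration-to-tensor encoding above can be sketched as follows; this is a minimal, self-contained version, and names like `random_walk_sequence` are illustrative rather than the repository's actual API:

```python
import numpy as np
import networkx as nx

def random_walk_sequence(graph, num_timesteps, num_nodes, num_actions, rng):
    """One exploration: each row is [one_hot(x_t) || one_hot(action_t)].

    Assumes every edge carries an "action_id" attribute, as described above.
    """
    seq = np.zeros((num_timesteps, num_nodes + num_actions), dtype=np.float32)
    node = rng.choice(list(graph.nodes))
    for t in range(num_timesteps):
        nxt = rng.choice(list(graph.neighbors(node)))   # uniform neighbour
        action = graph.edges[node, nxt]["action_id"]
        seq[t, node] = 1.0                    # one-hot current node
        seq[t, num_nodes + action] = 1.0      # one-hot action taken
        node = nxt                            # the target x_{t+1}
    return seq
```

Stacking such sequences across explorations and graphs yields the `[N, T, num_nodes + num_actions]` tensors described below.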


Datasets

All datasets are built on top of BaseGraphDataset (core/base/dataset.py), which:

  • Generates graphs using a configurable BaseGraphGenerator (by default, a NetworkX random graph).

  • For each graph:

    • Runs multiple random explorations.
    • Converts each exploration into a tensor of shape [num_timesteps, num_nodes + num_actions].
  • Concatenates samples across graphs into:

    • train_data: shape [N_train, T, num_nodes + num_actions]
    • test_data: shape [N_test, T, num_nodes + num_actions]
  • Optionally plots:

    • Example train/test graphs.
    • Per-node-pair action histograms.
    • Triplet statistics and constraints.

StandardGraphDataset

dataset/standard.py

This is the “easy” dataset where train and test share the same action semantics.

Key properties:

  • Graphs are sampled from gnp_random_graph(num_nodes, edge_prob), conditioned to:

    • Be connected.
    • Have at least one edge.
  • Edges are labelled with actions such that:

    • Around each node u, the actions on edges incident to u are unique.
    • This makes (u, action) a deterministic key for the neighbour.
  • The same action space is shared across:

    • All training graphs
    • All test graphs

Effectively, there is a single global notion of action semantics: action 3 from node 7 on graph A plays the same semantic role as action 3 from node 7 on graph B.
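The unique-labelling constraint around each node can be enforced greedily; a minimal sketch (illustrative, not the repository's exact generator, which may instead resample the graph):

```python
import random
import networkx as nx

def label_edges_uniquely(graph, num_actions, rng):
    """Assign each edge an action_id so that no node sees the same action
    on two incident edges. The greedy pass can fail when a node's degree
    approaches num_actions; in that case, regenerate the graph and retry."""
    for u, v in graph.edges:
        used_u = {graph.edges[u, w].get("action_id") for w in graph.neighbors(u)}
        used_v = {graph.edges[v, w].get("action_id") for w in graph.neighbors(v)}
        free = [a for a in range(num_actions) if a not in used_u | used_v]
        if not free:
            raise ValueError("no free action id; regenerate the graph")
        graph.edges[u, v]["action_id"] = rng.choice(free)
    return graph
```

With unique labels in place, (u, action) is a deterministic key for u's neighbour, which is what makes the prediction task well-posed.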

DisjointGraphDataset

dataset/disjoint.py

This is the “hard” dataset designed to break global memorisation strategies.

Key idea: train and test do not share (u, action, v) triplets.

Mechanism:

  • For each unordered node pair {u, v}:

    • The action set {0, ..., num_actions-1} is randomly partitioned into:

      • A train-only subset of size k_train
      • A test-only subset of size k_test
  • For train split:

    • All edges between {u, v} must use actions from the train subset.
  • For test split:

    • Edges between {u, v} must use actions from the test subset.
  • Additional constraints guarantee:

    • Graphs are connected and have bounded degree vs. number of actions.
    • Each node still has unique actions on its incident edges.

This means:

Any (node, action, next-node) pattern seen in training cannot appear in the test graphs.

A model that relies on a global mapping from action ids to structural meaning cannot generalise here. To do well, it must truly learn locally, on each graph, as it explores.
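The per-pair action partition can be sketched as follows (a hypothetical helper, not the repository's exact code; `k_test` is implicitly `num_actions - k_train` here):

```python
import random

def partition_actions(num_actions, k_train, rng):
    """For one node pair {u, v}: randomly split the action ids into a
    train-only subset of size k_train and a disjoint test-only subset.
    Train edges between {u, v} may only use the first set, test edges
    only the second, so no (u, action, v) triplet can appear in both."""
    actions = list(range(num_actions))
    rng.shuffle(actions)
    return set(actions[:k_train]), set(actions[k_train:])
```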


Models

All models inherit from BaseGraphNetwork (base/network.py), which provides:

  • Deterministic seeding across NumPy, Python, and PyTorch.
  • Model config serialization.
  • Saving and loading checkpoints.
  • A notion of slow parameters (trainable weights) vs fast weights (ephemeral buffers like plastic matrices).

RNNGraphNetwork

network/rnn.py

A straightforward baseline:

  • Input dimension: num_nodes + num_actions

  • Architecture:

    • nn.RNN(input_dim, hidden_dim, batch_first=True)

    • A small MLP decoder:

      • Hidden: hidden_dim → intermediate_dim
      • Output: intermediate_dim → num_nodes
  • Forward pass:

    • Feed the whole sequence into the RNN: Z = rnn(X)
    • Discard the last RNN output (Z[:, :-1]) to align with “predict next node”.
    • Decode to logits over nodes at each timestep.
  • Loss:

    • Cross-entropy between predicted node logits and the ground truth next-node one-hot.

Intuitively, this model has all the expressive power of a generic RNN. It can, in principle, learn:

  • A representation of the current node and recent history.
  • A mapping from hidden state to next node.

But it has no explicit plasticity mechanism; all structure is stored in fixed weights.
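A minimal PyTorch sketch of this baseline, with dimension names following the text (the repository's `RNNGraphNetwork` may differ in detail, e.g. in its loss handling and base-class plumbing):

```python
import torch
import torch.nn as nn

class RNNBaseline(nn.Module):
    def __init__(self, num_nodes, num_actions, hidden_dim, intermediate_dim):
        super().__init__()
        self.rnn = nn.RNN(num_nodes + num_actions, hidden_dim, batch_first=True)
        self.decoder = nn.Sequential(          # small MLP decoder
            nn.Linear(hidden_dim, intermediate_dim),
            nn.ReLU(),
            nn.Linear(intermediate_dim, num_nodes),
        )

    def forward(self, x):                  # x: [B, T, num_nodes + num_actions]
        z, _ = self.rnn(x)                 # z: [B, T, hidden_dim]
        return self.decoder(z[:, :-1])     # drop last step: logits for x_{t+1}
```

The targets are the node one-hots at timesteps 1..T-1, so the output at step t is compared against the node actually visited at step t+1.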

PlasticGraphNetwork

network/plastic.py

This is the core of the project: a fast-weight network with online plasticity.

Key components:

  • Input dimension: num_nodes + num_actions + 2

    • The extra 2 bits are mode bits:

      • [1, 0] = prediction mode
      • [0, 1] = acquisition (learning) mode
  • Two linear layers:

    • c1: Linear(input_dim, hidden_dim)
    • ca: Linear(num_nodes, hidden_dim)
  • A plastic weight matrix Wp of shape [batch_size, hidden_dim, hidden_dim]

    • Registered as a buffer, not a parameter.
    • Initialised with random values, then row-normalised.
  • A decoder MLP:

    • hidden_dim → intermediate_dim → num_nodes

Online update rule

At each timestep t we do two things:

  1. Acquire new experience from (x_{t-1}, a_{t-1}) → x_t (only for t > 0)

    • Construct acquisition input with mode bits [0, 1].

    • Compute:

      • A context vector from previous node + action: c1_acq = relu(c1(input_acq))
      • A node embedding for current node: ca_out = relu(ca(x_t))
    • Update Wp with an outer product:

      delta_Wp = torch.bmm(
          ca_out.unsqueeze(2),  # [B, H, 1]
          c1_acq.unsqueeze(1)   # [B, 1, H]
      )
      Wp_current = Wp_current + delta_Wp
    • Row-normalise each fast-weight row to keep things stable.

    This acts like a learnable Hebbian rule: “When we see that (node, action) led to this next node, strengthen the connection between their representations.”

  2. Predict the next node for the current timestep

    • Construct prediction input with mode bits [1, 0].

    • Compute c1_pred = relu(c1(input_pred)).

    • Apply the fast-weight matrix:

      h = torch.bmm(Wp_current, c1_pred.unsqueeze(2)).squeeze(2)
      h = relu(h)
      logits = decoder(h)
    • This is where the plastic matrix effectively maps states encoding (node, action) to the representation of the predicted next node.

At the end of the sequence, the updated fast weights for each batch element are written back to self.Wp. There is also a reset_plastic_weights() method called:

  • At the start of each training batch
  • Before some evaluation phases

so that each new exploration starts from a clean slate of fast weights.

In short, the plastic network learns how to update its own fast-weight lookup table from experience, not the raw table entries themselves.
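Putting the two phases together, one timestep of the plastic forward pass might look like the following sketch. It is consistent with the fragments above, but the function boundaries and names (`plastic_step`, `PRED`, `ACQ`) are illustrative:

```python
import torch

PRED = torch.tensor([1.0, 0.0])   # prediction mode bits
ACQ = torch.tensor([0.0, 1.0])    # acquisition mode bits

def plastic_step(Wp, c1, ca, decoder, prev_node_action, cur_node, cur_node_action):
    """One timestep. prev_node_action / cur_node_action are the
    [one_hot(node) || one_hot(action)] rows; cur_node is one_hot(x_t)."""
    B = Wp.shape[0]
    # 1. Acquisition: strengthen (x_{t-1}, a_{t-1}) -> x_t via an outer product.
    acq_in = torch.cat([prev_node_action, ACQ.expand(B, 2)], dim=1)
    c1_acq = torch.relu(c1(acq_in))                       # [B, H] context
    ca_out = torch.relu(ca(cur_node))                     # [B, H] node embedding
    Wp = Wp + torch.bmm(ca_out.unsqueeze(2), c1_acq.unsqueeze(1))
    Wp = Wp / (Wp.norm(dim=2, keepdim=True) + 1e-8)       # row-normalise
    # 2. Prediction: read the next node out of the fast weights.
    pred_in = torch.cat([cur_node_action, PRED.expand(B, 2)], dim=1)
    c1_pred = torch.relu(c1(pred_in))
    h = torch.relu(torch.bmm(Wp, c1_pred.unsqueeze(2)).squeeze(2))
    return Wp, decoder(h)                                 # logits over nodes
```

Only `c1`, `ca`, and the decoder carry slow (trained) parameters; `Wp` is updated by this rule alone, with no gradient-based writes during the sequence.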


Training and evaluation pipeline

The orchestration logic lives in utils/run.py.

RunConfig

RunConfig is a dataclass that bundles all knobs:

  • Dataset:

    • dataset_type: "standard" or "disjoint"
    • num_nodes
    • edge_prob
    • num_samples_per_graph
    • num_graphs
    • num_timesteps
    • num_actions
    • test_fraction
  • Network:

    • network_type: "rnn" or "plastic"
    • hidden_dim
    • batch_size (required for plastic network, ignored for RNN)
  • Training:

    • epochs
    • trainer_batch_size
    • lr
    • val_split
    • scheduler ("none" or "cosine")
  • Logging:

    • wandb_enabled
    • wandb_project
    • wandb_entity
    • wandb_group
    • wandb_tags
  • Evaluation:

    • long_timesteps, long_window_size
    • boundary_num_graphs, boundary_repeats_per_graph, boundary_window_size

Trainer

base/trainer.py

The Trainer class provides a generic training loop:

  • Splits the data into train/validation numpy arrays.

  • Iterates over mini-batches:

    • Converts each batch to a PyTorch tensor on the chosen device.
    • Runs forward pass, computes cross-entropy loss.
    • Backpropagates and computes the gradient norm (the norm is logged; the clipping threshold is high enough that no clipping effectively occurs).
    • Tracks accuracy per epoch.
  • Optionally uses a cosine LR scheduler.

  • Keeps track of:

    • Train/val loss
    • Train/val accuracy
    • LR over epochs
    • Gradient norms
    • Best validation loss and corresponding model state dict
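The core of the loop can be sketched as follows (generic PyTorch; scheduler, validation, and metric bookkeeping omitted, and `on_batch_start` is an illustrative name for the callback hook):

```python
import numpy as np
import torch

def train_epoch(model, data, loss_fn, optimiser, batch_size, device,
                on_batch_start=None):
    """One epoch over a numpy array of shape [N, T, D]; returns mean loss."""
    perm = np.random.permutation(len(data))
    losses = []
    for i in range(0, len(data), batch_size):
        batch = torch.as_tensor(data[perm[i:i + batch_size]], device=device)
        if on_batch_start is not None:
            on_batch_start(batch)       # e.g. reset plastic weights
        optimiser.zero_grad()
        loss = loss_fn(model(batch), batch)
        loss.backward()
        # A very large max_norm logs the gradient norm without clipping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)
        optimiser.step()
        losses.append(loss.item())
    return float(np.mean(losses))
```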

Run lifecycle

Run.train() does the following:

  1. Build dataset (StandardGraphDataset or DisjointGraphDataset).

  2. Log dataset stats and configuration.

  3. Save:

    • Random train/test graph plots.
    • Action-id histograms per node-pair for train/test splits.
  4. Check structural constraints using DatasetUtils:

    • Triplet overlap fraction (should be zero for disjoint).
    • Consistency of (u, action) → v mapping.
    • An estimate of the optimal prediction rate using an oracle-like lookup.
  5. Build the network (RNNGraphNetwork or PlasticGraphNetwork).

  6. Log parameter counts (slow vs fast).

  7. Create a Trainer with callbacks:

    • For plastic networks, reset plastic weights at the start of each batch.
    • For logging metrics into wandb.
  8. Train on the training data.

  9. Save:

    • Model checkpoint (model.pth) with model + dataset configs.
    • Dataset and trainer configs (dataset_config.json, trainer_config.json, run_config.json).
    • Training curves figure.
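The triplet-overlap check from step 4 can be sketched as follows (an illustrative helper, not necessarily the signature in DatasetUtils):

```python
def triplet_overlap_fraction(train_graphs, test_graphs):
    """Fraction of test (u, action, v) triplets also present in training.
    Should be exactly 0.0 for the disjoint dataset."""
    def triplets(graphs):
        out = set()
        for g in graphs:
            for u, v, data in g.edges(data=True):
                a = data["action_id"]
                out.add((u, a, v))
                out.add((v, a, u))   # undirected: count both directions
        return out
    train_t, test_t = triplets(train_graphs), triplets(test_graphs)
    return len(test_t & train_t) / max(len(test_t), 1)
```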

Run.evaluate() then:

  • Computes test next-node accuracy.
  • Plots rolling accuracy during a long exploration on a single random test graph.
  • Plots rolling accuracy across multiple graphs with visible boundaries, to see how quickly the model adapts after a graph switch.

Run.generate_report() uses RunReportBuilder to produce a minimal LaTeX report (optionally compiled to PDF) summarising:

  • Configuration.
  • Key metrics.
  • Training curves, rolling accuracy, graphs, and histograms.

All run artefacts are written into a fresh directory:

run/
  <uuid>/
    model.pth
    run_config.json
    dataset_config.json
    trainer_config.json
    train_graph.png
    test_graph.png
    train_action_hist.png
    test_action_hist.png
    training_curves.png
    rolling_accuracy.png
    rolling_accuracy_boundaries.png
    report.tex
    report.pdf (if compilation succeeds)
    run.log

Installation and usage

Prerequisites

  • Python (3.10+ recommended)

  • PyTorch with CUDA if you want GPU acceleration

  • uv Python package manager (for environment and dependency management)

  • Optional:

    • LaTeX (tectonic or pdflatex) for PDF report compilation
    • wandb for experiment tracking

Install dependencies with uv

From the repository root:

uv sync

This will create a virtual environment and install dependencies specified in your project configuration (for example a pyproject.toml).

You can then run commands inside that environment using:

uv run python main.py --config path/to/config.yaml

or, if you prefer manual activation, activate the environment that uv created and call python directly.


Running experiments

The main entry point is main.py:

uv run python main.py --config configs/rnn_standard.yaml

or, without uv:

python main.py --config configs/rnn_standard.yaml

This will:

  1. Parse the YAML config.
  2. Build a RunConfig.
  3. Create a Run object.
  4. Call run.train(), run.evaluate(), and run.generate_report().
  5. Save all artefacts under run/<uuid>/.

Configuration file structure

main.py expects a YAML file structured into sections:

  • global
  • dataset
  • network
  • trainer
  • evaluation
  • logging

Each section is optional; keys are merged into the RunConfig. The trainer section is a bit special: it is flattened into the top-level config but also allows batch_size to be aliased to trainer_batch_size.
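The merge behaviour can be sketched as follows (hypothetical; the real main.py may resolve the alias differently):

```python
def load_run_kwargs(cfg):
    """Flatten optional YAML sections into one kwargs dict for RunConfig.
    In the trainer section, batch_size is accepted as an alias for
    trainer_batch_size."""
    kwargs = {}
    for section in ("global", "dataset", "network", "evaluation", "logging"):
        kwargs.update(cfg.get(section) or {})
    trainer = dict(cfg.get("trainer") or {})
    if "batch_size" in trainer:
        trainer["trainer_batch_size"] = trainer.pop("batch_size")
    kwargs.update(trainer)
    return kwargs
```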

Example: standard RNN on standard dataset

# configs/rnn_standard.yaml

global:
  seed: 42
  dataset_type: standard        # "standard" or "disjoint"
  network_type: rnn             # "rnn" or "plastic"

dataset:
  num_nodes: 20
  edge_prob: 0.2
  num_samples_per_graph: 64
  num_graphs: 200
  num_timesteps: 32
  num_actions: 8
  test_fraction: 0.2

network:
  hidden_dim: 64                # RNN hidden dimension

trainer:
  epochs: 50
  trainer_batch_size: 32        # minibatch size for Trainer
  lr: 0.0005
  val_split: 0.2
  scheduler: cosine             # or "none"

evaluation:
  long_timesteps: 2000
  long_window_size: 100
  boundary_num_graphs: 5
  boundary_repeats_per_graph: 5
  boundary_window_size: 50

logging:
  wandb_enabled: false          # set true to enable Weights & Biases
  wandb_project: "plastic-graphs"
  wandb_entity: null
  wandb_group: null
  wandb_tags: ["rnn", "standard"]

Run it with:

uv run python main.py --config configs/rnn_standard.yaml

Example: plastic network on disjoint dataset with larger graphs

# configs/plastic_disjoint.yaml

global:
  seed: 123
  dataset_type: disjoint
  network_type: plastic

dataset:
  num_nodes: 50
  edge_prob: 0.1
  num_samples_per_graph: 32
  num_graphs: 150
  num_timesteps: 40
  num_actions: 15
  test_fraction: 0.2

network:
  hidden_dim: 128

trainer:
  epochs: 80
  trainer_batch_size: 32
  lr: 0.0005
  val_split: 0.2
  scheduler: cosine

evaluation:
  long_timesteps: 2000
  long_window_size: 100
  boundary_num_graphs: 5
  boundary_repeats_per_graph: 5
  boundary_window_size: 50

logging:
  wandb_enabled: true
  wandb_project: "plastic-graphs"
  wandb_entity: "your_entity"
  wandb_group: "disjoint-50n15a"
  wandb_tags: ["plastic", "disjoint", "n50", "a15"]

Run it with:

uv run python main.py --config configs/plastic_disjoint.yaml

The artefacts in the corresponding run/<uuid>/ directory should then show:

  • Near-perfect next-node test accuracy.
  • Rolling accuracy plots that show quick adaptation when switching between graphs.
  • A LaTeX/PDF report summarising the run.

Extending the project

Some natural directions:

  • Try other sequence models (GRU, Transformer) as baselines.

  • Add variants of plastic networks (e.g. gated Hebbian updates, learned decay).

  • Introduce more realistic environments:

    • Temporal changes in the graph structure.
    • Non-uniform and state-dependent action distributions.
  • Learn the exploration policy itself rather than walking randomly.

The current codebase is intentionally modular:

  • New datasets inherit from BaseGraphDataset.
  • New models inherit from BaseGraphNetwork.
  • Everything plugs into the same Run and Trainer pipeline.

License

MIT License - see LICENSE file for details.


This project shows that biologically inspired plasticity mechanisms can give neural networks rapid, online adaptation to novel environments that standard architectures struggle with.
