Unified Interface for World Models in Reinforcement Learning
One API. Multiple Architectures. Clear Contracts.
Alpha (v0.1.1) — Under active development. API may change between minor versions.
WorldFlux provides a unified Python interface for world models used in reinforcement learning.
World models let RL agents imagine before acting by predicting future states, rewards, and outcomes without touching the real environment. Upstream literature reports strong sample-efficiency gains for world-model methods in many settings (Hafner et al., 2023; Hansen et al., 2024).
The problem: every research team reimplements the same core components from scratch. DreamerV3, TD-MPC2, JEPA — different codebases, different APIs, incompatible training loops. Want to swap an encoder while keeping DreamerV3's dynamics? Rewrite everything.
WorldFlux solves this with a unified interface:
```python
# One API for any world model architecture
model = create_world_model("dreamerv3:size12m")
state = model.encode(obs)
trajectory = model.rollout(state, actions)  # imagine 15 steps ahead
```

- Swap components independently with the 5-layer pluggable architecture
- Reference-family implementations with proof-mode parity workflows against upstream baselines; public proof claims require published evidence bundles
- Training infrastructure with replay buffers, checkpointing, and callbacks
- One API — `encode()`, `transition()`, `decode()`, `rollout()` — works across all model families
- Unified API: common interface across model families
- v3-first API: `create_world_model()` defaults to `api_version="v3"` (strict contracts enabled)
- Universal Payload Layer: `ActionPayload`/`ConditionPayload` for polymorphic conditioning
- Planner Contract: planners return `ActionPayload` with `extras["wf.planner.horizon"]`
- Simple Usage: one-liner model creation with `create_world_model()`
- Pluggable 5-layer core: optional `component_overrides` for encoder/dynamics/conditioner/decoder/rollout
- Training Infrastructure: complete training loop with callbacks, checkpointing, and logging
- Type Safe: full type annotations and mypy compatibility
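As a mental model for the four-method contract above, here is a toy stand-in whose "latent state" is a single float. Everything in this sketch (the class, its dynamics, and the reward rule) is hypothetical and illustrative; it is not WorldFlux internals:

```python
# Toy illustration of an encode/transition/rollout contract.
# All names and dynamics here are made up for illustration.
from dataclasses import dataclass


@dataclass
class Trajectory:
    states: list[float]
    rewards: list[float]


class ToyWorldModel:
    def encode(self, obs: float) -> float:
        # The "latent state" is just the observation itself.
        return obs

    def transition(self, state: float, action: float) -> float:
        # Deterministic toy dynamics: drift toward the action.
        return 0.9 * state + 0.1 * action

    def rollout(self, state: float, actions: list[float]) -> Trajectory:
        # Imagine forward without touching any real environment.
        states, rewards = [], []
        for a in actions:
            state = self.transition(state, a)
            states.append(state)
            rewards.append(-abs(state))  # toy reward: stay near zero
        return Trajectory(states, rewards)


model = ToyWorldModel()
traj = model.rollout(model.encode(1.0), [0.0] * 5)
print(len(traj.states))  # 5 imagined steps
```

The point is only the shape of the interface: a planner can score candidate action sequences by the imagined rewards a rollout returns, regardless of which model family produced them.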
Install uv first if you do not have it yet: uv installation guide.

```bash
uv tool install worldflux
worldflux init my-world-model
```

Optional: enable the InquirerPy-powered prompt UI.

```bash
uv tool install --with inquirerpy worldflux
```

`worldflux init` now performs cross-platform pre-init dependency assurance. It provisions a user-scoped bootstrap virtual environment and installs the selected environment dependencies before scaffolding:

- Linux/macOS default: `~/.worldflux/bootstrap/py<major><minor>`
- Windows default: `%LOCALAPPDATA%/WorldFlux/bootstrap/py<major><minor>`

Environment variables:

- `WORLDFLUX_BOOTSTRAP_HOME`: override the bootstrap root path
- `WORLDFLUX_INIT_ENSURE_DEPS=0`: disable auto-bootstrap (emergency bypass)
```bash
git clone https://github.com/worldflux/WorldFlux.git
cd worldflux
uv sync
source .venv/bin/activate
worldflux init my-world-model

# With training dependencies
uv sync --extra training

# With all optional dependencies
uv sync --extra all

# For development
uv sync --extra dev
```

```bash
uv pip install worldflux
worldflux init my-world-model
```

```bash
worldflux doctor
```

```bash
cd website
npm ci
npm run build

# Optional: local docs dev server
npm start
```

```bash
uv sync --extra dev
uv run python examples/quickstart_cpu_success.py --quick
```

This official smoke path uses a random replay buffer and a CI-sized model to validate installation and core contracts on CPU. It is not a benchmark or a real-environment reproduction path.
```python
from worldflux import create_world_model

model = create_world_model("dreamerv3:size12m")
```

```python
from worldflux import ActionPayload, ConditionPayload

state = model.encode(obs)
next_state = model.transition(
    state,
    ActionPayload(kind="continuous", tensor=action),
    conditions=ConditionPayload(goal=goal_tensor),
)
```

```python
from worldflux import create_world_model

model = create_world_model(
    "tdmpc2:ci",
    obs_shape=(4,),
    action_dim=2,
    component_overrides={
        # values can be registered component ids, classes, or instances
        "action_conditioner": "my_plugin.zero_action_conditioner",
    },
)
```

External packages can register plugins through entry-point groups: `worldflux.models` and `worldflux.components`.
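Entry-point registration could look like the following `pyproject.toml` fragment. The two group names come from this README; the package name, module paths, and object names (`my_plugin`, `MyWorldModel`, `ZeroActionConditioner`) are hypothetical placeholders:

```toml
# Hypothetical plugin package exposing a model and a component to WorldFlux.
[project.entry-points."worldflux.models"]
my_model = "my_plugin.models:MyWorldModel"

[project.entry-points."worldflux.components"]
zero_action_conditioner = "my_plugin.conditioners:ZeroActionConditioner"
```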
```python
import torch

obs = torch.randn(1, 3, 64, 64)
state = model.encode(obs)

actions = torch.randn(15, 1, 6)  # [horizon, batch, action_dim]
trajectory = model.rollout(state, actions)

print(f"Predicted rewards: {trajectory.rewards.shape}")
print(f"Continue probs: {trajectory.continues.shape}")
```

```python
from worldflux import create_world_model
from worldflux.training import train, ReplayBuffer

model = create_world_model("dreamerv3:size12m", obs_shape=(3, 64, 64), action_dim=6)
buffer = ReplayBuffer.load("trajectories.npz")

trained_model = train(model, buffer, total_steps=50_000)
trained_model.save_pretrained("./my_model")
```

| Family | Presets | Status |
|---|---|---|
| DreamerV3 | size12m, size25m, size50m, size100m, size200m | Reference-family |
| TD-MPC2 | 5m, 19m, 48m, 317m | Reference-family |
| JEPA | base | Experimental |
| V-JEPA2 | ci, tiny, base | Experimental |
| Token | base | Experimental |
| Diffusion | base | Experimental |
Reference-family models map to maintained upstream families and internal proof-mode parity workflows. Public proof claims require published evidence bundles; local fixtures and internal runs are not enough on their own. Experimental models implement the full API but do not carry the same parity workflow coverage and may return `None` for some predictions (e.g. rewards).
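Because experimental families may legitimately return `None` for some predictions, downstream code should guard before consuming them. A minimal sketch; the `Trajectory` class below is a simplified stand-in for illustration, not the exact WorldFlux type:

```python
# Guarding against optional predictions from experimental model families.
# `Trajectory` here is a simplified stand-in, not the WorldFlux type.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Trajectory:
    states: list[float]
    rewards: Optional[list[float]] = None  # experimental models may omit this


def total_imagined_reward(traj: Trajectory, default: float = 0.0) -> float:
    """Sum predicted rewards, falling back when the model returned None."""
    if traj.rewards is None:
        return default
    return sum(traj.rewards)


print(total_imagined_reward(Trajectory(states=[0.1, 0.2], rewards=[1.0, 2.0])))  # 3.0
print(total_imagined_reward(Trajectory(states=[0.1, 0.2])))  # 0.0 fallback
```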
This table lists commonly used presets. For the full catalog (including CI, experimental, and skeleton families), run:

```bash
worldflux models list --verbose
```

All world models implement the `WorldModel` base class:
```python
state = model.encode(obs)
next_state = model.transition(state, action)
next_state = model.update(state, action, obs)

output = model.decode(state)
preds = output.preds  # e.g. {"obs", "reward", "continue"}

trajectory = model.rollout(initial_state, actions)
loss_out = model.loss(batch)  # LossOutput (loss_out.loss, loss_out.components)
```

```python
from worldflux.training import (
    Trainer,
    TrainingConfig,
    ReplayBuffer,
    train,
)
from worldflux.training.callbacks import (
    LoggingCallback,
    CheckpointCallback,
    EarlyStoppingCallback,
    ProgressCallback,
)
```

See the examples/ directory:
- `quickstart_cpu_success.py` - Official CPU-first smoke path using a random replay buffer
- `compare_unified_training.py` - Shared-contract smoke comparison for DreamerV3 and TD-MPC2
- `worldflux_quickstart.ipynb` - Interactive Colab notebook
- `train_dreamer.py` - Training example
- `train_tdmpc2.py` - Training example
- `visualize_imagination.py` - Imagination rollout visualization
```bash
uv run python examples/quickstart_cpu_success.py --quick
uv run python examples/compare_unified_training.py --quick
uv run python examples/train_dreamer.py --test
uv run python examples/train_dreamer.py --data trajectories.npz --steps 100000
```

- Full Documentation - Guides and API reference
- API Reference - Contract and symbol-level docs
- Reference - Operational and quality docs
Join our Discord to discuss world models, get help, and connect with other researchers and developers.
- Support channels and response paths: SUPPORT.md
- Community expectations and reporting: CODE_OF_CONDUCT.md
See SECURITY.md for security considerations, especially regarding loading model checkpoints from untrusted sources.
Apache License 2.0 - see LICENSE and NOTICE for details.
Contributions are welcome. Please read our Contributing Guide before submitting pull requests.
If you use this library in your research, please cite:
```bibtex
@software{worldflux,
  title = {WorldFlux: Unified Interface for World Models},
  year  = {2026},
  url   = {https://github.com/worldflux/WorldFlux}
}
```