Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,21 @@ img = pipeline.append_frame(uint8_img) # returns passed image

Note: returned `img` is always on the same device as `engine.device`

## Quantization

The model can be quantized by passing the `quant` argument to `WorldEngine`:
```
engine = WorldEngine("Overworld/Waypoint-1.5-1B", quant="intw8a8", device="cuda")
```
Supported inference quantization schemes are:

| Config | Description | Supported GPUs |
|--------|-------------|----------------|
| `intw8a8` | INT8 weights + INT8 dynamic per-token activations | NVIDIA (30xx, 40xx, Ampere+) |
| `fp8w8a8` | FP8 (e4m3) weights + FP8 per-tensor activations via `torch._scaled_mm` | NVIDIA Ada Lovelace / Hopper+ (RTX 40xx, H100) |
| `nvfp4` | NVFP4 weights + FP4 activations via FlashInfer/CUTLASS | NVIDIA Blackwell (B100, B200, RTX 5090) |


### WorldEngine

`WorldEngine` computes each new frame from past frames, the controls, and the current prompt, then appends it to the sequence so later frames stay aligned with what has already been generated.
Expand Down
3 changes: 2 additions & 1 deletion examples/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def get_warm_engine(model_uri, model_overrides=None):
engine = WorldEngine(
model_uri,
model_config_overrides=model_config_overrides,
quant="intw8a8",
device="cuda",
load_weights=False
)
Expand Down Expand Up @@ -124,4 +125,4 @@ def target():
if blocking:
torch.cuda.synchronize()

benchmark.pedantic(target, setup=setup, rounds=20)
benchmark.pedantic(target, setup=setup, rounds=20)
2 changes: 1 addition & 1 deletion examples/gen_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


# Create inference engine
engine = WorldEngine(sys.argv[1], device="cuda")
engine = WorldEngine(sys.argv[1], quant=None, device="cuda")


# Define sequence of controller inputs applied
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ requires-python = ">=3.10"
dependencies = [
"taehv @ git+https://github.com/madebyollin/taehv.git@7dc60ec6601af2e668e31bc70acc4cb3665e4c22",
"torch==2.10.0",
"torchvision==0.25.0",
"torchaudio==2.10.0",
"fbgemm-gpu-genai==1.5.0; sys_platform == 'linux'",
"gemlite",
"einops",
"tensordict==0.10.0",
"transformers>=5.3.0",
Expand Down
20 changes: 0 additions & 20 deletions src/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,6 @@
from torch import nn
import torch


MODEL_CONFIG_DEFAULTS = OmegaConf.create(
{
"auto_aspect_ratio": True,
"gated_attn": False,
"inference_fps": "${base_fps}",
"model_type": "waypoint-1",
"n_kv_heads": "${n_heads}",
"patch": [1, 1],
"prompt_conditioning": None,
"prompt_encoder_uri": "google/umt5-xl",
"rope_nyquist_frac": 0.8,
"rope_theta": 10000.0,
"taehv_ae": False,
"temporal_compression": 1,
"value_residual": False,
}
)


class BaseModel(nn.Module):
@classmethod
def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None, load_weights: bool = True):
Expand Down
119 changes: 114 additions & 5 deletions src/quantize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional
from functools import partial

import torch
import torch.nn as nn
Expand All @@ -12,6 +13,12 @@
QUANTS.append("nvfp4")
except ImportError:
pass
try:
from gemlite.helper import A8W8_INT8_dynamic
import gemlite
gemlite.set_autotune("max")
except ImportError:
A8W8_INT8_dynamic = None


@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
Expand Down Expand Up @@ -108,7 +115,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class FP8W8A8Linear(nn.Module):
__constants__ = ("in_features", "out_features")

def __init__(self, lin: nn.Linear):
def __init__(self, lin: nn.Linear, smoothquant: bool = False):
super().__init__()
self.in_features, self.out_features = lin.in_features, lin.out_features

Expand All @@ -127,10 +134,25 @@ def __init__(self, lin: nn.Linear):
else:
self.register_buffer("bias", lin.bias.detach().to(torch.float16))

smooth = getattr(lin, "_smooth_scale", None)
if smoothquant and smooth is None:
raise ValueError(
f"smoothquant=True but this checkpoint has no _smooth_scale on "
f"{type(lin).__name__}(in={lin.in_features}, out={lin.out_features}). "
"SmoothQuant cannot be applied to this model checkpoint."
)
if smooth is not None:
self.register_buffer("_smooth_scale", smooth.detach())
else:
self._smooth_scale = None

def forward(self, x: torch.Tensor) -> torch.Tensor:
s = x.shape
x2 = x.reshape(-1, s[-1])

if self._smooth_scale is not None:
x2 = x2 * self._smooth_scale

xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float() # 0-d
xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()

Expand Down Expand Up @@ -183,8 +205,94 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

return result.reshape(x.shape[:-1] + (-1,))

def _per_token_quant_int8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Local per-token symmetric int8 quantization matching SGLang's W8A8 flow:
scale = absmax / 127
x_q = round(x / scale)
Returns:
x_q: [..., K] int8
scales: [..., 1] float32
"""
x_fp = x.float().nan_to_num()
scales = (x_fp.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) / 127.0).float()
x_q = torch.round(x_fp / scales).clamp(-127, 127).to(torch.int8)
return x_q, scales


@torch.library.custom_op("world_engine::w8a8_int8_linear", mutates_args=())
def w8a8_int8_linear(
    a: torch.Tensor,
    b_int8_T: torch.Tensor,
    b_scale: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    """
    INT8 W8A8 linear: dynamically quantize `a` per token, then run the
    sgl-kernel int8 scaled matmul against pre-quantized weights.

    Args:
        a:        [M, K] activations (any floating dtype; sets out_dtype).
        b_int8_T: [K, N] column-major int8 weight view.
        b_scale:  [N, 1] float32 per-output-channel weight scales.
        bias:     bias tensor; an empty tensor (numel() == 0) means no bias
                  (custom ops cannot take Optional[Tensor], hence the sentinel).

    Raises:
        ImportError: if sgl-kernel is not installed.
    """
    if sgl_int8_scaled_mm is None:
        raise ImportError("sgl-kernel is required for quant='w8a8'")

    assert a.ndim == 2, "expected [M, K] input"
    a_q, a_scale = _per_token_quant_int8(a.contiguous())

    return sgl_int8_scaled_mm(
        a_q,                                    # [M, K] row-major int8
        b_int8_T,                               # [K, N] column-major int8 view
        a_scale,                                # [M, 1] float32
        b_scale,                                # [N, 1] float32
        out_dtype=a.dtype,
        bias=bias if bias.numel() != 0 else None,
    )


@w8a8_int8_linear.register_fake
def _w8a8_int8_linear_fake(
    a: torch.Tensor,
    b_int8_T: torch.Tensor,
    b_scale: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    # Shape/dtype-only meta implementation for tracing and torch.compile:
    # output is [M, N] in the activation dtype, matching the real kernel.
    m, n = a.shape[0], b_int8_T.shape[1]
    return torch.empty((m, n), device=a.device, dtype=a.dtype)

class INT8W8A8GemLite(nn.Module):
    """
    Drop-in replacement for nn.Linear using GemLite's dynamic INT8 W8A8 path
    (int8 weights + per-token dynamic int8 activations).

    Optionally applies a SmoothQuant activation scale (`_smooth_scale`,
    read from the source linear layer) before the quantized matmul.
    """

    __constants__ = ("in_features", "out_features")

    def __init__(self, lin: nn.Linear, smoothquant: bool = False):
        """
        Args:
            lin: source linear layer; assumed to already be on the target
                 CUDA device (minimal wrapper, no device transfer here).
            smoothquant: require a `_smooth_scale` buffer on `lin`; raises
                 ValueError if the checkpoint does not provide one.

        Raises:
            ImportError: if gemlite is not installed.
            ValueError:  if smoothquant is requested but unavailable.
        """
        super().__init__()
        if A8W8_INT8_dynamic is None:
            # NOTE: this class is dispatched under quant="intw8a8" in
            # quantize_model, so name that key in the error message.
            raise ImportError("Install gemlite for quant='intw8a8'")

        self.in_features = lin.in_features
        self.out_features = lin.out_features

        # Minimal wrapper: assumes the layer is already on the target CUDA device.
        self.impl = A8W8_INT8_dynamic(
            device=str(lin.weight.device),
            dtype=lin.weight.dtype,
        ).from_linear(lin)

        smooth = getattr(lin, "_smooth_scale", None)
        if smoothquant and smooth is None:
            raise ValueError(
                f"smoothquant=True but this checkpoint has no _smooth_scale on "
                f"{type(lin).__name__}(in={lin.in_features}, out={lin.out_features}). "
                "SmoothQuant cannot be applied to this model checkpoint."
            )
        if smooth is not None:
            self.register_buffer("_smooth_scale", smooth.detach())
        else:
            self._smooth_scale = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-scale activations when SmoothQuant is active, then run the
        # quantized GEMM; type_as keeps the caller's activation dtype.
        if self._smooth_scale is not None:
            x = x * self._smooth_scale
        return self.impl(x).type_as(x)


def quantize_model(model: nn.Module, quant: str):
def quantize_model(model: nn.Module, quant: str, smoothquant: bool = False):
if quant is None:
return model

Expand All @@ -198,13 +306,14 @@ def eligible(m: nn.Module) -> bool:
return (o % 32 == 0) and (k % 32 == 0)

new_linear = {
"w8a8": FP8W8A8Linear,
"intw8a8": partial(INT8W8A8GemLite, smoothquant=smoothquant),
"fp8w8a8": partial(FP8W8A8Linear, smoothquant=smoothquant),
"nvfp4": FP4Linear,
"fp8": FP8Linear,
}[quant]

for name, child in model.named_children():
setattr(model, name, new_linear(child)) if eligible(child) else quantize_model(
child, quant
child, quant, smoothquant
)
return model
return model
10 changes: 6 additions & 4 deletions src/world_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# "shape_padding": True,
}


import time
@dataclass
class CtrlInput:
button: Set[int] = field(default_factory=set) # pressed button IDs
Expand All @@ -38,15 +38,15 @@ def __init__(
self,
model_uri: str,
quant: Optional[str] = None,
smooth: bool = False,
model_config_overrides: Optional[Dict] = None,
device=None,
dtype=torch.bfloat16,
load_weights: bool = True
):
"""
model_uri: HF URI or local folder containing model.safetensors and config.yaml
quant: None | w8a8 | nvfp4

quant: None | intw8a8 | fp8w8a8 | nvfp4
model_config_overrides: Dict to override model config values
- auto_aspect_ratio: set to False to work in ae raw space, otherwise in/out are 720p or 360p
"""
Expand Down Expand Up @@ -77,7 +77,9 @@ def __init__(
).eval()
apply_inference_patches(self.model)
if quant is not None:
quantize_model(self.model, quant)
start_time = time.time()
quantize_model(self.model, quant, smoothquant=smooth)
print(f"Quantization took {time.time() - start_time:.2f}s")

self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device)

Expand Down