diff --git a/README.md b/README.md index 8d68292..005129d 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,21 @@ img = pipeline.append_frame(uint8_img) # returns passed image Note: returned `img` is always on the same device as `engine.device` +## Quantization + +The model can be quantized by passing the `quant` argument to `WorldEngine`: +``` +engine = WorldEngine("Overworld/Waypoint-1.5-1B", quant="intw8a8", device="cuda") +``` +Supported inference quantization schemes are: + +| Config | Description | Supported GPUs | +|--------|-------------|----------------| +| `intw8a8` | INT8 weights + INT8 dynamic per-token activations | NVIDIA (30xx, 40xx, Ampere+) | +| `fp8w8a8` | FP8 (e4m3) weights + FP8 per-tensor activations via `torch._scaled_mm` | NVIDIA Ada Lovelace / Hopper+ (RTX 40xx, H100) | +| `nvfp4` | NVFP4 weights + FP4 activations via FlashInfer/CUTLASS | NVIDIA Blackwell (B100, B200, RTX 5090) | + + ### WorldEngine `WorldEngine` computes each new frame from past frames, the controls, and the current prompt, then appends it to the sequence so later frames stay aligned with what has already been generated. 
diff --git a/examples/benchmark.py b/examples/benchmark.py index 3d8e623..c9882c9 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -56,6 +56,7 @@ def get_warm_engine(model_uri, model_overrides=None): engine = WorldEngine( model_uri, model_config_overrides=model_config_overrides, + quant="intw8a8", device="cuda", load_weights=False ) @@ -124,4 +125,4 @@ def target(): if blocking: torch.cuda.synchronize() - benchmark.pedantic(target, setup=setup, rounds=20) + benchmark.pedantic(target, setup=setup, rounds=20) \ No newline at end of file diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 0217896..8aae0c6 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -12,7 +12,7 @@ # Create inference engine -engine = WorldEngine(sys.argv[1], device="cuda") +engine = WorldEngine(sys.argv[1], quant=None, device="cuda") # Define sequence of controller inputs applied diff --git a/pyproject.toml b/pyproject.toml index 1bd9d0a..b6c41eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,10 @@ requires-python = ">=3.10" dependencies = [ "taehv @ git+https://github.com/madebyollin/taehv.git@7dc60ec6601af2e668e31bc70acc4cb3665e4c22", "torch==2.10.0", + "torchvision==0.25.0", + "torchaudio==2.10.0", + "fbgemm-gpu-genai==1.5.0; sys_platform == 'linux'", + "gemlite", "einops", "tensordict==0.10.0", "transformers>=5.3.0", diff --git a/src/model/base_model.py b/src/model/base_model.py index f94314e..02a9b4f 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -6,26 +6,6 @@ from torch import nn import torch - -MODEL_CONFIG_DEFAULTS = OmegaConf.create( - { - "auto_aspect_ratio": True, - "gated_attn": False, - "inference_fps": "${base_fps}", - "model_type": "waypoint-1", - "n_kv_heads": "${n_heads}", - "patch": [1, 1], - "prompt_conditioning": None, - "prompt_encoder_uri": "google/umt5-xl", - "rope_nyquist_frac": 0.8, - "rope_theta": 10000.0, - "taehv_ae": False, - "temporal_compression": 1, - "value_residual": False, - } 
-) - - class BaseModel(nn.Module): @classmethod def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None, load_weights: bool = True): diff --git a/src/quantize.py b/src/quantize.py index b74825b..37c06dd 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -1,4 +1,5 @@ from typing import Optional +from functools import partial import torch import torch.nn as nn @@ -12,6 +13,12 @@ QUANTS.append("nvfp4") except ImportError: pass +try: + from gemlite.helper import A8W8_INT8_dynamic + import gemlite + gemlite.set_autotune("max") +except ImportError: + A8W8_INT8_dynamic = None @torch.library.custom_op("world_engine::fp4_linear", mutates_args=()) @@ -108,7 +115,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FP8W8A8Linear(nn.Module): __constants__ = ("in_features", "out_features") - def __init__(self, lin: nn.Linear): + def __init__(self, lin: nn.Linear, smoothquant: bool = False): super().__init__() self.in_features, self.out_features = lin.in_features, lin.out_features @@ -127,10 +134,25 @@ def __init__(self, lin: nn.Linear): else: self.register_buffer("bias", lin.bias.detach().to(torch.float16)) + smooth = getattr(lin, "_smooth_scale", None) + if smoothquant and smooth is None: + raise ValueError( + f"smoothquant=True but this checkpoint has no _smooth_scale on " + f"{type(lin).__name__}(in={lin.in_features}, out={lin.out_features}). " + "SmoothQuant cannot be applied to this model checkpoint." 
+ ) + if smooth is not None: + self.register_buffer("_smooth_scale", smooth.detach()) + else: + self._smooth_scale = None + def forward(self, x: torch.Tensor) -> torch.Tensor: s = x.shape x2 = x.reshape(-1, s[-1]) + if self._smooth_scale is not None: + x2 = x2 * self._smooth_scale + xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float() # 0-d xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous() @@ -183,8 +205,94 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return result.reshape(x.shape[:-1] + (-1,)) +def _per_token_quant_int8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Local per-token symmetric int8 quantization matching SGLang's W8A8 flow: + scale = absmax / 127 + x_q = round(x / scale) + Returns: + x_q: [..., K] int8 + scales: [..., 1] float32 + """ + x_fp = x.float().nan_to_num() + scales = (x_fp.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) / 127.0).float() + x_q = torch.round(x_fp / scales).clamp(-127, 127).to(torch.int8) + return x_q, scales + + +@torch.library.custom_op("world_engine::w8a8_int8_linear", mutates_args=()) +def w8a8_int8_linear( + a: torch.Tensor, + b_int8_T: torch.Tensor, + b_scale: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + if sgl_int8_scaled_mm is None: + raise ImportError("sgl-kernel is required for quant='w8a8'") + + assert a.ndim == 2, "expected [M, K] input" + x_q, x_scale = _per_token_quant_int8(a.contiguous()) + + bias_arg = None if bias.numel() == 0 else bias + return sgl_int8_scaled_mm( + x_q, # [M, K] row-major int8 + b_int8_T, # [K, N] column-major int8 view + x_scale, # [M, 1] float32 + b_scale, # [N, 1] float32 + out_dtype=a.dtype, + bias=bias_arg, + ) + + +@w8a8_int8_linear.register_fake +def _w8a8_int8_linear_fake( + a: torch.Tensor, + b_int8_T: torch.Tensor, + b_scale: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return torch.empty( + (a.shape[0], b_int8_T.shape[1]), + device=a.device, + dtype=a.dtype, + ) + +class INT8W8A8GemLite(nn.Module): + 
 __constants__ = ("in_features", "out_features") + + def __init__(self, lin: nn.Linear, smoothquant: bool = False): + super().__init__() + if A8W8_INT8_dynamic is None: + raise ImportError("Install gemlite for quant='intw8a8'") + + self.in_features = lin.in_features + self.out_features = lin.out_features + + # Minimal wrapper: assumes the layer is already on the target CUDA device. + self.impl = A8W8_INT8_dynamic( + device=str(lin.weight.device), + dtype=lin.weight.dtype, + ).from_linear(lin) + + smooth = getattr(lin, "_smooth_scale", None) + if smoothquant and smooth is None: + raise ValueError( + f"smoothquant=True but this checkpoint has no _smooth_scale on " + f"{type(lin).__name__}(in={lin.in_features}, out={lin.out_features}). " + "SmoothQuant cannot be applied to this model checkpoint." + ) + if smooth is not None: + self.register_buffer("_smooth_scale", smooth.detach()) + else: + self._smooth_scale = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._smooth_scale is not None: + x = x * self._smooth_scale + return self.impl(x).type_as(x) + -def quantize_model(model: nn.Module, quant: str): +def quantize_model(model: nn.Module, quant: str, smoothquant: bool = False): if quant is None: return model @@ -198,13 +306,14 @@ def eligible(m: nn.Module) -> bool: return (o % 32 == 0) and (k % 32 == 0) new_linear = { - "w8a8": FP8W8A8Linear, + "intw8a8": partial(INT8W8A8GemLite, smoothquant=smoothquant), + "fp8w8a8": partial(FP8W8A8Linear, smoothquant=smoothquant), "nvfp4": FP4Linear, "fp8": FP8Linear, }[quant] for name, child in model.named_children(): setattr(model, name, new_linear(child)) if eligible(child) else quantize_model( - child, quant + child, quant, smoothquant ) - return model + return model \ No newline at end of file diff --git a/src/world_engine.py b/src/world_engine.py index 91fc2ae..7321efc 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -25,7 +25,7 @@ # "shape_padding": True, } - +import time @dataclass class 
CtrlInput: button: Set[int] = field(default_factory=set) # pressed button IDs @@ -38,6 +38,7 @@ def __init__( self, model_uri: str, quant: Optional[str] = None, + smooth: bool = False, model_config_overrides: Optional[Dict] = None, device=None, dtype=torch.bfloat16, @@ -45,8 +46,7 @@ def __init__( ): """ model_uri: HF URI or local folder containing model.safetensors and config.yaml - quant: None | w8a8 | nvfp4 - + quant: None | intw8a8 | fp8w8a8 | nvfp4 model_config_overrides: Dict to override model config values - auto_aspect_ratio: set to False to work in ae raw space, otherwise in/out are 720p or 360p """ @@ -77,7 +77,9 @@ def __init__( ).eval() apply_inference_patches(self.model) if quant is not None: - quantize_model(self.model, quant) + start_time = time.time() + quantize_model(self.model, quant, smoothquant=smooth) + print(f"Quantization took {time.time() - start_time:.2f}s") self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device)