From fd7c4c2afd7dd5b49c14c78fad0cff5e9e6c8747 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 14 Mar 2026 04:37:44 +0530 Subject: [PATCH 01/20] add torchao quantize_ --- src/quantize.py | 22 +++++++++++++++++++++- src/world_engine.py | 7 +++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/quantize.py b/src/quantize.py index b74825b..9c7b346 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -182,7 +182,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return result.reshape(x.shape[:-1] + (-1,)) - + def quantize_model(model: nn.Module, quant: str): if quant is None: @@ -208,3 +208,23 @@ def eligible(m: nn.Module) -> bool: child, quant ) return model + +from torchao.quantization import quantize_, Int4WeightOnlyConfig, Int8WeightOnlyConfig, Float8WeightOnlyConfig +_LAYER_FILTERS = { + "mlp": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "dit_mlp" in fqn, + "attention": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and ".attn." in fqn, +} + + +def quantize_qat_model(model, config: str, layers: str = None): + """Apply QAT in-place. layers: 'mlp', 'attention', or None for all Linear layers.""" + filter_fn = _LAYER_FILTERS.get(layers) if layers else None + + if config == "int4_weights": + qconfig = Int4WeightOnlyConfig(group_size=32) + elif config == "int8_weights": + qconfig = Int8WeightOnlyConfig() + elif config == "fp8_weights": + qconfig = Float8WeightOnlyConfig() + + quantize_(model, qconfig, filter_fn=filter_fn) \ No newline at end of file diff --git a/src/world_engine.py b/src/world_engine.py index 91fc2ae..cd1b991 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -6,7 +6,7 @@ from .model import WorldModel, StaticKVCache, PromptEncoder from .ae import get_ae from .patch_model import apply_inference_patches -from .quantize import quantize_model +from .quantize import quantize_model, quantize_qat_model # Global torch optimizations @@ -45,8 +45,7 @@ def __init__( ): """ model_uri: HF URI or local folder containing model.safetensors and config.yaml - quant: None | w8a8 | nvfp4 - + quant: None | fp8_weights | int8_weights | int4_weights model_config_overrides: Dict to override model config values - auto_aspect_ratio: set to False to work in ae raw space, otherwise in/out are 720p or 360p """ @@ -77,7 +76,7 @@ def __init__( ).eval() apply_inference_patches(self.model) if quant is not None: - quantize_model(self.model, quant) + quantize_qat_model(self.model, quant, layers="mlp") self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device) From e37c4376c9bedad487311b74a4b3a745faf54184 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 14 Mar 2026 04:44:40 +0530 Subject: [PATCH 02/20] testing --- examples/gen_sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 0217896..d5b2ada 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -12,7 +12,7 @@ # Create inference engine -engine = WorldEngine(sys.argv[1], device="cuda") +engine = WorldEngine(sys.argv[1], quant="int4_weights", device="cuda") # Define sequence of controller inputs applied From 86fd345dafc171cb6816d30bd1ada3586076fa23 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 14 Mar 2026 04:52:21 +0530 Subject: [PATCH 03/20] testing yes --- examples/gen_sample.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index d5b2ada..6198089 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -12,7 +12,12 @@ # Create inference engine -engine = WorldEngine(sys.argv[1], quant="int4_weights", device="cuda") +model_config_overrides = {"ae_uri": "Overworld-Models/taehv1_5"} +model_config_overrides.update({}) +engine = WorldEngine(sys.argv[1], + model_config_overrides=model_config_overrides, + quant="int4_weights", + device="cuda") # Define sequence of controller inputs applied From cf8a4ecd1da479dd7c0d29c33e7ca41b20c20f11 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 14 Mar 2026 05:05:29 +0530 Subject: [PATCH 04/20] use taehv overide --- examples/gen_sample.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 6198089..37992c1 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -12,7 +12,8 @@ # Create inference engine -model_config_overrides = {"ae_uri": "Overworld-Models/taehv1_5"} +model_config_overrides = {"ae_uri": "Overworld-Models/taehv1_5", + "use_taehv_ae": True} model_config_overrides.update({}) engine = WorldEngine(sys.argv[1], model_config_overrides=model_config_overrides, From 62320304540410f855835d74da919a6c7306fcc4 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sat, 14 Mar 2026 05:06:03 +0530 Subject: [PATCH 05/20] yuh --- examples/gen_sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 37992c1..c668f7e 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -13,7 +13,7 @@ # Create inference engine model_config_overrides = {"ae_uri": "Overworld-Models/taehv1_5", - "use_taehv_ae": True} + "taehv_ae": True} model_config_overrides.update({}) engine = WorldEngine(sys.argv[1], model_config_overrides=model_config_overrides, From f59f7cc1b5bd5d08f18fc48ea505c8b92cbc7f1c Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sun, 15 Mar 2026 00:56:47 +0530 Subject: [PATCH 06/20] add apply qat --- examples/gen_sample.py | 3 ++- src/model/base_model.py | 4 ++++ src/quantize.py | 37 ++++++++++++++++++++++++++++++++++--- src/world_engine.py | 6 ++++-- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index c668f7e..72b9273 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -13,7 +13,8 @@ # Create inference engine model_config_overrides = {"ae_uri": "Overworld-Models/taehv1_5", - "taehv_ae": True} + "taehv_ae": True, + "quant": "int4_weights"} model_config_overrides.update({}) engine = WorldEngine(sys.argv[1], model_config_overrides=model_config_overrides, diff --git a/src/model/base_model.py b/src/model/base_model.py index f94314e..5635594 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -6,6 +6,7 @@ from torch import nn import torch +from quantize import apply_qat MODEL_CONFIG_DEFAULTS = OmegaConf.create( { @@ -42,6 +43,9 @@ def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None, load_weig cfg = cls.load_config(path) model = cls(cfg).to(dtype=dtype, device=device) + if cfg.quant is not None: + apply_qat(model, quant_config=cfg.quant, layers="mlp", step="prepare") + if load_weights: safetensors_path = os.path.join(path, "model.safetensors") model.load_state_dict(load_file(safetensors_path, device=device), strict=True) diff --git a/src/quantize.py b/src/quantize.py index 9c7b346..1b6e63d 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -210,14 +210,15 @@ def eligible(m: nn.Module) -> bool: return model from torchao.quantization import quantize_, Int4WeightOnlyConfig, Int8WeightOnlyConfig, Float8WeightOnlyConfig +from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig, Float8FakeQuantizeConfig, PerTensor + _LAYER_FILTERS = { "mlp": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "dit_mlp" in fqn, "attention": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and ".attn." in fqn, } - -def quantize_qat_model(model, config: str, layers: str = None): - """Apply QAT in-place. layers: 'mlp', 'attention', or None for all Linear layers.""" +def apply_ptq_model(model, config: str, layers: str = None): + """Apply PTQ in-place. layers: 'mlp', 'attention', or None for all Linear layers.""" filter_fn = _LAYER_FILTERS.get(layers) if layers else None if config == "int4_weights": @@ -227,4 +228,34 @@ def quantize_qat_model(model, config: str, layers: str = None): elif config == "fp8_weights": qconfig = Float8WeightOnlyConfig() + quantize_(model, qconfig, filter_fn=filter_fn) + +def apply_qat(model, quant_config: str = "fp8_general", layers: str = None, step: str = "prepare"): + """Apply QAT in-place. layers: 'mlp', 'attention', or None for all Linear layers.""" + filter_fn = _LAYER_FILTERS.get(layers) if layers else None + + if step == "prepare": + if quant_config == "fp8_general": + weight_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn, granularity=PerTensor()) + qconfig = QATConfig(weight_config=weight_config, step=step) + elif quant_config == "int8_weights": + weight_config = IntxFakeQuantizeConfig(torch.int8, group_size=32, is_symmetric=True) + qconfig = QATConfig(weight_config=weight_config, step=step) + elif quant_config == "int4_weights": + config = Int4WeightOnlyConfig( + group_size=32, + ) + qconfig = QATConfig(base_config=config, step=step) + else: + raise ValueError(f"Unknown quant_config: {quant_config!r}") + elif step == "convert": + # convert step requires a real PTQ base_config (not FakeQuantizeConfigBase) + # or None (which just strips fake-quant wrappers back to plain nn.Linear) + if quant_config == "fp8_general": + qconfig = QATConfig(base_config=Float8WeightOnlyConfig(), step=step) + elif quant_config == "int4_weights": + qconfig = QATConfig(base_config=Int4WeightOnlyConfig(group_size=32), step=step) + else: + raise ValueError(f"Unknown quant_config: {quant_config!r}") + quantize_(model, qconfig, filter_fn=filter_fn) \ No newline at end of file diff --git a/src/world_engine.py b/src/world_engine.py index cd1b991..37b0977 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -6,7 +6,7 @@ from .model import WorldModel, StaticKVCache, PromptEncoder from .ae import get_ae from .patch_model import apply_inference_patches -from .quantize import quantize_model, quantize_qat_model +from .quantize import quantize_model, apply_ptq_model, apply_qat # Global torch optimizations @@ -76,7 +76,9 @@ def __init__( ).eval() apply_inference_patches(self.model) if quant is not None: - quantize_qat_model(self.model, quant, layers="mlp") + # apply_ptq_model(self.model, quant, layers="mlp") + apply_qat(self.model, quant_config=quant, layers="mlp", step="convert") + self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device) From e22f8a5f60c83a84a628c3da9568ca8bb60d86cb Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sun, 15 Mar 2026 01:00:56 +0530 Subject: [PATCH 07/20] yuh --- src/model/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/base_model.py b/src/model/base_model.py index 5635594..506eeec 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -6,7 +6,7 @@ from torch import nn import torch -from quantize import apply_qat +from ..quantize import apply_qat MODEL_CONFIG_DEFAULTS = OmegaConf.create( { From 2737827d778b5222dd7d395fc617d750d242709d Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Sun, 15 Mar 2026 01:03:17 +0530 Subject: [PATCH 08/20] uh --- src/quantize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/quantize.py b/src/quantize.py index 1b6e63d..fe9611c 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -209,8 +209,8 @@ def eligible(m: nn.Module) -> bool: ) return model -from torchao.quantization import quantize_, Int4WeightOnlyConfig, Int8WeightOnlyConfig, Float8WeightOnlyConfig -from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig, Float8FakeQuantizeConfig, PerTensor +from torchao.quantization import quantize_, Int4WeightOnlyConfig, Int8WeightOnlyConfig, Float8WeightOnlyConfig, PerTensor +from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig, Float8FakeQuantizeConfig _LAYER_FILTERS = { "mlp": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "dit_mlp" in fqn, From 3bc5568a8512c0e7904107007217352442ef6167 Mon Sep 17 00:00:00 2001 From: anm-ol Date: Mon, 16 Mar 2026 10:22:54 +0000 Subject: [PATCH 09/20] enable int4 benchmarking and inference --- examples/benchmark.py | 1 + examples/gen_sample.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 3d8e623..3125124 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -56,6 +56,7 @@ def get_warm_engine(model_uri, model_overrides=None): engine = WorldEngine( model_uri, model_config_overrides=model_config_overrides, + quant=model_config_overrides.get("quant"), device="cuda", load_weights=False ) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 72b9273..302ff8d 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -13,8 +13,12 @@ # Create inference engine model_config_overrides = {"ae_uri": "Overworld-Models/taehv1_5", - "taehv_ae": True, - "quant": "int4_weights"} + "patch": [2, 2], + "temporal_compression": 4, + "inference_fps": 60, + "quant": "int4_weights", + "taehv_ae": True} + model_config_overrides.update({}) engine = WorldEngine(sys.argv[1], model_config_overrides=model_config_overrides, From 20c81817299c16da185e4395bf05d8f6da4b7744 Mon Sep 17 00:00:00 2001 From: anm-ol Date: Wed, 18 Mar 2026 04:43:00 +0000 Subject: [PATCH 10/20] apply quantize_model w8a8 --- examples/benchmark.py | 3 ++- examples/gen_sample.py | 8 ++++++-- src/quantize.py | 34 ++++++++++++++++++++++++++-------- src/world_engine.py | 8 +++++--- 4 files changed, 39 insertions(+), 14 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 3125124..a178666 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -56,7 +56,8 @@ def get_warm_engine(model_uri, model_overrides=None): engine = WorldEngine( model_uri, model_config_overrides=model_config_overrides, - quant=model_config_overrides.get("quant"), + # quant=model_config_overrides.get("quant"), + quant="w8a8", device="cuda", load_weights=False ) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 302ff8d..14d6cd8 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -16,15 +16,19 @@ "patch": [2, 2], "temporal_compression": 4, "inference_fps": 60, - "quant": "int4_weights", + # "quant": "int8_weights", + "quant": None, "taehv_ae": True} model_config_overrides.update({}) engine = WorldEngine(sys.argv[1], model_config_overrides=model_config_overrides, - quant="int4_weights", + # quant="int4_weights", + quant=None, device="cuda") +total_linear_params = sum(mod.weight.numel() for _, mod in engine.model.named_modules() if isinstance(mod, torch.nn.Linear)) +print(f"Total linear layer parameters: {total_linear_params:,}") # Define sequence of controller inputs applied controller_sequence = [ diff --git a/src/quantize.py b/src/quantize.py index fe9611c..06aa379 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -209,25 +209,37 @@ def eligible(m: nn.Module) -> bool: ) return model -from torchao.quantization import quantize_, Int4WeightOnlyConfig, Int8WeightOnlyConfig, Float8WeightOnlyConfig, PerTensor -from torchao.quantization.qat import QATConfig, IntxFakeQuantizeConfig, Float8FakeQuantizeConfig +from torchao.quantization import (quantize_, + Int4WeightOnlyConfig, + Int8WeightOnlyConfig, + Float8WeightOnlyConfig, + Float8DynamicActivationFloat8WeightConfig, + PerTensor, PerRow) +from torchao.quantization.quantize_.workflows import Int4PackingFormat, Float8PackingFormat +from torchao.quantization.qat import (QATConfig, + IntxFakeQuantizeConfig, + Float8FakeQuantizeConfig) _LAYER_FILTERS = { - "mlp": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "dit_mlp" in fqn, + "mlp": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "transformer.blocks" in fqn and ".mlp." in fqn, "attention": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and ".attn." in fqn, + "all": lambda mod, fqn: isinstance(mod, torch.nn.Linear), } -def apply_ptq_model(model, config: str, layers: str = None): +def apply_ptq_model(model, config: str, layers: str = "mlp"): """Apply PTQ in-place. layers: 'mlp', 'attention', or None for all Linear layers.""" filter_fn = _LAYER_FILTERS.get(layers) if layers else None if config == "int4_weights": - qconfig = Int4WeightOnlyConfig(group_size=32) + qconfig = Int4WeightOnlyConfig(group_size=32, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq") elif config == "int8_weights": qconfig = Int8WeightOnlyConfig() elif config == "fp8_weights": qconfig = Float8WeightOnlyConfig() - + elif config == "f8aw": + qconfig = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor ()) + else: + raise ValueError(f"Unknown quant_config: {config!r}") quantize_(model, qconfig, filter_fn=filter_fn) def apply_qat(model, quant_config: str = "fp8_general", layers: str = None, step: str = "prepare"): @@ -244,18 +256,24 @@ def apply_qat(model, quant_config: str = "fp8_general", layers: str = None, step elif quant_config == "int4_weights": config = Int4WeightOnlyConfig( group_size=32, + int4_packing_format=Int4PackingFormat.PRESHUFFLED ) qconfig = QATConfig(base_config=config, step=step) else: raise ValueError(f"Unknown quant_config: {quant_config!r}") + elif step == "convert": # convert step requires a real PTQ base_config (not FakeQuantizeConfigBase) # or None (which just strips fake-quant wrappers back to plain nn.Linear) if quant_config == "fp8_general": qconfig = QATConfig(base_config=Float8WeightOnlyConfig(), step=step) + elif quant_config == "int8_weights": + quantize_(model, QATConfig(step=step), filter_fn=filter_fn) # need to run quantize to convert fake quant to real quant for int8, since int8 fake quant is not a simple wrapper around int8 PTQ module + qconfig = QATConfig(base_config=Int8WeightOnlyConfig(), step=step) elif quant_config == "int4_weights": - qconfig = QATConfig(base_config=Int4WeightOnlyConfig(group_size=32), step=step) + qconfig = QATConfig(base_config=Int4WeightOnlyConfig(group_size=32, int4_packing_format=Int4PackingFormat.PRESHUFFLED), step=step) + elif quant_config == "bf16": + qconfig = QATConfig(step=step) else: raise ValueError(f"Unknown quant_config: {quant_config!r}") - quantize_(model, qconfig, filter_fn=filter_fn) \ No newline at end of file diff --git a/src/world_engine.py b/src/world_engine.py index 37b0977..41cab59 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -76,9 +76,11 @@ def __init__( ).eval() apply_inference_patches(self.model) if quant is not None: - # apply_ptq_model(self.model, quant, layers="mlp") - apply_qat(self.model, quant_config=quant, layers="mlp", step="convert") - + print(f"Applying {quant} PTQ...") + # apply_qat(self.model, quant_config="bf16", layers="mlp", step="convert") + quantize_model(self.model, quant) + # apply_qat(self.model, quant_config=quant, layers="all", step="convert") + # apply_ptq_model(self.model, quant, layers="all") self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device) From 59158fecf0b9070e5ced50eaa4f7d2e9912eda8b Mon Sep 17 00:00:00 2001 From: anm-ol Date: Wed, 18 Mar 2026 20:44:19 +0000 Subject: [PATCH 11/20] add int8 ptq --- examples/benchmark.py | 2 +- src/quantize.py | 19 ++++++++++++++++--- src/world_engine.py | 4 ++-- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index a178666..c74f65c 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -57,7 +57,7 @@ def get_warm_engine(model_uri, model_overrides=None): model_uri, model_config_overrides=model_config_overrides, # quant=model_config_overrides.get("quant"), - quant="w8a8", + quant="int_w8a8", device="cuda", load_weights=False ) diff --git a/src/quantize.py b/src/quantize.py index 06aa379..f3b77d8 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -212,6 +212,10 @@ def eligible(m: nn.Module) -> bool: from torchao.quantization import (quantize_, Int4WeightOnlyConfig, Int8WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + # NVFP4DynamicActivationNVFP4WeightConfig, + Float8DynamicActivationInt4WeightConfig, + Int8DynamicActivationIntxWeightConfig, Float8WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig, PerTensor, PerRow) @@ -222,7 +226,8 @@ def eligible(m: nn.Module) -> bool: _LAYER_FILTERS = { "mlp": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "transformer.blocks" in fqn and ".mlp." in fqn, - "attention": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and ".attn." in fqn, + "attn": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and ".attn." in fqn,\ + "mlp_and_attn": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "transformer.blocks" in fqn and (".mlp." in fqn or ".attn." in fqn), "all": lambda mod, fqn: isinstance(mod, torch.nn.Linear), } @@ -234,10 +239,18 @@ def apply_ptq_model(model, config: str, layers: str = "mlp"): qconfig = Int4WeightOnlyConfig(group_size=32, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq") elif config == "int8_weights": qconfig = Int8WeightOnlyConfig() + elif config == "int_w8a8": + qconfig = Int8DynamicActivationInt8WeightConfig() + elif config == "int4w_int8a": + qconfig = Int8DynamicActivationIntxWeightConfig(weight_dtype=torch.int4) + elif config == "int4w_fp8a": + qconfig = Float8DynamicActivationInt4WeightConfig() + elif config == "fp_w8a8": + qconfig = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) + # elif config == "fp_w4a4": + # qconfig = NVFP4DynamicActivationNVFP4WeightConfig() elif config == "fp8_weights": qconfig = Float8WeightOnlyConfig() - elif config == "f8aw": - qconfig = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor ()) else: raise ValueError(f"Unknown quant_config: {config!r}") quantize_(model, qconfig, filter_fn=filter_fn) diff --git a/src/world_engine.py b/src/world_engine.py index 41cab59..9ebb9e0 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -78,9 +78,9 @@ def __init__( if quant is not None: print(f"Applying {quant} PTQ...") # apply_qat(self.model, quant_config="bf16", layers="mlp", step="convert") - quantize_model(self.model, quant) + # quantize_model(self.model, quant) # apply_qat(self.model, quant_config=quant, layers="all", step="convert") - # apply_ptq_model(self.model, quant, layers="all") + apply_ptq_model(self.model, quant, layers="mlp_and_attn") self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device) From f898e8e451e7b1763eb68bb8fa7d12bceb98817b Mon Sep 17 00:00:00 2001 From: anm-ol Date: Wed, 18 Mar 2026 21:55:51 +0000 Subject: [PATCH 12/20] quant none --- examples/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index c74f65c..ce83b1b 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -57,7 +57,7 @@ def get_warm_engine(model_uri, model_overrides=None): model_uri, model_config_overrides=model_config_overrides, # quant=model_config_overrides.get("quant"), - quant="int_w8a8", + quant=None, device="cuda", load_weights=False ) From b5cab413695d51ad2d44748d96ae7b247e5f97db Mon Sep 17 00:00:00 2001 From: anm-ol Date: Thu, 19 Mar 2026 21:28:47 +0000 Subject: [PATCH 13/20] int8 gemlite implementation --- examples/benchmark.py | 2 +- examples/gen_sample.py | 2 +- src/quantize.py | 189 ++++++++++++++++++++++++++++++++++++++++- src/world_engine.py | 4 +- 4 files changed, 190 insertions(+), 7 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index ce83b1b..35d1c68 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -57,7 +57,7 @@ def get_warm_engine(model_uri, model_overrides=None): model_uri, model_config_overrides=model_config_overrides, # quant=model_config_overrides.get("quant"), - quant=None, + quant="w8a8_gemlite", device="cuda", load_weights=False ) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 14d6cd8..edcb392 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -24,7 +24,7 @@ engine = WorldEngine(sys.argv[1], model_config_overrides=model_config_overrides, # quant="int4_weights", - quant=None, + quant="w8a8_gemlite", device="cuda") total_linear_params = sum(mod.weight.numel() for _, mod in engine.model.named_modules() if isinstance(mod, torch.nn.Linear)) diff --git a/src/quantize.py b/src/quantize.py index f3b77d8..12afe58 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -12,6 +12,22 @@ QUANTS.append("nvfp4") except ImportError: pass +try: + from sgl_kernel import int8_scaled_mm as sgl_int8_scaled_mm + if "w8a8" not in QUANTS: + QUANTS.append("w8a8") +except ImportError: + sgl_int8_scaled_mm = None +try: + from gemlite.helper import A8W8_INT8_dynamic + import gemlite + gemlite.set_autotune("max") +except ImportError: + A8W8_INT8_dynamic = None +try: + from lmdeploy.pytorch.models.q_modules import QLinear +except ImportError: + QLinear = None @torch.library.custom_op("world_engine::fp4_linear", mutates_args=()) @@ -182,7 +198,171 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) return result.reshape(x.shape[:-1] + (-1,)) - + +def _per_token_quant_int8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Local per-token symmetric int8 quantization matching SGLang's W8A8 flow: + scale = absmax / 127 + x_q = round(x / scale) + Returns: + x_q: [..., K] int8 + scales: [..., 1] float32 + """ + x_fp = x.float().nan_to_num() + scales = (x_fp.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) / 127.0).float() + x_q = torch.round(x_fp / scales).clamp(-127, 127).to(torch.int8) + return x_q, scales + + +@torch.library.custom_op("world_engine::w8a8_int8_linear", mutates_args=()) +def w8a8_int8_linear( + a: torch.Tensor, + b_int8_T: torch.Tensor, + b_scale: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + if sgl_int8_scaled_mm is None: + raise ImportError("sgl-kernel is required for quant='w8a8'") + + assert a.ndim == 2, "expected [M, K] input" + x_q, x_scale = _per_token_quant_int8(a.contiguous()) + + bias_arg = None if bias.numel() == 0 else bias + return sgl_int8_scaled_mm( + x_q, # [M, K] row-major int8 + b_int8_T, # [K, N] column-major int8 view + x_scale, # [M, 1] float32 + b_scale, # [N, 1] float32 + out_dtype=a.dtype, + bias=bias_arg, + ) + + +@w8a8_int8_linear.register_fake +def _w8a8_int8_linear_fake( + a: torch.Tensor, + b_int8_T: torch.Tensor, + b_scale: torch.Tensor, + bias: torch.Tensor, +) -> torch.Tensor: + return torch.empty( + (a.shape[0], b_int8_T.shape[1]), + device=a.device, + dtype=a.dtype, + ) + + +class W8A8Int8LinearSGLang(nn.Module): + """ + INT8 W8A8 linear using sgl-kernel's int8_scaled_mm. + Weight path: + - static per-channel symmetric int8 + - stored as a transposed [K, N] view (column-major for the kernel) + Activation path: + - dynamic per-token symmetric int8 + """ + + __constants__ = ("in_features", "out_features") + + def __init__(self, lin: nn.Linear): + super().__init__() + + if sgl_int8_scaled_mm is None: + raise ImportError("sgl-kernel is required for quant='w8a8'") + + self.in_features = lin.in_features + self.out_features = lin.out_features + + # Your current eligible() already enforces % 32, which is stricter than needed. + w = lin.weight.detach() # [N, K] + + # Per-output-channel symmetric weight quantization. + w_scale = ( + w.float() + .abs() + .nan_to_num() + .amax(dim=1, keepdim=True) + .clamp_min(1e-10) + / 127.0 + ).float() # [N, 1] + + w_q = torch.round(w.float() / w_scale).clamp(-127, 127).to(torch.int8) # [N, K] + + # IMPORTANT: keep this as a transpose view, not contiguous(). + # sgl-kernel expects mat_b to be column-major [K, N] with stride(0) == 1. + self.register_buffer("weight_int8_T", w_q.t()) # [K, N], column-major view + self.register_buffer("weight_scale", w_scale.contiguous()) # [N, 1] + + if lin.bias is None: + self.register_buffer( + "bias", + torch.empty(0, device=w.device, dtype=lin.weight.dtype), + ) + else: + self.register_buffer( + "bias", + lin.bias.detach().to(lin.weight.dtype).contiguous(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert x.is_cuda, "w8a8 requires CUDA" + assert x.dtype in (torch.float16, torch.bfloat16), \ + "w8a8 expects fp16/bf16 activations" + + s = x.shape + x2 = x.reshape(-1, s[-1]).contiguous() + bias = self.bias + if bias.numel() != 0 and bias.dtype != x2.dtype: + bias = bias.to(x2.dtype) + y = w8a8_int8_linear( + x2, + self.weight_int8_T, + self.weight_scale, + bias, + ) + return y.reshape(*s[:-1], self.out_features) + + +class INT8W8A8GemLite(nn.Module): + __constants__ = ("in_features", "out_features") + + def __init__(self, lin: nn.Linear): + super().__init__() + if A8W8_INT8_dynamic is None: + raise ImportError("Install gemlite for quant='w8a8_gemlite'") + + self.in_features = lin.in_features + self.out_features = lin.out_features + + # Minimal wrapper: assumes the layer is already on the target CUDA device. + self.impl = A8W8_INT8_dynamic( + device=str(lin.weight.device), + dtype=lin.weight.dtype, + ).from_linear(lin) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + s = x.shape + y = self.impl(x.reshape(-1, s[-1]).contiguous()) + return y.reshape(*s[:-1], self.out_features).to(x.dtype) + + +class INT8W8A8LMDeploy(nn.Module): + __constants__ = ("in_features", "out_features") + + def __init__(self, lin: nn.Linear): + super().__init__() + if QLinear is None: + raise ImportError("Install lmdeploy for quant='w8a8_lmdeploy'") + + self.in_features = lin.in_features + self.out_features = lin.out_features + self.impl = QLinear.from_float(lin) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + s = x.shape + y = self.impl(x.reshape(-1, s[-1]).contiguous()) + return y.reshape(*s[:-1], self.out_features).to(x.dtype) + def quantize_model(model: nn.Module, quant: str): if quant is None: @@ -198,7 +378,10 @@ def eligible(m: nn.Module) -> bool: return (o % 32 == 0) and (k % 32 == 0) new_linear = { - "w8a8": FP8W8A8Linear, + "w8a8_gemlite": INT8W8A8GemLite, + "w8a8_lmdeploy": INT8W8A8LMDeploy, + "w8a8_sglang": W8A8Int8LinearSGLang, + "fp8w8a8": FP8W8A8Linear, "nvfp4": FP4Linear, "fp8": FP8Linear, }[quant] @@ -209,11 +392,11 @@ def eligible(m: nn.Module) -> bool: ) return model + from torchao.quantization import (quantize_, Int4WeightOnlyConfig, Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, - # NVFP4DynamicActivationNVFP4WeightConfig, Float8DynamicActivationInt4WeightConfig, Int8DynamicActivationIntxWeightConfig, Float8WeightOnlyConfig, diff --git a/src/world_engine.py b/src/world_engine.py index 9ebb9e0..cd7f471 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -78,9 +78,9 @@ def __init__( if quant is not None: print(f"Applying {quant} PTQ...") # apply_qat(self.model, quant_config="bf16", layers="mlp", step="convert") - # quantize_model(self.model, quant) + quantize_model(self.model, quant) # apply_qat(self.model, quant_config=quant, layers="all", step="convert") - apply_ptq_model(self.model, quant, layers="mlp_and_attn") + # apply_ptq_model(self.model, quant, layers="mlp_and_attn") self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device) From f0651fb8e744a1d196b017a9becfc8f560ec1e1d Mon Sep 17 00:00:00 2001 From: anm-ol Date: Fri, 20 Mar 2026 06:51:48 +0000 Subject: [PATCH 14/20] clean up, remove torchao quantization --- examples/gen_sample.py | 3 -- src/model/base_model.py | 24 ----------- src/quantize.py | 91 ++--------------------------------------- src/world_engine.py | 6 +-- 4 files changed, 4 insertions(+), 120 deletions(-) diff --git a/examples/gen_sample.py b/examples/gen_sample.py index edcb392..07d4418 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -16,14 +16,11 @@ "patch": [2, 2], "temporal_compression": 4, "inference_fps": 60, - # "quant": "int8_weights", - "quant": None, "taehv_ae": True} model_config_overrides.update({}) engine = WorldEngine(sys.argv[1], model_config_overrides=model_config_overrides, - # quant="int4_weights", quant="w8a8_gemlite", device="cuda") diff --git a/src/model/base_model.py b/src/model/base_model.py index 506eeec..02a9b4f 100644 --- a/src/model/base_model.py +++ b/src/model/base_model.py @@ -6,27 +6,6 @@ from torch import nn import torch -from ..quantize import apply_qat - -MODEL_CONFIG_DEFAULTS = OmegaConf.create( - { - "auto_aspect_ratio": True, - "gated_attn": False, - "inference_fps": "${base_fps}", - "model_type": "waypoint-1", - "n_kv_heads": "${n_heads}", - "patch": [1, 1], - "prompt_conditioning": None, - "prompt_encoder_uri": "google/umt5-xl", - "rope_nyquist_frac": 0.8, - "rope_theta": 10000.0, - "taehv_ae": False, - "temporal_compression": 1, - "value_residual": False, - } -) - - class BaseModel(nn.Module): @classmethod def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None, load_weights: bool = True): @@ -43,9 +22,6 @@ def from_pretrained(cls, path: str, cfg=None, device=None, dtype=None, load_weig cfg = cls.load_config(path) model = cls(cfg).to(dtype=dtype, device=device) - if cfg.quant is not None: - apply_qat(model, quant_config=cfg.quant, layers="mlp", step="prepare") - if load_weights: safetensors_path = os.path.join(path, "model.safetensors") model.load_state_dict(load_file(safetensors_path, device=device), strict=True) diff --git a/src/quantize.py b/src/quantize.py index 12afe58..4d724ed 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -330,7 +330,7 @@ def __init__(self, lin: nn.Linear): super().__init__() if A8W8_INT8_dynamic is None: raise ImportError("Install gemlite for quant='w8a8_gemlite'") - + self.in_features = lin.in_features self.out_features = lin.out_features @@ -341,10 +341,7 @@ def __init__(self, lin: nn.Linear): ).from_linear(lin) def forward(self, x: torch.Tensor) -> torch.Tensor: - s = x.shape - y = self.impl(x.reshape(-1, s[-1]).contiguous()) - return y.reshape(*s[:-1], self.out_features).to(x.dtype) - + return self.impl(x).type_as(x) class INT8W8A8LMDeploy(nn.Module): __constants__ = ("in_features", "out_features") @@ -390,86 +387,4 @@ def eligible(m: nn.Module) -> bool: setattr(model, name, new_linear(child)) if eligible(child) else quantize_model( child, quant ) - return model - - -from torchao.quantization import (quantize_, - Int4WeightOnlyConfig, - Int8WeightOnlyConfig, - Int8DynamicActivationInt8WeightConfig, - Float8DynamicActivationInt4WeightConfig, - Int8DynamicActivationIntxWeightConfig, - Float8WeightOnlyConfig, - Float8DynamicActivationFloat8WeightConfig, - PerTensor, PerRow) -from torchao.quantization.quantize_.workflows import Int4PackingFormat, Float8PackingFormat -from torchao.quantization.qat import (QATConfig, - IntxFakeQuantizeConfig, - Float8FakeQuantizeConfig) - -_LAYER_FILTERS = { - "mlp": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "transformer.blocks" in fqn and ".mlp." in fqn, - "attn": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and ".attn." in fqn,\ - "mlp_and_attn": lambda mod, fqn: isinstance(mod, torch.nn.Linear) and "transformer.blocks" in fqn and (".mlp." in fqn or ".attn." in fqn), - "all": lambda mod, fqn: isinstance(mod, torch.nn.Linear), -} - -def apply_ptq_model(model, config: str, layers: str = "mlp"): - """Apply PTQ in-place. layers: 'mlp', 'attention', or None for all Linear layers.""" - filter_fn = _LAYER_FILTERS.get(layers) if layers else None - - if config == "int4_weights": - qconfig = Int4WeightOnlyConfig(group_size=32, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq") - elif config == "int8_weights": - qconfig = Int8WeightOnlyConfig() - elif config == "int_w8a8": - qconfig = Int8DynamicActivationInt8WeightConfig() - elif config == "int4w_int8a": - qconfig = Int8DynamicActivationIntxWeightConfig(weight_dtype=torch.int4) - elif config == "int4w_fp8a": - qconfig = Float8DynamicActivationInt4WeightConfig() - elif config == "fp_w8a8": - qconfig = Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()) - # elif config == "fp_w4a4": - # qconfig = NVFP4DynamicActivationNVFP4WeightConfig() - elif config == "fp8_weights": - qconfig = Float8WeightOnlyConfig() - else: - raise ValueError(f"Unknown quant_config: {config!r}") - quantize_(model, qconfig, filter_fn=filter_fn) - -def apply_qat(model, quant_config: str = "fp8_general", layers: str = None, step: str = "prepare"): - """Apply QAT in-place. layers: 'mlp', 'attention', or None for all Linear layers.""" - filter_fn = _LAYER_FILTERS.get(layers) if layers else None - - if step == "prepare": - if quant_config == "fp8_general": - weight_config = Float8FakeQuantizeConfig(dtype=torch.float8_e4m3fn, granularity=PerTensor()) - qconfig = QATConfig(weight_config=weight_config, step=step) - elif quant_config == "int8_weights": - weight_config = IntxFakeQuantizeConfig(torch.int8, group_size=32, is_symmetric=True) - qconfig = QATConfig(weight_config=weight_config, step=step) - elif quant_config == "int4_weights": - config = Int4WeightOnlyConfig( - group_size=32, - int4_packing_format=Int4PackingFormat.PRESHUFFLED - ) - qconfig = QATConfig(base_config=config, step=step) - else: - raise ValueError(f"Unknown quant_config: {quant_config!r}") - - elif step == "convert": - # convert step requires a real PTQ base_config (not FakeQuantizeConfigBase) - # or None (which just strips fake-quant wrappers back to plain nn.Linear) - if quant_config == "fp8_general": - qconfig = QATConfig(base_config=Float8WeightOnlyConfig(), step=step) - elif quant_config == "int8_weights": - quantize_(model, QATConfig(step=step), filter_fn=filter_fn) # need to run quantize to convert fake quant to real quant for int8, since int8 fake quant is not a simple wrapper around int8 PTQ module - qconfig = QATConfig(base_config=Int8WeightOnlyConfig(), step=step) - elif quant_config == "int4_weights": - qconfig = QATConfig(base_config=Int4WeightOnlyConfig(group_size=32, int4_packing_format=Int4PackingFormat.PRESHUFFLED), step=step) - elif quant_config == "bf16": - qconfig = QATConfig(step=step) - else: - raise ValueError(f"Unknown quant_config: {quant_config!r}") - quantize_(model, qconfig, filter_fn=filter_fn) \ No newline at end of file + return model \ No newline at end of file diff --git a/src/world_engine.py b/src/world_engine.py index cd7f471..e8f4492 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -45,7 +45,7 @@ def __init__( ): """ model_uri: HF URI or local folder containing model.safetensors and config.yaml - quant: None | fp8_weights | int8_weights | int4_weights + quant: None | w8a8_gemlite | fp8w8a8 | nvfp4 model_config_overrides: Dict to override model config values - auto_aspect_ratio: set to False to work in ae raw space, otherwise in/out are 720p or 360p """ @@ -76,11 +76,7 @@ def __init__( ).eval() apply_inference_patches(self.model) if quant is not None: - print(f"Applying {quant} PTQ...") - # apply_qat(self.model, quant_config="bf16", layers="mlp", step="convert") quantize_model(self.model, quant) - # apply_qat(self.model, quant_config=quant, layers="all", step="convert") - # apply_ptq_model(self.model, quant, layers="mlp_and_attn") self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device) From 2480e10022109de055f19f377f43f5b81033537c Mon Sep 17 00:00:00 2001 From: anm-ol Date: Fri, 20 Mar 2026 06:53:18 +0000 Subject: [PATCH 15/20] add gemlite to requirements --- pyproject.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 1bd9d0a..eb87976 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,12 @@ requires-python = ">=3.10" dependencies = [ "taehv @ git+https://github.com/madebyollin/taehv.git@7dc60ec6601af2e668e31bc70acc4cb3665e4c22", "torch==2.10.0", + "torchvision==0.25.0", + "torchaudio==2.10.0", + "torchao==0.16.0", + "flashinfer-python==0.6.6", + "fbgemm-gpu-genai==1.5.0; sys_platform == 'linux'", + "gemlite==0.5.1.post1" "einops", "tensordict==0.10.0", "transformers>=5.3.0", From 60215480493946d05bc7847a0faac38a1bc9bc16 Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Fri, 20 Mar 2026 23:25:05 +0530 Subject: [PATCH 16/20] remove unused quant kernels and imports --- src/quantize.py | 103 +------------------------------------------- src/world_engine.py | 2 +- 2 files changed, 2 insertions(+), 103 deletions(-) diff --git a/src/quantize.py b/src/quantize.py index 4d724ed..1786d33 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -12,22 +12,12 @@ QUANTS.append("nvfp4") except ImportError: pass -try: - from sgl_kernel import int8_scaled_mm as sgl_int8_scaled_mm - if "w8a8" not in QUANTS: - QUANTS.append("w8a8") -except ImportError: - sgl_int8_scaled_mm = None try: from gemlite.helper import A8W8_INT8_dynamic import gemlite gemlite.set_autotune("max") except ImportError: A8W8_INT8_dynamic = None -try: - from lmdeploy.pytorch.models.q_modules import QLinear -except ImportError: - QLinear = None @torch.library.custom_op("world_engine::fp4_linear", mutates_args=()) @@ -251,78 +241,6 @@ def _w8a8_int8_linear_fake( dtype=a.dtype, ) - -class W8A8Int8LinearSGLang(nn.Module): - """ - INT8 W8A8 linear using sgl-kernel's int8_scaled_mm. - Weight path: - - static per-channel symmetric int8 - - stored as a transposed [K, N] view (column-major for the kernel) - Activation path: - - dynamic per-token symmetric int8 - """ - - __constants__ = ("in_features", "out_features") - - def __init__(self, lin: nn.Linear): - super().__init__() - - if sgl_int8_scaled_mm is None: - raise ImportError("sgl-kernel is required for quant='w8a8'") - - self.in_features = lin.in_features - self.out_features = lin.out_features - - # Your current eligible() already enforces % 32, which is stricter than needed. - w = lin.weight.detach() # [N, K] - - # Per-output-channel symmetric weight quantization. - w_scale = ( - w.float() - .abs() - .nan_to_num() - .amax(dim=1, keepdim=True) - .clamp_min(1e-10) - / 127.0 - ).float() # [N, 1] - - w_q = torch.round(w.float() / w_scale).clamp(-127, 127).to(torch.int8) # [N, K] - - # IMPORTANT: keep this as a transpose view, not contiguous(). - # sgl-kernel expects mat_b to be column-major [K, N] with stride(0) == 1. - self.register_buffer("weight_int8_T", w_q.t()) # [K, N], column-major view - self.register_buffer("weight_scale", w_scale.contiguous()) # [N, 1] - - if lin.bias is None: - self.register_buffer( - "bias", - torch.empty(0, device=w.device, dtype=lin.weight.dtype), - ) - else: - self.register_buffer( - "bias", - lin.bias.detach().to(lin.weight.dtype).contiguous(), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - assert x.is_cuda, "w8a8 requires CUDA" - assert x.dtype in (torch.float16, torch.bfloat16), \ - "w8a8 expects fp16/bf16 activations" - - s = x.shape - x2 = x.reshape(-1, s[-1]).contiguous() - bias = self.bias - if bias.numel() != 0 and bias.dtype != x2.dtype: - bias = bias.to(x2.dtype) - y = w8a8_int8_linear( - x2, - self.weight_int8_T, - self.weight_scale, - bias, - ) - return y.reshape(*s[:-1], self.out_features) - - class INT8W8A8GemLite(nn.Module): __constants__ = ("in_features", "out_features") @@ -343,23 +261,6 @@ def __init__(self, lin: nn.Linear): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.impl(x).type_as(x) -class INT8W8A8LMDeploy(nn.Module): - __constants__ = ("in_features", "out_features") - - def __init__(self, lin: nn.Linear): - super().__init__() - if QLinear is None: - raise ImportError("Install lmdeploy for quant='w8a8_lmdeploy'") - - self.in_features = lin.in_features - self.out_features = lin.out_features - self.impl = QLinear.from_float(lin) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - s = x.shape - y = self.impl(x.reshape(-1, s[-1]).contiguous()) - return y.reshape(*s[:-1], self.out_features).to(x.dtype) - def quantize_model(model: nn.Module, quant: str): if quant is None: @@ -375,9 +276,7 @@ def eligible(m: nn.Module) -> bool: return (o % 32 == 0) and (k % 32 == 0) new_linear = { - "w8a8_gemlite": INT8W8A8GemLite, - "w8a8_lmdeploy": INT8W8A8LMDeploy, - "w8a8_sglang": W8A8Int8LinearSGLang, + "intw8a8": INT8W8A8GemLite, "fp8w8a8": FP8W8A8Linear, "nvfp4": FP4Linear, "fp8": FP8Linear, diff --git a/src/world_engine.py b/src/world_engine.py index e8f4492..4514e7c 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -6,7 +6,7 @@ from .model import WorldModel, StaticKVCache, PromptEncoder from .ae import get_ae from .patch_model import apply_inference_patches -from .quantize import quantize_model, apply_ptq_model, apply_qat +from .quantize import quantize_model # Global torch optimizations From fc70b79dbe55c8290d8a2067a3b81d15c4a80b1c Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Fri, 20 Mar 2026 23:33:52 +0530 Subject: [PATCH 17/20] restore gen_sample.py, more cleanup --- examples/benchmark.py | 5 ++--- examples/gen_sample.py | 14 +------------- pyproject.toml | 2 -- src/world_engine.py | 2 +- 4 files changed, 4 insertions(+), 19 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 35d1c68..c9882c9 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -56,8 +56,7 @@ def get_warm_engine(model_uri, model_overrides=None): engine = WorldEngine( model_uri, model_config_overrides=model_config_overrides, - # quant=model_config_overrides.get("quant"), - quant="w8a8_gemlite", + quant="intw8a8", device="cuda", load_weights=False ) @@ -126,4 +125,4 @@ def target(): if blocking: torch.cuda.synchronize() - benchmark.pedantic(target, setup=setup, rounds=20) + benchmark.pedantic(target, setup=setup, rounds=20) \ No newline at end of file diff --git a/examples/gen_sample.py b/examples/gen_sample.py index 07d4418..8aae0c6 100644 --- a/examples/gen_sample.py +++ b/examples/gen_sample.py @@ -12,20 +12,8 @@ # Create inference engine -model_config_overrides = {"ae_uri": "Overworld-Models/taehv1_5", - "patch": [2, 2], - "temporal_compression": 4, - "inference_fps": 60, - "taehv_ae": True} +engine = WorldEngine(sys.argv[1], quant=None, device="cuda") -model_config_overrides.update({}) -engine = WorldEngine(sys.argv[1], - model_config_overrides=model_config_overrides, - quant="w8a8_gemlite", - device="cuda") - -total_linear_params = sum(mod.weight.numel() for _, mod in engine.model.named_modules() if isinstance(mod, torch.nn.Linear)) -print(f"Total linear layer parameters: {total_linear_params:,}") # Define sequence of controller inputs applied controller_sequence = [ diff --git a/pyproject.toml b/pyproject.toml index eb87976..1a195a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,6 @@ dependencies = [ "torch==2.10.0", "torchvision==0.25.0", "torchaudio==2.10.0", - "torchao==0.16.0", - "flashinfer-python==0.6.6", "fbgemm-gpu-genai==1.5.0; sys_platform == 'linux'", "gemlite==0.5.1.post1" "einops", diff --git a/src/world_engine.py b/src/world_engine.py index 4514e7c..713d048 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -45,7 +45,7 @@ def __init__( ): """ model_uri: HF URI or local folder containing model.safetensors and config.yaml - quant: None | w8a8_gemlite | fp8w8a8 | nvfp4 + quant: None | intw8a8 | fp8w8a8 | nvfp4 model_config_overrides: Dict to override model config values - auto_aspect_ratio: set to False to work in ae raw space, otherwise in/out are 720p or 360p """ From 8d0e665484d576027eaa6b99bac674ed67ac6c3e Mon Sep 17 00:00:00 2001 From: Anmol Agarwal Date: Fri, 20 Mar 2026 23:48:36 +0530 Subject: [PATCH 18/20] update readme with Quantization docs --- README.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/README.md b/README.md index 8d68292..005129d 100644 --- a/README.md +++ b/README.md @@ -118,6 +118,21 @@ img = pipeline.append_frame(uint8_img) # returns passed image Note: returned `img` is always on the same device as `engine.device` +## Quantization + +Model can be quantized by passing quant argument to WorldEngine +``` +engine = WorldEngine("Overworld/Waypoint-1.5-1B", quant="intw8a8", device="cuda") +``` +Supported inference quantization schemes are: + +| Config | Description | Supported GPUs | +|--------|-------------|----------------| +| `intw8a8` | INT8 weights + INT8 dynamic per-token activations | NVIDIA (30xx, 40xx, Ampere+) | +| `fp8w8a8` | FP8 (e4m3) weights + FP8 per-tensor activations via `torch._scaled_mm` | NVIDIA Ada Lovelace / Hopper+ (RTX 40xx, H100) | +| `nvfp4` | NVFP4 weights + FP4 activations via FlashInfer/CUTLASS | NVIDIA Blackwell (B100, B200, RTX 5090) | + + ### WorldEngine `WorldEngine` computes each new frame from past frames, the controls, and the current prompt, then appends it to the sequence so later frames stay aligned with what has already been generated. From 5989fb4a9d51038e753e14cf84ccf537aa05801e Mon Sep 17 00:00:00 2001 From: anm-ol Date: Fri, 20 Mar 2026 18:46:24 +0000 Subject: [PATCH 19/20] fixed requirements gemlite --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1a195a1..b6c41eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "torchvision==0.25.0", "torchaudio==2.10.0", "fbgemm-gpu-genai==1.5.0; sys_platform == 'linux'", - "gemlite==0.5.1.post1" + "gemlite", "einops", "tensordict==0.10.0", "transformers>=5.3.0", From 5c279c7612a876de113ff0d824a3bed04b492ed0 Mon Sep 17 00:00:00 2001 From: anm-ol Date: Fri, 27 Mar 2026 06:49:13 +0000 Subject: [PATCH 20/20] WorldEngine takes flag for smoothquant, raise error if not applicable to model uri --- src/quantize.py | 44 +++++++++++++++++++++++++++++++++++++------- src/world_engine.py | 7 +++++-- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/quantize.py b/src/quantize.py index 1786d33..37c06dd 100644 --- a/src/quantize.py +++ b/src/quantize.py @@ -1,4 +1,5 @@ from typing import Optional +from functools import partial import torch import torch.nn as nn @@ -114,7 +115,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FP8W8A8Linear(nn.Module): __constants__ = ("in_features", "out_features") - def __init__(self, lin: nn.Linear): + def __init__(self, lin: nn.Linear, smoothquant: bool = False): super().__init__() self.in_features, self.out_features = lin.in_features, lin.out_features @@ -133,10 +134,25 @@ def __init__(self, lin: nn.Linear): else: self.register_buffer("bias", lin.bias.detach().to(torch.float16)) + smooth = getattr(lin, "_smooth_scale", None) + if smoothquant and smooth is None: + raise ValueError( + f"smoothquant=True but this checkpoint has no _smooth_scale on " + f"{type(lin).__name__}(in={lin.in_features}, out={lin.out_features}). " + "SmoothQuant cannot be applied to this model checkpoint." + ) + if smooth is not None: + self.register_buffer("_smooth_scale", smooth.detach()) + else: + self._smooth_scale = None + def forward(self, x: torch.Tensor) -> torch.Tensor: s = x.shape x2 = x.reshape(-1, s[-1]) + if self._smooth_scale is not None: + x2 = x2 * self._smooth_scale + xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float() # 0-d xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous() @@ -244,11 +260,11 @@ def _w8a8_int8_linear_fake( class INT8W8A8GemLite(nn.Module): __constants__ = ("in_features", "out_features") - def __init__(self, lin: nn.Linear): + def __init__(self, lin: nn.Linear, smoothquant: bool = False): super().__init__() if A8W8_INT8_dynamic is None: raise ImportError("Install gemlite for quant='w8a8_gemlite'") - + self.in_features = lin.in_features self.out_features = lin.out_features @@ -258,11 +274,25 @@ def __init__(self, lin: nn.Linear): dtype=lin.weight.dtype, ).from_linear(lin) + smooth = getattr(lin, "_smooth_scale", None) + if smoothquant and smooth is None: + raise ValueError( + f"smoothquant=True but this checkpoint has no _smooth_scale on " + f"{type(lin).__name__}(in={lin.in_features}, out={lin.out_features}). " + "SmoothQuant cannot be applied to this model checkpoint." + ) + if smooth is not None: + self.register_buffer("_smooth_scale", smooth.detach()) + else: + self._smooth_scale = None + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._smooth_scale is not None: + x = x * self._smooth_scale return self.impl(x).type_as(x) -def quantize_model(model: nn.Module, quant: str): +def quantize_model(model: nn.Module, quant: str, smoothquant: bool = False): if quant is None: return model @@ -276,14 +306,14 @@ def eligible(m: nn.Module) -> bool: return (o % 32 == 0) and (k % 32 == 0) new_linear = { - "intw8a8": INT8W8A8GemLite, - "fp8w8a8": FP8W8A8Linear, + "intw8a8": partial(INT8W8A8GemLite, smoothquant=smoothquant), + "fp8w8a8": partial(FP8W8A8Linear, smoothquant=smoothquant), "nvfp4": FP4Linear, "fp8": FP8Linear, }[quant] for name, child in model.named_children(): setattr(model, name, new_linear(child)) if eligible(child) else quantize_model( - child, quant + child, quant, smoothquant ) return model \ No newline at end of file diff --git a/src/world_engine.py b/src/world_engine.py index 713d048..7321efc 100644 --- a/src/world_engine.py +++ b/src/world_engine.py @@ -25,7 +25,7 @@ # "shape_padding": True, } - +import time @dataclass class CtrlInput: button: Set[int] = field(default_factory=set) # pressed button IDs @@ -38,6 +38,7 @@ def __init__( self, model_uri: str, quant: Optional[str] = None, + smooth: bool = False, model_config_overrides: Optional[Dict] = None, device=None, dtype=torch.bfloat16, @@ -76,7 +77,9 @@ def __init__( ).eval() apply_inference_patches(self.model) if quant is not None: - quantize_model(self.model, quant) + start_time = time.time() + quantize_model(self.model, quant, smoothquant=smooth) + print(f"Quantization took {time.time() - start_time:.2f}s") self.kv_cache = StaticKVCache(self.model_cfg, batch_size=1, dtype=dtype).to(device=device)