Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
daaeb5c
Reduce RAM and compute time in model saving with Loras
rattus128 Dec 31, 2025
db78623
ops: Do bias dtype conversion on compute stream
rattus128 Dec 22, 2025
a08aed2
mm: Implement cast buffer allocations
rattus128 Dec 23, 2025
37567cb
move string_to_seed to utils.py
rattus128 Jan 8, 2026
3c2ce0d
pinned_memory: add python
rattus128 Jan 13, 2026
b6fd3dc
mp: wrap get_free_memory
rattus128 Jan 13, 2026
13a7b68
mp/mm: APi expansions for dynamic loading
rattus128 Jan 13, 2026
594b472
mp: add mode for non comfy weight prioritization
rattus128 Jan 13, 2026
6a8255f
ops/mp: implement aimdo
rattus128 Jan 13, 2026
c862c42
models: Use CoreModelPatcher
rattus128 Jan 13, 2026
469d7a6
execution: add aimdo primary pytorch cache integration
rattus128 Jan 13, 2026
04bf6ef
main: Go live with --fast dynamic_vram
rattus128 Jan 13, 2026
ff434ea
mm: fix sync
rattus128 Jan 13, 2026
e2d62b8
write better tx commentary
rattus128 Jan 13, 2026
e8c9977
add missing del on unpin
rattus128 Jan 13, 2026
7a18963
misc cleanup
rattus128 Jan 13, 2026
01ca403
ruff
rattus128 Jan 13, 2026
9f701f6
sd: empty cache on tiler fallback
rattus128 Jan 13, 2026
0983fb8
clip: support assign load when taking clip from a ckpt
rattus128 Jan 15, 2026
f302177
sampling: improve progress meter accuracy for dynamic loading
rattus128 Jan 15, 2026
3908056
main: Rework aimdo into process
rattus128 Jan 15, 2026
5684c67
aimdo version bump
rattus128 Jan 15, 2026
b0580b8
remove junk arg
rattus128 Jan 15, 2026
2f29e21
ops: defer creation of the parameters until state dict load
rattus128 Jan 18, 2026
cecf8c5
implement lightweight safetensors with READ mmap
rattus128 Jan 18, 2026
607d15c
execution: remove per node gc.collect()
rattus128 Jan 20, 2026
322d917
mm: remove left over hooks draft code
rattus128 Jan 20, 2026
f3854f6
mp: handle blank __new__ call
rattus128 Jan 20, 2026
e54440a
nodes_model_patch: fix copy-paste coding error
rattus128 Jan 20, 2026
12263b7
ruff
rattus128 Jan 21, 2026
49809b7
mp: big bump on the VBAR sizes
rattus128 Jan 21, 2026
d1778d8
archive the model defined dtypes
rattus128 Jan 21, 2026
36c7652
ops: fix __init__ return
rattus128 Jan 21, 2026
ede3d4b
MPDynamic: Add support for model defined dtype
rattus128 Jan 21, 2026
355172f
remove bad pyt2.4 versions gate
rattus128 Jan 24, 2026
8bb291b
disable async pin population
rattus128 Jan 25, 2026
4c875a2
fix syncs
rattus128 Jan 25, 2026
f98c86c
add missing signature set for non comfy
rattus128 Jan 26, 2026
2a76ec6
fix missing import
rattus128 Jan 27, 2026
04141ef
mm: Dont GPU load models
rattus128 Jan 27, 2026
cd08531
ops: dont discard pins
rattus128 Jan 27, 2026
101367b
mm: redefine free memory for Windows
rattus128 Jan 27, 2026
dff1ee9
free dynamic pins properly
rattus128 Jan 27, 2026
f8f9a89
bump aimdo to 1.4
rattus128 Jan 27, 2026
8067cb4
mm: dont clear_cache with mempools
rattus128 Jan 28, 2026
bc80f78
Fix ram freeing logic
rattus128 Jan 29, 2026
b1eb25b
Go back to pre-pins
rattus128 Jan 29, 2026
46f9ac1
bump aimdo
rattus128 Jan 29, 2026
74584f6
fixes to pinning rework
rattus128 Jan 30, 2026
58fd609
bump aimdo
rattus128 Jan 30, 2026
882a3bc
remove bad assertion
rattus128 Jan 31, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions comfy/audio_encoders/audio_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def __init__(self, config):
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

def get_sd(self):
return self.model.state_dict()
Expand Down
4 changes: 4 additions & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
DynamicVRAM = "dynamic_vram"

parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

Expand Down Expand Up @@ -257,3 +258,6 @@ def is_valid_directory(path: str) -> str:
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)

def enables_dynamic_vram():
return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
4 changes: 2 additions & 2 deletions comfy/clip_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def __init__(self, json_config):
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

def get_sd(self):
return self.model.state_dict()
Expand Down
2 changes: 1 addition & 1 deletion comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())

self.compression_ratio = compression_ratio
self.global_average_pooling = global_average_pooling
Expand Down
33 changes: 32 additions & 1 deletion comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,49 @@
import math
import time
from functools import partial

from scipy import integrate
import torch
from torch import nn
import torchsde
from tqdm.auto import trange, tqdm
from tqdm.auto import trange as trange_, tqdm

from . import utils
from . import deis
from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling

import comfy.memory_management


def trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None:
return trange_(*args, **kwargs)

pbar = trange_(*args, **kwargs, smoothing=1.0)
pbar._i = 0
pbar.set_postfix_str(" Model Initializing ... ")

_update = pbar.update

def warmup_update(n=1):
pbar._i += 1
if pbar._i == 1:
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based the the diff between first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")

_update(n)

pbar.update = warmup_update
return pbar


def append_zero(x):
return torch.cat([x, x.new_zeros([1])])

Expand Down
4 changes: 2 additions & 2 deletions comfy/ldm/hunyuan_video/upsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ def __init__(self, model_type, config):
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())

def get_sd(self):
return self.model.state_dict()
Expand Down
81 changes: 81 additions & 0 deletions comfy/memory_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import math
import torch
from typing import NamedTuple

from comfy.quant_ops import QuantizedTensor

class TensorGeometry(NamedTuple):
shape: any
dtype: torch.dtype

def element_size(self):
info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
return info.bits // 8

def numel(self):
return math.prod(self.shape)

def tensors_to_geometries(tensors, dtype=None):
geometries = []
for t in tensors:
if t is None or isinstance(t, QuantizedTensor):
geometries.append(t)
continue
tdtype = t.dtype
if hasattr(t, "_model_dtype"):
tdtype = t._model_dtype
if dtype is not None:
tdtype = dtype
geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
return geometries

def vram_aligned_size(tensor):
if isinstance(tensor, list):
return sum([vram_aligned_size(t) for t in tensor])

if isinstance(tensor, QuantizedTensor):
inner_tensors, _ = tensor.__tensor_flatten__()
return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ])

if tensor is None:
return 0

size = tensor.numel() * tensor.element_size()
aligment_req = 1024
return (size + aligment_req - 1) // aligment_req * aligment_req

def interpret_gathered_like(tensors, gathered):
offset = 0
dest_views = []

if gathered.dim() != 1 or gathered.element_size() != 1:
raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")

for tensor in tensors:

if tensor is None:
dest_views.append(None)
continue

if isinstance(tensor, QuantizedTensor):
inner_tensors, qt_ctx = tensor.__tensor_flatten__()
templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
else:
templates = { "data": tensor }

actuals = {}
for attr, template in templates.items():
size = template.numel() * template.element_size()
if offset + size > gathered.numel():
raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ")
actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
offset += vram_aligned_size(template)

if isinstance(tensor, QuantizedTensor):
dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
else:
dest_views.append(actuals["data"])

return dest_views

aimdo_allocator = None
15 changes: 7 additions & 8 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)

comfy.model_management.archive_model_dtypes(self.diffusion_model)

self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
Expand Down Expand Up @@ -299,15 +301,15 @@ def extra_conds(self, **kwargs):

return out

def load_model_weights(self, sd, unet_prefix=""):
def load_model_weights(self, sd, unet_prefix="", assign=False):
to_load = {}
keys = list(sd.keys())
for k in keys:
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)

to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))

Expand All @@ -322,18 +324,15 @@ def process_latent_in(self, latent):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)

def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
if vae_state_dict is not None:
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])

Expand Down Expand Up @@ -776,8 +775,8 @@ def extra_conds(self, **kwargs):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
for k in d:
s = d[k]
Expand Down
Loading