Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,6 +1182,68 @@ def sync_stream(device, stream):
return
current_stream(device).wait_stream(stream)

def use_tiled_vae_decode(memory_needed, device=None):
    """Decide whether a VAE decode should fall back to tiled mode.

    Returns True when a full (non-tiled) decode needing ``memory_needed``
    bytes would not fit on ``device`` and the shortfall cannot be safely
    covered by offloading currently loaded models to CPU RAM. Returns
    False when a full decode is expected to fit.

    Args:
        memory_needed: Estimated bytes required for a full VAE decode.
        device: Target torch device; defaults to get_torch_device().

    Any unexpected internal error defaults to True (the safe, tiled path).
    """
    import logging  # local import: keeps this block self-contained
    try:
        if device is None:
            device = get_torch_device()

        # Pure-CPU execution has no GPU memory pressure to check.
        if cpu_state == CPUState.CPU or args.cpu_vae:
            return False

        inference_memory = minimum_inference_memory()
        memory_required = max(inference_memory, memory_needed + extra_reserved_memory())

        gpu_free = get_free_memory(device)
        cpu_free = psutil.virtual_memory().available

        # Does the GPU have enough space for a full decode (with reserves)?
        if gpu_free >= memory_required:
            return False

        # With --gpu-only, models can't offload to CPU (offload device = GPU),
        # so tiling is the only way to reduce the footprint.
        if args.gpu_only:
            return True

        # VRAM shortfall we would have to free up for a full decode.
        memory_to_free = memory_required - gpu_free

        # Tally how much of the currently loaded model memory could move to
        # CPU. Only models whose offload_device is CPU count; with --highvram
        # the UNet has offload_device=GPU so it CAN'T be offloaded. Models
        # without an offload_device attribute are assumed offloadable, which
        # matches the original branch structure.
        loaded_model_memory = 0
        cpu_offloadable_memory = 0

        for loaded_model in current_loaded_models:
            if loaded_model.device != device:
                continue
            model_size = loaded_model.model_loaded_memory()
            loaded_model_memory += model_size
            offload_dev = getattr(loaded_model.model, 'offload_device', None)
            if offload_dev is None or is_device_cpu(offload_dev):
                cpu_offloadable_memory += model_size

        # Not enough offloadable memory to make room: must tile.
        if cpu_offloadable_memory < memory_to_free:
            return True

        # Check that system RAM can actually receive the offloaded weights
        # (prevents 0xC0000005 access-violation crashes on Windows).
        # Smart memory ON (default): partial offload — only memory_to_free
        # bytes move to CPU. Smart memory OFF (--disable-smart-memory): ALL
        # loaded models get fully unloaded to CPU.
        if DISABLE_SMART_MEMORY:
            return cpu_free < loaded_model_memory
        return cpu_free < memory_to_free

    except Exception:
        # Fail safe towards tiled decoding, but do not swallow the error
        # silently — log it so real bugs remain visible.
        logging.warning("use_tiled_vae_decode check failed, defaulting to tiled VAE decode.", exc_info=True)
        return True

def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
wf_context = nullcontext()
Expand Down
61 changes: 42 additions & 19 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,29 +927,52 @@ def decode(self, samples_in, vae_options={}):
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)

for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.

memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)

# Memory check: switch to tiled decode if GPU can't fit full decode
# and models can't offload to CPU, preventing 0xC0000005 crash
if model_management.use_tiled_vae_decode(memory_used, self.device):
logging.warning("Insufficient memory for regular VAE decoding, switching to tiled VAE decoding.")
do_tile = True

if not do_tile:
try:
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)

for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out

except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
#So we just set a flag for tiler fallback so that tensor gc can happen once the
#exception is fully off the books.
Copy link
Copy Markdown
Contributor

@asagi4 asagi4 Jan 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to set do_tile = True here to actually do the tiled VAE retry.

I think this patch would be fairly helpful on AMD especially. Some VAE VRAM estimates with AMD seem to be kind of bonkers; the Flux VAE requests 11.6GB of VRAM to decode a 1 megapixel image and somehow I don't think it actually uses anywhere near that much.

EDIT: I just did a quick memory dump after a VAE decode. Torch maximum memory usage was about 6.6GB, and that would probably include the loaded VAE model and anything else that might be in VRAM. I'm not sure how to accurately tell what the actual VAE decoding used, but clearly not 11.6GB

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, do_tile should be set to True in this case. I pushed new commit with appropriate changes.

do_tile = True

if do_tile:
dims = samples_in.ndim - 2
if dims == 1:
tile_shape = (1, samples_in.shape[1], 128) # 1D tile estimate
elif dims == 2:
tile_shape = (1, samples_in.shape[1], 64, 64) # 2D tile: 64x64
else:
tile = 256 // self.spacial_compression_decode()
tile_shape = (1, samples_in.shape[1], 8, tile, tile) # 3D tile estimate

# Calculate tile memory
tile_memory = self.memory_used_decode(tile_shape, self.vae_dtype)

model_management.load_models_gpu([self.patcher], memory_required=tile_memory, force_full_load=self.disable_offload)

if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
Expand Down