Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/scope/cloud/fal_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,16 @@ def setup(self):
print(f"GPU check failed: {e}")
raise

# Log CUDA environment so failures in plugin pipelines (e.g. flashvsr)
# that surface as "No CUDA GPUs are available" can be correlated with
# the worker configuration seen at startup time.
cvd = os.environ.get("CUDA_VISIBLE_DEVICES", "<not set>")
nv_vis = os.environ.get("NVIDIA_VISIBLE_DEVICES", "<not set>")
print(
f"CUDA env at startup: CUDA_VISIBLE_DEVICES={cvd!r} "
f"NVIDIA_VISIBLE_DEVICES={nv_vis!r}"
)

# Environment for scope - whitelist only necessary variables (security)
ENV_WHITELIST = [
# Required for process execution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,19 +457,32 @@ def qkv_fn(x):
# If we are using local attention and the current KV cache size is larger than the local attention size, we need to truncate the KV cache
kv_cache_size = kv_cache["k"].shape[1]
num_new_tokens = roped_query.shape[1]
# Normalize cache indices to Python ints. On the very first chunk after
# a cache reset, initialize_kv_cache() stores torch.tensor([0], ...) in
# these slots. If we leave them as tensors, all subsequent arithmetic
# (local_end_index, cache_current_block_start, …) also becomes tensors.
# When cache_current_block_start is captured as a tensor in score_mod and
# passed to torch.compile(flex_attention, dynamic=False), flex_attention
# tries to re-trace score_mod on every chunk because the captured tensor
# *object* identity changes, which triggers:
# "Detected that you are using FX to symbolically trace a
# dynamo-optimized function." (+ _dispatch_keys TypeError)
# int() is safe for both Python ints and single-element torch.Tensors.
cache_global_end: int = int(kv_cache["global_end_index"])
cache_local_end: int = int(kv_cache["local_end_index"])
if (
self.local_attn_size != -1
and (current_end > kv_cache["global_end_index"])
and (num_new_tokens + kv_cache["local_end_index"] > kv_cache_size)
and (current_end > cache_global_end)
and (num_new_tokens + cache_local_end > kv_cache_size)
):
# Calculate the number of new tokens added in this step
# Shift existing cache content left to discard oldest tokens
# Clone the source slice to avoid overlapping memory error
num_evicted_tokens = (
num_new_tokens + kv_cache["local_end_index"] - kv_cache_size
num_new_tokens + cache_local_end - kv_cache_size
)
num_rolled_tokens = (
kv_cache["local_end_index"] - num_evicted_tokens - sink_tokens
cache_local_end - num_evicted_tokens - sink_tokens
)
kv_cache["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = (
kv_cache["k"][
Expand All @@ -489,9 +502,9 @@ def qkv_fn(x):
)
# Insert the new keys/values at the end
local_end_index = (
kv_cache["local_end_index"]
cache_local_end
+ current_end
- kv_cache["global_end_index"]
- cache_global_end
- num_evicted_tokens
)
local_start_index = local_end_index - num_new_tokens
Expand All @@ -500,9 +513,9 @@ def qkv_fn(x):
else:
# Assign new keys/values directly up to current_end
local_end_index = (
kv_cache["local_end_index"]
cache_local_end
+ current_end
- kv_cache["global_end_index"]
- cache_global_end
)
local_start_index = local_end_index - num_new_tokens
kv_cache["k"][:, local_start_index:local_end_index] = roped_key
Expand Down Expand Up @@ -541,24 +554,30 @@ def qkv_fn(x):
cached_v, target_padded_length, pad_dim=1
)

# Convert scalars to tensors to avoid ShapeAsConstantBuffer dtype issues during compilation
# This is critical when using torch.compile with flex_attention
frame_seqlen_tensor = torch.as_tensor(
frame_seqlen, dtype=torch.int32, device=roped_query.device
)
cache_current_block_start_tensor = torch.as_tensor(
cache_current_block_start, dtype=torch.int32, device=roped_query.device
).squeeze()
log_scale_tensor = torch.as_tensor(
log_scale, dtype=roped_query.dtype, device=roped_query.device
)
# Use Python scalar literals (int/float) as constants in score_mod.
# Capturing freshly-created CUDA tensors caused two errors:
# 1. FX symbolic-trace error: torch.compile(flex_attention, dynamic=False)
# tries to re-trace score_mod when captured tensor *objects* change
# (cache_current_block_start shifts each chunk), and the FX tracer hits
# the already-compiled flex_attention, raising:
# "Detected that you are using FX to symbolically trace a
# dynamo-optimized function."
# 2. _dispatch_keys TypeError: FakeTensors (used during trace) collide
# with real CUDA tensors captured in the closure.
# Python scalars become stable graph constants, avoiding both issues.
# The old tensor-conversion workaround targeted a ShapeAsConstantBuffer bug
# in pre-2.9 PyTorch; that bug is not present in torch>=2.9.
_fs: int = frame_seqlen
_ccbs: int = cache_current_block_start
_ls: float = log_scale

def score_mod(score, b_idx, h_idx, q_idx, kv_idx):
# Apply bias only to past frames (exclude first frame and current block)
# Apply bias only to past frames (exclude first frame and current block).
# kv_idx is an int32 index scalar supplied by flex_attention; Python int
# comparisons are safe and compile cleanly without tensor captures.
return torch.where(
(kv_idx >= frame_seqlen_tensor)
& (kv_idx < cache_current_block_start_tensor),
score + log_scale_tensor,
(kv_idx >= _fs) & (kv_idx < _ccbs),
score + _ls,
score,
)

Expand Down
56 changes: 56 additions & 0 deletions src/scope/server/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,53 @@ def get_device() -> torch.device:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _assert_cuda_accessible() -> None:
"""Raise RuntimeError with a clear message if CUDA cannot actually be used.

``torch.cuda.is_available()`` only checks that the CUDA *runtime* is
installed; it does **not** guarantee that a physical GPU is visible. On
fal.ai GPU workers that use MIG partitions or that set
``CUDA_VISIBLE_DEVICES`` to an unexpected value the check passes but any
subsequent attempt to allocate a CUDA tensor raises
"No CUDA GPUs are available".

This helper forces lazy CUDA initialisation early so that the error surface
is a clean, actionable exception rather than a cryptic failure buried deep
inside a plugin's ``__init__``.
"""
import os

if not torch.cuda.is_available():
n_devs = torch.cuda.device_count()
cvd = os.environ.get("CUDA_VISIBLE_DEVICES", "<not set>")
raise RuntimeError(
f"No CUDA GPUs are available (device_count={n_devs}, "
f"CUDA_VISIBLE_DEVICES={cvd!r}). "
"Check that the worker has a visible GPU and that "
"CUDA_VISIBLE_DEVICES is set correctly."
)

# is_available() returned True — now do a real device-count check and a
# tiny test allocation to catch cases where CUDA context init will fail
# (e.g. empty CUDA_VISIBLE_DEVICES, invalid MIG UUID, driver mismatch).
n_devs = torch.cuda.device_count()
cvd = os.environ.get("CUDA_VISIBLE_DEVICES", "<not set>")
if n_devs == 0:
raise RuntimeError(
f"No CUDA GPUs are available (device_count=0, "
f"CUDA_VISIBLE_DEVICES={cvd!r}). "
"CUDA runtime is installed but no devices are visible."
)

try:
_ = torch.zeros(1, device="cuda")
except RuntimeError as exc:
raise RuntimeError(
f"CUDA device_count={n_devs} but test tensor allocation failed "
f"(CUDA_VISIBLE_DEVICES={cvd!r}): {exc}"
) from exc


class PipelineNotAvailableException(Exception):
"""Exception raised when pipeline is not available for processing."""

Expand Down Expand Up @@ -733,6 +780,15 @@ def _load_pipeline_implementation(
logger.info(f"Loading plugin pipeline: {pipeline_id}")
if stage_callback:
stage_callback("Initializing pipeline...")

# Validate that CUDA is actually accessible before handing off to
# the plugin. Plugin __init__ methods often allocate CUDA tensors
# immediately (model loads, warmup passes) and the generic
# "No CUDA GPUs are available" error they produce is hard to trace.
# _assert_cuda_accessible() surfaces the problem early with extra
# diagnostic context (device_count, CUDA_VISIBLE_DEVICES).
_assert_cuda_accessible()

config_class = pipeline_class.get_config_class()
# Get defaults from schema fields
schema_defaults = {}
Expand Down
Loading