Rename buffer config keys and add mixed-precision dtype zones#44
Conversation
…filter Ray actor options

- min_training_samples → min_samples, max_training_samples → max_samples
- validation_window_size → validation_samples
- resample_batch_size → simulate_count
- keep_resampling → simulate_when_full, resample_interval → simulate_interval
- chunk_size (in ray:) → sample_chunk_size (node level)
- Filter non-Ray keys from actor_config to prevent ValueError
- Support deterministic intermediate nodes in inference graph (BFS expansion)
- Rename internal graph attributes: parents_dict → forward_deps, sorted_node_names → forward_order, sorted_inference_node_names → backward_order

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Embedding pipeline and posterior no longer inherit float64 from numpy theta. Neural network layers (TransformerEmbedding, MLP) run in float32 for tensor-core speed; parameter-space operations (covariance, sampling) stay float64 for precision.

- Rename LazyOnlineNorm → RunningNorm with 3D reduce_dims and output_dtype
- GaussianPosterior: override to() to protect float64 buffers, cast MLP output to float64 before de-whitening
- base.py: remove dtype forcing, cast conditions to float32 for embeddings

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##            main      #44     +/-  ##
========================================
- Coverage   9.69%    9.58%   -0.11%
========================================
  Files         32       32
  Lines       3797     3839      +42
========================================
  Hits         368      368
- Misses      3429     3471      +42
```
falcon/embeddings/norms.py (Outdated)

```python
        self.initialized = False

    def forward(self, x):
        reduce_dims = tuple(range(x.dim() - 1))
```
Not convinced this is always the best strategy.
Agreed — computing reduce_dims on every forward call is a bit heavy-handed. The concern was that the same RunningNorm instance might see both 2D and 3D input, but in practice each pipeline layer sees a fixed rank.
Options:
1. Lazy-set once: compute on first forward call and store it, raise if rank changes later
2. Constructor arg: add `input_ndim` param (default 2), compute at init
3. Keep as-is: cost of `tuple(range(x.dim() - 1))` is negligible vs the actual computation
I'd lean toward (1) — minimal API change, still auto-detects, but makes the assumption explicit. Want me to go with that?
Right, these are two different normalization modes:
1. Per-element (images/spectra): reduce over batch only (`dim=0`), stats shape matches spatial dims. Mean spectrum subtracted per pixel — correct for sequential SBI where the per-pixel mean shifts across rounds.
2. Per-feature (transformers): reduce over batch + sequence (`dim=(0, 1, ..., N-2)`), stats shape is `(n_features,)`. One mean/var per feature across all positions.
The old code always did (1). The new code always does (2). Neither is universally right.
Simplest fix — add a per_feature flag (default False to preserve old behavior):
```python
class RunningNorm(nn.Module):
    def __init__(self, ..., per_feature=False):
        self.per_feature = per_feature
        ...

    def forward(self, x):
        if self.per_feature:
            reduce_dims = tuple(range(x.dim() - 1))  # (B, S, F) → stats (F,)
        else:
            reduce_dims = (0,)  # (B, ...) → stats (...)
        ...
```

- `per_feature=False`: per-pixel/per-element stats (images, spectra)
- `per_feature=True`: per-feature stats (transformer input)
For the spectral analysis example, the config would use `per_feature: true`, since RunningNorm sits between TokenEmbed `(B, S, 5)` and the transformer.
Addressed in d4fa3f8. Replaced the computed `reduce_dims` with an explicit `dim` parameter (default `(0,)`, matching `torch.mean` semantics). Uses `keepdim=True` internally for correct broadcasting with arbitrary reduction dims.

Config usage: `dim: [0, 1]` reduces over batch and sequence dims (per-feature stats for transformer input); `dim: [0]` reduces over batch only (per-element stats for images/spectra).
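To make the two reduction modes concrete, here is a minimal, self-contained sketch of a running normalizer with an explicit `dim` parameter and `keepdim=True`. The class name and EMA details are illustrative assumptions, not the actual d4fa3f8 code:

```python
import torch
import torch.nn as nn

class RunningNormSketch(nn.Module):
    """Running mean/var normalization with an explicit reduction `dim`.

    dim=(0,)   → reduce over batch only (per-element stats, images/spectra)
    dim=(0, 1) → reduce over batch + sequence (per-feature stats, transformers)
    """
    def __init__(self, dim=(0,), momentum=0.1, eps=1e-5):
        super().__init__()
        self.dim = tuple(dim)
        self.momentum = momentum
        self.eps = eps
        self.mean = None
        self.var = None

    def forward(self, x):
        if self.training:
            # keepdim=True keeps singleton dims so the stats broadcast
            # against x regardless of which dims were reduced
            batch_mean = x.mean(dim=self.dim, keepdim=True)
            batch_var = x.var(dim=self.dim, unbiased=False, keepdim=True)
            if self.mean is None:
                self.mean, self.var = batch_mean, batch_var
            else:
                m = self.momentum
                # non-in-place EMA so dtype follows promotion rules
                self.mean = (1 - m) * self.mean + m * batch_mean
                self.var = (1 - m) * self.var + m * batch_var
        return (x - self.mean) / torch.sqrt(self.var + self.eps)

# With (B, S, F) input, dim=(0, 1) yields stats of shape (1, 1, F);
# dim=(0,) yields per-element stats of shape (1, S, F).
```

The `keepdim=True` choice is what makes the broadcasting work for any reduction dims: the stats keep the same rank as the input, so no reshape is needed at normalization time.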
falcon/estimators/base.py (Outdated)

```diff
  embedding.eval()
  with torch.no_grad():
-     conditions_device = {k: v.to(self.device, dtype=dtype) for k, v in conditions.items()}
+     conditions_device = {k: v.to(self.device).float() for k, v in conditions.items()}
```
Why do we enforce float here? That is pretty extreme. It should be up to the embedding.
You're right, this is too aggressive. The .float() was a quick fix for the dtype mismatch after removing the blanket dtype=theta.dtype cast — numpy arrays produce float64 tensors, which crash float32 embedding modules (e.g. E_identity, E_fft_norm in examples 04/05).
But forcing float32 on all conditions is wrong — it's the embedding's job to decide its input dtype. Better approach: let the embedding declare what it expects, or just cast conditions to match the embedding's parameter dtype:
```python
# Cast to embedding's dtype (respects whatever the embedding was initialized as)
emb_dtype = next(embedding.parameters()).dtype if list(embedding.parameters()) else torch.float32
conditions_device = {k: v.to(self.device, dtype=emb_dtype) for k, v in conditions.items()}
```

This way parameter-free embeddings (like E_identity) default to float32, while embeddings with float64 parameters keep float64. Want me to go this route?
Addressed in d4fa3f8. All three `.float()` calls removed. Each consumer handles its own dtype:

- GaussianPosterior casts conditions to MLP dtype internally via `conditions.to(self._input_mean.dtype)`
- Example 05 embeddings now cast to float32 at entry (`x = x.float()`)
This allows mixed-dtype embedding pipelines (e.g., float64 TokenEmbed → RunningNorm → float32 Transformer) to work correctly.
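The boundary cast can be seen in a standalone sketch: float64 input (as produced by numpy-derived conditions) feeding a default-float32 `nn.Linear` fails without an explicit cast, so the consumer casts at its own entry point. The module name and dimensions here are illustrative assumptions:

```python
import torch
import torch.nn as nn

# Hypothetical embedding whose neural layers run in float32
class Float32Embed(nn.Module):
    def __init__(self, in_features=5, out_features=8):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features)  # float32 params by default

    def forward(self, x):
        # nn.Linear raises a dtype mismatch for float64 input, so this
        # consumer casts at entry (the `x = x.float()` pattern above)
        return self.proj(x.float())

x64 = torch.randn(2, 3, 5, dtype=torch.float64)  # e.g. numpy-derived conditions
out = Float32Embed()(x64)
assert out.dtype == torch.float32
```

Earlier pipeline stages (e.g. a tokenizer or running norm) can stay float64; only the layer that owns float32 parameters performs the cast.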
falcon/estimators/base.py (Outdated)

```diff
  theta_logprob = self._to_tensor(batch[f"{self.theta_key}.log_prob"])
  conditions = {
-     k: self._to_tensor(batch[f"{k}.value"], self.device)
+     k: self._to_tensor(batch[f"{k}.value"], self.device).float()
```
Same issue as above — see reply on line 454. Will fix all three .float() calls together with the embedding-dtype-aware approach.
falcon/estimators/base.py (Outdated)

```diff
  conditions_device = {
-     k: self._to_tensor(v, self.device).expand(num_samples, *v.shape[1:])
+     k: self._to_tensor(v, self.device).float().expand(num_samples, *v.shape[1:])
```
Same — will fix with the other two.
```python
        result._residual_cov = result._residual_cov.double()
        result._residual_eigvals = result._residual_eigvals.double()
        result._residual_eigvecs = result._residual_eigvecs.double()
        return result
```
Shouldn't there be just a separate argument for the parameter-space buffer type?
Good point. A `param_dtype` argument (default `torch.float64`) would make this explicit and configurable rather than hardcoded in `to()`:

```python
class GaussianPosterior(nn.Module):
    def __init__(self, ..., param_dtype=torch.float64):
        self.param_dtype = param_dtype
        ...
        # Register buffers in param_dtype
        self.register_buffer('_output_mean', torch.zeros(param_dim, dtype=param_dtype))
        ...

    def to(self, *args, **kwargs):
        result = super().to(*args, **kwargs)
        # Restore param-space buffers to their declared dtype
        for name in ('_output_mean', '_output_std', '_residual_cov',
                     '_residual_eigvals', '_residual_eigvecs'):
            buf = getattr(result, name)
            setattr(result, name, buf.to(self.param_dtype))
        return result
```

Then users can set `param_dtype: float32` in the YAML `network:` config if their problem doesn't need float64 precision. The default stays float64 for the sequential SBI use case. Want me to implement this?
Addressed in d4fa3f8. No separate `param_dtype` argument needed. Instead:

- `to()` preserves the current buffer dtype (inferred, not hardcoded)
- All EMA updates use non-in-place ops (`(1 - m) * buf + m * val`) which auto-promote via PyTorch's type promotion rules
- When theta is float64, output buffers auto-promote to float64; when theta is float32, everything stays float32
This means the param-space dtype is inferred from the first theta batch — no configuration needed.
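The promotion behavior the fix relies on can be checked in isolation. This is a standalone illustration of PyTorch's type promotion rules, not the actual buffer-update code:

```python
import torch

m = 0.1
buf = torch.zeros(3, dtype=torch.float32)    # buffer starts float32
theta = torch.ones(3, dtype=torch.float64)   # first float64 batch arrives

# In-place EMA keeps the buffer's dtype. (Calling lerp_ with a raw
# float64 end tensor on a float32 buffer would raise a RuntimeError,
# so the operand has to be downcast first — precision is lost.)
inplace = buf.clone()
inplace.lerp_(theta.to(inplace.dtype), m)
assert inplace.dtype == torch.float32

# Non-in-place EMA promotes: float32 * float64 → float64,
# so the buffer "upgrades" to theta's precision automatically.
promoted = (1 - m) * buf + m * theta
assert promoted.dtype == torch.float64
```

This is why swapping `lerp_`/`copy_` for `(1 - m) * buf + m * val` makes the param-space dtype follow the first theta batch with no explicit configuration.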
falcon/estimators/gaussian.py (Outdated)

```diff
  x_norm = (conditions - self._input_mean.detach()) / self._input_std.detach()
- r = self.net(x_norm)
+ r = self.net(x_norm)  # float32
+ r = r.double()  # → float64
```
Not convinced this is always the best, or a good default to have. There might be situations where the parameters are fundamentally float32, which would break the setup.
Fair point. The r.double() cast would indeed break if theta is float32 — _output_mean would be float64 but theta would be float32, causing a dtype mismatch in log_prob.
This ties into the param_dtype argument from the comment above. If we add that, this line becomes:
```python
r = r.to(self.param_dtype)  # matches whatever the buffers are
```

With `param_dtype=float32`, there's no cast at all — everything stays float32. With `param_dtype=float64` (default for sequential SBI), the cast happens here. Clean and explicit either way.
Addressed in d4fa3f8. The hardcoded `r.double()` cast is gone. Now uses `r.to(self._output_mean.dtype)` — so it follows whatever dtype the output buffers have (which is inferred from theta via auto-promotion). If theta is float32, everything stays float32. If theta is float64, buffers auto-promote and de-whitening runs in float64.
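A minimal sketch of that buffer-following de-whitening step, with stand-in tensors for the buffers and the MLP output (names are illustrative, not the actual gaussian.py code):

```python
import torch

# Stand-ins for the posterior's output buffers, assumed to have
# auto-promoted to float64 after seeing a float64 theta batch
output_mean = torch.zeros(4, dtype=torch.float64)
output_std = torch.ones(4, dtype=torch.float64)

r = torch.randn(2, 4, dtype=torch.float32)  # float32 MLP output
r = r.to(output_mean.dtype)                 # follows the buffers, no hardcoded .double()
theta_pred = r * output_std + output_mean   # de-whitening runs in the buffers' dtype
assert theta_pred.dtype == torch.float64
```

If the buffers had stayed float32 (float32 theta), `r.to(output_mean.dtype)` would be a no-op and the whole step would run in float32, avoiding the breakage flagged above.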
…ility

RunningNorm: Replace computed reduce_dims with explicit `dim` parameter (default (0,)) using keepdim=True for flexible reduction semantics.

GaussianPosterior: Replace in-place lerp_/copy_ with non-in-place EMA ops for automatic dtype promotion from theta. to() preserves current buffer dtype rather than hardcoding float64.

LossBasedEstimator: Remove .float() casts — each consumer (embedding, posterior) is responsible for its own dtype handling.

Example 05 embeddings: Add x.float() at entry since nn.Linear expects float32 input, and replace lerp_ with non-in-place EMA.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary

- Rename buffer config keys: `min_training_samples` → `min_samples`, `max_training_samples` → `max_samples`, `validation_window_size` → `validation_samples`, `resample_batch_size` → `simulate_count`, `keep_resampling` → `simulate_when_full`, `resample_interval` → `simulate_interval`
- Move `chunk_size` from the `ray:` section to node-level `sample_chunk_size`; filter non-Ray keys from actor options to prevent ValueError
- Support deterministic intermediate nodes (`theta → x → tokens` graphs)
- `LazyOnlineNorm` → `RunningNorm` with 3D support and `output_dtype` parameter; backwards-compatible alias kept
- `GaussianPosterior`: override `to()` to protect float64 buffers, cast MLP output to float64 before de-whitening
- `base.py`: stop forcing theta dtype onto embedding/posterior; cast conditions to float32

Test plan

- `pytest tests/test_examples_smoke.py -v`
- `loss.backward()` works

🤖 Generated with Claude Code