
Rename buffer config keys and add mixed-precision dtype zones#44

Merged
cweniger merged 3 commits into main from refactor/mixed-precision on Feb 22, 2026

Conversation

@cweniger (Owner)

Summary

  • Buffer config renames: min_training_samples → min_samples, max_training_samples → max_samples, validation_window_size → validation_samples, resample_batch_size → simulate_count, keep_resampling → simulate_when_full, resample_interval → simulate_interval
  • Move chunk_size from ray: section to node-level sample_chunk_size; filter non-Ray keys from actor options to prevent ValueError
  • Inference graph expansion: support deterministic intermediate nodes reachable via evidence (BFS expansion), enabling theta → x → tokens graphs
  • Mixed-precision dtype zones: neural network layers run in float32 (tensor-core speed), parameter-space operations (covariance, eigendecomp, sampling) stay float64
  • Rename LazyOnlineNorm → RunningNorm with 3D support and output_dtype parameter; backwards-compatible alias kept
  • GaussianPosterior: override to() to protect float64 buffers, cast MLP output to float64 before de-whitening
  • base.py: stop forcing theta dtype onto embedding/posterior; cast conditions to float32
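The dtype-zone split described in the summary can be sketched as follows. This is a minimal illustration only; the class name and buffer names are hypothetical stand-ins, loosely modeled on the GaussianPosterior changes in this PR:

```python
import torch
import torch.nn as nn

class DtypeZonePosterior(nn.Module):
    """Hypothetical sketch: float32 network zone, float64 parameter-space zone."""
    def __init__(self, in_features, param_dim):
        super().__init__()
        # Network zone: float32 parameters for tensor-core throughput.
        self.net = nn.Linear(in_features, param_dim)
        # Parameter-space zone: float64 statistics buffers.
        self.register_buffer("_output_mean", torch.zeros(param_dim, dtype=torch.float64))
        self.register_buffer("_output_std", torch.ones(param_dim, dtype=torch.float64))

    def forward(self, conditions):
        r = self.net(conditions.float())       # compute in float32
        r = r.to(self._output_mean.dtype)      # cross into the float64 zone
        return r * self._output_std + self._output_mean  # de-whiten in float64

m = DtypeZonePosterior(8, 3)
out = m(torch.randn(4, 8))
print(out.dtype)  # torch.float64
```

Gradients flow through the `.to()` cast, so the float32 network still trains normally while downstream parameter-space math stays in float64.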

Test plan

  • All 5 smoke tests pass (pytest tests/test_examples_smoke.py -v)
  • GaussianPosterior dtype zones verified (MLP float32, output/residual buffers float64, loss.backward() works)
  • RunningNorm: 2D/3D input, output_dtype cast, backwards-compatible alias
  • All example configs updated to new key names

🤖 Generated with Claude Code

cweniger and others added 2 commits February 22, 2026 20:32
…filter Ray actor options

- min_training_samples → min_samples, max_training_samples → max_samples
- validation_window_size → validation_samples
- resample_batch_size → simulate_count
- keep_resampling → simulate_when_full, resample_interval → simulate_interval
- chunk_size (in ray:) → sample_chunk_size (node level)
- Filter non-Ray keys from actor_config to prevent ValueError
- Support deterministic intermediate nodes in inference graph (BFS expansion)
- Rename internal graph attributes: parents_dict → forward_deps,
  sorted_node_names → forward_order, sorted_inference_node_names → backward_order

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Embedding pipeline and posterior no longer inherit float64 from numpy
theta. Neural network layers (TransformerEmbedding, MLP) run in float32
for tensor-core speed; parameter-space operations (covariance, sampling)
stay float64 for precision.

- Rename LazyOnlineNorm → RunningNorm with 3D reduce_dims and output_dtype
- GaussianPosterior: override to() to protect float64 buffers, cast MLP
  output to float64 before de-whitening
- base.py: remove dtype forcing, cast conditions to float32 for embeddings

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@codecov

codecov bot commented Feb 22, 2026

Codecov Report

❌ Patch coverage is 1.03093% with 96 lines in your changes missing coverage. Please review.
✅ Project coverage is 9.58%. Comparing base (343ed0d) to head (d4fa3f8).
⚠️ Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
falcon/core/graph.py 0.00% 32 Missing ⚠️
falcon/estimators/gaussian.py 0.00% 20 Missing ⚠️
falcon/core/raystore.py 6.25% 15 Missing ⚠️
falcon/embeddings/norms.py 0.00% 12 Missing ⚠️
falcon/core/deployed_graph.py 0.00% 11 Missing ⚠️
falcon/estimators/base.py 0.00% 3 Missing ⚠️
falcon/cli.py 0.00% 2 Missing ⚠️
falcon/embeddings/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##            main     #44      +/-   ##
========================================
- Coverage   9.69%   9.58%   -0.11%     
========================================
  Files         32      32              
  Lines       3797    3839      +42     
========================================
  Hits         368     368              
- Misses      3429    3471      +42     
Flag Coverage Δ
unit 9.58% <1.03%> (-0.11%) ⬇️


        self.initialized = False

    def forward(self, x):
        reduce_dims = tuple(range(x.dim() - 1))
@cweniger (Owner, Author):

Not convinced this is always the best strategy.

@cweniger (Owner, Author):

Agreed — computing reduce_dims on every forward call is a bit heavy-handed. The concern was that the same RunningNorm instance might see both 2D and 3D input, but in practice each pipeline layer sees a fixed rank.

Options:

  1. Lazy-set once: compute on first forward call and store it, raise if rank changes later
  2. Constructor arg: add input_ndim param (default 2), compute at init
  3. Keep as-is: cost of tuple(range(x.dim() - 1)) is negligible vs the actual computation

I'd lean toward (1) — minimal API change, still auto-detects, but makes the assumption explicit. Want me to go with that?
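A minimal sketch of option (1), lazy-setting the reduction dims on the first call and raising on a later rank change. The class name and the stand-in stats computation are hypothetical; the real RunningNorm tracks running mean/variance:

```python
import torch
import torch.nn as nn

class RunningNormLazyDims(nn.Module):
    """Hypothetical sketch: fix reduce_dims on first forward, error on rank change."""
    def __init__(self):
        super().__init__()
        self._reduce_dims = None

    def forward(self, x):
        if self._reduce_dims is None:
            # First call locks in the input rank.
            self._reduce_dims = tuple(range(x.dim() - 1))
        elif len(self._reduce_dims) != x.dim() - 1:
            raise ValueError(
                f"RunningNorm saw rank {x.dim()} after being initialized "
                f"for rank {len(self._reduce_dims) + 1}"
            )
        # Stand-in for the real running-stats update.
        return x.mean(dim=self._reduce_dims)

norm = RunningNormLazyDims()
norm(torch.randn(4, 5))  # locks in reduce_dims = (0,)
```

A subsequent 3D input would now fail loudly instead of silently changing the statistics' shape.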

@cweniger (Owner, Author):

Right, these are two different normalization modes:

  1. Per-element (images/spectra): reduce over batch only (dim=0), stats shape matches spatial dims. Mean spectrum subtracted per pixel — correct for sequential SBI where the per-pixel mean shifts across rounds.

  2. Per-feature (transformers): reduce over batch + sequence (dim=(0,1,...,N-2)), stats shape is (n_features,). One mean/var per feature across all positions.

The old code always did (1). The new code always does (2). Neither is universally right.

Simplest fix — add a per_feature flag (default False to preserve old behavior):

class RunningNorm(nn.Module):
    def __init__(self, ..., per_feature=False):
        self.per_feature = per_feature
        ...

    def forward(self, x):
        if self.per_feature:
            reduce_dims = tuple(range(x.dim() - 1))  # (B, S, F) → stats (F,)
        else:
            reduce_dims = (0,)                         # (B, ...) → stats (...)
        ...

per_feature=False: per-pixel/per-element stats (images, spectra)
per_feature=True: per-feature stats (transformer input)

For the spectral analysis example, the config would use per_feature: true since RunningNorm sits between TokenEmbed (B, S, 5) and the transformer.

@cweniger (Owner, Author):

Addressed in d4fa3f8. Replaced computed reduce_dims with an explicit dim parameter (default (0,), matching torch.mean semantics). Uses keepdim=True internally for correct broadcasting with arbitrary reduction dims.

Config usage: dim: [0, 1] reduces over batch and sequence dims (per-feature stats for transformer input), dim: [0] reduces over batch only (per-element stats for images/spectra).
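The keepdim=True broadcasting described here can be sketched with a small helper (the function name is hypothetical; the real RunningNorm additionally maintains running statistics across batches):

```python
import torch

def running_stats(x, dim=(0,)):
    """Reduce over an explicit dim tuple, matching torch.mean semantics.

    keepdim=True leaves the reduced axes as size 1, so the resulting
    stats broadcast back against x no matter which dims were reduced."""
    mean = x.mean(dim=dim, keepdim=True)
    var = x.var(dim=dim, unbiased=False, keepdim=True)
    return mean, var

x = torch.randn(8, 16, 5)                 # (batch, seq, features)
m_feat, _ = running_stats(x, dim=(0, 1))  # per-feature stats, shape (1, 1, 5)
m_elem, _ = running_stats(x, dim=(0,))    # per-element stats, shape (1, 16, 5)
normed = x - m_feat                       # broadcasts cleanly in both cases
```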

  embedding.eval()
  with torch.no_grad():
-     conditions_device = {k: v.to(self.device, dtype=dtype) for k, v in conditions.items()}
+     conditions_device = {k: v.to(self.device).float() for k, v in conditions.items()}
@cweniger (Owner, Author):

Why do we enforce float here? That is pretty extreme. It should be up to the embedding to decide its input dtype.

@cweniger (Owner, Author):

You're right, this is too aggressive. The .float() was a quick fix for the dtype mismatch after removing the blanket dtype=theta.dtype cast — numpy arrays produce float64 tensors, which crash float32 embedding modules (e.g. E_identity, E_fft_norm in examples 04/05).

But forcing float32 on all conditions is wrong — it's the embedding's job to decide its input dtype. Better approach: let the embedding declare what it expects, or just cast conditions to match the embedding's parameter dtype:

# Cast to embedding's dtype (respects whatever the embedding was initialized as)
emb_dtype = next(embedding.parameters()).dtype if list(embedding.parameters()) else torch.float32
conditions_device = {k: v.to(self.device, dtype=emb_dtype) for k, v in conditions.items()}

This way parameter-free embeddings (like E_identity) default to float32, while embeddings with float64 parameters keep float64. Want me to go this route?

@cweniger (Owner, Author):

Addressed in d4fa3f8. All three .float() calls removed. Each consumer handles its own dtype:

  • GaussianPosterior casts conditions to MLP dtype internally via conditions.to(self._input_mean.dtype)
  • Example 05 embeddings now cast to float32 at entry (x = x.float())

This allows mixed-dtype embedding pipelines (e.g., float64 TokenEmbed → RunningNorm → float32 Transformer) to work correctly.
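Such a consumer-handles-its-own-dtype pipeline can be sketched as follows. The module names are hypothetical stand-ins for the TokenEmbed → RunningNorm → Transformer chain mentioned above:

```python
import torch
import torch.nn as nn

class TokenEmbed64(nn.Module):
    """Hypothetical front end that does its math in float64."""
    def forward(self, x):
        return x.double()

class Float32Net(nn.Module):
    """Hypothetical float32 network that casts its own input at entry."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 8)    # float32 parameters

    def forward(self, x):
        return self.linear(x.float())    # each consumer casts for itself

pipe = nn.Sequential(TokenEmbed64(), Float32Net())
out = pipe(torch.randn(4, 5))            # float64 intermediate, float32 output
print(out.dtype)  # torch.float32
```

Because no trainer-level cast is imposed, the float64 stage and the float32 stage coexist in one pipeline.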

  theta_logprob = self._to_tensor(batch[f"{self.theta_key}.log_prob"])
  conditions = {
-     k: self._to_tensor(batch[f"{k}.value"], self.device)
+     k: self._to_tensor(batch[f"{k}.value"], self.device).float()
@cweniger (Owner, Author):

Why float?

@cweniger (Owner, Author):

Same issue as above — see reply on line 454. Will fix all three .float() calls together with the embedding-dtype-aware approach.


  conditions_device = {
-     k: self._to_tensor(v, self.device).expand(num_samples, *v.shape[1:])
+     k: self._to_tensor(v, self.device).float().expand(num_samples, *v.shape[1:])
@cweniger (Owner, Author):

Again, why float?

@cweniger (Owner, Author):

Same — will fix with the other two.

    result._residual_cov = result._residual_cov.double()
    result._residual_eigvals = result._residual_eigvals.double()
    result._residual_eigvecs = result._residual_eigvecs.double()
    return result
@cweniger (Owner, Author):

Shouldn't there be just a separate argument for the parameter-space buffer type?

@cweniger (Owner, Author):

Good point. A param_dtype argument (default torch.float64) would make this explicit and configurable rather than hardcoded in to():

class GaussianPosterior(nn.Module):
    def __init__(self, ..., param_dtype=torch.float64):
        self.param_dtype = param_dtype
        ...
        # Register buffers in param_dtype
        self.register_buffer('_output_mean', torch.zeros(param_dim, dtype=param_dtype))
        ...

    def to(self, *args, **kwargs):
        result = super().to(*args, **kwargs)
        # Restore param-space buffers to their declared dtype
        for name in ('_output_mean', '_output_std', '_residual_cov', '_residual_eigvals', '_residual_eigvecs'):
            buf = getattr(result, name)
            setattr(result, name, buf.to(self.param_dtype))
        return result

Then users can set param_dtype: float32 in the YAML network: config if their problem doesn't need float64 precision. The default stays float64 for the sequential SBI use case. Want me to implement this?

@cweniger (Owner, Author):

Addressed in d4fa3f8. No separate param_dtype argument needed. Instead:

  • to() preserves the current buffer dtype (inferred, not hardcoded)
  • All EMA updates use non-in-place ops ((1-m)*buf + m*val) which auto-promote via PyTorch's type promotion rules
  • When theta is float64, output buffers auto-promote to float64; when theta is float32, everything stays float32

This means the param-space dtype is inferred from the first theta batch — no configuration needed.
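The auto-promotion behavior relied on here can be shown with a minimal sketch (the helper name is hypothetical):

```python
import torch

def ema_update(buf, val, m=0.1):
    """Non-in-place EMA: (1 - m) * buf + m * val.

    PyTorch's type-promotion rules pick the wider dtype of buf and val,
    so a float32 buffer auto-promotes to float64 the first time it sees
    float64 data. An in-place buf.lerp_(val, m) cannot do this: in-place
    ops are not allowed to change the buffer's dtype."""
    return (1 - m) * buf + m * val

buf = torch.zeros(3)                          # float32 buffer
theta64 = torch.ones(3, dtype=torch.float64)  # float64 sample statistics
buf = ema_update(buf, theta64)
print(buf.dtype)  # torch.float64
```

If theta had been float32 instead, the result would stay float32 with no extra configuration, which is exactly the inference-from-first-batch behavior described above.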

  x_norm = (conditions - self._input_mean.detach()) / self._input_std.detach()
- r = self.net(x_norm)
+ r = self.net(x_norm)  # float32
+ r = r.double()        # → float64
@cweniger (Owner, Author):

Not convinced this is always the best, or a good default to have. There might be situations where the parameters are fundamentally float32, which would break the setup.

@cweniger (Owner, Author):

Fair point. The r.double() cast would indeed break if theta is float32 — _output_mean would be float64 but theta would be float32, causing a dtype mismatch in log_prob.

This ties into the param_dtype argument from the comment above. If we add that, this line becomes:

r = r.to(self.param_dtype)  # matches whatever the buffers are

With param_dtype=float32, there's no cast at all — everything stays float32. With param_dtype=float64 (default for sequential SBI), the cast happens here. Clean and explicit either way.

@cweniger (Owner, Author):

Addressed in d4fa3f8. The r.double() hardcoded cast is gone. Now uses r.to(self._output_mean.dtype) — so it follows whatever dtype the output buffers have (which is inferred from theta via auto-promotion). If theta is float32, everything stays float32. If theta is float64, buffers auto-promote and de-whitening runs in float64.

…ility

RunningNorm: Replace computed reduce_dims with explicit `dim` parameter
(default (0,)) using keepdim=True for flexible reduction semantics.

GaussianPosterior: Replace in-place lerp_/copy_ with non-in-place EMA ops
for automatic dtype promotion from theta. to() preserves current buffer
dtype rather than hardcoding float64.

LossBasedEstimator: Remove .float() casts — each consumer (embedding,
posterior) is responsible for its own dtype handling.

Example 05 embeddings: Add x.float() at entry since nn.Linear expects
float32 input, and replace lerp_ with non-in-place EMA.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@cweniger cweniger merged commit e62cece into main Feb 22, 2026
4 of 6 checks passed
@cweniger cweniger deleted the refactor/mixed-precision branch March 9, 2026 22:35