
MicroSplit API: noise model & loss refactoring#700

Open
CatEek wants to merge 29 commits into dev/microsplit_api from iz/feat/noise_model_refac

Conversation

@CatEek (Contributor) commented Jan 25, 2026

Disclaimer

  • I am an AI agent.
  • I have used AI and I thoroughly reviewed every line.
  • I have not used AI extensively.

Description

Note

tldr: Refactor MicroSplit noise-model and loss configuration, replacing
fragmented likelihood configs with a unified logic.

Background - why do we need this PR?

The previous implementation spread noise-model-related code across three separate
Pydantic models (GaussianLikelihoodConfig, NMLikelihoodConfig, and
MultiChannelNMConfig), three loss types (musplit, denoisplit,
denoisplit_musplit), and multiple VAEModule attributes
(gaussian_likelihood, noise_model_likelihood). This made it hard to reason
about which combinations of configs were valid, and left the code hard to read
and error-prone.

Overview - what changed?

  • Likelihood configuration objects are removed; their parameters live directly
    in LVAELossConfig.
  • predict_logvar is now a plain bool everywhere (was Literal[None, "pixelwise"]).
  • GaussianMixtureNMConfig exposes a from_npz classmethod instead of
    accepting a path field that triggered side-effectful validation (this is
    questionable and likely to change).
  • VAEModule holds data statistics directly and passes them into the loss
    function, removing the intermediary likelihood objects.
  • NoiseModelTrainer is added (not yet used).
  • musplit_loss, denoisplit_loss, and denoisplit_musplit_loss are replaced
    by a single microsplit_loss, controlled by musplit_weight /
    denoisplit_weight. The SupportedLoss enum retains only hdn and
    microsplit.
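Concretely, the new weighting scheme could be sketched as follows. This is a minimal illustration only: the dataclass stands in for the actual Pydantic LVAELossConfig, and the function signature and NLL inputs are invented, not the PR's real code.

```python
from dataclasses import dataclass


@dataclass
class LossWeights:
    """Stand-in for the weight fields described above (hypothetical layout)."""

    musplit_weight: float = 0.0
    denoisplit_weight: float = 1.0


def microsplit_loss(gaussian_nll: float, nm_nll: float, w: LossWeights) -> float:
    """Single entry point replacing the three former loss functions.

    Setting one weight to 0 recovers the old pure musplit or denoisplit
    modes; both weights > 0 corresponds to the former denoisplit_musplit.
    """
    loss = 0.0
    if w.musplit_weight > 0:
        loss += w.musplit_weight * gaussian_nll  # Gaussian-likelihood term
    if w.denoisplit_weight > 0:
        loss += w.denoisplit_weight * nm_nll  # noise-model term
    return loss
```

The point of the single entry point is that the valid combinations are now expressed by two floats rather than by which of three config objects happens to be non-None.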

Performance tests are not finished!

Please ensure your PR meets the following requirements:

  • Code builds and passes tests locally, including doctests
  • New tests have been added (for bug fixes/features)
  • Pre-commit passes
  • PR to the documentation exists (for bug fixes / features)

@CatEek CatEek marked this pull request as draft January 25, 2026 20:57
@CatEek CatEek force-pushed the iz/feat/noise_model_refac branch from c35d2bc to f66e8d5 Compare January 31, 2026 22:47
@jdeschamps jdeschamps requested review from jdeschamps and removed request for jdeschamps February 4, 2026 09:25
@jdeschamps jdeschamps changed the title BB Noise model & loss refactoring for MicroSplit MicroSplit API: Noise model & loss refactoring Mar 25, 2026
@jdeschamps jdeschamps changed the title MicroSplit API: Noise model & loss refactoring MicroSplit API: noise model & loss refactoring Mar 25, 2026
@CatEek CatEek requested a review from jdeschamps March 25, 2026 17:32
@jdeschamps jdeschamps changed the base branch from main to microsplit_api March 26, 2026 09:27
@jdeschamps jdeschamps marked this pull request as ready for review March 26, 2026 12:47
-    predict_logvar: Literal[None, "pixelwise"] = "pixelwise"
+    predict_logvar: bool = True
Member:
can we add a parameter description?
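For reference, following the docstring-under-field convention the config module already uses for the KL parameters, this could look like the sketch below; the container class and the description wording are guesses, not the final text.

```python
class GaussianLossParams:  # hypothetical container, for illustration only
    predict_logvar: bool = True
    """Whether the decoder also predicts a (pixelwise) log-variance."""
```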

Comment on lines -13 to -26
loss_type: Literal["kl", "kl_restricted"] = "kl"
"""Type of KL divergence used as KL loss."""
rescaling: Literal["latent_dim", "image_dim"] = "latent_dim"
"""Rescaling of the KL loss."""
aggregation: Literal["sum", "mean"] = "mean"
"""Aggregation of the KL loss across different layers."""
free_bits_coeff: float = 0.0
"""Free bits coefficient for the KL loss."""
annealing: bool = False
"""Whether to apply KL loss annealing."""
start: int = -1
"""Epoch at which KL loss annealing starts."""
annealtime: int = 10
"""Number of epochs for which KL loss annealing is applied."""
Member:

Can you comment on why these parameters were removed? Were they never changed and just left over from over-parametrization?

Member:

I don't see the point of having a Pydantic model for a list of single channel NMs

"noise_model, data_mean, and data_std required when denoisplit_weight > 0"
)
recons_loss: torch.Tensor | float = 0.0
if nm_weight > 0 and gaussian_weight > 0:
Member:

Since there are multiple if clauses with the same conditions, it would make sense to split them into different functions.
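A sketch of the suggested split (the helper names and the float-valued NLL terms are invented for illustration; the real terms come from the Gaussian and noise-model likelihood computations):

```python
def _gaussian_term(gaussian_nll: float, weight: float) -> float:
    """Gaussian-likelihood contribution; zero when its weight is disabled."""
    return weight * gaussian_nll if weight > 0 else 0.0


def _noise_model_term(nm_nll: float, weight: float) -> float:
    """Noise-model contribution, guarded by the same condition pattern."""
    return weight * nm_nll if weight > 0 else 0.0


def reconstruction_loss(
    gaussian_nll: float, nm_nll: float, gaussian_weight: float, nm_weight: float
) -> float:
    """Each repeated `if` branch now lives in its own helper."""
    return _gaussian_term(gaussian_nll, gaussian_weight) + _noise_model_term(
        nm_nll, nm_weight
    )
```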

Member:

I feel the loss could be further refactored in a subsequent PR.

raise NotImplementedError(
f"Model {model_config.model_type} is not implemented"
)
if model_config.model_type == "GaussianMixtureNoiseModel":
Member:

That string should indeed be simplified

)

n_channels = signal.shape[-3]
if n_channels != self._nm_cnt:
Member:

self._nm_cnt is not an easy name to read or understand.
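One possible rename, sketched on a hypothetical holder class (only the naming point is taken from the diff; the class and method around it are invented):

```python
class MultiChannelNoiseModel:  # hypothetical, for illustration
    def __init__(self, noise_models: list) -> None:
        # Descriptive name instead of the terse `_nm_cnt`.
        self._num_noise_models = len(noise_models)

    def check_channels(self, n_channels: int) -> None:
        """Mirror of the channel check above, with the clearer name."""
        if n_channels != self._num_noise_models:
            raise ValueError(
                f"Expected {self._num_noise_models} channels, got {n_channels}."
            )
```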

"""Module containing pytorch implementations for obtaining predictions from an LVAE.

from typing import Any
DEPRECATED: This module uses the old likelihood-based approach and needs to be updated
Member:

Does that mean that prediction is not compatible with the current state of this branch?

