MicroSplit API: noise model & loss refactoring #700

CatEek wants to merge 29 commits into dev/microsplit_api from
Conversation
Co-authored-by: Vera Galinova <32124316+veegalinova@users.noreply.github.com>
```diff
-    predict_logvar: Literal[None, "pixelwise"] = "pixelwise"
+    predict_logvar: bool = True
```
Can we add a parameter description?
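A sketch of one way to answer this, in the codebase's own style (fields are documented with a docstring on the following line, as in the other config diffs in this PR). The class name here is a hypothetical stand-in, not the real config:

```python
class LikelihoodConfigSketch:
    """Hypothetical stand-in for the real likelihood config class."""

    predict_logvar: bool = True
    """Whether the decoder additionally predicts a per-pixel log-variance
    alongside the mean (replaces the old ``Literal[None, "pixelwise"]``)."""
```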
```diff
-    loss_type: Literal["kl", "kl_restricted"] = "kl"
-    """Type of KL divergence used as KL loss."""
-    rescaling: Literal["latent_dim", "image_dim"] = "latent_dim"
-    """Rescaling of the KL loss."""
-    aggregation: Literal["sum", "mean"] = "mean"
-    """Aggregation of the KL loss across different layers."""
-    free_bits_coeff: float = 0.0
-    """Free bits coefficient for the KL loss."""
-    annealing: bool = False
-    """Whether to apply KL loss annealing."""
-    start: int = -1
-    """Epoch at which KL loss annealing starts."""
-    annealtime: int = 10
-    """Number of epochs for which KL loss annealing is applied."""
```
Can you comment on why these parameters were removed? Are they never changed and were just left over from over-parametrization?
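For context, a minimal sketch of how the removed `free_bits_coeff`, `annealing`, `start`, and `annealtime` fields typically enter a KL term. This is illustrative only, written with plain floats; the real loss code lives elsewhere in the codebase:

```python
def annealing_weight(epoch: int, start: int = -1, annealtime: int = 10) -> float:
    """Linear KL annealing weight; `start=-1` disables annealing entirely."""
    if start < 0:
        return 1.0  # annealing disabled: full KL weight from the beginning
    if epoch < start:
        return 0.0  # before the start epoch, the KL term is switched off
    # Ramp linearly from 0 to 1 over `annealtime` epochs.
    return min(1.0, (epoch - start) / annealtime)


def kl_with_free_bits(kl_per_layer, free_bits_coeff: float = 0.0):
    """Clamp each layer's KL from below ("free bits") to fight posterior collapse."""
    return [max(kl, free_bits_coeff) for kl in kl_per_layer]
```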
I don't see the point of having a Pydantic model for a list of single-channel noise models (NMs).
```python
            "noise_model, data_mean, and data_std required when denoisplit_weight > 0"
        )
    recons_loss: torch.Tensor | float = 0.0
    if nm_weight > 0 and gaussian_weight > 0:
```
Since there are multiple if clauses with the same conditions, it would make sense to split them into different functions.
I feel the loss could be further refactored in a subsequent PR.
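A hedged sketch of the split the reviewer suggests. The weight names mirror `nm_weight`/`gaussian_weight` from the diff; the helper bodies are toy stand-ins, not the real likelihood terms:

```python
def _noise_model_nll(pred: float, target: float) -> float:
    # Toy stand-in for the learned noise-model likelihood term.
    return abs(pred - target)


def _gaussian_nll(pred: float, target: float) -> float:
    # Toy stand-in for the Gaussian likelihood term.
    return (pred - target) ** 2


def reconstruction_loss(pred, target, *, nm_weight=0.0, gaussian_weight=0.0):
    """Combine the two terms; each branch is testable in isolation."""
    loss = 0.0
    if nm_weight > 0:
        loss += nm_weight * _noise_model_nll(pred, target)
    if gaussian_weight > 0:
        loss += gaussian_weight * _gaussian_nll(pred, target)
    return loss
```

Pulling each weighted branch into its own function keeps the top-level loss a readable sum and removes the repeated condition checks.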
```python
    raise NotImplementedError(
        f"Model {model_config.model_type} is not implemented"
    )
if model_config.model_type == "GaussianMixtureNoiseModel":
```
That string should indeed be simplified.
```python
n_channels = signal.shape[-3]
if n_channels != self._nm_cnt:
```
`self._nm_cnt` is not an easy name to read or understand.
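One purely illustrative renaming that would address this; the class is a hypothetical stand-in for the multi-channel noise-model wrapper:

```python
class MultiChannelNoiseModelSketch:
    """Hypothetical stand-in illustrating a more readable counter name."""

    def __init__(self, noise_models):
        self._noise_models = list(noise_models)
        # Descriptive attribute name instead of the terse `_nm_cnt`.
        self.n_noise_models = len(self._noise_models)

    def check_channels(self, n_channels: int) -> None:
        # Same check as `if n_channels != self._nm_cnt:` in the diff above,
        # but the intent is now legible from the name alone.
        if n_channels != self.n_noise_models:
            raise ValueError(
                f"Signal has {n_channels} channels but "
                f"{self.n_noise_models} noise models were provided."
            )
```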
```python
"""Module containing pytorch implementations for obtaining predictions from an LVAE.

DEPRECATED: This module uses the old likelihood-based approach and needs to be updated.
"""
from typing import Any
```
Does that mean that prediction is not compatible with the current state of this branch?
Disclaimer

Description

> **Note**
> tldr: Refactor the MicroSplit noise-model and loss configuration, replacing fragmented likelihood configs with unified logic.

Background - why do we need this PR?

The previous implementation spread noise-model-related code across three separate Pydantic models (`GaussianLikelihoodConfig`, `NMLikelihoodConfig`, and `MultiChannelNMConfig`), three loss types (`musplit`, `denoisplit`, `denoisplit_musplit`), and multiple `VAEModule` attributes (`gaussian_likelihood`, `noise_model_likelihood`). This made it hard to reason about which combination of configs was valid, and it made the code hard to read and error-prone.

Overview - what changed?

- In `LVAELossConfig`, `predict_logvar` is now a plain `bool` everywhere (was `Literal[None, "pixelwise"]`).
- `GaussianMixtureNMConfig` exposes a `from_npz` classmethod instead of accepting a `path` field that triggered side-effectful validation. (That is questionable and likely to be changed.)
- `VAEModule` holds data statistics directly and passes them into the loss function, removing the intermediary likelihood objects.
- A `NoiseModelTrainer` is added (not yet used).
- The three loss types are replaced by a single `microsplit_loss`, controlled by `musplit_weight` / `denoisplit_weight`. The `SupportedLoss` enum retains only `hdn` and `microsplit`.
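A minimal sketch of the `from_npz` classmethod pattern described above. The class name, dataclass shape, and the `trained_weight` key inside the `.npz` archive are assumptions for illustration, not the real API:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class GMNMConfigSketch:
    """Hypothetical stand-in for GaussianMixtureNMConfig."""

    weights: np.ndarray

    @classmethod
    def from_npz(cls, path: str) -> "GMNMConfigSketch":
        # Loading happens in an explicit, named constructor rather than
        # as a side effect of validating a `path` field.
        data = np.load(path)
        return cls(weights=data["trained_weight"])  # key name is an assumption
```

An explicit constructor makes the I/O visible at the call site (`GMNMConfigSketch.from_npz("model.npz")`) and keeps plain construction side-effect-free.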
Performance tests are not finished!
Please ensure your PR meets the following requirements: