
[ENH] MDNRegressor (Mixture Density Network)#796

Open
joshdunnlime wants to merge 8 commits into sktime:main from joshdunnlime:mdn

Conversation

@joshdunnlime
Contributor

@joshdunnlime joshdunnlime commented Mar 5, 2026

Reference Issues/PRs

No issue opened.

What does this implement/fix? Explain your changes.

A new regressor implementing a Mixture Density Network (MDN) as per Bishop (1994), with noise regularisation as per Rothfuss et al. (2019).
It also supports optionally passing PyTorch activation functions and optimizers.
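For orientation, the standard MDN parameterisation from Bishop (1994) can be sketched as below. This is a hypothetical illustration, not the PR's code: the class and function names (`MDNHead`, `mdn_nll`) and layer layout are my own, but the structure (softmax weights, unconstrained means, exponentiated log-scales, log-sum-exp NLL) is the textbook formulation.

```python
import torch
import torch.nn as nn

class MDNHead(nn.Module):
    """Hypothetical MDN output head: per-row mixture parameters for K normals."""

    def __init__(self, in_features, n_components):
        super().__init__()
        self.logits = nn.Linear(in_features, n_components)
        self.mu = nn.Linear(in_features, n_components)
        self.log_sigma = nn.Linear(in_features, n_components)

    def forward(self, h):
        weights = torch.softmax(self.logits(h), dim=-1)  # rows sum to 1
        mu = self.mu(h)
        sigma = torch.exp(self.log_sigma(h))  # strictly positive scales
        return weights, mu, sigma

def mdn_nll(y, weights, mu, sigma):
    """Negative log-likelihood of y under each row's own mixture."""
    comp = torch.distributions.Normal(mu, sigma)
    log_prob = comp.log_prob(y.unsqueeze(-1)) + torch.log(weights)
    return -torch.logsumexp(log_prob, dim=-1).mean()
```

Each row of the batch gets its own weights, means, and scales, which is exactly why a row-wise mixture distribution is needed on the output side.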

It implements a fully vectorised NormalMixture distribution where each row's weights are individually learnt and applied. It also implements a custom vectorised bisection for fast _ppf method calls.
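A minimal sketch of what such a vectorised bisection for the mixture ppf can look like. This is an assumption-laden illustration, not the PR's implementation: the function names, the ±10σ bracketing of the support, and the tolerances are all mine; only the idea (bisect all rows simultaneously against a vectorised mixture CDF) comes from the PR description.

```python
import numpy as np
from scipy.stats import norm

def mixture_cdf(x, weights, mu, sigma):
    # x: (n_rows,); weights/mu/sigma: (n_rows, n_components)
    return np.sum(weights * norm.cdf(x[:, None], mu, sigma), axis=1)

def mixture_ppf(q, weights, mu, sigma, tol=1e-8, max_iter=100):
    # bracket each row's root using the outermost component supports
    lo = np.min(mu - 10 * sigma, axis=1)
    hi = np.max(mu + 10 * sigma, axis=1)
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        below = mixture_cdf(mid, weights, mu, sigma) < q
        lo = np.where(below, mid, lo)  # move the lower bound up where CDF < q
        hi = np.where(below, hi, mid)  # move the upper bound down elsewhere
        if np.max(hi - lo) < tol:
            break
    return 0.5 * (lo + hi)

# usage: two rows, each with its own 2-component mixture
weights = np.array([[0.5, 0.5], [0.2, 0.8]])
mu = np.array([[0.0, 0.0], [-1.0, 1.0]])
sigma = np.array([[1.0, 1.0], [0.5, 0.5]])
medians = mixture_ppf(np.array([0.5, 0.5]), weights, mu, sigma)
```

Because the mixture CDF is monotone, all rows converge together in roughly log2(range/tol) iterations, with no per-row Python loop.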

Does your contribution introduce a new dependency? If yes, which one?

Yes, soft dependencies: PyTorch and pytorch-optimizer.

What should a reviewer concentrate their feedback on?

The NormalMixture distribution. This adds a slightly opinionated design choice to the API.

Did you add any tests for the change?

Yes: standard library parameter tests for the distribution and the estimator. Additional estimator tests are coming.

Any other comments?

PR checklist

For all contributions
  • I've added myself to the list of contributors with any new badges I've earned :-)
    How to: add yourself to the all-contributors file in the skpro root directory (not the CONTRIBUTORS.md). Common badges: code - fixing a bug, or adding code logic. doc - writing or improving documentation or docstrings. bug - reporting or diagnosing a bug (get this plus code if you also fixed the bug in the PR). maintenance - CI, test framework, release.
    See here for full badge reference
  • The PR title starts with either [ENH], [MNT], [DOC], or [BUG]. [BUG] - bugfix, [MNT] - CI, test framework, [ENH] - adding or improving code, [DOC] - writing or improving documentation or docstrings.
For new estimators
  • I've added the estimator to the API reference - in docs/source/api_reference/taskname.rst, follow the pattern.
  • I've added one or more illustrative usage examples to the docstring, in a pydocstyle compliant Examples section.
  • If the estimator relies on a soft dependency, I've set the python_dependencies tag and ensured
    dependency isolation, see the estimator dependencies guide.

Collaborator

@fkiraly fkiraly left a comment


Thanks! Nice!

  • question: why implement a NormalMixture when Mixture is available and should be capable of representing a mixture of normals by composing Normal with Mixture?
  • if we add NormalMixture, it should also be added to the API reference for distributions

@joshdunnlime
Contributor Author

My understanding of Mixture is that it only has global weights for each sub-distribution. MDNRegressor could use this; however, it would mean creating a new Mixture instance for each prediction row (computationally this would be crippling). NormalMixture has a performant, vectorised implementation of most of the class methods, so it is actually pretty quick!

You could argue that we should call it Normal(Row|Instance|HeteroWeighted)Mixture, but I think it is documented clearly enough to avoid confusion.
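The global-weights vs. row-wise-weights distinction can be made concrete with a small sketch. This is a hypothetical illustration (the function name `rowwise_mixture_pdf` is mine, not the skpro API): with a weight matrix, the whole batch's density evaluates in one broadcasted call, whereas composing one Mixture per row forces a Python loop.

```python
import numpy as np
from scipy.stats import norm

def rowwise_mixture_pdf(x, weights, mu, sigma):
    # weights/mu/sigma: (n_rows, n_components); x: (n_rows,)
    return np.sum(weights * norm.pdf(x[:, None], mu, sigma), axis=1)

x = np.array([0.0, 1.0])
weights = np.array([[0.3, 0.7], [0.6, 0.4]])  # different weights per row
mu = np.array([[-1.0, 1.0], [0.0, 2.0]])
sigma = np.ones((2, 2))

# one vectorised call over all rows at once
vectorised = rowwise_mixture_pdf(x, weights, mu, sigma)

# the equivalent per-row loop (what one-Mixture-per-row would imply)
looped = np.array([
    sum(w * norm.pdf(xi, m, s) for w, m, s in zip(wr, mr, sr))
    for xi, wr, mr, sr in zip(x, weights, mu, sigma)
])
```

Both give identical values; the difference is purely that the first scales to large prediction batches without per-row object construction.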

Add NormalMixture to the distributions list. Also split and rename bandwidth rules to silverman and scott, as is the more Pythonic convention.
Initial testing shows ISJ provides improved convergence and final NLL scores. Bandwidths are added as a separate module, as this will be very useful for other kernel-based estimators.
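For reference, the two classic rules the commit renames can be sketched as below. This is a hypothetical 1D sketch, not the PR's bandwidth module; note that the constant factors vary slightly between textbooks and libraries (scipy and statsmodels each use their own convention), so these follow the common Silverman (1986) rule-of-thumb forms.

```python
import numpy as np

def scott(x):
    # normal-reference rule: h = 1.06 * sigma * n^(-1/5)
    x = np.asarray(x)
    return 1.06 * x.std(ddof=1) * len(x) ** (-1 / 5)

def silverman(x):
    # Silverman's rule of thumb: robust to heavy tails via the IQR
    x = np.asarray(x)
    iqr = np.subtract(*np.percentile(x, [75, 25]))
    sigma = min(x.std(ddof=1), iqr / 1.34)
    return 0.9 * sigma * len(x) ** (-1 / 5)
```

Putting these in their own module makes sense, since any kernel-based estimator (KDE regressors, the ISJ selector mentioned above) can reuse them.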
Collaborator

@fkiraly fkiraly left a comment


Ah, I see. Yes, that makes sense.

Do you want to open an issue requesting an abstract mixture that can have row-wise different mixture weights? That could be an extension of Mixture, where the weights are not just a list but a matrix.

@joshdunnlime
Contributor Author

I did consider the following:

  1. Extending Mixture to include "instance" and "global" weights.
  2. Creating a generic InstanceWeightedMixture class.
  3. Creating a base class BaseInstanceWeightedMixture.

The downsides were:

  1. Far too much work. IMHO these are two quite different classes, to the point that I think this would just be a wrapper around a Mixture and an InstanceWeightedMixture.

  2. An InstanceWeightedMixture that accepts a mixture of different distributions comes with a lot of complications. The primary one is that NormalMixture has very well-defined (and mostly exact) implementations of all of the standard methods. My understanding of weighted mixtures of distributions stops at normals, so I have no idea what is feasible there. The vectorisation is very well defined for normal distributions in scipy, and the assumptions on the support for root-finding in _ppf were fairly straightforward. Implementation-wise, I believe it serves to keep NormalMixture even if we have 3), as the speed-ups and exactness are significant.

  3. Happy to open an issue. My knowledge of anything beyond normal mixtures is effectively zero, so I doubt I'd be up to a PR.

@joshdunnlime
Contributor Author

Just as a side note: the natural extension to MDNs is probably something flow- or kernel-based, as opposed to extending this with non-normal distributions.

Collaborator

@fkiraly fkiraly left a comment


Should be fine now, but one request:

the estimator MDNRegressor does not actually depend on pytorch-optimizer, except if the string "SOAP" is passed. But that only does aliasing, and since you are doing soft dependency checking there, there is no intrinsic dependency. Hence I would remove pytorch-optimizer from the dependency set.

I would also suggest replacing the try/except for dependency checking with _check_soft_dependencies(severity="none").
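The suggested pattern can be sketched as below. The real helper lives in skbase (and is what skpro uses); the `check_soft_dependencies` stand-in here is a minimal importlib-based reimplementation, included only so the `severity="none"` behaviour is runnable without skpro installed, and is not the actual skbase code.

```python
from importlib.util import find_spec

def check_soft_dependencies(package, severity="error"):
    """Stand-in mimicking the skbase helper: severity="none" returns a bool."""
    present = find_spec(package) is not None
    if not present and severity == "error":
        raise ModuleNotFoundError(
            f"{package} is a soft dependency and is not installed."
        )
    return present

# instead of a bare try/except around the import:
has_po = check_soft_dependencies("pytorch_optimizer", severity="none")
# ...and only resolve the "SOAP" optimizer alias when has_po is True
```

The advantage over try/except is uniform error messages and a single place where dependency policy lives.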


XGBoostLSS

Neural conditional density estimation
Collaborator


how about "deep learning based regressors" instead?

