Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions linear_operator_learning/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,31 +101,27 @@ def orthonormal_fro_reg(x: Tensor) -> Tensor:

Given a batch of realizations of `x`, the orthonormality regularization term penalizes:

1. Orthogonality: Linear dependencies among dimensions,
2. Normality: Deviations of each dimensions variance from 1,
3. Centering: Deviations of each dimensions mean from 0.
1. Orthogonality: Linear dependencies among features,
2. Normality: Deviations of each dimension's variance from 1,
3. Centering: Deviations of each dimension's mean from 0.

.. math::

\frac{1}{D} \| \mathbf{C}_{X} - I \|_F^2 + 2 \| \mathbb{E}_{X} x \|^2 = \frac{1}{D} (\text{tr}(\mathbf{C}^2_{X}) - 2 \text{tr}(\mathbf{C}_{X}) + D + 2 \| \mathbb{E}_{X} x \|^2)
\frac{1}{D} \left( \| \mathbb{E}[XX^\top] - I \|_F^2 + 2 \| \mathbb{E}X \|^2 \right).

Args:
x (Tensor): Input features.

Shape:
``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.
"""
x_mean = x.mean(dim=0, keepdim=True)
x_centered = x - x_mean
# As ||Cx||_F^2 = E_(x,x')~p(x) [((x - E_p(x) x)^T (x' - E_p(x) x'))^2] = tr(Cx^2), involves the product of
# covariances, unbiased estimation of this term requires the use of U-statistics
Cx_fro_2 = cov_norm_squared_unbiased(x_centered)
# tr(Cx) = E_p(x) [(x - E_p(x))^T (x - E_p(x))] ≈ 1/N Σ_n (x_n - E_p(x))^T (x_n - E_p(x))
tr_Cx = torch.einsum("ij,ij->", x_centered, x_centered) / x.shape[0]
centering_loss = (x_mean**2).sum() # ||E_p(x) x||^2
D = x.shape[-1] # ||I||_F^2 = D
reg = Cx_fro_2 - 2 * tr_Cx + D + 2 * centering_loss
return reg / D
n, d = x.shape

inner_products = x @ x.T
off_diag = (inner_products**2 + 2 * inner_products).fill_diagonal_(0)
reg = off_diag.sum() / (n * (n - 1)) - 2 * (x**2).sum() / n + d

return reg / d


def orthonormal_logfro_reg(x: Tensor) -> Tensor:
Expand Down