diff --git a/linear_operator_learning/nn/functional.py b/linear_operator_learning/nn/functional.py index 3cc6482..b290f04 100644 --- a/linear_operator_learning/nn/functional.py +++ b/linear_operator_learning/nn/functional.py @@ -101,13 +101,13 @@ def orthonormal_fro_reg(x: Tensor) -> Tensor: Given a batch of realizations of `x`, the orthonormality regularization term penalizes: - 1. Orthogonality: Linear dependencies among dimensions, - 2. Normality: Deviations of each dimension’s variance from 1, - 3. Centering: Deviations of each dimension’s mean from 0. + 1. Orthogonality: Linear dependencies among features, + 2. Normality: Deviations of each dimension's variance from 1, + 3. Centering: Deviations of each dimension's mean from 0. .. math:: - \frac{1}{D} \| \mathbf{C}_{X} - I \|_F^2 + 2 \| \mathbb{E}_{X} x \|^2 = \frac{1}{D} (\text{tr}(\mathbf{C}^2_{X}) - 2 \text{tr}(\mathbf{C}_{X}) + D + 2 \| \mathbb{E}_{X} x \|^2) + \frac{1}{D} \left( \| \mathbb{E}[XX^\top] - I \|_F^2 + 2 \| \mathbb{E}X \|^2 \right). Args: x (Tensor): Input features. @@ -115,17 +115,13 @@ def orthonormal_fro_reg(x: Tensor) -> Tensor: Shape: ``x``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. """ - x_mean = x.mean(dim=0, keepdim=True) - x_centered = x - x_mean - # As ||Cx||_F^2 = E_(x,x')~p(x) [((x - E_p(x) x)^T (x' - E_p(x) x'))^2] = tr(Cx^2), involves the product of - # covariances, unbiased estimation of this term requires the use of U-statistics - Cx_fro_2 = cov_norm_squared_unbiased(x_centered) - # tr(Cx) = E_p(x) [(x - E_p(x))^T (x - E_p(x))] ≈ 1/N Σ_n (x_n - E_p(x))^T (x_n - E_p(x)) - tr_Cx = torch.einsum("ij,ij->", x_centered, x_centered) / x.shape[0] - centering_loss = (x_mean**2).sum() # ||E_p(x) x||^2 - D = x.shape[-1] # ||I||_F^2 = D - reg = Cx_fro_2 - 2 * tr_Cx + D + 2 * centering_loss - return reg / D + n, d = x.shape + + inner_products = x @ x.T + off_diag = (inner_products**2 + 2 * inner_products).fill_diagonal_(0) + reg = off_diag.sum() / (n * (n - 1)) - 2 * (x**2).sum() / n + d + + return reg / d def orthonormal_logfro_reg(x: Tensor) -> Tensor: