-
Notifications
You must be signed in to change notification settings - Fork 30
feat: Add robust fallback for ill-conditioned overlap matrices in Eigenvalues solver
#320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -60,14 +60,15 @@ def __init__( | |||||||||||||||||
| def forward(self, | ||||||||||||||||||
| data: AtomicDataDict.Type, | ||||||||||||||||||
| nk: Optional[int]=None, | ||||||||||||||||||
| eig_solver: str='torch') -> AtomicDataDict.Type: | ||||||||||||||||||
| eig_solver: str='torch', | ||||||||||||||||||
| ill_threshold: Optional[float]=1e-5) -> AtomicDataDict.Type: | ||||||||||||||||||
|
|
||||||||||||||||||
| if eig_solver is None: | ||||||||||||||||||
| eig_solver = 'torch' | ||||||||||||||||||
| log.warning("eig_solver is not set, using default 'torch'.") | ||||||||||||||||||
| if eig_solver not in ['torch', 'numpy']: | ||||||||||||||||||
| log.error(f"eig_solver should be 'torch' or 'numpy', but got {eig_solver}.") | ||||||||||||||||||
| raise ValueError | ||||||||||||||||||
| raise ValueError | ||||||||||||||||||
|
Comment on lines
69
to
+71
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Include error message in the exception. The exception is raised without a message, making it harder to diagnose when caught elsewhere. Move the message into the exception: Suggested fix if eig_solver not in ['torch', 'numpy']:
- log.error(f"eig_solver should be 'torch' or 'numpy', but got {eig_solver}.")
- raise ValueError
+ msg = f"eig_solver should be 'torch' or 'numpy', but got {eig_solver}."
+ log.error(msg)
+ raise ValueError(msg)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||
|
|
||||||||||||||||||
| kpoints = data[AtomicDataDict.KPOINT_KEY] | ||||||||||||||||||
| if kpoints.is_nested: | ||||||||||||||||||
|
|
@@ -80,29 +81,135 @@ def forward(self, | |||||||||||||||||
| eigvals = [] | ||||||||||||||||||
| if nk is None: | ||||||||||||||||||
| nk = num_k | ||||||||||||||||||
|
|
||||||||||||||||||
| for i in range(int(np.ceil(num_k / nk))): | ||||||||||||||||||
| data[AtomicDataDict.KPOINT_KEY] = kpoints[i*nk:(i+1)*nk] | ||||||||||||||||||
| data = self.h2k(data) | ||||||||||||||||||
| h_transformed_np = None | ||||||||||||||||||
|
|
||||||||||||||||||
| batch_eigvals_torch = None | ||||||||||||||||||
| batch_eigvals_np = None | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.overlap: | ||||||||||||||||||
| data = self.s2k(data) | ||||||||||||||||||
| if eig_solver == 'torch': | ||||||||||||||||||
| chklowt = torch.linalg.cholesky(data[self.s_out_field]) | ||||||||||||||||||
| chklowtinv = torch.linalg.inv(chklowt) | ||||||||||||||||||
| data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj()) | ||||||||||||||||||
| if ill_threshold is None: | ||||||||||||||||||
| chklowt = torch.linalg.cholesky(data[self.s_out_field]) | ||||||||||||||||||
| chklowtinv = torch.linalg.inv(chklowt) | ||||||||||||||||||
| data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj()) | ||||||||||||||||||
| else: | ||||||||||||||||||
| S_k = data[self.s_out_field] | ||||||||||||||||||
| H_k = data[self.h_out_field] | ||||||||||||||||||
| egval_S, egvec_S = torch.linalg.eigh(S_k) | ||||||||||||||||||
| B = S_k.shape[0] | ||||||||||||||||||
| num_orbitals = H_k.shape[-1] | ||||||||||||||||||
| real_dtype = torch.float32 if H_k.dtype in [torch.complex64, torch.float32] else torch.float64 | ||||||||||||||||||
|
|
||||||||||||||||||
| processed_eigvals_list = [] | ||||||||||||||||||
| for k_idx in range(B): | ||||||||||||||||||
| healthy_mask = egval_S[k_idx] > ill_threshold | ||||||||||||||||||
| n_healthy = int(healthy_mask.sum().item()) | ||||||||||||||||||
|
|
||||||||||||||||||
| if n_healthy == 0: | ||||||||||||||||||
| egval = torch.full((num_orbitals,), 1e4, dtype=real_dtype, device=H_k.device) | ||||||||||||||||||
| processed_eigvals_list.append(egval) | ||||||||||||||||||
| continue | ||||||||||||||||||
|
Comment on lines
+113
to
+116
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hardcoded padding value The padding value propagates to band structure, DOS, and Fermi surface calculations (see
Consider:
Example using inf if n_healthy == 0:
- egval = torch.full((num_orbitals,), 1e4, dtype=real_dtype, device=H_k.device)
+ egval = torch.full((num_orbitals,), float('inf'), dtype=real_dtype, device=H_k.device)
...
if num_projected_out > 0:
- padding = torch.full((num_projected_out,), 1e4, dtype=egval_proj.dtype, device=egval_proj.device)
+ padding = torch.full((num_projected_out,), float('inf'), dtype=egval_proj.dtype, device=egval_proj.device)Also applies to: 137-143 🤖 Prompt for AI Agents |
||||||||||||||||||
|
|
||||||||||||||||||
| if healthy_mask.all(): | ||||||||||||||||||
| L = torch.linalg.cholesky(S_k[k_idx]) | ||||||||||||||||||
| L_inv = torch.linalg.inv(L) | ||||||||||||||||||
| H_transformed = L_inv @ H_k[k_idx] @ L_inv.conj().T | ||||||||||||||||||
| egval = torch.linalg.eigvalsh(H_transformed) | ||||||||||||||||||
| processed_eigvals_list.append(egval) | ||||||||||||||||||
| continue | ||||||||||||||||||
|
|
||||||||||||||||||
| U_sel = egvec_S[k_idx, :, healthy_mask] | ||||||||||||||||||
| eval_sel = egval_S[k_idx, healthy_mask] | ||||||||||||||||||
|
|
||||||||||||||||||
| H_proj = U_sel.conj().T @ H_k[k_idx] @ U_sel | ||||||||||||||||||
| S_proj = torch.diag(eval_sel).to(dtype=H_proj.dtype, device=H_proj.device) | ||||||||||||||||||
|
|
||||||||||||||||||
| L = torch.linalg.cholesky(S_proj) | ||||||||||||||||||
| L_inv = torch.linalg.inv(L) | ||||||||||||||||||
| H_transformed = L_inv @ H_proj @ L_inv.conj().T | ||||||||||||||||||
| egval_proj = torch.linalg.eigvalsh(H_transformed) | ||||||||||||||||||
|
|
||||||||||||||||||
| num_projected_out = num_orbitals - egval_proj.shape[0] | ||||||||||||||||||
| if num_projected_out > 0: | ||||||||||||||||||
| padding = torch.full((num_projected_out,), 1e4, dtype=egval_proj.dtype, device=egval_proj.device) | ||||||||||||||||||
| egval = torch.cat([egval_proj, padding], dim=0) | ||||||||||||||||||
| else: | ||||||||||||||||||
| egval = egval_proj | ||||||||||||||||||
|
|
||||||||||||||||||
| processed_eigvals_list.append(egval) | ||||||||||||||||||
| batch_eigvals_torch = torch.stack(processed_eigvals_list, dim=0) | ||||||||||||||||||
|
|
||||||||||||||||||
| elif eig_solver == 'numpy': | ||||||||||||||||||
| s_np = data[self.s_out_field].detach().cpu().numpy() | ||||||||||||||||||
| h_np = data[self.h_out_field].detach().cpu().numpy() | ||||||||||||||||||
| chklowt = np.linalg.cholesky(s_np) | ||||||||||||||||||
| chklowtinv = np.linalg.inv(chklowt) | ||||||||||||||||||
| h_transformed_np = chklowtinv @ h_np @ np.transpose(chklowtinv,(0,2,1)).conj() | ||||||||||||||||||
| if ill_threshold is None: | ||||||||||||||||||
| s_np = data[self.s_out_field].detach().cpu().numpy() | ||||||||||||||||||
| h_np = data[self.h_out_field].detach().cpu().numpy() | ||||||||||||||||||
| chklowt = np.linalg.cholesky(s_np) | ||||||||||||||||||
| chklowtinv = np.linalg.inv(chklowt) | ||||||||||||||||||
| h_transformed_np = chklowtinv @ h_np @ np.transpose(chklowtinv,(0,2,1)).conj() | ||||||||||||||||||
| else: | ||||||||||||||||||
| s_np = data[self.s_out_field].detach().cpu().numpy() | ||||||||||||||||||
| h_np = data[self.h_out_field].detach().cpu().numpy() | ||||||||||||||||||
| egval_S, egvec_S = np.linalg.eigh(s_np) | ||||||||||||||||||
| B = s_np.shape[0] | ||||||||||||||||||
| num_orbitals = h_np.shape[-1] | ||||||||||||||||||
| real_dtype = np.float32 if h_np.dtype in [np.complex64, np.float32] else np.float64 | ||||||||||||||||||
|
|
||||||||||||||||||
| processed_eigvals_list = [] | ||||||||||||||||||
| for k_idx in range(B): | ||||||||||||||||||
| healthy_mask = egval_S[k_idx] > ill_threshold | ||||||||||||||||||
| n_healthy = int(healthy_mask.sum()) | ||||||||||||||||||
|
|
||||||||||||||||||
| if n_healthy == 0: | ||||||||||||||||||
| egval = np.full((num_orbitals,), 1e4, dtype=real_dtype) | ||||||||||||||||||
| processed_eigvals_list.append(egval) | ||||||||||||||||||
| continue | ||||||||||||||||||
|
|
||||||||||||||||||
| if healthy_mask.all(): | ||||||||||||||||||
| L = np.linalg.cholesky(s_np[k_idx]) | ||||||||||||||||||
| L_inv = np.linalg.inv(L) | ||||||||||||||||||
| H_transformed = L_inv @ h_np[k_idx] @ L_inv.conj().T | ||||||||||||||||||
| egval = np.linalg.eigvalsh(H_transformed) | ||||||||||||||||||
| processed_eigvals_list.append(egval) | ||||||||||||||||||
| continue | ||||||||||||||||||
|
|
||||||||||||||||||
| U_sel = egvec_S[k_idx, :, healthy_mask] | ||||||||||||||||||
| eval_sel = egval_S[k_idx, healthy_mask] | ||||||||||||||||||
|
|
||||||||||||||||||
| H_proj = U_sel.conj().T @ h_np[k_idx] @ U_sel | ||||||||||||||||||
| S_proj = np.diag(eval_sel).astype(H_proj.dtype) | ||||||||||||||||||
|
|
||||||||||||||||||
| L = np.linalg.cholesky(S_proj) | ||||||||||||||||||
| L_inv = np.linalg.inv(L) | ||||||||||||||||||
| H_transformed = L_inv @ H_proj @ L_inv.conj().T | ||||||||||||||||||
| egval_proj = np.linalg.eigvalsh(H_transformed) | ||||||||||||||||||
|
|
||||||||||||||||||
| num_projected_out = num_orbitals - egval_proj.shape[0] | ||||||||||||||||||
| if num_projected_out > 0: | ||||||||||||||||||
| padding = np.full((num_projected_out,), 1e4, dtype=egval_proj.dtype) | ||||||||||||||||||
| egval = np.concatenate([egval_proj, padding], axis=0) | ||||||||||||||||||
| else: | ||||||||||||||||||
| egval = egval_proj | ||||||||||||||||||
|
|
||||||||||||||||||
| processed_eigvals_list.append(egval) | ||||||||||||||||||
| batch_eigvals_np = np.stack(processed_eigvals_list, axis=0) | ||||||||||||||||||
|
|
||||||||||||||||||
| if eig_solver == 'torch': | ||||||||||||||||||
| eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field])) | ||||||||||||||||||
| if batch_eigvals_torch is not None: | ||||||||||||||||||
| eigvals.append(batch_eigvals_torch) | ||||||||||||||||||
| else: | ||||||||||||||||||
| eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field])) | ||||||||||||||||||
| elif eig_solver == 'numpy': | ||||||||||||||||||
| if h_transformed_np is None: | ||||||||||||||||||
| h_transformed_np = data[self.h_out_field].detach().cpu().numpy() | ||||||||||||||||||
| eigvals_np = np.linalg.eigvalsh(a=h_transformed_np) | ||||||||||||||||||
| if batch_eigvals_np is not None: | ||||||||||||||||||
| eigvals_np = batch_eigvals_np | ||||||||||||||||||
| else: | ||||||||||||||||||
| if h_transformed_np is None: | ||||||||||||||||||
| h_transformed_np = data[self.h_out_field].detach().cpu().numpy() | ||||||||||||||||||
| eigvals_np = np.linalg.eigvalsh(a=h_transformed_np) | ||||||||||||||||||
| # Preserve dtype by converting to the Hamiltonian's original dtype | ||||||||||||||||||
| eigvals.append(torch.from_numpy(eigvals_np).to(dtype=self.h2k.dtype, device=self.h2k.device)) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Default
ill_threshold=1e-5silently changes behavior for all existing callers.The PR states backward compatibility is preserved by setting
ill_threshold=None, but the default is1e-5. This means all existing call sites (e.g.,loss.py,calculator.py,elec_struc_cal.py) that don't pass this parameter will now use the eigen-decomposition path instead of the Cholesky path—a silent behavioral change that could affect numerical results and performance.Consider defaulting to
ill_threshold=Noneto preserve existing behavior, requiring callers to explicitly opt-in to the fallback mechanism:Suggested change
def forward(self, data: AtomicDataDict.Type, nk: Optional[int]=None, eig_solver: str='torch', - ill_threshold: Optional[float]=1e-5) -> AtomicDataDict.Type: + ill_threshold: Optional[float]=None) -> AtomicDataDict.Type:📝 Committable suggestion
🤖 Prompt for AI Agents