135 changes: 121 additions & 14 deletions dptb/nn/energy.py
@@ -60,14 +60,15 @@ def __init__(
def forward(self,
data: AtomicDataDict.Type,
nk: Optional[int]=None,
-                eig_solver: str='torch') -> AtomicDataDict.Type:
+                eig_solver: str='torch',
+                ill_threshold: Optional[float]=1e-5) -> AtomicDataDict.Type:
Comment on lines +63 to +64
⚠️ Potential issue | 🟠 Major

Default ill_threshold=1e-5 silently changes behavior for all existing callers.

The PR states backward compatibility is preserved by setting ill_threshold=None, but the default is 1e-5. This means all existing call sites (e.g., loss.py, calculator.py, elec_struc_cal.py) that don't pass this parameter will now use the eigen-decomposition path instead of the Cholesky path—a silent behavioral change that could affect numerical results and performance.

Consider defaulting to ill_threshold=None to preserve existing behavior, requiring callers to explicitly opt-in to the fallback mechanism:

Suggested change
     def forward(self, 
                 data: AtomicDataDict.Type, 
                 nk: Optional[int]=None,
                 eig_solver: str='torch',
-                ill_threshold: Optional[float]=1e-5) -> AtomicDataDict.Type:
+                ill_threshold: Optional[float]=None) -> AtomicDataDict.Type:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@dptb/nn/energy.py` around lines 63 - 64: the default value of the `ill_threshold` parameter is currently `1e-5`, which silently changes runtime behavior for all existing callers. Change the default to `None` (`ill_threshold: Optional[float]=None`) so callers keep the original Cholesky-first behavior and must explicitly opt in to the eigen-decomposition fallback. Update any docstring or parameter description that mentions `ill_threshold`, and adjust any tests or call sites that relied on the old default to pass `1e-5` explicitly if they intend to use the fallback. Ensure the code paths that check `ill_threshold` treat `None` as "do not use the fallback" and only trigger the eigen-decomposition path for numeric values.
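As an illustration of the opt-in default this comment asks for, here is a minimal standalone sketch in plain NumPy (the `solve_generalized` helper is hypothetical, not the dptb API): with `ill_threshold=None` the original Cholesky path runs unchanged, and passing a numeric threshold explicitly enables the eigen-decomposition fallback.

```python
import numpy as np
from typing import Optional

def solve_generalized(H: np.ndarray, S: np.ndarray,
                      ill_threshold: Optional[float] = None) -> np.ndarray:
    """Eigenvalues of H c = e S c.

    ill_threshold=None (the default) keeps the Cholesky path; a numeric
    threshold opts in to the eigen-decomposition fallback for ill-conditioned S.
    """
    if ill_threshold is None:
        # Original behavior: Cholesky factorization (fails if S is not positive definite).
        L = np.linalg.cholesky(S)
        L_inv = np.linalg.inv(L)
        return np.linalg.eigvalsh(L_inv @ H @ L_inv.conj().T)
    # Opt-in fallback: drop near-null directions of S before transforming.
    w, U = np.linalg.eigh(S)
    keep = w > ill_threshold
    U_sel = U[:, keep]
    H_proj = U_sel.conj().T @ H @ U_sel
    L = np.linalg.cholesky(np.diag(w[keep]))
    L_inv = np.linalg.inv(L)
    return np.linalg.eigvalsh(L_inv @ H_proj @ L_inv.conj().T)
```

For a well-conditioned `S`, both paths return the same spectrum, which is exactly why the `None` default is safe for existing callers.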


if eig_solver is None:
eig_solver = 'torch'
log.warning("eig_solver is not set, using default 'torch'.")
if eig_solver not in ['torch', 'numpy']:
log.error(f"eig_solver should be 'torch' or 'numpy', but got {eig_solver}.")
raise ValueError
Comment on lines 69 to +71
⚠️ Potential issue | 🟡 Minor

Include error message in the exception.

The exception is raised without a message, making it harder to diagnose when caught elsewhere. Move the message into the exception:

Suggested fix
         if eig_solver not in ['torch', 'numpy']:
-            log.error(f"eig_solver should be 'torch' or 'numpy', but got {eig_solver}.")
-            raise ValueError
+            msg = f"eig_solver should be 'torch' or 'numpy', but got {eig_solver}."
+            log.error(msg)
+            raise ValueError(msg)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@dptb/nn/energy.py` around lines 69 - 71: the guard checking `eig_solver` logs an error but raises `ValueError` without a message, so callers that catch the exception get no details. In the `if eig_solver not in ['torch', 'numpy']` block, change the raise to `raise ValueError(f"eig_solver should be 'torch' or 'numpy', but got {eig_solver}.")`; the `log.error` call can remain or be removed as you prefer.
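The log-and-raise pattern this comment recommends can be sketched as follows (the `check_solver` helper is hypothetical, shown only to illustrate reusing one message for both the log and the exception):

```python
import logging

log = logging.getLogger(__name__)

def check_solver(eig_solver: str) -> str:
    """Validate the solver name.

    Log and raise with the same message so callers that catch the
    exception still see the diagnostic details.
    """
    if eig_solver not in ('torch', 'numpy'):
        msg = f"eig_solver should be 'torch' or 'numpy', but got {eig_solver}."
        log.error(msg)
        raise ValueError(msg)
    return eig_solver
```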


kpoints = data[AtomicDataDict.KPOINT_KEY]
if kpoints.is_nested:
@@ -80,29 +81,135 @@ def forward(self,
eigvals = []
if nk is None:
nk = num_k

for i in range(int(np.ceil(num_k / nk))):
data[AtomicDataDict.KPOINT_KEY] = kpoints[i*nk:(i+1)*nk]
data = self.h2k(data)
h_transformed_np = None

batch_eigvals_torch = None
batch_eigvals_np = None

if self.overlap:
data = self.s2k(data)
if eig_solver == 'torch':
-                    chklowt = torch.linalg.cholesky(data[self.s_out_field])
-                    chklowtinv = torch.linalg.inv(chklowt)
-                    data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj())
+                    if ill_threshold is None:
+                        chklowt = torch.linalg.cholesky(data[self.s_out_field])
+                        chklowtinv = torch.linalg.inv(chklowt)
+                        data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj())
else:
S_k = data[self.s_out_field]
H_k = data[self.h_out_field]
egval_S, egvec_S = torch.linalg.eigh(S_k)
B = S_k.shape[0]
num_orbitals = H_k.shape[-1]
real_dtype = torch.float32 if H_k.dtype in [torch.complex64, torch.float32] else torch.float64

processed_eigvals_list = []
for k_idx in range(B):
healthy_mask = egval_S[k_idx] > ill_threshold
n_healthy = int(healthy_mask.sum().item())

if n_healthy == 0:
egval = torch.full((num_orbitals,), 1e4, dtype=real_dtype, device=H_k.device)
processed_eigvals_list.append(egval)
continue
Comment on lines +113 to +116
⚠️ Potential issue | 🟠 Major

Hardcoded padding value 1e4 may cause silent issues in downstream physics calculations.

The padding value propagates to band structure, DOS, and Fermi surface calculations (see get_eigenvalues in calculator.py). A fixed value of 1e4 could:

  • Incorrectly contribute to density of states at high energies
  • Affect Fermi level calculations if not filtered out
  • Be indistinguishable from legitimate high-energy eigenvalues

Consider:

  1. Using torch.inf/np.inf to make padded values clearly distinguishable
  2. Making the padding value configurable via parameter
  3. Documenting that consumers must filter out these sentinel values
Example using inf
                         if n_healthy == 0:
-                            egval = torch.full((num_orbitals,), 1e4, dtype=real_dtype, device=H_k.device)
+                            egval = torch.full((num_orbitals,), float('inf'), dtype=real_dtype, device=H_k.device)
                             ...
                         if num_projected_out > 0:
-                            padding = torch.full((num_projected_out,), 1e4, dtype=egval_proj.dtype, device=egval_proj.device)
+                            padding = torch.full((num_projected_out,), float('inf'), dtype=egval_proj.dtype, device=egval_proj.device)

Also applies to: 137-143

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@dptb/nn/energy.py` around lines 113 - 116: the hardcoded padding value `1e4` (used where `egval` is created and appended to `processed_eigvals_list` when `n_healthy == 0`) should be replaced with a clear sentinel. Either construct `egval` with `torch.inf` (or `np.inf`), or make the padding configurable by adding a `pad_value` parameter to the enclosing function and using it when creating `egval`. Update the other occurrences (around lines 137 - 143) the same way, propagate the new `pad_value` to callers such as `get_eigenvalues` in `calculator.py`, and add a short docstring note that consumers must filter out the sentinel (`inf`) values when computing DOS or Fermi levels so they are not treated as real high-energy eigenvalues.
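If the `inf` sentinel is adopted, downstream consumers would filter the padding before any statistics. A minimal sketch (the `finite_eigvals` helper name and the `inf` padding are assumptions, not part of the dptb API):

```python
import numpy as np

def finite_eigvals(eigvals: np.ndarray) -> np.ndarray:
    """Drop sentinel entries (inf padding from ill-conditioned k-points)
    before computing DOS or Fermi-level statistics."""
    return eigvals[np.isfinite(eigvals)]
```

With `1e4` padding, the same filtering is impossible to do reliably, since `1e4` can collide with legitimate high-energy eigenvalues.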


if healthy_mask.all():
L = torch.linalg.cholesky(S_k[k_idx])
L_inv = torch.linalg.inv(L)
H_transformed = L_inv @ H_k[k_idx] @ L_inv.conj().T
egval = torch.linalg.eigvalsh(H_transformed)
processed_eigvals_list.append(egval)
continue

U_sel = egvec_S[k_idx, :, healthy_mask]
eval_sel = egval_S[k_idx, healthy_mask]

H_proj = U_sel.conj().T @ H_k[k_idx] @ U_sel
S_proj = torch.diag(eval_sel).to(dtype=H_proj.dtype, device=H_proj.device)

L = torch.linalg.cholesky(S_proj)
L_inv = torch.linalg.inv(L)
H_transformed = L_inv @ H_proj @ L_inv.conj().T
egval_proj = torch.linalg.eigvalsh(H_transformed)

num_projected_out = num_orbitals - egval_proj.shape[0]
if num_projected_out > 0:
padding = torch.full((num_projected_out,), 1e4, dtype=egval_proj.dtype, device=egval_proj.device)
egval = torch.cat([egval_proj, padding], dim=0)
else:
egval = egval_proj

processed_eigvals_list.append(egval)
batch_eigvals_torch = torch.stack(processed_eigvals_list, dim=0)

elif eig_solver == 'numpy':
-                    s_np = data[self.s_out_field].detach().cpu().numpy()
-                    h_np = data[self.h_out_field].detach().cpu().numpy()
-                    chklowt = np.linalg.cholesky(s_np)
-                    chklowtinv = np.linalg.inv(chklowt)
-                    h_transformed_np = chklowtinv @ h_np @ np.transpose(chklowtinv,(0,2,1)).conj()
+                    if ill_threshold is None:
+                        s_np = data[self.s_out_field].detach().cpu().numpy()
+                        h_np = data[self.h_out_field].detach().cpu().numpy()
+                        chklowt = np.linalg.cholesky(s_np)
+                        chklowtinv = np.linalg.inv(chklowt)
+                        h_transformed_np = chklowtinv @ h_np @ np.transpose(chklowtinv,(0,2,1)).conj()
+                    else:
s_np = data[self.s_out_field].detach().cpu().numpy()
h_np = data[self.h_out_field].detach().cpu().numpy()
egval_S, egvec_S = np.linalg.eigh(s_np)
B = s_np.shape[0]
num_orbitals = h_np.shape[-1]
real_dtype = np.float32 if h_np.dtype in [np.complex64, np.float32] else np.float64

processed_eigvals_list = []
for k_idx in range(B):
healthy_mask = egval_S[k_idx] > ill_threshold
n_healthy = int(healthy_mask.sum())

if n_healthy == 0:
egval = np.full((num_orbitals,), 1e4, dtype=real_dtype)
processed_eigvals_list.append(egval)
continue

if healthy_mask.all():
L = np.linalg.cholesky(s_np[k_idx])
L_inv = np.linalg.inv(L)
H_transformed = L_inv @ h_np[k_idx] @ L_inv.conj().T
egval = np.linalg.eigvalsh(H_transformed)
processed_eigvals_list.append(egval)
continue

U_sel = egvec_S[k_idx, :, healthy_mask]
eval_sel = egval_S[k_idx, healthy_mask]

H_proj = U_sel.conj().T @ h_np[k_idx] @ U_sel
S_proj = np.diag(eval_sel).astype(H_proj.dtype)

L = np.linalg.cholesky(S_proj)
L_inv = np.linalg.inv(L)
H_transformed = L_inv @ H_proj @ L_inv.conj().T
egval_proj = np.linalg.eigvalsh(H_transformed)

num_projected_out = num_orbitals - egval_proj.shape[0]
if num_projected_out > 0:
padding = np.full((num_projected_out,), 1e4, dtype=egval_proj.dtype)
egval = np.concatenate([egval_proj, padding], axis=0)
else:
egval = egval_proj

processed_eigvals_list.append(egval)
batch_eigvals_np = np.stack(processed_eigvals_list, axis=0)

if eig_solver == 'torch':
-                eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field]))
+                if batch_eigvals_torch is not None:
+                    eigvals.append(batch_eigvals_torch)
+                else:
+                    eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field]))
elif eig_solver == 'numpy':
-                if h_transformed_np is None:
-                    h_transformed_np = data[self.h_out_field].detach().cpu().numpy()
-                eigvals_np = np.linalg.eigvalsh(a=h_transformed_np)
+                if batch_eigvals_np is not None:
+                    eigvals_np = batch_eigvals_np
+                else:
+                    if h_transformed_np is None:
+                        h_transformed_np = data[self.h_out_field].detach().cpu().numpy()
+                    eigvals_np = np.linalg.eigvalsh(a=h_transformed_np)
# Preserve dtype by converting to the Hamiltonian's original dtype
eigvals.append(torch.from_numpy(eigvals_np).to(dtype=self.h2k.dtype, device=self.h2k.device))
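As a sanity check on the projection-and-pad scheme in this hunk, here is a self-contained NumPy sketch on toy matrices: eigen-decompose the overlap, keep only directions above the threshold, solve the reduced problem, and pad back to the full orbital count. The `eigvals_with_fallback` helper is a hypothetical standalone reduction of the diff's per-k-point loop, not the dptb implementation; the `1e4` sentinel mirrors the current code.

```python
import numpy as np

def eigvals_with_fallback(H, S, ill_threshold=1e-5, pad_value=1e4):
    # Eigen-decompose S and keep only well-conditioned directions.
    w, U = np.linalg.eigh(S)
    keep = w > ill_threshold
    U_sel, w_sel = U[:, keep], w[keep]
    # Project H onto the healthy subspace and solve the reduced problem.
    H_proj = U_sel.conj().T @ H @ U_sel
    L_inv = np.linalg.inv(np.linalg.cholesky(np.diag(w_sel)))
    ev = np.linalg.eigvalsh(L_inv @ H_proj @ L_inv.conj().T)
    # Pad back to the full orbital count with the sentinel value.
    n_pad = H.shape[-1] - ev.shape[0]
    return np.concatenate([ev, np.full(n_pad, pad_value)])

# Overlap with one near-null direction (eigenvalue ~1e-12), which would
# make a plain Cholesky factorization numerically unreliable.
S = np.diag([1.0, 1.0, 1e-12])
H = np.diag([0.3, 0.7, 5.0])
```

On this example the healthy subspace is two-dimensional, so the result contains the two physical eigenvalues followed by one padded entry.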
