Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 88 additions & 40 deletions mdpath/src/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from multiprocessing import Pool
from Bio import PDB
from itertools import combinations
from scipy.spatial import cKDTree
import logging


Expand Down Expand Up @@ -78,6 +79,50 @@ def calculate_distance(self, atom1: tuple, atom2: tuple) -> float:
distance = np.linalg.norm(distance_vector)
return distance

def _build_kdtree(self, dist: float) -> tuple:
    """Find residue pairs that have heavy atoms within ``dist`` of each other.

    The structure in ``self.pdb`` is parsed, the coordinates of all C, N, O
    and S atoms belonging to amino-acid residues numbered up to
    ``self.last_res_num`` are indexed in a KD-tree, and every atom pair
    found within the cutoff is reduced to the pair of residue numbers it
    belongs to.

    Args:
        dist (float): Distance cutoff.

    Returns:
        close_res_pairs (set): Set of (min_id, max_id) tuples for close residue pairs.
        all_unique_res (list): Sorted list of all unique residue IDs.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("pdb_structure", self.pdb)
    heavy_atoms = {"C", "N", "O", "S"}

    coords = []
    res_ids = []
    for res in structure.get_residues():
        if not PDB.Polypeptide.is_aa(res):
            continue
        # NOTE(review): only the residue number is kept, so identically
        # numbered residues from different chains share one id - confirm
        # this is intended for multi-chain inputs.
        rid = res.get_id()[1]
        if rid <= self.last_res_num:
            for atom in res:
                if atom.element in heavy_atoms:
                    coords.append(atom.coord)
                    res_ids.append(rid)

    # No usable atoms: nothing to index, report "no pairs, no residues".
    if not coords:
        return set(), []

    coords = np.array(coords)
    res_ids = np.array(res_ids)
    tree = cKDTree(coords)

    # Request the atom pairs as an (n_pairs, 2) index array so the mapping
    # to residue numbers happens in NumPy rather than in a Python loop over
    # a potentially very large set of tuples.
    atom_pairs = tree.query_pairs(r=dist, output_type="ndarray")
    first = res_ids[atom_pairs[:, 0]]
    second = res_ids[atom_pairs[:, 1]]

    # Drop contacts inside a single residue and order each remaining pair
    # as (smaller id, larger id); the set removes duplicates.
    between_residues = first != second
    lower = np.minimum(first, second)[between_residues]
    upper = np.maximum(first, second)[between_residues]
    close_res_pairs = set(zip(lower.tolist(), upper.tolist()))

    all_unique_res = sorted(set(res_ids.tolist()))
    return close_res_pairs, all_unique_res

def calculate_residue_suroundings(self, dist: float, mode: str) -> pd.DataFrame:
"""Calculates residues that are either close to or far away from each other in a PDB structure.

Expand All @@ -92,42 +137,46 @@ def calculate_residue_suroundings(self, dist: float, mode: str) -> pd.DataFrame:
if mode not in ["close", "far"]:
raise ValueError("Mode must be either 'close' or 'far'.")

parser = PDB.PDBParser(QUIET=True)
structure = parser.get_structure("pdb_structure", self.pdb)
heavy_atoms = ["C", "N", "O", "S"]
residue_pairs = []
residues = [
res for res in structure.get_residues() if PDB.Polypeptide.is_aa(res)
]
close_res_pairs, all_unique_res = self._build_kdtree(dist)

for res1, res2 in tqdm(
combinations(residues, 2),
desc=f"\033[1mCalculating {mode} residue surroundings\033[0m",
total=len(residues) * (len(residues) - 1) // 2,
):
res1_id = res1.get_id()[1]
res2_id = res2.get_id()[1]
if res1_id <= self.last_res_num and res2_id <= self.last_res_num:
condition_met = False if mode == "close" else True
for atom1 in res1:
if atom1.element in heavy_atoms:
for atom2 in res2:
if atom2.element in heavy_atoms:
distance = self.calculate_distance(
atom1.coord, atom2.coord
)
if (mode == "close" and distance <= dist) or (
mode == "far" and distance > dist
):
condition_met = True
break
if condition_met:
break
if condition_met:
residue_pairs.append((res1_id, res2_id))
if not all_unique_res:
return pd.DataFrame(columns=["Residue1", "Residue2"])

if mode == "close":
residue_pairs = sorted(close_res_pairs)
else:
all_res_pairs = set(
(min(r1, r2), max(r1, r2)) for r1, r2 in combinations(all_unique_res, 2)
)
residue_pairs = sorted(all_res_pairs - close_res_pairs)

return pd.DataFrame(residue_pairs, columns=["Residue1", "Residue2"])

def calculate_close_and_far(self, dist: float):
    """Return the close and the far residue pairs for one distance cutoff.

    Both tables come from a single call to ``_build_kdtree``: the far pairs
    are all residue pairs that were not reported as close.

    Args:
        dist (float): Distance cutoff for residue pairs.

    Returns:
        df_close (pd.DataFrame): Close residue pairs, sorted, in the columns
            "Residue1" and "Residue2".
        df_far (pd.DataFrame): Far residue pairs in the same layout.
    """
    columns = ["Residue1", "Residue2"]
    near, residue_ids = self._build_kdtree(dist)

    # Nothing to pair up: hand back two independent empty frames.
    if not residue_ids:
        blank = pd.DataFrame(columns=columns)
        return blank, blank.copy()

    # Every unordered residue pair, stored as (smaller id, larger id).
    every_pair = {
        (a, b) if a <= b else (b, a) for a, b in combinations(residue_ids, 2)
    }
    distant = every_pair - near

    return (
        pd.DataFrame(sorted(near), columns=columns),
        pd.DataFrame(sorted(distant), columns=columns),
    )


class DihedralAngles:
"""Calculate dihedral angle movements for residues in a molecular dynamics (MD) trajectory.
Expand Down Expand Up @@ -167,7 +216,7 @@ def calc_dihedral_angle_movement(self, res_id: int) -> tuple:
try:
res = self.traj.residues[res_id]
ags = [res.phi_selection()]
if not all(ags): # Check if any selections are None
if not all(ags):
return None
R = Dihedral(ags).run()
dihedrals = R.results.angles
Expand All @@ -191,7 +240,7 @@ def calculate_dihedral_movement_parallel(
Returns:
pd.DataFrame: DataFrame with all residue dihedral angle movements.
"""
df_all_residues = pd.DataFrame()
collected = []

try:
with Pool(processes=num_parallel_processes) as pool:
Expand All @@ -212,11 +261,8 @@ def calculate_dihedral_movement_parallel(

res_id, dihedral_data = result
try:
df_residue = pd.DataFrame(
dihedral_data, columns=[f"Res {res_id}"]
)
df_all_residues = pd.concat(
[df_all_residues, df_residue], axis=1
collected.append(
pd.DataFrame(dihedral_data, columns=[f"Res {res_id}"])
)
except Exception as e:
logging.error(
Expand All @@ -228,4 +274,6 @@ def calculate_dihedral_movement_parallel(
except Exception as e:
logging.error(f"Parallel processing failed: {str(e)}")

return df_all_residues
if collected:
return pd.concat(collected, axis=1)
return pd.DataFrame()
Loading