diff --git a/mdpath/src/structure.py b/mdpath/src/structure.py index 9f9e302..1a5207e 100644 --- a/mdpath/src/structure.py +++ b/mdpath/src/structure.py @@ -19,6 +19,7 @@ from multiprocessing import Pool from Bio import PDB from itertools import combinations +from scipy.spatial import cKDTree import logging @@ -78,6 +79,50 @@ def calculate_distance(self, atom1: tuple, atom2: tuple) -> float: distance = np.linalg.norm(distance_vector) return distance + def _build_kdtree(self, dist: float): + """Builds a KDTree from heavy atoms and returns close residue pairs. + + Args: + dist (float): Distance cutoff. + + Returns: + close_res_pairs (set): Set of (min_id, max_id) tuples for close residue pairs. + all_unique_res (list): Sorted list of all unique residue IDs. + """ + parser = PDB.PDBParser(QUIET=True) + structure = parser.get_structure("pdb_structure", self.pdb) + heavy_atoms = {"C", "N", "O", "S"} + residues = [ + res for res in structure.get_residues() if PDB.Polypeptide.is_aa(res) + ] + + coords = [] + res_ids = [] + for res in residues: + rid = res.get_id()[1] + if rid <= self.last_res_num: + for atom in res: + if atom.element in heavy_atoms: + coords.append(atom.coord) + res_ids.append(rid) + + if not coords: + return set(), [] + + coords = np.array(coords) + res_ids = np.array(res_ids) + tree = cKDTree(coords) + atom_pairs = tree.query_pairs(r=dist) + + close_res_pairs = set() + for i, j in atom_pairs: + r1, r2 = int(res_ids[i]), int(res_ids[j]) + if r1 != r2: + close_res_pairs.add((min(r1, r2), max(r1, r2))) + + all_unique_res = sorted(set(res_ids.tolist())) + return close_res_pairs, all_unique_res + def calculate_residue_suroundings(self, dist: float, mode: str) -> pd.DataFrame: """Calculates residues that are either close to or far away from each other in a PDB structure. @@ -92,42 +137,46 @@ def calculate_residue_suroundings(self, dist: float, mode: str) -> pd.DataFrame: if mode not in ["close", "far"]: raise ValueError("Mode must be either 'close' or 'far'.") - parser = PDB.PDBParser(QUIET=True) - structure = parser.get_structure("pdb_structure", self.pdb) - heavy_atoms = ["C", "N", "O", "S"] - residue_pairs = [] - residues = [ - res for res in structure.get_residues() if PDB.Polypeptide.is_aa(res) - ] + close_res_pairs, all_unique_res = self._build_kdtree(dist) - for res1, res2 in tqdm( - combinations(residues, 2), - desc=f"\033[1mCalculating {mode} residue surroundings\033[0m", - total=len(residues) * (len(residues) - 1) // 2, - ): - res1_id = res1.get_id()[1] - res2_id = res2.get_id()[1] - if res1_id <= self.last_res_num and res2_id <= self.last_res_num: - condition_met = False if mode == "close" else True - for atom1 in res1: - if atom1.element in heavy_atoms: - for atom2 in res2: - if atom2.element in heavy_atoms: - distance = self.calculate_distance( - atom1.coord, atom2.coord - ) - if (mode == "close" and distance <= dist) or ( - mode == "far" and distance > dist - ): - condition_met = True - break - if condition_met: - break - if condition_met: - residue_pairs.append((res1_id, res2_id)) + if not all_unique_res: + return pd.DataFrame(columns=["Residue1", "Residue2"]) + + if mode == "close": + residue_pairs = sorted(close_res_pairs) + else: + all_res_pairs = set( + (min(r1, r2), max(r1, r2)) for r1, r2 in combinations(all_unique_res, 2) + ) + residue_pairs = sorted(all_res_pairs - close_res_pairs) return pd.DataFrame(residue_pairs, columns=["Residue1", "Residue2"]) + def calculate_close_and_far(self, dist: float): + """Calculates both close and far residue pairs in a single KDTree pass. + + Args: + dist (float): Distance cutoff for residue pairs. + + Returns: + df_close (pd.DataFrame): Close residue pairs. + df_far (pd.DataFrame): Far residue pairs. + """ + close_res_pairs, all_unique_res = self._build_kdtree(dist) + + if not all_unique_res: + empty = pd.DataFrame(columns=["Residue1", "Residue2"]) + return empty, empty.copy() + + all_res_pairs = set( + (min(r1, r2), max(r1, r2)) for r1, r2 in combinations(all_unique_res, 2) + ) + far_res_pairs = all_res_pairs - close_res_pairs + + df_close = pd.DataFrame(sorted(close_res_pairs), columns=["Residue1", "Residue2"]) + df_far = pd.DataFrame(sorted(far_res_pairs), columns=["Residue1", "Residue2"]) + return df_close, df_far + class DihedralAngles: """Calculate dihedral angle movements for residues in a molecular dynamics (MD) trajectory. @@ -167,7 +216,7 @@ def calc_dihedral_angle_movement(self, res_id: int) -> tuple: try: res = self.traj.residues[res_id] ags = [res.phi_selection()] - if not all(ags): # Check if any selections are None + if not all(ags): return None R = Dihedral(ags).run() dihedrals = R.results.angles @@ -191,7 +240,7 @@ def calculate_dihedral_movement_parallel( Returns: pd.DataFrame: DataFrame with all residue dihedral angle movements. """ - df_all_residues = pd.DataFrame() + collected = [] try: with Pool(processes=num_parallel_processes) as pool: @@ -212,11 +261,8 @@ def calculate_dihedral_movement_parallel( res_id, dihedral_data = result try: - df_residue = pd.DataFrame( - dihedral_data, columns=[f"Res {res_id}"] - ) - df_all_residues = pd.concat( - [df_all_residues, df_residue], axis=1 + collected.append( + pd.DataFrame(dihedral_data, columns=[f"Res {res_id}"]) ) except Exception as e: logging.error( @@ -228,4 +274,6 @@ def calculate_dihedral_movement_parallel( except Exception as e: logging.error(f"Parallel processing failed: {str(e)}") - return df_all_residues + if collected: + return pd.concat(collected, axis=1) + return pd.DataFrame()