class ChEBIFromList(_ChEBIDataExtractor):
    """
    A ChEBI dataset where labels are selected from a predefined list of classes.

    The list is read once, at construction time, from a plain-text file that
    holds one class ID per line. Blank lines are skipped and surrounding
    whitespace is stripped from every entry.
    """

    READER = dr.ChemDataReader

    def __init__(
        self,
        class_list,
        **kwargs,
    ):
        """
        Initializes the ChEBIFromList dataset.

        Args:
            class_list: Path to a text file with one class ID per line. The IDs
                are used as labels in the dataset.
            **kwargs: Additional keyword arguments passed to the superclass initializer

        Raises:
            OSError: If the file at ``class_list`` cannot be opened.
        """
        # Read the IDs before calling the superclass initializer so that
        # `self.class_list` is already set while the superclass initializes.
        with open(class_list, "r", encoding="utf-8") as f:
            # NOTE(review): IDs are kept as the stripped strings read from the
            # file; confirm this matches the ID type used for the graph nodes.
            self.class_list = [line.strip() for line in f if line.strip()]
        super().__init__(**kwargs)

    @property
    def _name(self) -> str:
        """
        Returns the name of the dataset.

        Returns:
            str: The dataset name.
        """
        return "ChEBI_from_list"

    def select_classes(self, g: DiGraph, *args, **kwargs) -> list:
        """
        Returns the predefined classes loaded from the class-list file.

        Args:
            g: The graph passed by the caller. It is not inspected here.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            list: The class IDs read in ``__init__``, in file order.
        """
        return self.class_list