Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions chebai/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from itertools import cycle, permutations, product
from typing import TYPE_CHECKING, Any, Generator, Literal, Optional

from networkx import DiGraph
import numpy as np
import pandas as pd
from rdkit import Chem
Expand Down Expand Up @@ -519,6 +520,45 @@ def processed_file_names_dict(self) -> dict:
return {"data": f"aug_data_var{self.aug_smiles_variations}.pt"}


class ChEBIFromList(_ChEBIDataExtractor):
    """
    A ChEBI dataset where labels are selected from a predefined list of classes.

    The list is read once, at construction time, from a plain-text file that
    holds one class ID per line. Surrounding whitespace is removed from each
    line and blank lines are skipped.

    Attributes:
        class_list: The class IDs read from the file, in file order.
    """

    READER = dr.ChemDataReader

    def __init__(
        self,
        class_list,
        **kwargs,
    ):
        """
        Initializes the ChEBIFromList dataset.

        Args:
            class_list: Path to a list of class IDs to be used as labels in the dataset
                (one ID per line).
            **kwargs: Additional keyword arguments passed to the superclass initializer

        Raises:
            OSError: If the file at ``class_list`` cannot be opened.
        """
        # Explicit encoding so the result does not depend on the platform locale.
        with open(class_list, "r", encoding="utf-8") as f:
            stripped_lines = (line.strip() for line in f)
            class_ids = [entry for entry in stripped_lines if entry]
        # NOTE(review): IDs are kept as the strings read from the file; confirm
        # that downstream code does not expect integer IDs.
        # Assigned before the superclass initializer runs, as in the original code.
        self.class_list = class_ids
        super().__init__(**kwargs)

    @property
    def _name(self) -> str:
        """
        Returns the name of the dataset.

        Returns:
            str: The dataset name.
        """
        return "ChEBI_from_list"

    def select_classes(self, g: DiGraph, *args, **kwargs) -> list:
        """
        Returns the predefined classes loaded at construction time.

        Args:
            g: Graph argument; not used by this implementation.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            list: The class IDs stored in ``self.class_list`` (the same list
            object, not a copy).
        """
        return self.class_list


class ChEBIOverX(_ChEBIDataExtractor):
"""
A class for extracting data from the ChEBI dataset with a threshold for selecting classes.
Expand Down
Loading