From 6ae1518bce2b7617a7438c76798a251d45ef94a6 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Mon, 23 Feb 2026 17:06:16 +0100 Subject: [PATCH] add chebiFromList dataset --- chebai/preprocessing/datasets/chebi.py | 40 ++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index e82a83a8..2a84b9b3 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -17,6 +17,7 @@ from itertools import cycle, permutations, product from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Optional, Union +from networkx import DiGraph import numpy as np import pandas as pd import torch @@ -706,6 +707,45 @@ def select_classes(self, g, *args, **kwargs): return JCI_500_COLUMNS_INT +class ChEBIFromList(_ChEBIDataExtractor): + """ + A ChEBI dataset where labels are selected from a predefined list of classes. + + """ + + READER = dr.ChemDataReader + + def __init__( + self, + class_list, + **kwargs, + ): + """ + Initializes the ChEBIFromList dataset. + + Args: + class_list: Path to a list of class IDs to be used as labels in the dataset. + **kwargs: Additional keyword arguments passed to the superclass initializer + """ + with open(class_list, "r") as f: + class_list = [line.strip() for line in f if line.strip()] + self.class_list = class_list + super().__init__(**kwargs) + + @property + def _name(self) -> str: + """ + Returns the name of the dataset. + + Returns: + str: The dataset name. + """ + return "ChEBI_from_list" + + def select_classes(self, g: DiGraph, *args, **kwargs) -> List: + return self.class_list + + class ChEBIOverX(_ChEBIDataExtractor): """ A class for extracting data from the ChEBI dataset with a threshold for selecting classes.