From 091c65a1b5af122f097cbebb4f69e6ca4b6d7f9e Mon Sep 17 00:00:00 2001 From: schnamo Date: Wed, 18 Mar 2026 14:00:49 +0100 Subject: [PATCH 1/2] fixes for tox21 dataset --- chebai/preprocessing/datasets/pubchem.py | 1 + chebai/preprocessing/datasets/tox21.py | 2 +- chebai/preprocessing/reader.py | 5 +++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py index 2f169b0e..ecb9c011 100644 --- a/chebai/preprocessing/datasets/pubchem.py +++ b/chebai/preprocessing/datasets/pubchem.py @@ -106,6 +106,7 @@ def _load_dict(input_file_path: str) -> Generator[dict, None, None]: Yields: dict: Dictionary containing 'features', 'labels' (None), and 'ident' fields. """ + # pubchem IDs are here with open(input_file_path, "r") as input_file: for row in input_file: ident, smiles = row.split("\t") diff --git a/chebai/preprocessing/datasets/tox21.py b/chebai/preprocessing/datasets/tox21.py index 709c620d..f6298293 100644 --- a/chebai/preprocessing/datasets/tox21.py +++ b/chebai/preprocessing/datasets/tox21.py @@ -161,7 +161,7 @@ def _load_dict(self, input_file_path: str) -> List[Dict]: for row in reader: smiles = row["smiles"] labels = [ - bool(int(float(label))) if len(label) > 1 else None + bool(int(float(label))) if len(label) >= 1 else None for label in (row[k] for k in self.HEADERS) ] # group = int(row["group"]) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index b0297a02..211f169a 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -98,8 +98,9 @@ def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]: under the additional `missing_labels` keyword.""" labels = self._get_raw_label(row) additional_kwargs = self._get_additional_kwargs(row) - if any(label is None for label in labels): - additional_kwargs["missing_labels"] = [label is None for label in labels] + if labels is not None: + if any(label is None for label in labels): + additional_kwargs["missing_labels"] = [label is None for label in labels] return dict( features=self._get_raw_data(row), labels=labels, From 6e382bb506cd998db94e9c1bbb964addad8f2287 Mon Sep 17 00:00:00 2001 From: schnamo Date: Wed, 18 Mar 2026 14:10:09 +0100 Subject: [PATCH 2/2] ruff fixes --- chebai/preprocessing/reader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 211f169a..cc70be7f 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -100,7 +100,9 @@ def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]: additional_kwargs = self._get_additional_kwargs(row) if labels is not None: if any(label is None for label in labels): - additional_kwargs["missing_labels"] = [label is None for label in labels] + additional_kwargs["missing_labels"] = [ + label is None for label in labels + ] return dict( features=self._get_raw_data(row), labels=labels,