Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chebai/preprocessing/datasets/pubchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def _load_dict(input_file_path: str) -> Generator[dict, None, None]:
Yields:
dict: Dictionary containing 'features', 'labels' (None), and 'ident' fields.
"""
# pubchem IDs are here
with open(input_file_path, "r") as input_file:
for row in input_file:
ident, smiles = row.split("\t")
Expand Down
2 changes: 1 addition & 1 deletion chebai/preprocessing/datasets/tox21.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _load_dict(self, input_file_path: str) -> List[Dict]:
for row in reader:
smiles = row["smiles"]
labels = [
bool(int(float(label))) if len(label) > 1 else None
bool(int(float(label))) if len(label) >= 1 else None
for label in (row[k] for k in self.HEADERS)
]
# group = int(row["group"])
Expand Down
7 changes: 5 additions & 2 deletions chebai/preprocessing/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,11 @@ def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]:
under the additional `missing_labels` keyword."""
labels = self._get_raw_label(row)
additional_kwargs = self._get_additional_kwargs(row)
if any(label is None for label in labels):
additional_kwargs["missing_labels"] = [label is None for label in labels]
if labels is not None:
if any(label is None for label in labels):
additional_kwargs["missing_labels"] = [
label is None for label in labels
]
return dict(
features=self._get_raw_data(row),
labels=labels,
Expand Down
Loading