diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/sample_kg_dataset.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/sample_kg_dataset.py index 59d72e888..f9f1ea1fa 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/sample_kg_dataset.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/sample_kg_dataset.py @@ -1,14 +1,14 @@ -from pyhealth.datasets import SampleBaseDataset +from torch.utils.data import Dataset -class SampleKGDataset(SampleBaseDataset): +class SampleKGDataset(Dataset): """Sample KG dataset class. - This class inherits from `SampleBaseDataset` and is specifically designed - for KG datasets. + This class inherits from `torch.utils.data.Dataset` and is specifically + designed for KG datasets. Args: - samples: a list of samples + samples: a list of samples A sample is a dict containing following data: { 'triple': a positive triple e.g., (0, 0, 2835) @@ -24,11 +24,11 @@ class SampleKGDataset(SampleBaseDataset): task_name: the name of the task. Default is None. """ def __init__( - self, - samples, - dataset_name="", - task_name="", - dev=False, + self, + samples, + dataset_name="", + task_name="", + dev=False, entity_num=0, relation_num=0, entity2id=None, @@ -36,7 +36,10 @@ def __init__( **kwargs ): - super().__init__(samples, dataset_name, task_name) + super().__init__() + self.samples = samples + self.dataset_name = dataset_name + self.task_name = task_name self.dev = dev self.entity_num = entity_num self.relation_num = relation_num @@ -65,6 +68,10 @@ def __getitem__(self, index): """ return self.samples[index] + def __len__(self): + """Returns the number of samples in the dataset.""" + return len(self.samples) + def stat(self): """Returns some statistics of the base dataset.""" lines = list() diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py index 559ecb7c4..d51aa226a 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py @@ -4,18 +4,16 @@ import numpy as np import torch -from pyhealth.datasets import SampleBaseDataset - def split( - dataset: SampleBaseDataset, + dataset, ratios: Union[Tuple[float, float, float], List[float]], seed: Optional[int] = None, ): """Splits the dataset by its outermost indexed items Args: - dataset: a `SampleBaseDataset` object + dataset: a `SampleKGDataset` object ratios: a list/tuple of ratios for train / val / test seed: random seed for shuffling the dataset diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/complex.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/complex.py index 8fa2a443a..d530c80f8 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/complex.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/complex.py @@ -1,5 +1,4 @@ from.kg_base import KGEBaseModel -from pyhealth.datasets import SampleBaseDataset import torch @@ -13,7 +12,7 @@ class ComplEx(KGEBaseModel): def __init__( self, - dataset: SampleBaseDataset, + dataset, e_dim: int = 600, r_dim: int = 600, ns: str = "adv", diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/distmult.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/distmult.py index e7563137c..80e40dd2d 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/distmult.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/distmult.py @@ -1,5 +1,4 @@ from.kg_base import KGEBaseModel -from pyhealth.datasets import SampleBaseDataset import torch @@ -12,7 +11,7 @@ class DistMult(KGEBaseModel): """ def __init__( self, - dataset: SampleBaseDataset, + dataset, e_dim: int = 300, r_dim: int = 300, ns: str = "adv", diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/kg_base.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/kg_base.py index 2de13afe2..c9dd75270 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/kg_base.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/kg_base.py @@ -1,5 +1,4 @@ from abc import ABC -from pyhealth.datasets import SampleBaseDataset import torch import time @@ -32,7 +31,7 @@ def device(self): def __init__( self, - dataset: SampleBaseDataset, + dataset, e_dim: int = 500, r_dim: int = 500, ns: str = "uniform", diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/rotate.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/rotate.py index df7143a6e..a932a6b5b 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/rotate.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/rotate.py @@ -1,5 +1,4 @@ from.kg_base import KGEBaseModel -from pyhealth.datasets import SampleBaseDataset import torch @@ -13,7 +12,7 @@ class RotatE(KGEBaseModel): def __init__( self, - dataset: SampleBaseDataset, + dataset, e_dim: int = 600, r_dim: int = 300, ns='adv', diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/transe.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/transe.py index fbb6e68f6..2799c4b79 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/models/transe.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/models/transe.py @@ -1,5 +1,4 @@ from.kg_base import KGEBaseModel -from pyhealth.datasets import SampleBaseDataset import torch @@ -13,7 +12,7 @@ class TransE(KGEBaseModel): def __init__( self, - dataset: SampleBaseDataset, + dataset, e_dim: int = 300, r_dim: int = 300, ns: str = "adv", diff --git a/pyproject.toml b/pyproject.toml index 934d4f1bb..df69a48c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ keywords = [ [project.optional-dependencies] graph = [ "torch-geometric>=2.6.0", + "pandarallel", ] nlp = [ "editdistance~=0.8.1",