diff --git a/deeprank2/trainer.py b/deeprank2/trainer.py
index 8f5b5d64..f60bf750 100644
--- a/deeprank2/trainer.py
+++ b/deeprank2/trainer.py
@@ -873,9 +873,17 @@ def test(
     def _load_params(self) -> None:
         """Loads the parameters of a pretrained model."""
         if torch.cuda.is_available():
-            state = torch.load(self.pretrained_model)
+            state = torch.load(self.pretrained_model, weights_only=False)
         else:
-            state = torch.load(self.pretrained_model, map_location=torch.device("cpu"))
+            state = torch.load(self.pretrained_model, map_location=torch.device("cpu"), weights_only=False)
+
+        features_transform = state["features_transform"]
+        for value in features_transform.values():
+            if value["transform"] is None:
+                continue
+
+            # Deserialize the function
+            value["transform"] = dill.loads(value["transform"])  # noqa: S301
 
         self.data_type = state["data_type"]
         self.model_load_state_dict = state["model_state"]
@@ -901,7 +909,7 @@ def _load_params(self) -> None:
         self.node_features = state["node_features"]
         self.edge_features = state["edge_features"]
         self.features = state["features"]
-        self.features_transform = state["features_transform"]
+        self.features_transform = features_transform
         self.means = state["means"]
         self.devs = state["devs"]
         self.cuda = state["cuda"]
@@ -912,16 +920,12 @@ def _save_model(self) -> dict[str, Any]:
         features_transform_to_save = copy.deepcopy(self.features_transform)
         # prepare transform dictionary for being saved
         if features_transform_to_save:
-            for key in features_transform_to_save.values():
-                if key["transform"] is None:
+            for value in features_transform_to_save.values():
+                if value["transform"] is None:
                     continue
+
                 # Serialize the function
-                serialized_func = dill.dumps(key["transform"])
-                # Deserialize the function
-                deserialized_func = dill.loads(serialized_func)  # noqa: S301
-                str_expr = inspect.getsource(deserialized_func)
-                match = re.search(r"[\"|\']transform[\"|\']:.*(lambda.*).*,.*[\"|\']standardize[\"|\'].*", str_expr).group(1)
-                key["transform"] = match
+                value["transform"] = dill.dumps(value["transform"])
 
         state = {
             "data_type": self.data_type,
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index b57ffdd6..0194e1c1 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -5,7 +5,9 @@
 import tempfile
 import unittest
+import uuid
 import warnings
 
+import numpy
 import h5py
 import pandas as pd
 import pytest
@@ -124,6 +126,32 @@ def setUpClass(class_) -> None:
     def tearDownClass(class_) -> None:
         shutil.rmtree(class_.work_directory)
 
+    def test_save_transform_function(self) -> None:
+        dataset = GridDataset(
+            hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5",
+            subset=None,
+            features=[Efeat.VDW],
+            target=targets.IRMSD,
+            task=targets.REGRESS,
+        )
+        trainer = Trainer(CnnRegression, dataset)
+        trainer.features_transform = {"float32": {"transform": lambda t: t.astype(numpy.float32), "standardize": True}}
+
+        state = trainer._save_model()
+        assert "features_transform" in state
+        assert "float32" in state["features_transform"]
+
+        tmp_path = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".pth")
+        try:
+            torch.save(state, tmp_path)
+
+            trainer.pretrained_model = tmp_path
+            trainer._load_params()
+        finally:
+            os.remove(tmp_path)
+
+        assert trainer.features_transform["float32"]["transform"](numpy.array([0.0])).dtype == numpy.float32
+
     def test_grid_regression(self) -> None:
         dataset = GridDataset(
             hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5",