Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions deeprank2/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,9 +873,17 @@ def test(
def _load_params(self) -> None:
"""Loads the parameters of a pretrained model."""
if torch.cuda.is_available():
state = torch.load(self.pretrained_model)
state = torch.load(self.pretrained_model, weights_only=False)
else:
state = torch.load(self.pretrained_model, map_location=torch.device("cpu"))
state = torch.load(self.pretrained_model, map_location=torch.device("cpu"), weights_only=False)

features_transform = state["features_transform"]
for value in features_transform.values():
if value["transform"] is None:
continue

# Deserialize the function
value["transform"] = dill.loads(value["transform"])

self.data_type = state["data_type"]
self.model_load_state_dict = state["model_state"]
Expand All @@ -901,7 +909,7 @@ def _load_params(self) -> None:
self.node_features = state["node_features"]
self.edge_features = state["edge_features"]
self.features = state["features"]
self.features_transform = state["features_transform"]
self.features_transform = features_transform
self.means = state["means"]
self.devs = state["devs"]
self.cuda = state["cuda"]
Expand All @@ -912,16 +920,12 @@ def _save_model(self) -> dict[str, Any]:
features_transform_to_save = copy.deepcopy(self.features_transform)
# prepare transform dictionary for being saved
if features_transform_to_save:
for key in features_transform_to_save.values():
if key["transform"] is None:
for value in features_transform_to_save.values():
if value["transform"] is None:
continue

# Serialize the function
serialized_func = dill.dumps(key["transform"])
# Deserialize the function
deserialized_func = dill.loads(serialized_func) # noqa: S301
str_expr = inspect.getsource(deserialized_func)
match = re.search(r"[\"|\']transform[\"|\']:.*(lambda.*).*,.*[\"|\']standardize[\"|\'].*", str_expr).group(1)
key["transform"] = match
value["transform"] = dill.dumps(value["transform"])

state = {
"data_type": self.data_type,
Expand Down
28 changes: 28 additions & 0 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import tempfile
import unittest
import warnings
import uuid

import numpy
import h5py
import pandas as pd
import pytest
Expand Down Expand Up @@ -124,6 +126,32 @@ def setUpClass(class_) -> None:
def tearDownClass(class_) -> None:
shutil.rmtree(class_.work_directory)

def test_save_transform_function(self) -> None:
    """Check that a features_transform lambda survives a save/load round trip.

    `_save_model` serializes the transform callable with dill and
    `_load_params` deserializes it again; after reloading from an actual
    file on disk, the function must still be callable and must still
    apply the intended dtype cast.
    """
    dataset = GridDataset(
        hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5",
        subset=None,
        features=[Efeat.VDW],
        target=targets.IRMSD,
        task=targets.REGRESS,
    )
    trainer = Trainer(CnnRegression, dataset)
    trainer.features_transform = {"float32": {"transform": lambda t: t.astype(numpy.float32), "standardize": True}}

    state = trainer._save_model()
    assert "features_transform" in state
    assert "float32" in state["features_transform"]

    # Round-trip through a real file, exactly as a user would with torch.save/torch.load.
    tmp_path = os.path.join(tempfile.gettempdir(), uuid.uuid4().hex + ".pth")
    try:
        torch.save(state, tmp_path)

        trainer.pretrained_model = tmp_path
        trainer._load_params()
    finally:
        # Guard the cleanup: if torch.save failed, the file never existed and
        # an unconditional os.remove would mask the original exception.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    # The reloaded transform must be callable AND still perform the float32 cast.
    transformed = trainer.features_transform["float32"]["transform"](numpy.array([0.0]))
    assert transformed.dtype == numpy.float32

def test_grid_regression(self) -> None:
dataset = GridDataset(
hdf5_path="tests/data/hdf5/1ATN_ppi.hdf5",
Expand Down
Loading