From 0ba9402f5f611f2c337b1e3f871f52843bbf2e06 Mon Sep 17 00:00:00 2001 From: Rafael Date: Fri, 13 Mar 2026 19:13:00 +0200 Subject: [PATCH] chg: small corrections --- frame/evaluate.py | 6 ++++-- frame/explain.py | 6 ++++-- frame/generate.py | 6 +++++- frame/source/datasets/decompose.py | 4 ++-- frame/source/datasets/default.py | 2 +- frame/source/train/epoch.py | 6 +++--- frame/tune.py | 6 +++--- 7 files changed, 22 insertions(+), 14 deletions(-) diff --git a/frame/evaluate.py b/frame/evaluate.py index 893e9ce..a680889 100644 --- a/frame/evaluate.py +++ b/frame/evaluate.py @@ -53,14 +53,16 @@ def main(): # * Get checkpoint and prepare Explainer model = models.select_model(model_name, tune) - model.load_state_dict(torch.load(path_checkpoint)) + model.load_state_dict(torch.load(path_checkpoint, + map_location=device, + weights_only=True)) model.eval() agg_pred = [] agg_lbl = [] agg_true = [] for data in tqdm(test_loader, ncols=120, desc="Explaining"): - data.to(device) + data = data.to(device) # * Make predictions model_out = model(x=data.x.float(), diff --git a/frame/explain.py b/frame/explain.py index 644abfc..465c0ed 100644 --- a/frame/explain.py +++ b/frame/explain.py @@ -60,7 +60,9 @@ def main(): # * Get checkpoint and prepare Explainer model = models.select_model(model_name, tune) - model.load_state_dict(torch.load(path_checkpoint)) + model.load_state_dict(torch.load(path_checkpoint, + map_location=device, + weights_only=True)) model.eval() if task == "classification": @@ -77,7 +79,7 @@ def main(): return_type="raw")) for data in tqdm(dataloader, ncols=120, desc="Explaining"): - data.to(device) + data = data.to(device) # * Make predictions model_out = model(x=data.x.float(), diff --git a/frame/generate.py b/frame/generate.py index 02853bb..5a3ad50 100644 --- a/frame/generate.py +++ b/frame/generate.py @@ -42,7 +42,11 @@ def main(): raise NotImplementedError("Loader not available") # * Export - bce_weight = (len(dataset.y) - sum(dataset.y)) / sum(dataset.y) + task = params["Data"].get("task", "classification").lower() + if task == "classification" and sum(dataset.y) > 0: + bce_weight = (len(dataset.y) - sum(dataset.y)) / sum(dataset.y) + else: + bce_weight = torch.tensor(1.0) metadata = {"feat_size": dataset.num_node_features, "edge_dim": dataset.num_edge_features, "bce_weight": bce_weight, diff --git a/frame/source/datasets/decompose.py b/frame/source/datasets/decompose.py index 0e7536a..2ba9884 100644 --- a/frame/source/datasets/decompose.py +++ b/frame/source/datasets/decompose.py @@ -59,7 +59,7 @@ def process_data(self): # * Iterate data_list = [] for line in tqdm(dataset, ncols=120, desc="Creating graphs"): - line = re.sub(r"\'.*\'", "", line) # Replace ".*" strings. + line = re.sub(r"\'.*?\'", "", line) # Replace '...' strings. line = line.split(",") # Get label @@ -206,7 +206,7 @@ def _gen_features(smiles): # [single, double, triple, aromatic, conjugation, ring] + stereo) # edge_attrs += [edge_attr, edge_attr] - # frag_edge_attr = torch.stack(edge_attrs, dim=0) + # frag_edge_attr = torch.stack(edge_attrs, dim=0) agg_x = torch.sum(frag_x, dim=0) return agg_x diff --git a/frame/source/datasets/default.py b/frame/source/datasets/default.py index 8941069..44d3a21 100644 --- a/frame/source/datasets/default.py +++ b/frame/source/datasets/default.py @@ -59,7 +59,7 @@ def process_data(self): # * Iterate data_list = [] for line in tqdm(dataset, ncols=120, desc="Creating graphs"): - line = re.sub(r"\'.*\'", "", line) # Replace ".*" strings. + line = re.sub(r"\'.*?\'", "", line) # Replace '...' strings. line = line.split(",") # Get label diff --git a/frame/source/train/epoch.py b/frame/source/train/epoch.py index 1be88d0..c90be48 100644 --- a/frame/source/train/epoch.py +++ b/frame/source/train/epoch.py @@ -26,12 +26,12 @@ def train_epoch(model, optim, scheduler, lossfn, loader): model = model.train() for batch in loader: batch = batch.to(device) - optim.zero_grad(batch) + optim.zero_grad() # * Make predictions out = model(x=batch.x.float(), edge_index=batch.edge_index, - edge_attr=batch.edge_attr, + edge_attr=batch.edge_attr.float(), batch=batch.batch) # * Compute loss @@ -64,7 +64,7 @@ def valid_epoch(model, task, loader): # * Make predictions out = model(x=batch.x.float(), edge_index=batch.edge_index, - edge_attr=batch.edge_attr, + edge_attr=batch.edge_attr.float(), batch=batch.batch) # * Read prediction values diff --git a/frame/tune.py b/frame/tune.py index 476ae15..a8e8afb 100644 --- a/frame/tune.py +++ b/frame/tune.py @@ -64,8 +64,8 @@ def objective(trial, params, dataset): patience_counter = 0 best_model_state = None + start = time.time() for epoch in tqdm(range(epochs), ncols=120, desc="Training"): - start = time.time() _ = train.train_epoch(model, optim, schdlr, lossfn, train_loader) val_metrics = train.valid_epoch(model, task, valid_loader) @@ -82,7 +82,7 @@ def objective(trial, params, dataset): if patience_counter >= patience: break - fit_time = time.time() - start + fit_time = time.time() - start # Prepare best model model.load_state_dict(best_model_state) @@ -153,7 +153,7 @@ def main(): params = yaml.safe_load(stream) # * Initialize - task = name = params["Data"]["task"] + task = params["Data"]["task"] name = params["Data"]["name"] if name.lower() == "none": name = str(uuid.uuid4()).split("-")[0]