Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions frame/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ def main():

# * Get checkpoint and prepare Explainer
model = models.select_model(model_name, tune)
model.load_state_dict(torch.load(path_checkpoint))
model.load_state_dict(torch.load(path_checkpoint,
map_location=device,
weights_only=True))
model.eval()

agg_pred = []
agg_lbl = []
agg_true = []
for data in tqdm(test_loader, ncols=120, desc="Explaining"):
data.to(device)
data = data.to(device)

# * Make predictions
model_out = model(x=data.x.float(),
Expand Down
6 changes: 4 additions & 2 deletions frame/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def main():

# * Get checkpoint and prepare Explainer
model = models.select_model(model_name, tune)
model.load_state_dict(torch.load(path_checkpoint))
model.load_state_dict(torch.load(path_checkpoint,
map_location=device,
weights_only=True))
model.eval()

if task == "classification":
Expand All @@ -77,7 +79,7 @@ def main():
return_type="raw"))

for data in tqdm(dataloader, ncols=120, desc="Explaining"):
data.to(device)
data = data.to(device)

# * Make predictions
model_out = model(x=data.x.float(),
Expand Down
6 changes: 5 additions & 1 deletion frame/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def main():
raise NotImplementedError("Loader not available")

# * Export
bce_weight = (len(dataset.y) - sum(dataset.y)) / sum(dataset.y)
task = params["Data"].get("task", "classification").lower()
if task == "classification" and sum(dataset.y) > 0:
bce_weight = (len(dataset.y) - sum(dataset.y)) / sum(dataset.y)
else:
bce_weight = torch.tensor(1.0)
metadata = {"feat_size": dataset.num_node_features,
"edge_dim": dataset.num_edge_features,
"bce_weight": bce_weight,
Expand Down
4 changes: 2 additions & 2 deletions frame/source/datasets/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def process_data(self):
# * Iterate
data_list = []
for line in tqdm(dataset, ncols=120, desc="Creating graphs"):
line = re.sub(r"\'.*\'", "", line) # Replace ".*" strings.
line = re.sub(r"\'.*?\'", "", line) # Replace '...' strings.
line = line.split(",")

# Get label
Expand Down Expand Up @@ -206,7 +206,7 @@ def _gen_features(smiles):
# [single, double, triple, aromatic, conjugation, ring] + stereo)

# edge_attrs += [edge_attr, edge_attr]
# frag_edge_attr = torch.stack(edge_attrs, dim=0)
# frag_edge_attr = torch.stack(edge_attrs, dim=0)

agg_x = torch.sum(frag_x, dim=0)
return agg_x
2 changes: 1 addition & 1 deletion frame/source/datasets/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def process_data(self):
# * Iterate
data_list = []
for line in tqdm(dataset, ncols=120, desc="Creating graphs"):
line = re.sub(r"\'.*\'", "", line) # Replace ".*" strings.
line = re.sub(r"\'.*?\'", "", line) # Replace '...' strings.
line = line.split(",")

# Get label
Expand Down
6 changes: 3 additions & 3 deletions frame/source/train/epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def train_epoch(model, optim, scheduler, lossfn, loader):
model = model.train()
for batch in loader:
batch = batch.to(device)
optim.zero_grad(batch)
optim.zero_grad()

# * Make predictions
out = model(x=batch.x.float(),
edge_index=batch.edge_index,
edge_attr=batch.edge_attr,
edge_attr=batch.edge_attr.float(),
batch=batch.batch)

# * Compute loss
Expand Down Expand Up @@ -64,7 +64,7 @@ def valid_epoch(model, task, loader):
# * Make predictions
out = model(x=batch.x.float(),
edge_index=batch.edge_index,
edge_attr=batch.edge_attr,
edge_attr=batch.edge_attr.float(),
batch=batch.batch)

# * Read prediction values
Expand Down
6 changes: 3 additions & 3 deletions frame/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def objective(trial, params, dataset):
patience_counter = 0
best_model_state = None

start = time.time()
for epoch in tqdm(range(epochs), ncols=120, desc="Training"):
start = time.time()
_ = train.train_epoch(model, optim, schdlr,
lossfn, train_loader)
val_metrics = train.valid_epoch(model, task, valid_loader)
Expand All @@ -82,7 +82,7 @@ def objective(trial, params, dataset):
if patience_counter >= patience:
break

fit_time = time.time() - start
fit_time = time.time() - start

# Prepare best model
model.load_state_dict(best_model_state)
Expand Down Expand Up @@ -153,7 +153,7 @@ def main():
params = yaml.safe_load(stream)

# * Initialize
task = name = params["Data"]["task"]
task = params["Data"]["task"]
name = params["Data"]["name"]
if name.lower() == "none":
name = str(uuid.uuid4()).split("-")[0]
Expand Down
Loading