Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/interpretability/gim_stagenet_mimic4.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def load_icd_description_map(dataset_root: str) -> dict:
dropout=0.3,
)

state_dict = torch.load("../resources/best.ckpt", map_location=device)
state_dict = torch.load("../resources/best.ckpt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion examples/interpretability/gim_transformer_mimic4.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def load_icd_description_map(dataset_root: str) -> dict:
f"Missing pretrained weights at {ckpt_path}. "
"Train the Transformer model and place the checkpoint in ../resources/."
)
state_dict = torch.load(str(ckpt_path), map_location=device)
state_dict = torch.load(str(ckpt_path), map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion examples/interpretability/shap_stagenet_mimic4.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def main():
dropout=0.3,
)

state_dict = torch.load("../resources/best.ckpt", map_location=device)
state_dict = torch.load("../resources/best.ckpt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion examples/lime_stagenet_mimic4.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def main():
dropout=0.3,
)

state_dict = torch.load("../resources/best.ckpt", map_location=device)
state_dict = torch.load("../resources/best.ckpt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/calib/predictionset/favmac/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _greedy_sequence(self, pred:np.ndarray):
if self.proxy_fn.is_additive():
Ss, _ = self.util_fn.greedy_maximize_seq(pred=pred, d_proxy = self.proxy_fn.values * (1-pred))
return Ss, list(map(proxy_fn, Ss))
except:
except Exception:
pass

Ss = [np.zeros(len(pred), dtype=int)]
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/calib/predictionset/scrib/quicksearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
pyximport.install()
from . import quicksearch_cython as cdc
_CYTHON_ENABLED = True
except:
except Exception:
print("This is a warning of potentially slow compute. You could uncomment this line and use the Python implementation instead of Cython.")

__all__ = ['loss_overall', 'loss_classspecific',
Expand Down
4 changes: 2 additions & 2 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,7 +832,7 @@ def _task_transform(
while not result.ready():
try:
progress.update(queue.get(timeout=1))
except:
except Exception:
pass

# remaining items
Expand Down Expand Up @@ -897,7 +897,7 @@ def _proc_transform(
while not result.ready():
try:
progress.update(queue.get(timeout=1))
except:
except Exception:
pass

# remaining items
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def is_homo_list(l: List) -> bool:
return True

# if the value vector is a mix of float and int, convert all to float
l = [float(i) if type(i) == int else i for i in l]
l = [float(i) if isinstance(i, int) else i for i in l]
return all(isinstance(i, type(l[0])) for i in l)


Expand Down
2 changes: 1 addition & 1 deletion pyhealth/interpret/methods/basic_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def attribute(self, save_to_batch=False, **data) -> Dict[str, torch.Tensor]:
output = self.model(image=batch_images, disease=batch_labels)
y_prob = output['y_prob']
target_class = y_prob.argmax(dim=1)
scores = y_prob.gather(1, target_class.unsqueeze(1)).squeeze()
scores = y_prob.gather(1, target_class.unsqueeze(1)).squeeze(1)

# Compute gradients
self.model.zero_grad()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
)

print('Loaded model: ', model)
state_dict = torch.load("/data/pj20/umls_kge/pretrained_model/umls_transe_new/1_250000_last.ckpt")
state_dict = torch.load("/data/pj20/umls_kge/pretrained_model/umls_transe_new/1_250000_last.ckpt", weights_only=True)
model.load_state_dict(state_dict)

# initialize a trainer and start training
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/metrics/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def ranking_metrics_fn(qrels: Dict[str, Dict[str, int]],
"""
try:
import pytrec_eval
except:
except ImportError:
raise ImportError("pytrec_eval is not installed. Please install it manually by running \
'pip install pytrec_eval'.")
ret = {}
Expand Down
4 changes: 2 additions & 2 deletions pyhealth/tasks/temple_university_EEG_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
"signal_file": edf_path,
"split": split,
"signal": signal,
"offending_channel": int(offending_channel.squeeze()),
"label": int(label.squeeze()) - 1,
"offending_channel": int(offending_channel.item()),
"label": int(label.item()) - 1,
}
if self.compute_stft:
# get_stft_torch expects (B, C, T); unsqueeze/squeeze the batch dim
Expand Down
Loading