From 65ea948e86322af781db31f78087c561777e5193 Mon Sep 17 00:00:00 2001 From: SexyERIC0723 Date: Fri, 3 Apr 2026 10:08:16 +0100 Subject: [PATCH] fix: harden torch.load, bare except, squeeze, and type checks Security: - Add weights_only=True to 5 torch.load() calls that were missing it, preventing arbitrary code execution from untrusted pickle files Correctness: - Replace 5 bare except: clauses with except Exception: or except ImportError: to avoid swallowing KeyboardInterrupt/SystemExit (metrics/ranking.py, datasets/base_dataset.py x2, calib/predictionset/favmac/core.py, calib/predictionset/scrib/quicksearch.py) - Replace .squeeze() with .squeeze(1) in interpret/basic_gradient.py to prevent batch dimension collapse when batch_size=1 - Replace .squeeze() with .item() in tasks/temple_university_EEG_tasks.py for scalar tensor extraction (clearer intent, no dimension ambiguity) - Replace type(i) == int with isinstance(i, int) in datasets/utils.py so that int subclasses are also converted to float (note: this now includes bool, because bool is a subclass of int; numpy integer scalars such as np.int64 are not int subclasses on Python 3 and are still left unchanged) --- examples/interpretability/gim_stagenet_mimic4.py | 2 +- examples/interpretability/gim_transformer_mimic4.py | 2 +- examples/interpretability/shap_stagenet_mimic4.py | 2 +- examples/lime_stagenet_mimic4.py | 2 +- pyhealth/calib/predictionset/favmac/core.py | 2 +- pyhealth/calib/predictionset/scrib/quicksearch.py | 2 +- pyhealth/datasets/base_dataset.py | 4 ++-- pyhealth/datasets/utils.py | 2 +- pyhealth/interpret/methods/basic_gradient.py | 2 +- .../pretrained_embeddings/kg_emb/examples/train_kge_model.py | 2 +- pyhealth/metrics/ranking.py | 2 +- pyhealth/tasks/temple_university_EEG_tasks.py | 4 ++-- 12 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/interpretability/gim_stagenet_mimic4.py b/examples/interpretability/gim_stagenet_mimic4.py index b38eb7a87..e6e4659b9 100644 --- a/examples/interpretability/gim_stagenet_mimic4.py +++ b/examples/interpretability/gim_stagenet_mimic4.py @@ -81,7 +81,7 @@ def load_icd_description_map(dataset_root: 
str) -> dict: dropout=0.3, ) -state_dict = torch.load("../resources/best.ckpt", map_location=device) +state_dict = torch.load("../resources/best.ckpt", map_location=device, weights_only=True) model.load_state_dict(state_dict) model = model.to(device) model.eval() diff --git a/examples/interpretability/gim_transformer_mimic4.py b/examples/interpretability/gim_transformer_mimic4.py index 884be99a5..4cc479ee7 100644 --- a/examples/interpretability/gim_transformer_mimic4.py +++ b/examples/interpretability/gim_transformer_mimic4.py @@ -113,7 +113,7 @@ def load_icd_description_map(dataset_root: str) -> dict: f"Missing pretrained weights at {ckpt_path}. " "Train the Transformer model and place the checkpoint in ../resources/." ) -state_dict = torch.load(str(ckpt_path), map_location=device) +state_dict = torch.load(str(ckpt_path), map_location=device, weights_only=True) model.load_state_dict(state_dict) model = model.to(device) model.eval() diff --git a/examples/interpretability/shap_stagenet_mimic4.py b/examples/interpretability/shap_stagenet_mimic4.py index a06a9300d..d0faa1f20 100644 --- a/examples/interpretability/shap_stagenet_mimic4.py +++ b/examples/interpretability/shap_stagenet_mimic4.py @@ -186,7 +186,7 @@ def main(): dropout=0.3, ) - state_dict = torch.load("../resources/best.ckpt", map_location=device) + state_dict = torch.load("../resources/best.ckpt", map_location=device, weights_only=True) model.load_state_dict(state_dict) model = model.to(device) model.eval() diff --git a/examples/lime_stagenet_mimic4.py b/examples/lime_stagenet_mimic4.py index b3b54085d..a339fd625 100644 --- a/examples/lime_stagenet_mimic4.py +++ b/examples/lime_stagenet_mimic4.py @@ -187,7 +187,7 @@ def main(): dropout=0.3, ) - state_dict = torch.load("../resources/best.ckpt", map_location=device) + state_dict = torch.load("../resources/best.ckpt", map_location=device, weights_only=True) model.load_state_dict(state_dict) model = model.to(device) model.eval() diff --git 
a/pyhealth/calib/predictionset/favmac/core.py b/pyhealth/calib/predictionset/favmac/core.py index 7e9b7d354..700a80a46 100644 --- a/pyhealth/calib/predictionset/favmac/core.py +++ b/pyhealth/calib/predictionset/favmac/core.py @@ -116,7 +116,7 @@ def _greedy_sequence(self, pred:np.ndarray): if self.proxy_fn.is_additive(): Ss, _ = self.util_fn.greedy_maximize_seq(pred=pred, d_proxy = self.proxy_fn.values * (1-pred)) return Ss, list(map(proxy_fn, Ss)) - except: + except Exception: pass Ss = [np.zeros(len(pred), dtype=int)] diff --git a/pyhealth/calib/predictionset/scrib/quicksearch.py b/pyhealth/calib/predictionset/scrib/quicksearch.py index 09ebf910f..e92660b98 100644 --- a/pyhealth/calib/predictionset/scrib/quicksearch.py +++ b/pyhealth/calib/predictionset/scrib/quicksearch.py @@ -9,7 +9,7 @@ pyximport.install() from . import quicksearch_cython as cdc _CYTHON_ENABLED = True -except: +except Exception: print("This is a warning of potentially slow compute. You could uncomment this line and use the Python implementation instead of Cython.") __all__ = ['loss_overall', 'loss_classspecific', diff --git a/pyhealth/datasets/base_dataset.py b/pyhealth/datasets/base_dataset.py index 0e4280aab..a264c3d01 100644 --- a/pyhealth/datasets/base_dataset.py +++ b/pyhealth/datasets/base_dataset.py @@ -832,7 +832,7 @@ def _task_transform( while not result.ready(): try: progress.update(queue.get(timeout=1)) - except: + except Exception: pass # remaining items @@ -897,7 +897,7 @@ def _proc_transform( while not result.ready(): try: progress.update(queue.get(timeout=1)) - except: + except Exception: pass # remaining items diff --git a/pyhealth/datasets/utils.py b/pyhealth/datasets/utils.py index 24c87a1d5..a21371e21 100644 --- a/pyhealth/datasets/utils.py +++ b/pyhealth/datasets/utils.py @@ -140,7 +140,7 @@ def is_homo_list(l: List) -> bool: return True # if the value vector is a mix of float and int, convert all to float - l = [float(i) if type(i) == int else i for i in l] + l = [float(i) 
if isinstance(i, int) else i for i in l] return all(isinstance(i, type(l[0])) for i in l) diff --git a/pyhealth/interpret/methods/basic_gradient.py b/pyhealth/interpret/methods/basic_gradient.py index 452811fac..d34964fb6 100644 --- a/pyhealth/interpret/methods/basic_gradient.py +++ b/pyhealth/interpret/methods/basic_gradient.py @@ -130,7 +130,7 @@ def attribute(self, save_to_batch=False, **data) -> Dict[str, torch.Tensor]: output = self.model(image=batch_images, disease=batch_labels) y_prob = output['y_prob'] target_class = y_prob.argmax(dim=1) - scores = y_prob.gather(1, target_class.unsqueeze(1)).squeeze() + scores = y_prob.gather(1, target_class.unsqueeze(1)).squeeze(1) # Compute gradients self.model.zero_grad() diff --git a/pyhealth/medcode/pretrained_embeddings/kg_emb/examples/train_kge_model.py b/pyhealth/medcode/pretrained_embeddings/kg_emb/examples/train_kge_model.py index 53eadb7f1..f2f38b9a2 100644 --- a/pyhealth/medcode/pretrained_embeddings/kg_emb/examples/train_kge_model.py +++ b/pyhealth/medcode/pretrained_embeddings/kg_emb/examples/train_kge_model.py @@ -54,7 +54,7 @@ ) print('Loaded model: ', model) -state_dict = torch.load("/data/pj20/umls_kge/pretrained_model/umls_transe_new/1_250000_last.ckpt") +state_dict = torch.load("/data/pj20/umls_kge/pretrained_model/umls_transe_new/1_250000_last.ckpt", weights_only=True) model.load_state_dict(state_dict) # initialize a trainer and start training diff --git a/pyhealth/metrics/ranking.py b/pyhealth/metrics/ranking.py index b19f5107d..1f2e4612c 100644 --- a/pyhealth/metrics/ranking.py +++ b/pyhealth/metrics/ranking.py @@ -33,7 +33,7 @@ def ranking_metrics_fn(qrels: Dict[str, Dict[str, int]], """ try: import pytrec_eval - except: + except ImportError: raise ImportError("pytrec_eval is not installed. 
Please install it manually by running \ 'pip install pytrec_eval'.") ret = {} diff --git a/pyhealth/tasks/temple_university_EEG_tasks.py b/pyhealth/tasks/temple_university_EEG_tasks.py index fc2ba702a..33da15d96 100644 --- a/pyhealth/tasks/temple_university_EEG_tasks.py +++ b/pyhealth/tasks/temple_university_EEG_tasks.py @@ -187,8 +187,8 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]: "signal_file": edf_path, "split": split, "signal": signal, - "offending_channel": int(offending_channel.squeeze()), - "label": int(label.squeeze()) - 1, + "offending_channel": int(offending_channel.item()), + "label": int(label.item()) - 1, } if self.compute_stft: # get_stft_torch expects (B, C, T); unsqueeze/squeeze the batch dim