From afff348aa579f4aa3c99f456995cb97dc844cd57 Mon Sep 17 00:00:00 2001 From: c4fun Date: Wed, 6 Mar 2024 11:13:46 +0800 Subject: [PATCH] fixed the linux model problem --- .gitignore | 8 ++++++++ rope/Models.py | 26 +++++++++++++++++++++----- 2 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..a8d6c00f --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +**/__pycache__/** + +models/*.ckpt +models/*.pth +models/*.onnx +saved_parameters.json +data.json +merged_embeddings.txt diff --git a/rope/Models.py b/rope/Models.py index f8fe9241..193c6ccc 100644 --- a/rope/Models.py +++ b/rope/Models.py @@ -12,6 +12,7 @@ from itertools import product as product import subprocess as sp onnxruntime.set_default_logger_severity(4) +from platform import system class Models(): def __init__(self): @@ -57,20 +58,29 @@ def run_detect(self, img, detect_mode='Retinaface', max_num=1, score=0.5): if detect_mode=='Retinaface': if not self.retinaface_model: - self.retinaface_model = onnxruntime.InferenceSession('.\models\det_10g.onnx', providers=self.providers) + if system() == 'Linux': + self.retinaface_model = onnxruntime.InferenceSession('./models/det_10g.onnx', providers=self.providers) + else: + self.retinaface_model = onnxruntime.InferenceSession('.\models\det_10g.onnx', providers=self.providers) kpss = self.detect_retinaface(img, max_num=max_num, score=score) elif detect_mode=='SCRDF': if not self.scrdf_model: - self.scrdf_model = onnxruntime.InferenceSession('.\models\scrfd_2.5g_bnkps.onnx', providers=self.providers) + if system() == 'Linux': + self.scrdf_model = onnxruntime.InferenceSession('./models/scrfd_2.5g_bnkps.onnx', providers=self.providers) + else: + self.scrdf_model = onnxruntime.InferenceSession('.\models\scrfd_2.5g_bnkps.onnx', providers=self.providers) kpss = self.detect_scrdf(img, max_num=max_num, score=score) elif detect_mode=='Yolov8': if not self.yoloface_model: - self.yoloface_model = onnxruntime.InferenceSession('.\models\yoloface_8n.onnx', providers=self.providers) + if system() == 'Linux': + self.yoloface_model = onnxruntime.InferenceSession('./models/yoloface_8n.onnx', providers=self.providers) + else: + self.yoloface_model = onnxruntime.InferenceSession('.\models\yoloface_8n.onnx', providers=self.providers) kpss = self.detect_yoloface(img, max_num=max_num, score=score) @@ -80,7 +90,10 @@ def run_align(self, img): points = [] if not self.insight106_model: - self.insight106_model = onnxruntime.InferenceSession('.\models\2d106det.onnx', providers=self.providers) + if system() == 'Linux': + self.insight106_model = onnxruntime.InferenceSession('./models/2d106det.onnx', providers=self.providers) + else: + self.insight106_model = onnxruntime.InferenceSession('.\models\2d106det.onnx', providers=self.providers) points = self.detect_insight106(img) @@ -102,7 +115,10 @@ def delete_models(self): def run_recognize(self, img, kps): if not self.recognition_model: - self.recognition_model = onnxruntime.InferenceSession('.\models\w600k_r50.onnx', providers=self.providers) + if system() == 'Linux': + self.recognition_model = onnxruntime.InferenceSession('./models/w600k_r50.onnx', providers=self.providers) + else: + self.recognition_model = onnxruntime.InferenceSession('.\models\w600k_r50.onnx', providers=self.providers) embedding, cropped_image = self.recognize(img, kps) return embedding, cropped_image