From 4825eba97d76375f2c4b0ec613089dac75fd6056 Mon Sep 17 00:00:00 2001 From: Politrees <143968312+Bebra777228@users.noreply.github.com> Date: Tue, 15 Apr 2025 23:26:15 +0500 Subject: [PATCH] CPU support --- Example.py | 2 +- FlashSR/FlashSR.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Example.py b/Example.py index fd79207..d88e408 100644 --- a/Example.py +++ b/Example.py @@ -10,7 +10,7 @@ student_ldm_ckpt_path:str = './ModelWeights/student_ldm.pth' sr_vocoder_ckpt_path:str = './ModelWeights/sr_vocoder.pth' vae_ckpt_path:str = './ModelWeights/vae.pth' -flashsr = FlashSR( student_ldm_ckpt_path, sr_vocoder_ckpt_path, vae_ckpt_path) +flashsr = FlashSR( student_ldm_ckpt_path, sr_vocoder_ckpt_path, vae_ckpt_path, device) flashsr = flashsr.to(device) audio_path:str = './Assets/ExampleInput/music.wav' diff --git a/FlashSR/FlashSR.py b/FlashSR/FlashSR.py index 578b5d1..fd1e1a5 100644 --- a/FlashSR/FlashSR.py +++ b/FlashSR/FlashSR.py @@ -18,6 +18,7 @@ def __init__( student_ldm_ckpt_path:str, sr_vocoder_ckpt_path:str, autoencoder_ckpt_path:str, + device:str, model_output_type:str = 'v_prediction', beta_schedule_type:str = 'cosine', **kwargs @@ -31,7 +32,7 @@ def __init__( self.vae = VAEWrapper(autoencoder_ckpt_path) self.sr_vocoder = SRVocoder() - sr_vocoder_state_dict = torch.load(sr_vocoder_ckpt_path) + sr_vocoder_state_dict = torch.load(sr_vocoder_ckpt_path, map_location=torch.device(device)) self.sr_vocoder.load_state_dict(sr_vocoder_state_dict) def forward(self,