From ce8fa3004ba151e3e3915762789528e75a6ed205 Mon Sep 17 00:00:00 2001
From: liuqiyuan
Date: Wed, 10 Apr 2024 17:46:29 +0800
Subject: [PATCH] Adapt to Ascend NPU

---
 environment_npu.yml   | 15 +++++++++++++++
 sample.py             |  3 +++
 train.py              |  3 +++
 utils/device_utils.py | 17 +++++++++++++++++
 4 files changed, 38 insertions(+)
 create mode 100644 environment_npu.yml
 create mode 100644 utils/device_utils.py

diff --git a/environment_npu.yml b/environment_npu.yml
new file mode 100644
index 00000000..829d9ba4
--- /dev/null
+++ b/environment_npu.yml
@@ -0,0 +1,15 @@
+name: DiT
+channels:
+  - pytorch
+dependencies:
+  - python >= 3.8
+  - pytorch >= 1.13
+  - torchvision
+  - pip:
+    - timm
+    - diffusers
+    - accelerate
+    - protobuf
+    - decorator
+    - scipy
+    - attrs
diff --git a/sample.py b/sample.py
index a4152afd..bed614ea 100644
--- a/sample.py
+++ b/sample.py
@@ -16,6 +16,9 @@ from download import find_model
 from models import DiT_models
 import argparse
 
+from utils.device_utils import is_npu_available
+if is_npu_available():
+    from torch_npu.contrib import transfer_to_npu
 
 def main(args):
     # Setup PyTorch:
diff --git a/train.py b/train.py
index 7cfee808..8506446a 100644
--- a/train.py
+++ b/train.py
@@ -30,6 +30,9 @@ from models import DiT_models
 from diffusion import create_diffusion
 from diffusers.models import AutoencoderKL
 
+from utils.device_utils import is_npu_available
+if is_npu_available():
+    from torch_npu.contrib import transfer_to_npu
 
 
 #################################################################################
diff --git a/utils/device_utils.py b/utils/device_utils.py
new file mode 100644
index 00000000..4e1149f8
--- /dev/null
+++ b/utils/device_utils.py
@@ -0,0 +1,17 @@
+import torch
+import importlib.util
+
+
+def is_npu_available():
+    """Checks if `torch_npu` is installed and whether an NPU device is usable."""
+    if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
+        return False
+
+    import torch_npu
+
+    try:
+        # Will raise a RuntimeError if no NPU is found
+        _ = torch.npu.device_count()
+        return torch.npu.is_available()
+    except RuntimeError:
+        return False
\ No newline at end of file