forked from Janos95/chamfer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
116 lines (92 loc) · 3.47 KB
/
setup.py
File metadata and controls
116 lines (92 loc) · 3.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from __future__ import annotations
import os
import sys
from pathlib import Path
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
# Default the macOS deployment target (read by the Apple toolchain) to 13.0,
# but respect any value the caller already exported.
# NOTE(review): 13.0 is presumably the minimum macOS version the Metal /
# Objective-C++ bridge is written against — TODO confirm.
if "MACOSX_DEPLOYMENT_TARGET" not in os.environ:
    os.environ["MACOSX_DEPLOYMENT_TARGET"] = "13.0"
def gather_include_dirs() -> list[str]:
    """Return the extra header search paths needed to build the extension.

    The torch paths come from ``torch.utils.cpp_extension.include_paths()``.
    If that helper cannot be imported, they are derived from the installed
    torch package: ``<torch>/include`` and the C++ API include directory.
    The nanobind headers and its bundled ``robin_map`` headers follow, in
    that order.

    Returns:
        Directory paths as strings, torch paths first, then nanobind paths.
    """
    import torch
    import nanobind

    try:
        from torch.utils.cpp_extension import include_paths
    except ImportError:
        # Fallback for torch builds that do not expose include_paths().
        torch_include = Path(torch.__file__).resolve().parent / "include"
        dirs = [
            str(torch_include),
            str(torch_include / "torch" / "csrc" / "api" / "include"),
        ]
    else:
        dirs = list(include_paths())

    nanobind_dir = Path(nanobind.__file__).resolve().parent
    dirs += [
        str(nanobind_dir / "include"),
        str(nanobind_dir / "ext" / "robin_map" / "include"),
    ]
    return dirs
def gather_extra_sources() -> list[str]:
    """Return nanobind's single-file amalgamated source, if it is installed.

    Looks for ``src/nb_combined.cpp`` inside the installed ``nanobind``
    package.

    Returns:
        A one-element list with that file's path when it exists, otherwise
        an empty list. No error is raised when the file is missing.
    """
    import nanobind

    combined_cpp = Path(nanobind.__file__).resolve().parent / "src" / "nb_combined.cpp"
    return [str(combined_cpp)] if combined_cpp.exists() else []
class TorchBuildExt(BuildExtension):
    """PyTorch ``BuildExtension`` that also wires in nanobind and Objective-C++.

    Before delegating to ``BuildExtension.build_extensions`` this command:

    * registers ``.mm`` as an Objective-C++ source suffix on the active
      compiler (the macOS build lists ``metal_bridge.mm`` as a source), and
    * adds the include directories from ``gather_include_dirs()`` and the
      sources from ``gather_extra_sources()`` to every extension.
    """

    def build_extensions(self) -> None:
        include_dirs = gather_include_dirs()
        extra_sources = gather_extra_sources()

        compiler = self.compiler
        if ".mm" not in compiler.src_extensions:
            compiler.src_extensions.append(".mm")
            compiler.language_map[".mm"] = "objc++"

        for ext in self.extensions:
            # Append only entries that are not already present. The Extension
            # objects are shared by every build_ext run in this process (e.g.
            # `develop` reinitializes and re-runs build_ext). A source listed
            # twice would be compiled and linked twice, which produces
            # duplicate-symbol link errors.
            for include_dir in include_dirs:
                if include_dir not in ext.include_dirs:
                    ext.include_dirs.append(include_dir)
            for source in extra_sources:
                if source not in ext.sources:
                    ext.sources.append(source)

        super().build_extensions()
def make_extension():
    """Describe the ``chamfer_ext`` native extension for the current platform.

    * macOS: compiles the Objective-C++ Metal bridge alongside the CPU
      sources, defines ``CHAMFER_WITH_MPS``, enables ARC and links the Metal
      and Foundation frameworks.
    * Other platforms: adds ``chamfer/src/kd_query_cuda.cu`` and defines
      ``CHAMFER_WITH_CUDA`` only when that file exists *and* a GPU toolchain
      is usable. A toolchain is usable when torch is a ROCm build, or when
      torch is a CUDA build and ``CUDA_HOME`` was located. Otherwise a
      CPU-only ``CppExtension`` is produced instead of failing at setup time.

    Source paths are relative to the current working directory, which
    setuptools expects to be the project root.

    Returns:
        A ``CppExtension`` or ``CUDAExtension`` named ``chamfer_ext``.
    """
    import torch
    from torch.utils.cpp_extension import CUDA_HOME

    torch_lib_dir = Path(torch.__file__).resolve().parent / "lib"

    base_sources = [
        "chamfer/src/bindings.cpp",
        "chamfer/src/kd_tree.cpp",
        "chamfer/src/kd_query_cpu.cpp",
    ]
    define_macros: list[tuple[str, str | None]] = []
    extra_compile_args: dict[str, list[str]] = {"cxx": ["-std=c++20", "-fvisibility=hidden"]}
    extra_link_args: list[str] = []

    if sys.platform == "darwin":
        sources = base_sources + ["chamfer/src/metal_bridge.mm"]
        define_macros.append(("CHAMFER_WITH_MPS", "1"))
        extra_compile_args["cxx"].append("-fobjc-arc")
        extra_link_args.extend(["-framework", "Metal", "-framework", "Foundation"])
        ext_cls = CppExtension
    else:
        sources = base_sources.copy()
        cuda_source = "chamfer/src/kd_query_cuda.cu"
        # CUDAExtension raises OSError at construction when CUDA_HOME is not
        # set. It also links against torch's CUDA libraries, which CPU-only
        # torch builds do not ship. Only take the GPU path when the toolchain
        # is actually usable. ROCm builds (torch.version.hip) keep using
        # CUDAExtension.
        gpu_toolchain = getattr(torch.version, "hip", None) is not None or (
            torch.version.cuda is not None and CUDA_HOME is not None
        )
        if gpu_toolchain and Path(cuda_source).exists():
            sources.append(cuda_source)
            define_macros.append(("CHAMFER_WITH_CUDA", "1"))
            extra_compile_args["nvcc"] = [
                "-O3",
                "-std=c++17",
                "-DCHAMFER_WITH_CUDA=1",
                "-Xcompiler=-fvisibility=hidden",
            ]
            ext_cls = CUDAExtension
        else:
            ext_cls = CppExtension

    # Embed torch's lib directory as an rpath so the dynamic loader can find
    # torch's shared libraries. Not done on macOS (as before). It is also
    # skipped on Windows: MSVC has no rpath option, and distutils raises
    # DistutilsPlatformError for a non-empty runtime_library_dirs there.
    runtime_library_dirs: list[str] = []
    if sys.platform not in ("darwin", "win32") and torch_lib_dir.exists():
        runtime_library_dirs = [str(torch_lib_dir)]

    return ext_cls(
        "chamfer_ext",
        sources=sources,
        define_macros=define_macros,
        extra_compile_args=extra_compile_args,
        extra_link_args=extra_link_args,
        runtime_library_dirs=runtime_library_dirs,
    )
# A source distribution only packages files, so the native extension is not
# described (and make_extension() is not called) when "sdist" was requested.
IS_BUILDING_SDIST = "sdist" in sys.argv
if IS_BUILDING_SDIST:
    extensions = []
else:
    extensions = [make_extension()]

setup(
    ext_modules=extensions,
    cmdclass={"build_ext": TorchBuildExt},
)