Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
fd9109c
add loadinline to support cuda kernel
jiannanWang Oct 28, 2025
93e5d4c
update
jiannanWang Oct 28, 2025
e896275
Merge branch 'main' into jiannanwang/loadinline
jiannanWang Oct 28, 2025
4a8d599
solve conflict
jiannanWang Oct 29, 2025
b5f441d
Merge branch 'main' into jiannanwang/loadinline
jiannanWang Oct 30, 2025
222a7fd
ruff
jiannanWang Oct 30, 2025
d7f8074
fix
jiannanWang Oct 30, 2025
d1c8649
fix
jiannanWang Oct 30, 2025
334a39e
add ninja to ci
jiannanWang Oct 30, 2025
08b7054
set CUDA_HOME
jiannanWang Oct 30, 2025
44a427b
add skip
jiannanWang Oct 30, 2025
1cf7b4d
fix
jiannanWang Oct 30, 2025
401473e
add no_implicit_headers
jiannanWang Oct 30, 2025
39eb648
test
jiannanWang Oct 30, 2025
e417006
test cuda version
jiannanWang Oct 30, 2025
e65f8e9
update
jiannanWang Oct 30, 2025
945353d
install cuda toolkit in ci
jiannanWang Oct 30, 2025
c35add9
install cuda toolkit in ci
jiannanWang Oct 30, 2025
e99eb6a
install cuda toolkit in ci
jiannanWang Oct 30, 2025
05fe3b8
fix
jiannanWang Oct 30, 2025
734c933
fix
jiannanWang Oct 30, 2025
42fbe36
fix
jiannanWang Oct 30, 2025
f75fbaf
skip cuda testing in CI since no CUDA_HOME
jiannanWang Oct 30, 2025
c2f179d
merge
jiannanWang Dec 18, 2025
27af6f5
find cuda path
jiannanWang Dec 18, 2025
4f95593
install cuda toolkit
jiannanWang Dec 18, 2025
eb594ae
install cuda toolkit 12.8
jiannanWang Dec 18, 2025
50e99c3
fix
jiannanWang Dec 18, 2025
ee2c8c0
generate cpp source from cu source
jiannanWang Jan 8, 2026
6d74881
add comment
jiannanWang Jan 8, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/smoke-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Install CUDA Toolkit
run: |
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get install -y cuda-toolkit-12-8
echo "CUDA_HOME=/usr/local/cuda-12.8" >> $GITHUB_ENV
echo "/usr/local/cuda-12.8/bin" >> $GITHUB_PATH

- name: Install uv
uses: astral-sh/setup-uv@v3

Expand Down
144 changes: 132 additions & 12 deletions BackendBench/backends/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@
import os
from typing import Callable, Dict

from torch.utils.cpp_extension import load_inline

from ..utils import folder_name_to_op_name, get_pytorch_op
from .base import Backend

logger = logging.getLogger(__name__)


class DirectoryBackend(Backend):
def __init__(self, ops_dir="generated_kernels"):
def __init__(self, ops_dir="generated_kernels", load_cpp_source=False):
super().__init__("directory")
self.ops_dir = ops_dir
self.compiled_kernels: Dict[str, Callable] = {}
self._load_kernels()
self.load_cpp_source = load_cpp_source
self._load_kernels(load_cpp_source)

def _load_kernels(self):
def _load_kernels(self, load_cpp_source=False):
"""
Discovers and loads kernel implementations from the operator directory structure.

Expand All @@ -47,7 +50,8 @@ def _load_kernels(self):
impl_files = [
f
for f in os.listdir(op_dir)
if f.endswith(".py") and f.startswith(f"{folder_name}_implementation")
if (f.endswith(".py") or f.endswith(".cu") or f.endswith(".cpp"))
and f.startswith(f"{folder_name}_implementation")
]
if not impl_files:
logger.debug(f"No implementation files found in {op_dir}")
Expand All @@ -58,7 +62,7 @@ def _load_kernels(self):

try:
op_name = folder_name_to_op_name(folder_name)
kernel_func = self._load_kernel_from_file(impl_path, folder_name)
kernel_func = self._load_kernel_from_file(impl_path, folder_name, load_cpp_source)

pytorch_op = get_pytorch_op(op_name)
if pytorch_op:
Expand All @@ -71,17 +75,13 @@ def _load_kernels(self):

logger.info(f"DirectoryBackend loaded {loaded_count} kernels from {self.ops_dir}/")

def _load_kernel_from_file(self, file_path: str, folder_name: str) -> Callable:
def _load_python_kernel(self, file_path: str, folder_name: str) -> Callable:
"""
Dynamically load a kernel implementation function from a Python file.

Each operator directory should contain implementation files that export a function
named {op_name}_kernel_impl. This function becomes the kernel implementation
that gets registered for all variants of the operator.
Load a kernel implementation from a Python file.

Args:
file_path: Path to the Python implementation file
op_name: Base name of the operator (e.g., "add", "mul", "conv2d")
folder_name: Base name of the operator (e.g., "add__Tensor")

Returns:
Callable kernel implementation function
Expand All @@ -99,6 +99,126 @@ def _load_kernel_from_file(self, file_path: str, folder_name: str) -> Callable:
else:
raise ValueError(f"No function named {kernel_func_name} found in {file_path}")

def _generate_cpp_source(self, base_name: str, cuda_source: str) -> str:
"""
Generate C++ source code from a CUDA file.

Args:
file_path: Path to the CUDA implementation file (.cu or .cpp)
folder_name: Base name of the operator (e.g., "add__Tensor")

Returns:
str: Generated C++ source code
"""
output_lines = []
# Always include the torch extension header
output_lines.append("#include <torch/extension.h>\n")
# Find the function signature for the given base_name
for line in cuda_source.splitlines():
stripped = line.strip()
if stripped.startswith("at::Tensor") and base_name in stripped:
# Remove the function body if present
signature = stripped.split("{")[0].rstrip()
# Ensure it ends with a semicolon
if not signature.endswith(";"):
signature += ";"
output_lines.append(signature + "\n")
break # Only one function per file is expected
return "".join(output_lines)

def _load_cuda_kernel(
    self, file_path: str, folder_name: str, load_cpp_source: bool = False
) -> Callable:
    """
    Load and compile a kernel implementation from CUDA files using load_inline.

    Args:
        file_path: Path to the CUDA implementation file
        folder_name: Base name of the operator (e.g., "add__Tensor"); also the
            name of the entry function exposed by the compiled module
        load_cpp_source: Whether to read the C++ binding source from a sibling
            .cpp file. When False, the binding is generated from the CUDA
            source via _generate_cpp_source. Defaults to False.

    Returns:
        Callable kernel implementation function, or None when no source is
        available to compile

    Raises:
        ValueError: If the expected kernel function is not found in the
            compiled module
    """
    file_dir = os.path.dirname(file_path)
    file_name = os.path.basename(file_path)
    base_name = file_name.rsplit(".", 1)[0]

    cu_file = os.path.join(file_dir, f"{base_name}.cu")

    cuda_source = ""
    # Read cuda file if exists
    if os.path.exists(cu_file):
        with open(cu_file, "r") as f:
            cuda_source = f.read()

    if cuda_source == "" and not load_cpp_source:
        logger.warning(f"No CUDA source found for {file_path}.")
        return None

    cpp_source = ""
    if load_cpp_source:
        # Read cpp file if exists
        cpp_file = os.path.join(file_dir, f"{base_name}.cpp")
        if os.path.exists(cpp_file):
            with open(cpp_file, "r") as f:
                cpp_source = f.read()
        if cuda_source == "" and cpp_source == "":
            # Nothing to compile: bail out like the no-CUDA-source path above
            # instead of handing load_inline two empty sources, which fails
            # with an opaque build error.
            logger.warning(f"No CUDA source found for {file_path}.")
            return None
    else:
        # Generate cpp source from cuda source
        cpp_source = self._generate_cpp_source(folder_name, cuda_source)

    # Use load_inline for all cases
    module_name = f"{folder_name}_cuda_inline"
    cuda_module = load_inline(
        name=module_name,
        cpp_sources=cpp_source,
        cuda_sources=cuda_source,
        functions=[folder_name],
        no_implicit_headers=True,
    )

    if hasattr(cuda_module, folder_name):
        return getattr(cuda_module, folder_name)
    else:
        raise ValueError(
            f"No function named {folder_name} found in compiled CUDA module from {file_path}"
        )

def _load_kernel_from_file(
self, file_path: str, folder_name: str, load_cpp_source: bool = False
) -> Callable:
"""
Dynamically load a kernel implementation function from a Python or CUDA file.

Dispatches to the appropriate loader based on file extension:
- .py files -> _load_python_kernel
- .cu or .cpp files -> _load_cuda_kernel

Args:
file_path: Path to the implementation file (Python or CUDA)
op_name: Base name of the operator (e.g., "add", "mul", "conv2d")
load_cpp_source: Whether to also load the corresponding .cpp file. Defaults to False.

Returns:
Callable kernel implementation function

Raises:
ValueError: If the file extension is unsupported or the kernel function is not found
"""
file_ext = os.path.splitext(file_path)[1]

if file_ext == ".py":
return self._load_python_kernel(file_path, folder_name)
elif file_ext in [".cu", ".cpp"]:
return self._load_cuda_kernel(file_path, folder_name, load_cpp_source)
else:
raise ValueError(
f"Unsupported file extension {file_ext} for {file_path}. Expected .py, .cu, or .cpp"
)

def __getitem__(self, key):
if key in self.compiled_kernels:
return self.compiled_kernels[key]
Expand Down
67 changes: 67 additions & 0 deletions BackendBench/scripts/create_simple_test_ops_cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Create simple kernel implementations for 5 common operations.
Each just calls the original PyTorch function.
"""

import argparse
import logging
import os

logger = logging.getLogger(__name__)


def create_add(base_dir):
    """Write a CUDA add.Tensor kernel implementation under base_dir.

    Creates base_dir/add__Tensor/add__Tensor_implementation_v1.cu containing
    an elementwise float32 add kernel and its at::Tensor entry point, in the
    directory layout DirectoryBackend expects.

    Args:
        base_dir: Base directory containing operator subdirectories
    """
    # os.path.join instead of f-string concatenation keeps paths portable.
    op_dir = os.path.join(base_dir, "add__Tensor")
    os.makedirs(op_dir, exist_ok=True)
    impl_path = os.path.join(op_dir, "add__Tensor_implementation_v1.cu")
    with open(impl_path, "w") as f:
        f.write("""#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

__global__ void add__Tensor_kernel(
    const float* __restrict__ x,
    const float* __restrict__ y,
    float* __restrict__ output,
    const int size) {
    const auto index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < size) {
        output[index] = x[index] + y[index];
    }
}

at::Tensor add__Tensor(const at::Tensor& a, const at::Tensor& b) {
    auto out = at::empty_like(a);
    int64_t numel = a.numel();
    const int threads = 256;
    const int blocks = (numel + threads - 1) / threads;
    add__Tensor_kernel<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
        a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), numel
    );
    return out;
}
""")
    logger.info("Created add implementation")


def main():
    """Create one simple CUDA test operation (add__Tensor).

    Parses --base-dir from the command line and writes the add__Tensor
    implementation into that directory.
    """
    parser = argparse.ArgumentParser(description="Creating cuda kernel implementations for testing")
    parser.add_argument(
        "--base-dir",
        default="generated_kernels",
        help="Base directory containing operator subdirectories",
    )

    args = parser.parse_args()

    create_add(args.base_dir)


if __name__ == "__main__":
    main()
11 changes: 10 additions & 1 deletion BackendBench/scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ def setup_logging(log_level):
default=True,
help="Use daemon worker processes (default: True). Use --no-daemon for Helion",
)
@click.option(
"--load_cpp_source",
default=False,
help="Load C++ source code for Cuda kernels. When set to False BackendBench will construct cpp source files from the given cuda source code.",
)
def cli(
log_level,
suite,
Expand All @@ -172,6 +177,7 @@ def cli(
p,
dsl,
daemon,
load_cpp_source,
):
if suite != "torchbench":
if topn_inputs is not None:
Expand Down Expand Up @@ -219,7 +225,10 @@ def cli(
if backends.KernelAgentBackend is None:
raise NotImplementedError("KernelAgent backend is for internal use only")
elif backend == "directory":
backend = backends.DirectoryBackend(ops_directory)
if dsl == "cuda":
backend = backends.DirectoryBackend(ops_directory, load_cpp_source=load_cpp_source)
else:
backend = backends.DirectoryBackend(ops_directory)
else:
backend = {
"aten": backends.AtenBackend,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"pandas",
"datasets",
"tenacity",
"ninja",
"nvidia-cutlass-dsl",
]

Expand Down
Loading