diff --git a/BackendBench/backends/kernel_agent.py b/BackendBench/backends/kernel_agent.py index cb8d110..006b2c8 100644 --- a/BackendBench/backends/kernel_agent.py +++ b/BackendBench/backends/kernel_agent.py @@ -8,7 +8,11 @@ import os from typing import Callable, Dict -from BackendBench.utils import compile_kernel_from_string, op_name_to_folder_name +from BackendBench.utils import ( + compile_kernel_from_string, + extract_operator_name, + op_name_to_folder_name, +) from .base import Backend @@ -154,7 +158,7 @@ def _adapt_kernel_function_name(self, kernel_code: str, op_name: str) -> str: Returns: Modified kernel code with correct function name """ - folder_name = os.path.basename(os.path.dirname(kernel_code)) + folder_name = op_name_to_folder_name(op_name) expected_name = f"{folder_name}_kernel_impl" # Replace 'def kernel_function' with 'def {op_name}_kernel_impl' @@ -179,38 +183,40 @@ def {expected_name}(*args, **kwargs): ''' return kernel_code + wrapper_code + def _generate_kernel_file_path(self, folder_name: str, attempt: int) -> str: + """Generate the file path for a kernel, creating the operator subdirectory if needed.""" + op_dir = os.path.join(self.kernels_dir, folder_name) + os.makedirs(op_dir, exist_ok=True) + return os.path.join(op_dir, f"{folder_name}_implementation_v{attempt}.py") + def compile_kernel_from_string( self, kernel_code: str, op_name: str, attempt: int = 1 ) -> Callable: """Compile a kernel from string code and return a callable.""" folder_name = op_name_to_folder_name(op_name) adapted_code = self._adapt_kernel_function_name(kernel_code, op_name) - kernel_file_path = os.path.join(self.kernels_dir, f"{folder_name}_kernel.py") - expected_fn_name = f"{folder_name}_kernel_impl" - module_name = f"kernel_agent_{folder_name}" - + kernel_file_path = self._generate_kernel_file_path(folder_name, attempt) + module_name = f"{folder_name}_implementation_v{attempt}" + kernel = None try: kernel = compile_kernel_from_string( kernel_code=adapted_code, op_name=op_name, kernel_file_path=kernel_file_path, - expected_fn_name=expected_fn_name, + expected_fn_name=folder_name, module_name=module_name, ) except Exception as e: raise e return kernel - def add_kernel(self, op, kernel_code: str, op_name: str): + def add_kernel(self, op, kernel_code: str, op_name: str, attempt: int = 1): """Add a kernel implementation for a specific operator.""" - compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=1) - self.compiled_kernels[op] = compiled_kernel - - # Save the original KernelAgent code as well - folder_name = op_name_to_folder_name(op_name) - original_file = os.path.join(self.kernels_dir, f"{folder_name}_original_kernel_agent.py") - with open(original_file, "w") as f: - f.write(kernel_code) + try: + compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=attempt) + self.compiled_kernels[op] = compiled_kernel + except Exception as e: + print(f"❌ Failed to compile kernel for {op_name}: {e}") def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]: """ @@ -260,14 +266,104 @@ def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]: except Exception as e: print(f" Warning: Could not preserve session: {e}") - return result["kernel_code"], True + return result["kernel_code"], result["rounds"], True else: print(f"❌ KernelAgent failed for {op_name}: {result['message']}") - return "", False + return "", result["rounds"], False except Exception as e: print(f"❌ KernelAgent error for {op_name}: {e}") - return "", False + return "", 0, False + + def generate_kernels(self, suite, daemon=True): + """Generate kernels for all operators in the suite with comprehensive feedback.""" + self.daemon = daemon + successful_ops = 0 + total_ops = 0 + + for op_test in suite: + total_ops += 1 + op = op_test.op + op_str = str(op) + op_name = extract_operator_name(op_str) + folder_name = op_name_to_folder_name(op_name) + + logger.info(f"Generating kernel for {op_name} (full op: {op_str})") + + # Generate kernel with feedback-driven retry + kernel_code, best_kernel_attempt, success = self.generate_kernel_with_agent( + op=op, op_name=op_name + ) + + # Add kernel to backend and track success + self.add_kernel(op, kernel_code, op_name) + if success: + successful_ops += 1 + logger.info(f"✓ Success! Best attempt: {best_kernel_attempt}") + else: + logger.info(f"✗ Failed after {best_kernel_attempt} rounds") + + # Write operation summary in the operator subdirectory + self._write_summary( + folder_name=folder_name, + op_name=op_name, + op_str=op_str, + best_kernel_attempt=best_kernel_attempt, + success=success, + ) + + # Generate and save overall summary + self._write_overall_summary(successful_ops, total_ops) + + def _write_summary( + self, + folder_name: str, + op_name: str, + op_str: str, + best_kernel_attempt: int, + success: bool, + ): + """Write operation summary to the operator subdirectory.""" + op_dir = os.path.join(self.kernels_dir, folder_name) + os.makedirs(op_dir, exist_ok=True) + summary_file = os.path.join(op_dir, f"{folder_name}_summary.txt") + with open(summary_file, "w") as f: + f.write(f"Operation: {op_name}\n") + f.write(f"Full op: {op_str}\n") + f.write("Backend: KernelAgent\n") + f.write(f"Workers: {self.num_workers}\n") + f.write(f"Max rounds: {self.max_rounds}\n") + f.write(f"Best kernel attempt: {best_kernel_attempt}\n") + f.write(f"Final Status: {'✓ Success' if success else '✗ Failure'}\n") + f.write(f"Final kernel file: {folder_name}_implementation_v{best_kernel_attempt}.py\n") + + def _write_overall_summary(self, successful_ops: int, total_ops: int): + """Write overall summary of kernel generation results.""" + failed_ops = total_ops - successful_ops + success_rate = f"{successful_ops / total_ops * 100:.1f}%" if total_ops > 0 else "0.0%" + + summary_lines = [ + "=" * 60, + "KERNEL AGENT BACKEND SETUP SUMMARY", + "=" * 60, + f"Total operations attempted: {total_ops}", + f"Successfully created correct kernels for: {successful_ops} ops", + f"Failed to create correct kernels for: {failed_ops} ops", + f"Success rate: {success_rate}", + f"Workers: {self.num_workers}", + f"Max rounds: {self.max_rounds}", + f"Generated kernels saved to: {self.kernels_dir}", + "Backend: KernelAgent", + "=" * 60, + ] + + # Log summary + for line in summary_lines: + logger.info(line) + + # Save to file + with open(os.path.join(self.kernels_dir, "OVERALL_SUMMARY.txt"), "w") as f: + f.write("\n".join(summary_lines)) def __getitem__(self, key): if key in self.compiled_kernels: diff --git a/BackendBench/data_loaders.py b/BackendBench/data_loaders.py index 4268481..1aa019d 100644 --- a/BackendBench/data_loaders.py +++ b/BackendBench/data_loaders.py @@ -20,6 +20,8 @@ from datasets import load_dataset from tqdm import tqdm +from BackendBench.utils import extract_operator_name + # constants for downloading the test set from huggingface # you can explore the dataset here # https://huggingface.co/datasets/GPUMODE/backendbench_tests @@ -224,9 +226,10 @@ def _load_from_parquet( table = pq.read_table(source) df = table.to_pandas() - # Apply filter if provided + # Apply filter if provided - use exact matching on extracted operator names + # e.g., "relu.default" should match "aten.relu.default" but NOT "aten.leaky_relu.default" if filter: - mask = df["op_name"].apply(lambda op: any(f in op for f in filter)) + mask = df["op_name"].apply(lambda op: extract_operator_name(op) in filter) df = df[mask] return df.to_dict("records") diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 7f0072b..a47558d 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -224,6 +224,8 @@ def cli( elif backend == "kernel_agent": if backends.KernelAgentBackend is None: raise NotImplementedError("KernelAgent backend is for internal use only") + backend = backends.KernelAgentBackend() + backend.generate_kernels(suite, daemon=daemon) elif backend == "directory": if dsl == "cuda": backend = backends.DirectoryBackend(ops_directory, load_cpp_source=load_cpp_source) @@ -233,7 +235,6 @@ def cli( backend = { "aten": backends.AtenBackend, "flag_gems": backends.FlagGemsBackend, - "kernel_agent": backends.KernelAgentBackend, "directory": backends.DirectoryBackend, }[backend]() @@ -253,6 +254,7 @@ def cli( if num_workers is None: for test in suite: + print(test.op) if test.op not in backend: continue