Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion BackendBench/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@

import logging
import math
import os
import time
import traceback
from dataclasses import dataclass
from typing import List, Tuple

import torch

from BackendBench.power import PowerManager
from BackendBench.utils import compute_errors, serialize_args, uses_cuda_stream


Expand All @@ -33,6 +36,7 @@ class PerformanceTestResult:
op_name: str
args: str
speedup: float
total_energy: float
benchmark_time_ms: float
reference_time_ms: float
error_msg: str = ""
Expand Down Expand Up @@ -148,7 +152,6 @@ def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult

def cpu_bench(fn, num_runs=100):
"""Simple CPU benchmarking using time.perf_counter."""
import time

for _ in range(10):
fn()
Expand All @@ -159,6 +162,44 @@ def cpu_bench(fn, num_runs=100):
return (time.perf_counter() - start) / num_runs


def do_bench_power(
    fn,
    warm_ups=10,
    num_runs=10000,
    output_dir="./bench_power",
    gpu_id=0,
    query_interval=0.01,
):
    """
    Benchmark a function while collecting GPU power information.

    Args:
        fn: The function (e.g., kernel call) to benchmark.
        warm_ups: Number of untimed warm-up invocations of fn() before
            power sampling starts.
        num_runs: Number of measured invocations of fn().
        output_dir: Directory to store results (power.csv, plots).
        gpu_id: GPU index to monitor.
        query_interval: Sampling interval for power measurement (seconds).

    Returns:
        Average energy consumed per invocation of fn(): the energy
        reported by PowerManager.finalize() divided by num_runs.
        PowerManager integrates watts over seconds, so this is joules
        per run. NOTE(review): an earlier docstring said "mJoules" —
        confirm the intended unit against PowerManager.finalize().
    """
    if num_runs <= 0:
        raise ValueError(f"num_runs must be positive, got {num_runs}")

    os.makedirs(output_dir, exist_ok=True)

    pm = PowerManager()
    pm.gpu_id = gpu_id
    pm.output_dir = output_dir
    pm.query_interval = query_interval

    # Warm up outside the measurement window so one-time costs
    # (kernel compilation, caches) do not pollute the energy reading.
    for _ in range(warm_ups):
        fn()
    pm.start()
    for _ in range(num_runs):
        fn()
    pm.stop()
    # finalize() flushes samples to power.csv and returns total energy.
    total_energy = pm.finalize() / num_runs
    return total_energy


def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult]]:
"""Evaluate performance of impl against tests."""
bench_fn = (
Expand Down Expand Up @@ -193,11 +234,13 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult
f"Reference and result tensors are not close: max absolute error {abs_error}, max relative error {rel_error}"
)
test_time = bench_fn(lambda: impl(*cached_args, **cached_kwargs))
total_energy = do_bench_power(lambda: impl(*cached_args, **cached_kwargs))
performance_results.append(
PerformanceTestResult(
op_name=op.__name__,
args=args_str,
speedup=base_time / test_time,
total_energy=total_energy,
successfully_ran=True,
benchmark_time_ms=test_time,
reference_time_ms=base_time,
Expand All @@ -211,6 +254,7 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult
args=args_str,
successfully_ran=False,
speedup=None,
total_energy=None,
benchmark_time_ms=None,
reference_time_ms=base_time,
error_msg=error_msg,
Expand Down
186 changes: 186 additions & 0 deletions BackendBench/power.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import csv
import os
import threading
import time
from dataclasses import asdict, dataclass, fields

import matplotlib.pyplot as plt
from pynvml import (
NVML_CLOCK_ID_CURRENT,
NVML_CLOCK_MEM,
NVML_CLOCK_SM,
NVML_FI_DEV_POWER_CURRENT_LIMIT,
NVML_FI_DEV_POWER_INSTANT,
NVML_TEMPERATURE_GPU,
nvmlDeviceGetClock,
nvmlDeviceGetFieldValues,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetTemperature,
nvmlInit,
nvmlShutdown,
)

# query every 10 ms
DEFAULT_QUERY_INTERVAL = 0.01


@dataclass
class PowerEvent:
    """A single NVML sample captured by GPUCollectorThread.

    Units (as produced in GPUCollectorThread.start()):
      timestamp: microseconds since the epoch (time.time_ns() / 1e3)
      sm_clock / mem_clock: clock rates from nvmlDeviceGetClock
          (MHz per the NVML API — confirm)
      power_draw_instant / power_draw_current_limit: watts
          (NVML milliwatt readings divided by 1000)
      gpu_temp: GPU temperature in degrees Celsius
    """

    timestamp: float
    sm_clock: float
    mem_clock: float
    power_draw_instant: float
    power_draw_current_limit: float
    gpu_temp: float


def check_nvml_status(nvml_status):
    """Raise RuntimeError if an NVML call reported a non-success status.

    pynvml's NVML_SUCCESS is 0, so any truthy status indicates failure.
    The status code is included in the message to aid debugging.
    """
    if nvml_status:
        raise RuntimeError(f"NVML initialization failed (status: {nvml_status})")


class GPUCollectorThread:
    """Polls NVML counters (clocks, power, temperature) for one GPU.

    Intended to run on a background thread via start(); the owner flips
    ``continue_monitoring`` to False to end the sampling loop.
    """

    def __init__(self, gpu_id=None, query_interval=DEFAULT_QUERY_INTERVAL) -> None:
        # Compare against None explicitly: ``if gpu_id`` would wrongly
        # fall back to the environment for the valid index 0. When
        # falling back, CUDA_VISIBLE_DEVICES may hold a comma-separated
        # list; take its first entry so the int() conversion cannot fail.
        # NOTE(review): CUDA_VISIBLE_DEVICES holds CUDA ordinals, which
        # may not match NVML device ordering — confirm on multi-GPU hosts.
        if gpu_id is not None:
            self.gpu_id = int(gpu_id)
        else:
            self.gpu_id = int(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0])
        # Assume Python GIL so not protecting this using Atomics
        self.continue_monitoring = True
        # Sampling interval in seconds
        self.sampling_interval = query_interval
        self.events = []
        self.iter = []
        check_nvml_status(nvmlInit())

    def start(self):
        """Sample the GPU every ``sampling_interval`` seconds until stopped."""
        handle = nvmlDeviceGetHandleByIndex(int(self.gpu_id))
        while self.continue_monitoring:
            # check gpu power event
            sm_clock = nvmlDeviceGetClock(handle, NVML_CLOCK_SM, NVML_CLOCK_ID_CURRENT)
            mem_clock = nvmlDeviceGetClock(handle, NVML_CLOCK_MEM, NVML_CLOCK_ID_CURRENT)
            power_info = nvmlDeviceGetFieldValues(
                handle, [NVML_FI_DEV_POWER_INSTANT, NVML_FI_DEV_POWER_CURRENT_LIMIT]
            )
            gpu_temp = nvmlDeviceGetTemperature(handle, NVML_TEMPERATURE_GPU)
            self.events.append(
                PowerEvent(
                    # microseconds since the epoch
                    timestamp=int(time.time_ns() / 1e3),
                    sm_clock=sm_clock,
                    mem_clock=mem_clock,
                    # NVML reports power in milliwatts; store watts
                    power_draw_instant=power_info[0].value.uiVal / 1000.0,
                    power_draw_current_limit=power_info[1].value.uiVal / 1000.0,
                    gpu_temp=gpu_temp,
                )
            )
            time.sleep(self.sampling_interval)
        nvmlShutdown()


class PowerManager:
    """Owns a GPUCollectorThread and turns its samples into an energy figure.

    Usage: set gpu_id / output_dir / query_interval, call start(), run the
    workload, call stop(), then finalize() to write power.csv and obtain
    the integrated energy.
    """

    def __init__(self) -> None:
        self.gpu_id = None
        self.output_dir = None
        self.query_interval = None
        # Created by start(); kept None until then so stop()/finalize()
        # fail loudly (AttributeError on None) if called out of order.
        self.collector = None
        self._t = None

    def start(self) -> None:
        """Spawn the background NVML sampling thread."""
        self.collector = GPUCollectorThread(self.gpu_id, self.query_interval)
        self._t = threading.Thread(target=self.collector.start)
        self._t.start()

    def stop(self) -> None:
        """Signal the sampler to stop and wait for the thread to exit."""
        self.collector.continue_monitoring = False
        self._t.join()

    def finalize(self) -> float:
        """Flush samples to <output_dir>/power.csv and return total energy.

        Energy is the sum over samples of power_draw_instant (watts)
        times the measured gap to the next sample (seconds), i.e. joules.

        Returns:
            Total energy integrated over the sampling window.
        """
        result_file = os.path.join(self.output_dir, "power.csv")
        with open(result_file, "w", newline="") as csvfile:
            # Get the field names from the dataclass to use as CSV header
            fieldnames = [field.name for field in fields(PowerEvent)]

            writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=";")
            writer.writeheader()

            total_energy = 0
            # The last sample has no successor; reuse the most recent
            # interval (initially the nominal query_interval).
            current_interval = self.query_interval
            for i, event in enumerate(self.collector.events):
                if i < len(self.collector.events) - 1:
                    # Timestamps are microseconds; convert the gap to seconds.
                    current_interval = (
                        self.collector.events[i + 1].timestamp
                        - self.collector.events[i].timestamp
                    ) / 1e6
                total_energy += event.power_draw_instant * current_interval
                writer.writerow(asdict(event))
        return total_energy


def plot_power_charts(benchmark_name: str, gpu_id: int, output_dir: str, power_csv_file: str):
    """Render power, temperature, and clock-frequency charts from a power.csv.

    Reads the semicolon-delimited CSV written by PowerManager.finalize()
    and saves three PNGs (<name>-power.png, <name>-temp.png,
    <name>-freq.png) into output_dir.

    Column selection is positional and must match PowerEvent's field
    order: [timestamp, sm_clock, mem_clock, power_draw_instant,
    power_draw_current_limit, gpu_temp].
    """
    # Read CSV into {column_name: [float, ...]}
    with open(power_csv_file) as f:
        reader = csv.reader(f, delimiter=";")
        header = [col.strip() for col in next(reader)]  # first row as header
        data = {col: [] for col in header}
        for row in reader:
            for col, value in zip(header, row):
                data[col].append(float(value))

    # Elapsed time per sample. Timestamps are recorded in microseconds,
    # so dividing by 1000 yields milliseconds, matching the x-axis label.
    # (Named elapsed_ms rather than `time` to avoid shadowing the module.)
    n_samples = len(next(iter(data.values())))
    t0 = data["timestamp"][0]
    elapsed_ms = [(data["timestamp"][i] - t0) / 1000.0 for i in range(n_samples)]

    def _save_chart(columns, ylabel, title_kind, suffix):
        # One figure per chart; close it afterwards so repeated calls
        # do not accumulate open matplotlib figures.
        fig = plt.figure(figsize=(10, 6))
        for col in columns:
            plt.plot(elapsed_ms, data[col], label=col)
        plt.xlabel("Time (ms)")
        plt.ylabel(ylabel)
        plt.legend()
        plt.title(f"{benchmark_name} {title_kind} over time on device {gpu_id}")
        plt.savefig(
            os.path.join(output_dir, f"{benchmark_name}-{suffix}.png"),
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)

    # Plot power chart
    _save_chart(header[3:5], "Power (W)", "power consumption", "power")
    # Plot temp chart
    _save_chart(header[5:], "Temperature (C)", "temperature", "temp")
    # Plot frequency chart
    _save_chart(header[1:3], "Frequency (MHz)", "frequency", "freq")


# Ad-hoc entry point: regenerate the charts from a previously collected
# ./bench_power/power.csv (benchmark name and paths are hard-coded).
if __name__ == "__main__":
    plot_power_charts("addmm", 0, "./bench_power", "./bench_power/power.csv")
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ dev-dependencies = [
"torch",
"numpy",
"pyarrow",
"nvidia-ml-py",
"matplotlib",
# cupy-cuda12x is platform specific, install manually if needed
]

Expand Down
5 changes: 5 additions & 0 deletions test/test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def _create_test_fixtures(self):
op_name="torch.ops.aten.add.Tensor",
args="[tensor([1, 2]), tensor([3, 4])]",
speedup=1.5,
total_energy=150.0,
benchmark_time_ms=10.0,
reference_time_ms=15.0,
successfully_ran=True,
Expand All @@ -67,6 +68,7 @@ def _create_test_fixtures(self):
op_name="torch.ops.aten.add.Tensor",
args="[tensor([5, 6]), tensor([7, 8])]",
speedup=2.0,
total_energy=200.0,
benchmark_time_ms=8.0,
reference_time_ms=16.0,
successfully_ran=True,
Expand All @@ -75,6 +77,7 @@ def _create_test_fixtures(self):
op_name="torch.ops.aten.mul.Tensor",
args="[tensor([1, 2]), tensor([3, 4])]",
speedup=1.0,
total_energy=100.0,
benchmark_time_ms=20.0,
reference_time_ms=20.0,
successfully_ran=True,
Expand All @@ -83,6 +86,7 @@ def _create_test_fixtures(self):
op_name="torch.ops.aten.sin.default",
args="[tensor([0.5])]",
speedup=None,
total_energy=0.0,
benchmark_time_ms=None,
reference_time_ms=20.0,
successfully_ran=False,
Expand Down Expand Up @@ -258,6 +262,7 @@ def test_edge_cases(self):
op_name="edge_case_op",
args="[tensor([nan])]",
speedup=float("inf"),
total_energy=0.0,
benchmark_time_ms=0.0,
reference_time_ms=1.0,
successfully_ran=True,
Expand Down