From 42a33e182fab5f7272f45c1fdd4d145a13a62279 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 16 Oct 2025 01:03:36 +0000 Subject: [PATCH] [WIP] Add testing for backwards passes Summary: Here we add correctness tests for backwards passes of ops. This PR does the following things 1) Figures out which ops not to test. (explained in depth at the top of BackendBench/backwards_utils.py + avoiding inplace ops) For simplcity we are not testing a) in place ops as we cannot just pass in the test args, but need special casing b) ops that require special handling with their args, c) one off corner cases. Every other 2) To do backwards passes (since the tensors naturally don't require grad in our suites), right now we add a gradient to all tensors in args and kwargs. This logic (+ test for if we should even run a backwards pass) is put in the suite as this can be handled on a per test level. For example in a follow up PR for this, we can add a backwards pass column in the torchbench dataset. 3) We also compare gradients and clear gradients after use to validate the backwards pass. We use the same allclose function as before. Note we don't copy tensors/args as sometimes they are views (at least in opinfo) which makes cloning difficult. 4) There are also a bunch of unit tests added to make sure the gradient checking utils work as expected. Test Plan: With this really slow correctish [mm implementation](https://gist.github.com/PaliC/e62859f0286f6bfa338ccb4140e9e74f) we get ```bash uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards ... correctness score (mean pass rate over all operators): 1.00 performance score (geomean speedup over all operators): 0.00 perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00 backwards correctness score (mean pass rate over all operators which support backwards): 1.00 ``` With the bad monkey patched implementation we get ``` uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards ... correctness score (mean pass rate over all operators): 0.00 performance score (geomean speedup over all operators): 1.00 perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00 backwards correctness score (mean pass rate over all operators which support backwards): 0.00 ``` The following two commands with aten also work as expected (100% correctness on forwards and backwards) ``` ``uv run python BackendBench/scripts/main.py --suite opinfo --backend aten --check-backwards`` `uv run python BackendBench/scripts/main.py --suite torchbench --topn 2 --backend aten --check-backwards` ``` Todo: - [ ] rename is_correct -> correct_output (originally in this pr but added noise for reviewers) - [ ] performance tests - [ ] for torchbench suite put backwards checking in dataset - [ ] Assuming the above support ops which have conditions on their args - [ ] support inplace ops --- BackendBench/backwards_utils.py | 173 +++++++++++++++++++++++++++ BackendBench/eval.py | 93 +++++++++++++- BackendBench/multiprocessing_eval.py | 14 ++- BackendBench/scripts/main.py | 33 +++++ BackendBench/suite/base.py | 3 +- BackendBench/suite/opinfo.py | 30 +++-- BackendBench/suite/torchbench.py | 21 +++- test/test_gradient_checks.py | 165 +++++++++++++++++++++++++ 8 files changed, 512 insertions(+), 20 deletions(-) create mode 100644 BackendBench/backwards_utils.py create mode 100644 test/test_gradient_checks.py diff --git a/BackendBench/backwards_utils.py b/BackendBench/backwards_utils.py new file mode 100644 index 00000000..71ebdad9 --- /dev/null +++ b/BackendBench/backwards_utils.py @@ -0,0 +1,173 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for backwards pass checking and gradient verification. +""" + +from typing import List + +import torch + +from BackendBench.scripts.op_map import query + +# Operations that should be exempted from backwards pass testing +BACKWARDS_PASS_TESTING_EXCEMPTIONS = [ + # We skip this op for 2 reasons: + # 1) This op has the args (shape, stride, storage_offset) where storage offset + # would change if a gradient is included in the tensor. Our suites (ie. opinfo) + # assume we are doing inference so storage is set to a bad value here. + # We'd have to write a custom suite for this. + # 2) As this is a tensor manipulation op, it doesn't really make sense to test + # a backwards pass for this yet. + "as_strided.default", + # The function is not differentiable with respect to argument 'running_mean'. + # This input cannot have requires_grad True. + # We likely need to handle this on the suite level. + "native_batch_norm.default", + "_native_batch_norm_legit.default", + "_batch_norm_with_update.default", + "native_batch_norm_backward.default", # in torchbench only + # The function 'soft_margin_loss' is not differentiable with respect to argument 'target'. + # This input cannot have requires_grad True. + "soft_margin_loss.default", + # The function 'multi_margin_loss' is not differentiable with respect to argument 'weight'. + # This input cannot have requires_grad True. + "multi_margin_loss.default", + # This op doesn't have a derivative unless it's defined explicitly. But there isn't a good way of detecting the fact that this op has no derivative. + "nextafter.default", + # This is the only op that does not pass opinfo + aten on backwards passes + # TODO: figure out why + "grid_sampler_2d.default", + # torchbench: gets IMA error when adding in the gradient on B200 + "max_pool2d_with_indices_backward.default", +] + + +def should_check_backwards_for_op(op_name: str, check_backwards: bool = True) -> bool: + """ + Determine if backwards checking should be performed for a given operation. + + Args: + op_name: The name of the operation (e.g., "aten.relu.default") + check_backwards: Whether backwards checking is globally enabled + + Returns: + True if backwards checking should be performed, False otherwise + """ + if not check_backwards: + return False + + # Check if op is in the exemption list + if op_name in BACKWARDS_PASS_TESTING_EXCEMPTIONS: + return False + + # Check if op is inplace (inplace ops are not supported for backwards checking) + op_map_entries = query(op_name) + if len(op_map_entries) == 1 and op_map_entries[0].get("is_inplace", False): + return False + + return True + + +def _apply_to_tensors(obj, tensor_fn, container_fn=None, accumulator=None): + """ + Generic functor to apply operations to tensors in nested data structures. + + Args: + obj: The object to traverse (tensor, list, tuple, dict, or other) + tensor_fn: Function to apply to each tensor. Should have signature (tensor, accumulator) -> Any + container_fn: Optional function to handle container reconstruction. + Signature: (container_type, transformed_items) -> Any + accumulator: Optional accumulator object passed to tensor_fn + + Returns: + Transformed object or None for in-place operations + """ + if isinstance(obj, torch.Tensor): + return tensor_fn(obj, accumulator) + elif isinstance(obj, list): + transformed = [ + _apply_to_tensors(item, tensor_fn, container_fn, accumulator) for item in obj + ] + return container_fn(list, transformed) if container_fn else transformed + elif isinstance(obj, tuple): + transformed = [ + _apply_to_tensors(item, tensor_fn, container_fn, accumulator) for item in obj + ] + return container_fn(tuple, transformed) if container_fn else tuple(transformed) + elif isinstance(obj, dict): + transformed = { + key: _apply_to_tensors(value, tensor_fn, container_fn, accumulator) + for key, value in obj.items() + } + return container_fn(dict, transformed) if container_fn else transformed + else: + # For immutable types or unknown types + return obj + + +def collect_gradients(args, kwargs) -> List[torch.Tensor]: + """ + Collect all gradients from args and kwargs into a flat list. + + Order is well-defined: + 1. Iterate through args in order + - If arg is a tensor with grad, append grad + - If arg is a list/tuple, iterate through elements in order and append tensor grads + 2. Iterate through kwargs in sorted key order + - If kwarg is a tensor with grad, append grad + - If kwarg is a list/tuple, iterate through elements in order and append tensor grads + + Args: + args: The arguments (can contain tensors or lists/tuples of tensors). + kwargs: The keyword arguments (can contain tensors or lists/tuples of tensors). + + Returns: + List of gradients (torch.Tensor) in the order specified above. + Returns empty list if no gradients are found. + """ + gradients = [] + + def collect_grad_fn(tensor, accumulator): + accumulator.append(tensor.grad) + + # Collect from args + for arg in args: + _apply_to_tensors(arg, collect_grad_fn, accumulator=gradients) + + # Collect from kwargs in sorted key order for deterministic ordering + for key in sorted(kwargs.keys()): + _apply_to_tensors(kwargs[key], collect_grad_fn, accumulator=gradients) + + return gradients + + +def make_tensors_require_gradients(args, kwargs): + def make_require_grad_fn(tensor, _): + # check dtype is floating or complex + if tensor.dtype not in [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + torch.complex64, + torch.complex128, + ]: + return + tensor.requires_grad = True + + _apply_to_tensors(args, make_require_grad_fn) + _apply_to_tensors(kwargs, make_require_grad_fn) + + +def clear_gradients(args, kwargs): + def clear_grad_fn(tensor, _): + if tensor.grad is not None: + tensor.grad = None + + _apply_to_tensors(args, clear_grad_fn) + _apply_to_tensors(kwargs, clear_grad_fn) diff --git a/BackendBench/eval.py b/BackendBench/eval.py index c88ee37f..44e2b427 100644 --- a/BackendBench/eval.py +++ b/BackendBench/eval.py @@ -6,12 +6,17 @@ import logging import math +import time import traceback from dataclasses import dataclass from typing import List, Tuple import torch +from BackendBench.backwards_utils import ( + clear_gradients, + collect_gradients, +) from BackendBench.utils import compute_errors, serialize_args, uses_cuda_stream @@ -26,6 +31,8 @@ class CorrectnessTestResult: max_abs_error: float = -math.inf max_rel_error: float = -math.inf test_type: str = "correctness" + has_correct_gradients: bool = False + checked_backwards: bool = False @dataclass @@ -90,25 +97,89 @@ def allclose(a, b, atol=1e-2, rtol=1e-2): return False -def eval_correctness_test(op, impl, test) -> CorrectnessTestResult: +def compare_gradients(res_grad, ref_grad, atol=1e-2, rtol=1e-2): + if res_grad is None and ref_grad is None: + return True + if res_grad is None or ref_grad is None: + raise ValueError("One of the gradients is None while the other is not.") + return allclose(res_grad, ref_grad, atol=atol, rtol=rtol) + + +def _check_if_output_has_backwards(output): + if isinstance(output, torch.Tensor): + # todo: ask why we have to do this and why isinstance(output.grad_fn, NotImplementedType) doesn't work for outputs of ops with no derivative like floor_divide.default + has_grad_fn = not (type(output.grad_fn).__name__ == "NotImplemented") + return output.requires_grad and has_grad_fn + elif isinstance(output, list) or isinstance(output, tuple): + return all(_check_if_output_has_backwards(x) for x in output) and len(output) > 0 + else: + return False + + +def _compute_loss(output): + if isinstance(output, torch.Tensor): + return output.sum() + elif isinstance(output, list) or isinstance(output, tuple): + return sum(_compute_loss(x) for x in output) + else: + raise ValueError(f"Unsupported type: {type(output)}") + + +def eval_correctness_test(op, impl, test, check_backwards=False) -> CorrectnessTestResult: """Evaluate impl of op against test. Returns: Tuple of (is_correct, error_message, absolute_error, relative_error) """ + + # Get the test_backwards flag from the test object if it exists + # The suite is responsible for setting this based on op capabilities + test_backwards = getattr(test, "test_backwards", False) + + # Combine with global check_backwards flag + check_backwards = check_backwards and test_backwards + args, kwargs = test.args, test.kwargs ref = op(*args, **kwargs) + + # we now modify check_backwards with another check. Specifically that ref is something that has gradients (aka returns a torch.tensor or a collection of torch.tensors as we cannot perform a backwards pass otherwise) + backwards_possible = _check_if_output_has_backwards(ref) + + check_backwards = backwards_possible and check_backwards + if check_backwards: + loss = _compute_loss(ref) + loss.backward() + ref_grads = collect_gradients(args, kwargs) + clear_gradients(args, kwargs) + else: + ref_grads = None + try: res = impl(*args, **kwargs) + if check_backwards: + loss = _compute_loss(res) + loss.backward() + res_grads = collect_gradients(args, kwargs) + clear_gradients(args, kwargs) + has_correct_gradients = compare_gradients(ref_grads, res_grads) + else: + res_grads = None + has_correct_gradients = False is_correct = allclose(ref, res) abs_error, rel_error = compute_errors(ref, res) + if check_backwards and not has_correct_gradients: + raise ValueError( + f"Gradients are not correct for {op.__name__} with args {serialize_args(args, kwargs)}" + ) result = CorrectnessTestResult( op_name=op.__name__, args=serialize_args(args, kwargs), is_correct=is_correct, max_abs_error=abs_error, max_rel_error=rel_error, + has_correct_gradients=has_correct_gradients, + checked_backwards=check_backwards, ) return result except Exception as e: @@ -125,14 +196,16 @@ def eval_correctness_test(op, impl, test) -> CorrectnessTestResult: return result -def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult]]: +def eval_correctness( + op, impl, tests, check_backwards=False +) -> Tuple[float, List[CorrectnessTestResult]]: """Evaluate correctness of impl against tests.""" correct, total = 0, 0 test_results: List[CorrectnessTestResult] = [] for test in tests: args_str = serialize_args(test.args, test.kwargs) logging.debug(f"Testing {op.__name__} with args {args_str}") - result = eval_correctness_test(op, impl, test) + result = eval_correctness_test(op, impl, test, check_backwards) test_results.append(result) if result.is_correct: correct += 1 @@ -148,7 +221,6 @@ def eval_correctness(op, impl, tests) -> Tuple[float, List[CorrectnessTestResult def cpu_bench(fn, num_runs=100): """Simple CPU benchmarking using time.perf_counter.""" - import time for _ in range(10): fn() @@ -164,6 +236,7 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult bench_fn = ( triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench ) + base_times = [] test_times = [] args_strs = [] @@ -176,6 +249,12 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult args_str = serialize_args(cached_args, cached_kwargs) args_strs.append(args_str) logging.debug(f"Benchmarking {op.__name__} with args {args_str}") + # Warmup: run both operations to compile CUDA kernels and warm up caches + for _ in range(25): + _ = op(*cached_args, **cached_kwargs) + _ = impl(*cached_args, **cached_kwargs) + if torch.cuda.is_available(): + torch.cuda.synchronize() base_time = bench_fn(lambda: op(*cached_args, **cached_kwargs)) base_times.append(base_time) # Note: If the test fails we consider the speedup to be 1.0 @@ -225,7 +304,7 @@ def eval_performance(op, impl, tests) -> Tuple[float, List[PerformanceTestResult def eval_one_op( - op, impl, correctness_tests, performance_tests + op, impl, correctness_tests, performance_tests, check_backwards=False ) -> Tuple[float, float, List[CorrectnessTestResult], List[PerformanceTestResult]]: """Evaluate impl of op against correctness_tests and performance_tests. @@ -261,7 +340,9 @@ def eval_one_op( ) return 0, 1.0, correctness_results, performance_results - correctness_score, correctness_results = eval_correctness(op, impl, correctness_tests) + correctness_score, correctness_results = eval_correctness( + op, impl, correctness_tests, check_backwards + ) performance_score, performance_results = eval_performance(op, impl, performance_tests) return ( correctness_score, diff --git a/BackendBench/multiprocessing_eval.py b/BackendBench/multiprocessing_eval.py index 09f86116..c24d36a4 100644 --- a/BackendBench/multiprocessing_eval.py +++ b/BackendBench/multiprocessing_eval.py @@ -43,6 +43,7 @@ class EvalTask: correctness_tests: List[Any] performance_tests: List[Any] device: str + check_backwards: bool = False @dataclass @@ -116,7 +117,13 @@ def test_to_device_iterator(tests, device): performance_score, correctness_results, performance_results, - ) = eval_one_op(op, impl, correctness_tests, performance_tests) + ) = eval_one_op( + op, + impl, + correctness_tests, + performance_tests, + check_backwards=task.check_backwards, + ) result = EvalResult( task_id=task.task_id, correctness_score=correctness_score, @@ -239,7 +246,9 @@ def __init__(self, num_workers: int = 1): logger.info(f"Initialized MultiprocessingEvaluator with {num_workers} workers") - def submit_task(self, op, impl, correctness_tests, performance_tests) -> int: + def submit_task( + self, op, impl, correctness_tests, performance_tests, check_backwards=False + ) -> int: task_id = self.next_task_id self.next_task_id += 1 if not is_pickleable(op): @@ -276,6 +285,7 @@ def submit_task(self, op, impl, correctness_tests, performance_tests) -> int: correctness_tests=cpu_correctness_tests, performance_tests=cpu_performance_tests, device=str(orig_device), + check_backwards=check_backwards, ) self.task_queue.put(task) diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 3039e11a..51960221 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -148,6 +148,12 @@ def setup_logging(log_level): type=click.Choice(["triton", "pytorch", "cutedsl"]), help="Which DSL to use for LLM backend", ) +@click.option( + "--check-backwards", + default=False, + is_flag=True, + help="Check gradients of the result and reference", +) def cli( log_level, suite, @@ -166,6 +172,7 @@ def cli( check_overhead_dominated_ops, p, dsl, + check_backwards, ): if suite != "torchbench": if topn_inputs is not None: @@ -184,6 +191,7 @@ def cli( "cuda", torch.bfloat16, filter=ops, + check_backwards=check_backwards, ), "torchbench": lambda: TorchBenchTestSuite( "torchbench", @@ -191,6 +199,7 @@ def cli( filter=ops, topn=topn_inputs, check_overhead_dominated_ops=check_overhead_dominated_ops, + check_backwards=check_backwards, ), "facto": lambda: FactoTestSuite( "facto_cuda_bfloat16", @@ -231,6 +240,11 @@ def cli( timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") log_dir = f"backendbench_output_{timestamp}" + if check_backwards: + assert backend_name == "directory" or backend_name == "aten", ( + "check-backwards is only supported for directory backend or aten backend (for smoketests)" + ) + overall_correctness = [] overall_performance = [] all_correctness_results = [] @@ -248,6 +262,7 @@ def cli( backend[test.op], test.correctness_tests, test.performance_tests, + check_backwards=check_backwards, ) overall_correctness.append(all(result.is_correct for result in correctness_results)) @@ -270,6 +285,7 @@ def cli( backend[test.op], test.correctness_tests, test.performance_tests, + check_backwards=check_backwards, ) # Start evaluation @@ -299,6 +315,23 @@ def cli( f"perf@p score (rate of correct samples with a speedup greater than p, p={p}): {perf_at_p_score:.2f}" ) + if check_backwards: + backwards_correctness = ( + torch.tensor( + [ + result.has_correct_gradients + for result in all_correctness_results + if result.checked_backwards + ] + ) + .float() + .mean() + .item() + ) + print( + f"backwards correctness score (mean pass rate over all operators which support backwards): {backwards_correctness:.2f}" + ) + command = "python -m BackendBench.scripts.main " + " ".join(sys.argv[1:]) # Save results if not disabled diff --git a/BackendBench/suite/base.py b/BackendBench/suite/base.py index 1f5fe635..71ec4800 100644 --- a/BackendBench/suite/base.py +++ b/BackendBench/suite/base.py @@ -6,9 +6,10 @@ class Test: - def __init__(self, *args, **kwargs): + def __init__(self, *args, test_backwards=False, **kwargs): self._args = args self._kwargs = kwargs + self.test_backwards = test_backwards @property def args(self): diff --git a/BackendBench/suite/opinfo.py b/BackendBench/suite/opinfo.py index 611c5063..bef3d3c6 100644 --- a/BackendBench/suite/opinfo.py +++ b/BackendBench/suite/opinfo.py @@ -10,6 +10,10 @@ from torch.testing._internal.common_methods_invocations import op_db from torch.utils._python_dispatch import TorchDispatchMode +from BackendBench.backwards_utils import ( + make_tensors_require_gradients, + should_check_backwards_for_op, +) from BackendBench.eval import allclose from .base import OpTest, TestSuite @@ -18,24 +22,33 @@ class OpInfoTest: - def __init__(self, *args, **kwargs): + def __init__(self, *args, test_backwards=False, **kwargs): self.args = args self.kwargs = kwargs + self.test_backwards = test_backwards class OpInfoOpTest(OpTest): - def __init__(self, op, correctness_tests, indices): + def __init__(self, op, correctness_tests, indices, check_backwards=False): self.op = op self._correctness_tests = correctness_tests self.indices = set(indices) self.performance_tests = [] + self._check_backwards = check_backwards @property def correctness_tests(self): + # Determine if this op should check backwards + test_backwards = should_check_backwards_for_op(self.op.__name__, self._check_backwards) + for idx, test in enumerate(self._correctness_tests): if idx in self.indices: # print(f"{idx} {test.input=} {test.args=} {test.kwargs=}") - yield OpInfoTest(test.input, *test.args, **test.kwargs) + if test_backwards: + make_tensors_require_gradients(test.args, test.kwargs) + yield OpInfoTest( + test.input, *test.args, test_backwards=test_backwards, **test.kwargs + ) class OpTracerMode(TorchDispatchMode): @@ -48,10 +61,11 @@ def __torch_dispatch__(self, fn, types, args=(), kwargs={}): self.ops.append(fn) self.args.append(args) self.kwargs.append(kwargs) + return fn(*args, **kwargs) -def build_op_tests(device, dtype, filter=None): +def build_op_tests(device, dtype, filter=None, check_backwards=False): op_info_op_tests = [] for op in op_db: if filter and op.name not in filter: @@ -85,11 +99,13 @@ def build_op_tests(device, dtype, filter=None): for overload, indices in op_indices.items(): if len(indices) > 0: - op_info_op_tests.append(OpInfoOpTest(overload, sample_inputs, indices)) + op_info_op_tests.append( + OpInfoOpTest(overload, sample_inputs, indices, check_backwards) + ) return op_info_op_tests class OpInfoTestSuite(TestSuite): - def __init__(self, name, device, dtype, filter=None): - super().__init__(name, build_op_tests(device, dtype, filter)) + def __init__(self, name, device, dtype, filter=None, check_backwards=False): + super().__init__(name, build_op_tests(device, dtype, filter, check_backwards)) diff --git a/BackendBench/suite/torchbench.py b/BackendBench/suite/torchbench.py index 2ee3d698..c45a0166 100644 --- a/BackendBench/suite/torchbench.py +++ b/BackendBench/suite/torchbench.py @@ -27,6 +27,10 @@ import torch # noqa: F401 +from BackendBench.backwards_utils import ( + make_tensors_require_gradients, + should_check_backwards_for_op, +) from BackendBench.data_loaders import ( _args_size, load_ops_from_source, @@ -37,16 +41,18 @@ class TorchBenchTest: - def __init__(self, *args, **kwargs): + def __init__(self, *args, test_backwards=False, **kwargs): self.args = args self.kwargs = kwargs + self.test_backwards = test_backwards class TorchBenchOpTest: - def __init__(self, op, inputs, topn): + def __init__(self, op, inputs, topn, check_backwards=False): self.op = eval(f"torch.ops.{op}") self.inputs = inputs self.topn = topn + self._check_backwards = check_backwards def tests(self): inputs_and_sizes = [] @@ -59,9 +65,14 @@ def tests(self): @property def correctness_tests(self): + # Determine if this op should check backwards + test_backwards = should_check_backwards_for_op(self.op.__name__, self._check_backwards) + for inp in self.tests(): args, kwargs = deserialize_args(inp) - yield TorchBenchTest(*args, **kwargs) + if test_backwards: + make_tensors_require_gradients(args, kwargs) + yield TorchBenchTest(*args, test_backwards=test_backwards, **kwargs) @property def performance_tests(self): @@ -78,9 +89,11 @@ def __init__( filter=None, topn=None, check_overhead_dominated_ops=False, + check_backwards=False, ): self.name = name self.topn = topn + self.check_backwards = check_backwards # Load operations using the shared data loader ops_list = load_ops_from_source( source=filename, @@ -102,4 +115,4 @@ def __iter__(self): for op, inputs in self.optests.items(): if any(s in op for s in UNSUPPORTED_OPERATORS): continue - yield TorchBenchOpTest(op, inputs, self.topn) + yield TorchBenchOpTest(op, inputs, self.topn, self.check_backwards) diff --git a/test/test_gradient_checks.py b/test/test_gradient_checks.py new file mode 100644 index 00000000..7cea2183 --- /dev/null +++ b/test/test_gradient_checks.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from BackendBench.backwards_utils import make_tensors_require_gradients +from BackendBench.eval import ( + _check_if_output_has_backwards, + clear_gradients, + collect_gradients, + eval_correctness_test, +) + + +class TestCollectGradients: + """Test the collect_gradients function.""" + + def test_collect_gradients_single_tensor(self): + """Test collecting gradients from a single tensor.""" + x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + y = x.sum() + y.backward() + + grads = collect_gradients([x], {}) + assert len(grads) == 1 + assert torch.allclose(grads[0], torch.ones(3)) + + def test_collect_gradients_multiple_tensors(self): + """Test collecting gradients from multiple tensors.""" + x = torch.tensor([1.0, 2.0], requires_grad=True) + y = torch.tensor([3.0, 4.0], requires_grad=True) + z = x.sum() + y.sum() + z.backward() + + grads = collect_gradients([x, y], {}) + assert len(grads) == 2 + assert torch.allclose(grads[0], torch.ones(2)) + assert torch.allclose(grads[1], torch.ones(2)) + + def test_collect_gradients_nested_list(self): + """Test collecting gradients from nested lists.""" + x = torch.tensor([1.0], requires_grad=True) + y = torch.tensor([2.0], requires_grad=True) + z = torch.tensor([3.0], requires_grad=True) + loss = (x + y + z).sum() + loss.backward() + + grads = collect_gradients([[x, y], z], {}) + assert len(grads) == 3 + for grad in grads: + assert torch.allclose(grad, torch.ones(1)) + + def test_collect_gradients_no_grad(self): + """Test collecting when tensors have no gradients.""" + x = torch.tensor([1.0, 2.0], requires_grad=True) + y = torch.tensor([3.0, 4.0], requires_grad=True) + # No backward call, so no gradients + + grads = collect_gradients([x, y], {}) + assert len(grads) == 2 + assert grads[0] is None + assert grads[1] is None + + +class TestMakeTensorsRequireGradients: + """Test the make_tensors_require_gradients function.""" + + def test_make_tensors_require_grad(self): + """Test that integer tensors don't get requires_grad.""" + x = torch.tensor([1, 2, 3]) # int tensor + y = torch.tensor([1.0, 2.0, 3.0]) # float tensor + + make_tensors_require_gradients([x, y], {}) + + assert not x.requires_grad # int tensors can't require grad + assert y.requires_grad + + +class TestClearGradients: + """Test the clear_gradients function.""" + + def test_clear_gradients_single(self): + """Test clearing gradient from single tensor.""" + x = torch.tensor([1.0, 2.0], requires_grad=True) + y = x.sum() + y.backward() + + assert x.grad is not None + + clear_gradients([x], {}) + + assert x.grad is None + + def test_clear_gradients_no_grad(self): + """Test clearing when there are no gradients.""" + x = torch.tensor([1.0], requires_grad=True) + + # Should not raise error + clear_gradients([x], {}) + + assert x.grad is None + + +class TestCheckIfOutputHasBackwards: + """Test the _check_if_output_has_backwards function.""" + + def test_check_tensor_with_grad_fn(self): + """Test tensor with grad_fn.""" + x = torch.tensor([1.0], requires_grad=True) + y = x * 2 + + assert _check_if_output_has_backwards(y) + + def test_check_tensor_without_grad_fn(self): + """Test tensor without grad_fn.""" + x = torch.tensor([1.0], requires_grad=False) + + assert not _check_if_output_has_backwards(x) + + +class TestEvalCorrectnessWithBackwards: + """Integration tests for eval_correctness_test with backwards checking.""" + + def test_eval_correctness_without_backwards(self): + """Test correctness evaluation without backwards checking.""" + op = torch.ops.aten.relu.default + impl = torch.ops.aten.relu.default + + class TestCase: + def __init__(self, args, kwargs): + self.args = args + self.kwargs = kwargs + self.test_backwards = False + + test = TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) + + result = eval_correctness_test(op, impl, test, check_backwards=False) + + assert result.is_correct + assert not result.checked_backwards + assert not result.has_correct_gradients + + def test_eval_correctness_backwards(self): + """Test backwards checking with multiple inputs.""" + op = torch.ops.aten.add.Tensor + impl = torch.ops.aten.add.Tensor + + class TestCase: + def __init__(self, args, kwargs): + self.args = args + self.kwargs = kwargs + self.test_backwards = True + + test = TestCase([torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])], {}) + + make_tensors_require_gradients(test.args, test.kwargs) + + result = eval_correctness_test(op, impl, test, check_backwards=True) + + assert result.is_correct + assert result.checked_backwards + assert result.has_correct_gradients