From 1867d3cf561ad8bc0c54c2b8879530c43992b8f6 Mon Sep 17 00:00:00 2001 From: Tong Date: Sun, 2 Mar 2025 20:37:04 -0500 Subject: [PATCH] Updated validate in utilities.py to be able to check torch tensor too --- npbench/infrastructure/utilities.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/npbench/infrastructure/utilities.py b/npbench/infrastructure/utilities.py index 7bdf8105..ce3f804f 100644 --- a/npbench/infrastructure/utilities.py +++ b/npbench/infrastructure/utilities.py @@ -158,6 +158,9 @@ def validate(ref, val, framework="Unknown", rtol=1e-5, atol=1e-8, norm_error=1e- val = [val] valid = True for r, v in zip(ref, val): + if f"{type(v).__module__}.{type(v).__name__}" == "torch.Tensor": + v = v.cpu().numpy() + if not np.allclose(r, v, rtol=rtol, atol=atol): try: import cupy