diff --git a/npbench/infrastructure/utilities.py b/npbench/infrastructure/utilities.py index 7bdf8105..ce3f804f 100644 --- a/npbench/infrastructure/utilities.py +++ b/npbench/infrastructure/utilities.py @@ -158,6 +158,9 @@ def validate(ref, val, framework="Unknown", rtol=1e-5, atol=1e-8, norm_error=1e- val = [val] valid = True for r, v in zip(ref, val): + if f"{type(v).__module__}.{type(v).__name__}" == "torch.Tensor": + v = v.cpu().numpy() + if not np.allclose(r, v, rtol=rtol, atol=atol): try: import cupy