diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py
index b9f63a57..808b84eb 100644
--- a/test/test_speed_v_torch.py
+++ b/test/test_speed_v_torch.py
@@ -1,4 +1,5 @@
 import os
+os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
 import unittest
 import torch
 torch.set_num_threads(1)
@@ -17,7 +18,7 @@ except ImportError:
 
 IN_CHANS = [int(x) for x in os.getenv("IN_CHANS", "4,16,64").split(",")]
 
-torch_device = torch.device('mps' if int(os.getenv("MPS", "0")) else 'cpu')
+torch_device = torch.device('mps' if int(os.getenv("MPS", "0")) else ('cuda' if int(os.getenv("CUDA", "0")) else 'cpu'))
 
 def colorize_float(x):
   ret = f"{x:7.2f}x"
@@ -42,9 +43,9 @@ def helper_test_speed(f1, *args):
     ret = f1(*args)
     if CL is not None and ret.device in ["GPU", "OPENCL"]:
       CL.cl_queue.finish()
-    if "mps" in str(ret.device):
+    if torch_device != "cpu":
       # TODO: better way to sync?
-      torch.zeros(1, device='mps').cpu()
+      torch.zeros(1, device=torch_device).cpu()
     et = (time.monotonic() - st) * 1000
     ets.append(et)
   if GlobalCounters.global_ops:
@@ -70,7 +71,7 @@ def helper_test_generic(name, f1, f2):
 
   flops = save_ops*1e-6
   mem = save_mem*4*1e-6
-  print(f"{prefix}{name:40s} {et_torch:7.2f} ms ({flops/et_torch:7.2f} GFLOPS {mem/et_torch:7.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:7.2f} GFLOPS {mem/et_tinygrad:7.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} slower {flops:7.2f} MOPS {mem:7.2f} MB")
+  print(f"{prefix}{name:40s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} slower {flops:7.2f} MOPS {mem:7.2f} MB")
   prefix = " "
   np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4, rtol=1e-3)
 