mirror of https://github.com/commaai/tinygrad.git
torch cuda is very fast
This commit is contained in:
parent
a949de873b
commit
6fe9edf30f
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
|
||||
import unittest
|
||||
import torch
|
||||
torch.set_num_threads(1)
|
||||
|
@ -17,7 +18,7 @@ except ImportError:
|
|||
|
||||
IN_CHANS = [int(x) for x in os.getenv("IN_CHANS", "4,16,64").split(",")]
|
||||
|
||||
torch_device = torch.device('mps' if int(os.getenv("MPS", "0")) else 'cpu')
|
||||
torch_device = torch.device('mps' if int(os.getenv("MPS", "0")) else ('cuda' if int(os.getenv("CUDA", "0")) else 'cpu'))
|
||||
|
||||
def colorize_float(x):
|
||||
ret = f"{x:7.2f}x"
|
||||
|
@ -42,9 +43,9 @@ def helper_test_speed(f1, *args):
|
|||
ret = f1(*args)
|
||||
if CL is not None and ret.device in ["GPU", "OPENCL"]:
|
||||
CL.cl_queue.finish()
|
||||
if "mps" in str(ret.device):
|
||||
if torch_device != "cpu":
|
||||
# TODO: better way to sync?
|
||||
torch.zeros(1, device='mps').cpu()
|
||||
torch.zeros(1, device=torch_device).cpu()
|
||||
et = (time.monotonic() - st) * 1000
|
||||
ets.append(et)
|
||||
if GlobalCounters.global_ops:
|
||||
|
@ -70,7 +71,7 @@ def helper_test_generic(name, f1, f2):
|
|||
|
||||
flops = save_ops*1e-6
|
||||
mem = save_mem*4*1e-6
|
||||
print(f"{prefix}{name:40s} {et_torch:7.2f} ms ({flops/et_torch:7.2f} GFLOPS {mem/et_torch:7.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:7.2f} GFLOPS {mem/et_tinygrad:7.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} slower {flops:7.2f} MOPS {mem:7.2f} MB")
|
||||
print(f"{prefix}{name:40s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} slower {flops:7.2f} MOPS {mem:7.2f} MB")
|
||||
prefix = " "
|
||||
np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-4, rtol=1e-3)
|
||||
|
||||
|
|
Loading…
Reference in New Issue