Synchronize devices for an honest speed comparison

This commit is contained in:
George Hotz 2023-03-24 10:24:27 -07:00
parent 1cb5b2d015
commit 23f88fb026
5 changed files with 26 additions and 11 deletions

View File

@ -10,6 +10,7 @@ import time
import numpy as np
np.set_printoptions(linewidth=160)
from functools import partial
from tinygrad.lazy import Device
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
@ -19,6 +20,14 @@ from tinygrad.jit import TinyJit
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
torch_device = torch.device('mps' if getenv("MPS", 0) else ('cuda' if getenv("TORCHCUDA", 0) else 'cpu'))
if str(torch_device) == "mps":
import torch.mps
sync = lambda: torch.mps.synchronize()
elif str(torch_device) == "cuda":
import torch.cuda
sync = lambda: torch.cuda.synchronize()
else:
sync = lambda: None
def colorize_float(x):
ret = f"{x:7.2f}x"
@ -47,9 +56,11 @@ def helper_test_speed(f1, *args):
if DEBUG >= 4: print("benchmark start")
st = time.monotonic()
ret = f1(*args)
# not ideal, it's copying (sometimes). why is this so slow in tinygrad?
if isinstance(ret, Tensor) or str(torch_device) == "cpu": ret.numpy()
else: ret.cpu().numpy()
if isinstance(ret, Tensor):
ret.realize()
Device[ret.device].synchronize()
else:
sync()
et = (time.monotonic() - st) * 1000
ets.append(et)
if DEBUG >= 4: print("benchmark stop")
@ -160,7 +171,7 @@ class TestSpeed(unittest.TestCase):
def test_gemm(self):
def f(a, b): return a @ b
helper_test_generic_square('gemm', 512, f, f)
helper_test_generic_square('gemm', 1024, f, f)
def test_gemm_unrolled(self):
N = 512

View File

@ -39,6 +39,7 @@ class Interpreted:
self.fxn_for_op = fxn_for_op
self.from_lazybuffer = from_lazybuffer
self.to_underlying = to_underlying
self.synchronize = lambda: None
self.codegen = None
def exec_ast(self, ast:LazyOp, output=None, context=None):
@ -122,8 +123,8 @@ class ASTRunner:
return min([(self.timeit(rawbufs, local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
class Compiled:
def __init__(self, buffer: Type[RawBuffer], codegen, runtime):
self.buffer, self.codegen, self.runtime = buffer, codegen, runtime
def __init__(self, buffer: Type[RawBuffer], codegen, runtime, synchronize=lambda: None):
self.buffer, self.codegen, self.runtime, self.synchronize = buffer, codegen, runtime, synchronize
self.method_cache: Dict[str, ASTRunner] = {}
def exec_ast(self, ast:LazyOp, output):

View File

@ -45,4 +45,4 @@ class CUDACodegen(CStyleCodegen):
gid = [f'blockDim.{chr(120+i)}*blockIdx.{chr(120+i)}+threadIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram)
CUDABuffer = Compiled(RawCUDABuffer, CUDACodegen, CUDAProgram, cuda.Context.synchronize)

View File

@ -75,4 +75,4 @@ class CLCodegen(CStyleCodegen):
barrier = "barrier(CLK_LOCAL_MEM_FENCE);", float4 = "(float4)",
gid = [f'get_global_id({i})' for i in range(3)], lid = [f'get_local_id({i})' for i in range(3)], uses_vload=True)
GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram)
GPUBuffer = Compiled(CLBuffer, CLCodegen, CLProgram, CL.cl_queue.finish)

View File

@ -14,6 +14,10 @@ class _METAL:
self.mtl_buffers_in_flight: List[Any] = []
self.device = Metal.MTLCreateSystemDefaultDevice()
self.mtl_queue = self.device.newCommandQueue()
# TODO: is there a better way to synchronize than waiting on every in-flight command buffer?
def synchronize(self):
for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
self.mtl_buffers_in_flight.clear()
METAL = _METAL()
class RawMetalBuffer(RawBufferMapped):
@ -22,8 +26,7 @@ class RawMetalBuffer(RawBufferMapped):
self._buf.release()
super().__del__()
def _buffer(self):
for cbuf in METAL.mtl_buffers_in_flight: cbuf.waitUntilCompleted()
METAL.mtl_buffers_in_flight.clear()
METAL.synchronize()
return self._buf.contents().as_buffer(self._buf.length())
def unwrap(x):
@ -80,4 +83,4 @@ class MetalCodegen(CStyleCodegen):
gid = [f"gid.{chr(120+i)}" for i in range(3)], lid = [f"lid.{chr(120+i)}" for i in range(3)],
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])
MetalBuffer = Compiled(RawMetalBuffer, MetalCodegen, MetalProgram)
MetalBuffer = Compiled(RawMetalBuffer, MetalCodegen, MetalProgram, METAL.synchronize)