From 1a039306d2a7d9c381a4b687c821c18dfcf242df Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 9 Mar 2023 20:51:22 -0800
Subject: [PATCH] good changes from llama branch (#671)

* good changes from llama

* transpose behavior changed
---
 .gitignore                    |  1 +
 extra/gemm/metal_matmul.py    | 20 +++++++-------------
 models/transformer.py         |  8 ++++----
 models/vit.py                 |  2 +-
 test/test_ops.py              | 10 +++++-----
 tinygrad/image.py             | 23 ++++++++++++++++++++++-
 tinygrad/lazy.py              |  1 +
 tinygrad/ops.py               | 14 +++++++++-----
 tinygrad/runtime/ops_clang.py |  4 +++-
 tinygrad/runtime/ops_cuda.py  |  4 +++-
 tinygrad/runtime/ops_gpu.py   |  4 ++--
 tinygrad/runtime/ops_metal.py | 11 ++++++++---
 tinygrad/tensor.py            | 29 ++++++++++-------------------
 13 files changed, 76 insertions(+), 55 deletions(-)

diff --git a/.gitignore b/.gitignore
index 386263bc..bfcc3fc7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,3 +16,4 @@ vertex.bin
 recognize*
 .idea
 disassemblers/applegpu
+*.prof
diff --git a/extra/gemm/metal_matmul.py b/extra/gemm/metal_matmul.py
index 64eca2c1..cebd0b18 100644
--- a/extra/gemm/metal_matmul.py
+++ b/extra/gemm/metal_matmul.py
@@ -1,16 +1,10 @@
 import numpy as np
-from tinygrad.runtime.ops_metal import CLBuffer, CLProgram
-
-def benchmark(prog):
-  e = prog()
-  e.waitUntilCompleted()
-  return (e.GPUEndTime() - e.GPUStartTime())*1e9
-def mb(prog, N=10): return min([benchmark(prog) for _ in range(N)])
+from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
 
 N = 2048
-a = CLBuffer(N*N*4)
-b = CLBuffer(N*N*4)
-c = CLBuffer(N*N*4)
+a = RawMetalBuffer(N*N*4)
+b = RawMetalBuffer(N*N*4)
+c = RawMetalBuffer(N*N*4)
 
 nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
 nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
@@ -23,7 +17,7 @@ c.copyin(nc)
 
 FLOPS = N*N*N*2
 
-prog = CLProgram("test", f"""
+prog = MetalProgram("test", f"""
 #include <metal_stdlib>
 #include <metal_simdgroup_matrix>  // Available from Metal version 2.3 released with OS X 11.0+
 using namespace metal;
@@ -92,12 +86,12 @@ kernel void test(device float *a, device const float *data1, device const float
     }}
   }}
 }}""")
 
-tm = mb(lambda: prog([N*N//(2*4*4)], [4*32], a._cl, b._cl, c._cl))
+tm = min([prog([N*N//(2*4*4)], [4*32], a, b, c, wait=True) for _ in range(10)])
 na = a.toCPU().reshape(N,N)
 comp = nb@nc
 if N <= 32:
   print(na)
   print(comp)
-print(f"{N*N:10d} {tm*1e-3:9.2f} us, would be {FLOPS/tm:.2f} GFLOPS matmul")
+print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul")
 np.testing.assert_allclose(na, comp, atol=1e-3)
diff --git a/models/transformer.py b/models/transformer.py
index 4255845e..f37f5c8f 100644
--- a/models/transformer.py
+++ b/models/transformer.py
@@ -26,13 +26,13 @@ class TransformerBlock:
       .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
       for y in [self.query, self.key, self.value]]
 
-    query = query.transpose(order=(0,2,1,3))  # (bs, num_heads, time, head_size)
-    key = key.transpose(order=(0,2,3,1))      # (bs, num_heads, head_size, time)
-    value = value.transpose(order=(0,2,1,3))  # (bs, num_heads, time, head_size)
+    query = query.permute(order=(0,2,1,3))  # (bs, num_heads, time, head_size)
+    key = key.permute(order=(0,2,3,1))      # (bs, num_heads, head_size, time)
+    value = value.permute(order=(0,2,1,3))  # (bs, num_heads, time, head_size)
 
     score = query.dot(key) * (1 / np.sqrt(self.head_size))
     weights = score.softmax()  # (bs, num_heads, time, time)
-    attention = 
weights.dot(value).transpose(order=(0,2,1,3)) # (bs, time, num_heads, head_size) + attention = weights.dot(value).permute(order=(0,2,1,3)) # (bs, time, num_heads, head_size) return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out) diff --git a/models/vit.py b/models/vit.py index 1791e3b8..ec1ab40d 100644 --- a/models/vit.py +++ b/models/vit.py @@ -17,7 +17,7 @@ class ViT: def patch_embed(self, x): x = x.conv2d(*self.embedding, stride=16) - x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1)) + x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).permute(order=(0,2,1)) return x def forward(self, x): diff --git a/test/test_ops.py b/test/test_ops.py index ac9689a8..ca941c92 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -26,7 +26,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra tinygrad_fp = time.monotonic() - st def compare(s, x,y,atol,rtol): - if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch {x.shape} != {y.shape}" + if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}" try: np.testing.assert_allclose(x,y, atol=atol, rtol=rtol) except Exception: @@ -255,10 +255,10 @@ class TestOps(unittest.TestCase): helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4))) def test_transpose(self): - helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1))) - helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(order=(2,1,0))) - helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1))) - helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0))) + helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2)) + helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2)) + helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1))) + helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0))) def test_reshape(self): helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6))) diff --git a/tinygrad/image.py b/tinygrad/image.py index 068566a3..34c3b865 100644 --- a/tinygrad/image.py +++ b/tinygrad/image.py @@ -1,6 +1,27 @@ -from tinygrad.helpers import IMAGE +from tinygrad.helpers import IMAGE, prod from tinygrad.lazy import get_single_root +def image_dot_decorator(normal_dot): + if IMAGE == 0: return normal_dot + def image_dot(self, w): + # NOTE: we use a 1x1 conv2d to do the matmul. 
mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) + bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2]) + cin, cout = w.shape[-2], w.shape[-1] + out_shape_t = self.shape[0:-2] + (cout,-1) + if len(self.shape) > 1: + order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2) + else: + order, out_shape_t = (0,), (cout, ) + worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2) + + # NOTE: with NHWC we can remove the transposes + # bs x groups*cin x H x W + cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1)) + # groups*cout x cin x H, W + cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1)) + return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order) + return image_dot + def image_conv2d_decorator(normal_conv): if IMAGE == 0: return normal_conv diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 879b2fdf..0edb9e0b 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -21,6 +21,7 @@ def get_buffer(name, base='tinygrad.runtime'): class _Device: def __init__(self) -> None: + # TODO: make this dynamic to when you try to access the _buffers self._buffers : Dict[str, Type[DeviceBuffer]] = {x.upper():get_buffer(x) for x in [os.path.splitext(x)[0][len("ops_"):] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime"))) if x.startswith("ops_")] if x is not None} self.DEFAULT : str = "CPU" diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ff72f160..63e69919 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -33,12 +33,16 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp: return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg) _T = TypeVar("_T") -class RawBuffer: - size : int - def __init__(self, size): raise NotImplementedError("must be implemented") +class Copyable: @classmethod def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented") - def toCPU(self:RawBuffer) -> np.ndarray: raise NotImplementedError("must be implemented") + def toCPU(self:Copyable) -> np.ndarray: raise NotImplementedError("must be implemented") + +class RawBuffer(Copyable): # pylint: disable=abstract-method + def __init__(self, size:int): + self.size : int = size + GlobalCounters.mem_used += self.size + def __del__(self): GlobalCounters.mem_used -= self.size class RawBufferCopyIn(RawBuffer): def copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented") @@ -58,7 +62,7 @@ class RawBufferCopyInOut(RawBufferCopyIn): return x # a placeholder class to extend by the exec classes -class DeviceBuffer(RawBuffer): +class DeviceBuffer(Copyable): _buf: Any # underlying buffer shape: Tuple[int, ...] 
@classmethod diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 1a36afcd..bbd4775e 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -6,7 +6,9 @@ from tinygrad.ops import CompiledBuffer, RawBufferCopyIn from tinygrad.codegen.gpu import GPUCodegen, GPULanguage class RawMallocBuffer(RawBufferCopyIn): - def __init__(self, size): self.size, self._buf = size, (ctypes.c_float * (size//4))() + def __init__(self, size): + super().__init__(size) + self._buf = (ctypes.c_float * (size//4))() def copyin(self, x:np.ndarray): ctypes.memmove(self._buf, x.ctypes.data, x.size*4) def toCPU(self): return np.ctypeslib.as_array(self._buf) diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index 1de21508..ef9d29e2 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -8,7 +8,9 @@ from tinygrad.ops import CompiledBuffer, RawBufferCopyInOut from tinygrad.codegen.gpu import GPUCodegen, GPULanguage class RawCUDABuffer(RawBufferCopyInOut): - def __init__(self, size): self.size, self._cl = size, cuda.mem_alloc(size) + def __init__(self, size): + super().__init__(size) + self._cl = cuda.mem_alloc(size) def copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream) def copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl) diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 31df47de..ae182157 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -30,7 +30,7 @@ class CLBuffer(RawBufferCopyInOut): # TODO: this can be in RawBuffer generically BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list) - def __init__(self, size): + def __init__(self, size): # pylint: disable=super-init-not-called self.size = size if len(CLBuffer.BUFFER_CACHE[size]) > 0: self._cl = CLBuffer.BUFFER_CACHE[size].pop() @@ -50,7 +50,7 @@ class CLImage(RawBuffer): # pylint: disable=abstract-method fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT) IMAGE : Final = True - def __init__(self, shape): + def __init__(self, shape): # pylint: disable=super-init-not-called self.size, self._cl = shape, cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0])) GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index fe1b6fcf..5f1bdb4a 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -20,9 +20,14 @@ class _METAL: METAL = _METAL() class RawMetalBuffer(RawBufferCopyIn): - def __init__(self, size): self.size, self._cl = size, METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared) - def __del__(self): self._cl.release() - def _as_np(self): return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32) + def __init__(self, size): + super().__init__(size) + self._cl = METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared) + def __del__(self): + self._cl.release() + super().__del__() + def _buffer(self): return self._cl.contents().as_buffer(self._cl.length()) + def _as_np(self, dtype=np.float32): return np.frombuffer(self._buffer(), dtype=dtype) def copyin(self, x:np.ndarray): np.copyto(self._as_np(), x.reshape(-1).data) def toCPU(self) -> np.ndarray: for cbuf in METAL.mtl_buffers_in_flight: cbuf.waitUntilCompleted() diff --git a/tinygrad/tensor.py 
b/tinygrad/tensor.py index 9d3db5c9..9c91fa80 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -5,7 +5,7 @@ import numpy as np from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten from tinygrad.lazy import Device, LazyBuffer, LazyNumpyArray -from tinygrad.image import image_conv2d_decorator +from tinygrad.image import image_conv2d_decorator, image_dot_decorator # An instantiation of the Function is the Context class Function: @@ -252,8 +252,10 @@ class Tensor: # (padding_left, padding_right, padding_top, padding_bottom) def pad2d(self, padding:Tuple[int, ...]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1]))) - # TODO: this is totally not transpose - def transpose(self, order=(1,0)) -> Tensor: return self.permute(order=order) + def transpose(self, ax1=1, ax2=0) -> Tensor: + order = list(range(len(self.shape))) + order[ax1], order[ax2] = order[ax2], order[ax1] + return self.permute(order) def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1])) # ***** reduce ops ***** @@ -335,23 +337,11 @@ class Tensor: ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox) return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1)) + @image_dot_decorator def dot(self, w:Tensor) -> Tensor: - # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) - bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2]) - cin, cout = w.shape[-2], w.shape[-1] - out_shape_t = self.shape[0:-2] + (cout,-1) - if len(self.shape) > 1: - order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2) - else: - order, out_shape_t = (0,), (cout, ) - worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2) - - # NOTE: with NHWC we can remove the transposes - # bs x groups*cin x H x W - cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1)) - # groups*cout x cin x H, W - cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1)) - return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order) + x = self.reshape(*self.shape[0:-1], 1, self.shape[-1]) + w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2) + return (x*w).sum(-1).reshape(*x.shape[0:-2], -1) # ***** mlops (unary) ***** @@ -363,6 +353,7 @@ class Tensor: def __neg__(self): return 0.0-self def sqrt(self): return self.pow(0.5) + def rsqrt(self): return self.pow(-0.5) def square(self): return self*self def clip(self, min_, max_): return ((self-min_).relu()+min_) - (self-max_).relu() def abs(self): return self.relu() + (-self).relu()
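
Note on the "transpose behavior changed" line in the commit message: Tensor.transpose no longer takes an order= tuple; the full reordering now lives in Tensor.permute (matching the models/ and test_ops.py updates above), and transpose(ax1, ax2) just swaps two axes the way torch.Tensor.transpose does. A minimal before/after sketch, not part of the patch, assuming a checkout at this commit (Tensor.randn and the shapes are only illustrative):

  from tinygrad.tensor import Tensor

  x = Tensor.randn(2, 3, 4, 5)
  # pre-#671: x.transpose(order=(0,2,1,3)) reordered all axes
  assert x.permute(order=(0,2,1,3)).shape == (2, 4, 3, 5)  # full reordering
  assert x.transpose(1, 2).shape == (2, 4, 3, 5)           # swap two axes, torch-style
  assert x.transpose().shape == (3, 2, 4, 5)               # defaults ax1=1, ax2=0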
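
The ops.py change also splits Copyable out of RawBuffer and moves size bookkeeping into RawBuffer.__init__/__del__, which add to and subtract from GlobalCounters.mem_used; the CLANG, CUDA and Metal buffers now call super().__init__(size), while the cached CLBuffer/CLImage intentionally skip it (the pylint disables above). A minimal sketch of a backend buffer under the new contract; RawDummyBuffer and its numpy backing store are hypothetical, only the base classes come from this patch:

  import numpy as np
  from tinygrad.ops import RawBufferCopyIn

  class RawDummyBuffer(RawBufferCopyIn):
    def __init__(self, size):
      super().__init__(size)                           # registers size with GlobalCounters.mem_used
      self._buf = np.empty(size//4, dtype=np.float32)  # size is in bytes, 4 bytes per float32
    def copyin(self, x: np.ndarray): np.copyto(self._buf, x.reshape(-1))
    def toCPU(self) -> np.ndarray: return self._buf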