good changes from llama branch (#671)

* good changes from llama

* transpose behavior changed
This commit is contained in:
George Hotz 2023-03-09 20:51:22 -08:00 committed by GitHub
parent de1b6d3e08
commit 1a039306d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 76 additions and 55 deletions

1
.gitignore vendored
View File

@ -16,3 +16,4 @@ vertex.bin
recognize*
.idea
disassemblers/applegpu
*.prof

View File

@ -1,16 +1,10 @@
import numpy as np
from tinygrad.runtime.ops_metal import CLBuffer, CLProgram
def benchmark(prog):
e = prog()
e.waitUntilCompleted()
return (e.GPUEndTime() - e.GPUStartTime())*1e9
def mb(prog, N=10): return min([benchmark(prog) for _ in range(N)])
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
N = 2048
a = CLBuffer(N*N*4)
b = CLBuffer(N*N*4)
c = CLBuffer(N*N*4)
a = RawMetalBuffer(N*N*4)
b = RawMetalBuffer(N*N*4)
c = RawMetalBuffer(N*N*4)
nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
@ -23,7 +17,7 @@ c.copyin(nc)
FLOPS = N*N*N*2
prog = CLProgram("test", f"""
prog = MetalProgram("test", f"""
#include <metal_stdlib>
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
using namespace metal;
@ -92,12 +86,12 @@ kernel void test(device float *a, device const float *data1, device const float
}}
}}
}}""")
tm = mb(lambda: prog([N*N//(2*4*4)], [4*32], a._cl, b._cl, c._cl))
tm = min([prog([N*N//(2*4*4)], [4*32], a, b, c, wait=True) for _ in range(10)])
na = a.toCPU().reshape(N,N)
comp = nb@nc
if N <= 32:
print(na)
print(comp)
print(f"{N*N:10d} {tm*1e-3:9.2f} us, would be {FLOPS/tm:.2f} GFLOPS matmul")
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul")
np.testing.assert_allclose(na, comp, atol=1e-3)

View File

@ -26,13 +26,13 @@ class TransformerBlock:
.reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
for y in [self.query, self.key, self.value]]
query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, time)
value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
query = query.permute(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
key = key.permute(order=(0,2,3,1)) # (bs, num_heads, head_size, time)
value = value.permute(order=(0,2,1,3)) # (bs, num_heads, time, head_size)
score = query.dot(key) * (1 / np.sqrt(self.head_size))
weights = score.softmax() # (bs, num_heads, time, time)
attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, time, num_heads, head_size)
attention = weights.dot(value).permute(order=(0,2,1,3)) # (bs, time, num_heads, head_size)
return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)

View File

@ -17,7 +17,7 @@ class ViT:
def patch_embed(self, x):
x = x.conv2d(*self.embedding, stride=16)
x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).permute(order=(0,2,1))
return x
def forward(self, x):

View File

@ -26,7 +26,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
tinygrad_fp = time.monotonic() - st
def compare(s, x,y,atol,rtol):
if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch {x.shape} != {y.shape}"
if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}"
try:
np.testing.assert_allclose(x,y, atol=atol, rtol=rtol)
except Exception:
@ -255,10 +255,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
def test_transpose(self):
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)))
helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(order=(2,1,0)))
helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)))
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)))
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2))
helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2))
helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1)))
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0)))
def test_reshape(self):
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))

View File

@ -1,6 +1,27 @@
from tinygrad.helpers import IMAGE
from tinygrad.helpers import IMAGE, prod
from tinygrad.lazy import get_single_root
def image_dot_decorator(normal_dot):
if IMAGE == 0: return normal_dot
def image_dot(self, w):
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = self.shape[0:-2] + (cout,-1)
if len(self.shape) > 1:
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
else:
order, out_shape_t = (0,), (cout, )
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
# NOTE: with NHWC we can remove the transposes
# bs x groups*cin x H x W
cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
# groups*cout x cin x H, W
cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
return image_dot
def image_conv2d_decorator(normal_conv):
if IMAGE == 0: return normal_conv

View File

@ -21,6 +21,7 @@ def get_buffer(name, base='tinygrad.runtime'):
class _Device:
def __init__(self) -> None:
# TODO: make this dynamic to when you try to access the _buffers
self._buffers : Dict[str, Type[DeviceBuffer]] = {x.upper():get_buffer(x) for x in
[os.path.splitext(x)[0][len("ops_"):] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime"))) if x.startswith("ops_")] if x is not None}
self.DEFAULT : str = "CPU"

View File

@ -33,12 +33,16 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)
_T = TypeVar("_T")
class RawBuffer:
size : int
def __init__(self, size): raise NotImplementedError("must be implemented")
class Copyable:
@classmethod
def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
def toCPU(self:RawBuffer) -> np.ndarray: raise NotImplementedError("must be implemented")
def toCPU(self:Copyable) -> np.ndarray: raise NotImplementedError("must be implemented")
class RawBuffer(Copyable): # pylint: disable=abstract-method
def __init__(self, size:int):
self.size : int = size
GlobalCounters.mem_used += self.size
def __del__(self): GlobalCounters.mem_used -= self.size
class RawBufferCopyIn(RawBuffer):
def copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
@ -58,7 +62,7 @@ class RawBufferCopyInOut(RawBufferCopyIn):
return x
# a placeholder class to extend by the exec classes
class DeviceBuffer(RawBuffer):
class DeviceBuffer(Copyable):
_buf: Any # underlying buffer
shape: Tuple[int, ...]
@classmethod

View File

@ -6,7 +6,9 @@ from tinygrad.ops import CompiledBuffer, RawBufferCopyIn
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
class RawMallocBuffer(RawBufferCopyIn):
def __init__(self, size): self.size, self._buf = size, (ctypes.c_float * (size//4))()
def __init__(self, size):
super().__init__(size)
self._buf = (ctypes.c_float * (size//4))()
def copyin(self, x:np.ndarray): ctypes.memmove(self._buf, x.ctypes.data, x.size*4)
def toCPU(self): return np.ctypeslib.as_array(self._buf)

View File

@ -8,7 +8,9 @@ from tinygrad.ops import CompiledBuffer, RawBufferCopyInOut
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
class RawCUDABuffer(RawBufferCopyInOut):
def __init__(self, size): self.size, self._cl = size, cuda.mem_alloc(size)
def __init__(self, size):
super().__init__(size)
self._cl = cuda.mem_alloc(size)
def copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream)
def copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl)

View File

@ -30,7 +30,7 @@ class CLBuffer(RawBufferCopyInOut):
# TODO: this can be in RawBuffer generically
BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)
def __init__(self, size):
def __init__(self, size): # pylint: disable=super-init-not-called
self.size = size
if len(CLBuffer.BUFFER_CACHE[size]) > 0:
self._cl = CLBuffer.BUFFER_CACHE[size].pop()
@ -50,7 +50,7 @@ class CLImage(RawBuffer): # pylint: disable=abstract-method
fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
IMAGE : Final = True
def __init__(self, shape):
def __init__(self, shape): # pylint: disable=super-init-not-called
self.size, self._cl = shape, cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height

View File

@ -20,9 +20,14 @@ class _METAL:
METAL = _METAL()
class RawMetalBuffer(RawBufferCopyIn):
def __init__(self, size): self.size, self._cl = size, METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
def __del__(self): self._cl.release()
def _as_np(self): return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32)
def __init__(self, size):
super().__init__(size)
self._cl = METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
def __del__(self):
self._cl.release()
super().__del__()
def _buffer(self): return self._cl.contents().as_buffer(self._cl.length())
def _as_np(self, dtype=np.float32): return np.frombuffer(self._buffer(), dtype=dtype)
def copyin(self, x:np.ndarray): np.copyto(self._as_np(), x.reshape(-1).data)
def toCPU(self) -> np.ndarray:
for cbuf in METAL.mtl_buffers_in_flight: cbuf.waitUntilCompleted()

View File

@ -5,7 +5,7 @@ import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
from tinygrad.lazy import Device, LazyBuffer, LazyNumpyArray
from tinygrad.image import image_conv2d_decorator
from tinygrad.image import image_conv2d_decorator, image_dot_decorator
# An instantiation of the Function is the Context
class Function:
@ -252,8 +252,10 @@ class Tensor:
# (padding_left, padding_right, padding_top, padding_bottom)
def pad2d(self, padding:Tuple[int, ...]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])))
# TODO: this is totally not transpose
def transpose(self, order=(1,0)) -> Tensor: return self.permute(order=order)
def transpose(self, ax1=1, ax2=0) -> Tensor:
order = list(range(len(self.shape)))
order[ax1], order[ax2] = order[ax2], order[ax1]
return self.permute(order)
def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
# ***** reduce ops *****
@ -335,23 +337,11 @@ class Tensor:
ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
@image_dot_decorator
def dot(self, w:Tensor) -> Tensor:
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
cin, cout = w.shape[-2], w.shape[-1]
out_shape_t = self.shape[0:-2] + (cout,-1)
if len(self.shape) > 1:
order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
else:
order, out_shape_t = (0,), (cout, )
worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
# NOTE: with NHWC we can remove the transposes
# bs x groups*cin x H x W
cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
# groups*cout x cin x H, W
cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
return (x*w).sum(-1).reshape(*x.shape[0:-2], -1)
# ***** mlops (unary) *****
@ -363,6 +353,7 @@ class Tensor:
def __neg__(self): return 0.0-self
def sqrt(self): return self.pow(0.5)
def rsqrt(self): return self.pow(-0.5)
def square(self): return self*self
def clip(self, min_, max_): return ((self-min_).relu()+min_) - (self-max_).relu()
def abs(self): return self.relu() + (-self).relu()