good changes from llama branch (#671)

* good changes from llama

* transpose behavior changed
George Hotz 2023-03-09 20:51:22 -08:00 committed by GitHub
parent de1b6d3e08
commit 1a039306d2
13 changed files with 76 additions and 55 deletions

.gitignore

@@ -16,3 +16,4 @@ vertex.bin
 recognize*
 .idea
 disassemblers/applegpu
+*.prof


@@ -1,16 +1,10 @@
 import numpy as np
-from tinygrad.runtime.ops_metal import CLBuffer, CLProgram
+from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
-def benchmark(prog):
-  e = prog()
-  e.waitUntilCompleted()
-  return (e.GPUEndTime() - e.GPUStartTime())*1e9
-def mb(prog, N=10): return min([benchmark(prog) for _ in range(N)])
 N = 2048
-a = CLBuffer(N*N*4)
-b = CLBuffer(N*N*4)
-c = CLBuffer(N*N*4)
+a = RawMetalBuffer(N*N*4)
+b = RawMetalBuffer(N*N*4)
+c = RawMetalBuffer(N*N*4)
 nb = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
 nc = np.random.default_rng().standard_normal(size=(N,N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
@@ -23,7 +17,7 @@ c.copyin(nc)
 FLOPS = N*N*N*2
-prog = CLProgram("test", f"""
+prog = MetalProgram("test", f"""
 #include <metal_stdlib>
 #include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
 using namespace metal;
@@ -92,12 +86,12 @@ kernel void test(device float *a, device const float *data1, device const float
 }}
 }}
 }}""")
-tm = mb(lambda: prog([N*N//(2*4*4)], [4*32], a._cl, b._cl, c._cl))
+tm = min([prog([N*N//(2*4*4)], [4*32], a, b, c, wait=True) for _ in range(10)])
 na = a.toCPU().reshape(N,N)
 comp = nb@nc
 if N <= 32:
   print(na)
   print(comp)
-print(f"{N*N:10d} {tm*1e-3:9.2f} us, would be {FLOPS/tm:.2f} GFLOPS matmul")
+print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul")
 np.testing.assert_allclose(na, comp, atol=1e-3)
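With the new runtime API, calling the program with wait=True appears to return the elapsed kernel time in seconds (the old benchmark helper returned nanoseconds), which is why the print switches to tm*1e6 for microseconds and FLOPS*1e-9/tm for GFLOPS. A quick units check, assuming a hypothetical 10 ms kernel time:

# units sanity check for the new timing path (numbers are made up)
N = 2048
FLOPS = N*N*N*2                       # 17,179,869,184 floating point operations
tm = 0.010                            # assume wait=True returned 10 ms, i.e. 0.010 s
print(f"{tm*1e6:9.2f} us")            # 10000.00 us
print(f"{FLOPS*1e-9/tm:.2f} GFLOPS")  # ~1717.99 GFLOPS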


@@ -26,13 +26,13 @@ class TransformerBlock:
       .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
       for y in [self.query, self.key, self.value]]
-    query = query.transpose(order=(0,2,1,3))  # (bs, num_heads, time, head_size)
-    key = key.transpose(order=(0,2,3,1))  # (bs, num_heads, head_size, time)
-    value = value.transpose(order=(0,2,1,3))  # (bs, num_heads, time, head_size)
+    query = query.permute(order=(0,2,1,3))  # (bs, num_heads, time, head_size)
+    key = key.permute(order=(0,2,3,1))  # (bs, num_heads, head_size, time)
+    value = value.permute(order=(0,2,1,3))  # (bs, num_heads, time, head_size)
     score = query.dot(key) * (1 / np.sqrt(self.head_size))
     weights = score.softmax()  # (bs, num_heads, time, time)
-    attention = weights.dot(value).transpose(order=(0,2,1,3))  # (bs, time, num_heads, head_size)
+    attention = weights.dot(value).permute(order=(0,2,1,3))  # (bs, time, num_heads, head_size)
     return attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size)).linear(*self.out)
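The attention rewrite above only renames the op: reordering all four axes is now spelled permute, which is what these calls were always doing. A minimal numpy sketch of the axis bookkeeping, with assumed example shapes:

import numpy as np

bs, num_heads, time, head_size = 2, 4, 8, 16   # assumed shapes for illustration
q = np.zeros((bs, time, num_heads, head_size))

# permute(order=(0,2,1,3)): arbitrary reorder of all axes
assert np.transpose(q, (0, 2, 1, 3)).shape == (bs, num_heads, time, head_size)
# the new Tensor.transpose(ax1, ax2) only swaps two axes; here swapping 1 and 2
# happens to produce the same layout
assert np.swapaxes(q, 1, 2).shape == (bs, num_heads, time, head_size)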


@@ -17,7 +17,7 @@ class ViT:
   def patch_embed(self, x):
     x = x.conv2d(*self.embedding, stride=16)
-    x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
+    x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).permute(order=(0,2,1))
     return x
   def forward(self, x):


@@ -26,7 +26,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
   tinygrad_fp = time.monotonic() - st
   def compare(s, x,y,atol,rtol):
-    if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch {x.shape} != {y.shape}"
+    if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}"
     try:
       np.testing.assert_allclose(x,y, atol=atol, rtol=rtol)
     except Exception:
@@ -255,10 +255,10 @@ class TestOps(unittest.TestCase):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)))
   def test_transpose(self):
-    helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)))
-    helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(order=(2,1,0)))
-    helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)))
-    helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)))
+    helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2))
+    helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2))
+    helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1)))
+    helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0)))
   def test_reshape(self):
     helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)))


@@ -1,6 +1,27 @@
-from tinygrad.helpers import IMAGE
+from tinygrad.helpers import IMAGE, prod
 from tinygrad.lazy import get_single_root
+def image_dot_decorator(normal_dot):
+  if IMAGE == 0: return normal_dot
+  def image_dot(self, w):
+    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
+    bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
+    cin, cout = w.shape[-2], w.shape[-1]
+    out_shape_t = self.shape[0:-2] + (cout,-1)
+    if len(self.shape) > 1:
+      order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
+    else:
+      order, out_shape_t = (0,), (cout, )
+    worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
+    # NOTE: with NHWC we can remove the transposes
+    # bs x groups*cin x H x W
+    cx = self.permute(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
+    # groups*cout x cin x H, W
+    cw = w.permute(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
+    return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).permute(order=order)
+  return image_dot
 def image_conv2d_decorator(normal_conv):
   if IMAGE == 0: return normal_conv
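As a sanity check on the 1x1-conv trick used by image_dot above, here is a small numpy sketch (shapes chosen arbitrarily, an einsum standing in for conv2d) showing that convolving a (1, k, m, 1) input with an (n, k, 1, 1) weight reproduces an mxk @ kxn matmul:

import numpy as np

m, k, n = 3, 4, 5
A = np.random.randn(m, k).astype(np.float32)
B = np.random.randn(k, n).astype(np.float32)

cx = A.T.reshape(1, k, m, 1)   # input laid out as (bs=1, cin=k, H=m, W=1)
cw = B.T.reshape(n, k, 1, 1)   # weight laid out as (cout=n, cin=k, 1, 1)
# a 1x1 convolution is just a contraction over the channel axis
out = np.einsum("bchw,oc->bohw", cx, cw[:, :, 0, 0])   # (1, n, m, 1)
np.testing.assert_allclose(out.reshape(n, m).T, A @ B, rtol=1e-5, atol=1e-4)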


@@ -21,6 +21,7 @@ def get_buffer(name, base='tinygrad.runtime'):
 class _Device:
   def __init__(self) -> None:
+    # TODO: make this dynamic to when you try to access the _buffers
     self._buffers : Dict[str, Type[DeviceBuffer]] = {x.upper():get_buffer(x) for x in
       [os.path.splitext(x)[0][len("ops_"):] for x in sorted(os.listdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "runtime"))) if x.startswith("ops_")] if x is not None}
     self.DEFAULT : str = "CPU"
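For reference, a tiny sketch of how the discovery above maps runtime filenames to device keys (the file list here is made up for illustration):

import os
files = ["ops_cpu.py", "ops_gpu.py", "ops_metal.py", "graph.py"]   # hypothetical listing
keys = [os.path.splitext(f)[0][len("ops_"):].upper() for f in files if f.startswith("ops_")]
print(keys)  # ['CPU', 'GPU', 'METAL']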


@@ -33,12 +33,16 @@ def map_buffers(real_srcs, x:LazyOp) -> LazyOp:
   return LazyOp(x.op, tuple((map_buffers(real_srcs, y) if isinstance(y, LazyOp) else real_srcs[y]) for y in x.src), x.arg)
 _T = TypeVar("_T")
-class RawBuffer:
-  size : int
-  def __init__(self, size): raise NotImplementedError("must be implemented")
+class Copyable:
   @classmethod
   def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
-  def toCPU(self:RawBuffer) -> np.ndarray: raise NotImplementedError("must be implemented")
+  def toCPU(self:Copyable) -> np.ndarray: raise NotImplementedError("must be implemented")
+class RawBuffer(Copyable):  # pylint: disable=abstract-method
+  def __init__(self, size:int):
+    self.size : int = size
+    GlobalCounters.mem_used += self.size
+  def __del__(self): GlobalCounters.mem_used -= self.size
 class RawBufferCopyIn(RawBuffer):
   def copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
@@ -58,7 +62,7 @@ class RawBufferCopyInOut(RawBufferCopyIn):
     return x
 # a placeholder class to extend by the exec classes
-class DeviceBuffer(RawBuffer):
+class DeviceBuffer(Copyable):
   _buf: Any  # underlying buffer
   shape: Tuple[int, ...]
   @classmethod
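The point of the new RawBuffer base class is that every concrete buffer calls super().__init__(size) and gets GlobalCounters.mem_used accounting for free. A minimal standalone sketch of that pattern, using stand-in classes rather than the real tinygrad ones:

class GlobalCounters:            # stand-in for tinygrad's GlobalCounters
  mem_used = 0

class RawBuffer:
  def __init__(self, size: int):
    self.size = size
    GlobalCounters.mem_used += self.size   # register the allocation
  def __del__(self):
    GlobalCounters.mem_used -= self.size   # unregister on free

class RawMallocLikeBuffer(RawBuffer):      # hypothetical subclass
  def __init__(self, size: int):
    super().__init__(size)                 # accounting happens here
    self._buf = bytearray(size)            # stand-in for the real backing store

buf = RawMallocLikeBuffer(1024)
print(GlobalCounters.mem_used)  # 1024
del buf
print(GlobalCounters.mem_used)  # 0 (CPython refcounting runs __del__ immediately)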


@@ -6,7 +6,9 @@ from tinygrad.ops import CompiledBuffer, RawBufferCopyIn
 from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
 class RawMallocBuffer(RawBufferCopyIn):
-  def __init__(self, size): self.size, self._buf = size, (ctypes.c_float * (size//4))()
+  def __init__(self, size):
+    super().__init__(size)
+    self._buf = (ctypes.c_float * (size//4))()
   def copyin(self, x:np.ndarray): ctypes.memmove(self._buf, x.ctypes.data, x.size*4)
   def toCPU(self): return np.ctypeslib.as_array(self._buf)


@@ -8,7 +8,9 @@ from tinygrad.ops import CompiledBuffer, RawBufferCopyInOut
 from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
 class RawCUDABuffer(RawBufferCopyInOut):
-  def __init__(self, size): self.size, self._cl = size, cuda.mem_alloc(size)
+  def __init__(self, size):
+    super().__init__(size)
+    self._cl = cuda.mem_alloc(size)
   def copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._cl, x, stream)
   def copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._cl)


@@ -30,7 +30,7 @@ class CLBuffer(RawBufferCopyInOut):
   # TODO: this can be in RawBuffer generically
   BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)
-  def __init__(self, size):
+  def __init__(self, size):  # pylint: disable=super-init-not-called
     self.size = size
     if len(CLBuffer.BUFFER_CACHE[size]) > 0:
       self._cl = CLBuffer.BUFFER_CACHE[size].pop()
@@ -50,7 +50,7 @@ class CLImage(RawBuffer):  # pylint: disable=abstract-method
   fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
   IMAGE : Final = True
-  def __init__(self, shape):
+  def __init__(self, shape):  # pylint: disable=super-init-not-called
     self.size, self._cl = shape, cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
     GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height


@@ -20,9 +20,14 @@ class _METAL:
 METAL = _METAL()
 class RawMetalBuffer(RawBufferCopyIn):
-  def __init__(self, size): self.size, self._cl = size, METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
-  def __del__(self): self._cl.release()
-  def _as_np(self): return np.frombuffer(self._cl.contents().as_buffer(self._cl.length()), dtype=np.float32)
+  def __init__(self, size):
+    super().__init__(size)
+    self._cl = METAL.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
+  def __del__(self):
+    self._cl.release()
+    super().__del__()
+  def _buffer(self): return self._cl.contents().as_buffer(self._cl.length())
+  def _as_np(self, dtype=np.float32): return np.frombuffer(self._buffer(), dtype=dtype)
   def copyin(self, x:np.ndarray): np.copyto(self._as_np(), x.reshape(-1).data)
   def toCPU(self) -> np.ndarray:
     for cbuf in METAL.mtl_buffers_in_flight: cbuf.waitUntilCompleted()


@@ -5,7 +5,7 @@ import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
 from tinygrad.helpers import prod, argfix, make_pair, getenv, DEBUG, flatten
 from tinygrad.lazy import Device, LazyBuffer, LazyNumpyArray
-from tinygrad.image import image_conv2d_decorator
+from tinygrad.image import image_conv2d_decorator, image_dot_decorator
 # An instantiation of the Function is the Context
 class Function:
@@ -252,8 +252,10 @@ class Tensor:
   # (padding_left, padding_right, padding_top, padding_bottom)
   def pad2d(self, padding:Tuple[int, ...]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1])))
-  # TODO: this is totally not transpose
-  def transpose(self, order=(1,0)) -> Tensor: return self.permute(order=order)
+  def transpose(self, ax1=1, ax2=0) -> Tensor:
+    order = list(range(len(self.shape)))
+    order[ax1], order[ax2] = order[ax2], order[ax1]
+    return self.permute(order)
   def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
   # ***** reduce ops *****
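This is the "transpose behavior changed" part of the commit: Tensor.transpose now takes two axis indices and swaps them by building a permutation, instead of accepting a full order tuple (that job belongs to permute). A numpy sketch of the new semantics, with an assumed 3-D shape:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)

# transpose(ax1=1, ax2=2): start from the identity order and swap the two axes
order = list(range(x.ndim))        # [0, 1, 2]
order[1], order[2] = order[2], order[1]
assert np.transpose(x, order).shape == (2, 4, 3)   # same as the old permute(order=(0,2,1))

# full reorders still go through permute(order=...)
assert np.transpose(x, (2, 0, 1)).shape == (4, 2, 3)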
@@ -335,23 +337,11 @@ class Tensor:
     ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox)
     return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
+  @image_dot_decorator
   def dot(self, w:Tensor) -> Tensor:
-    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-    bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2])
-    cin, cout = w.shape[-2], w.shape[-1]
-    out_shape_t = self.shape[0:-2] + (cout,-1)
-    if len(self.shape) > 1:
-      order = tuple(range(len(self.shape)-2)) + (len(self.shape)-1, len(self.shape)-2)
-    else:
-      order, out_shape_t = (0,), (cout, )
-    worder = tuple(range(len(w.shape)-2)) + (len(w.shape)-1, len(w.shape)-2)
-    # NOTE: with NHWC we can remove the transposes
-    # bs x groups*cin x H x W
-    cx = self.transpose(order=order).reshape(shape=(bs//groups, groups*cin, -1, 1))
-    # groups*cout x cin x H, W
-    cw = w.transpose(order=worder).reshape(shape=(groups*cout, cin, 1, 1))
-    return cx.conv2d(cw, groups=groups).reshape(shape=out_shape_t).transpose(order=order)
+    x = self.reshape(*self.shape[0:-1], 1, self.shape[-1])
+    w = w.reshape(*w.shape[0:-2], 1, w.shape[-2], w.shape[-1]).transpose(-1, -2)
+    return (x*w).sum(-1).reshape(*x.shape[0:-2], -1)
   # ***** mlops (unary) *****
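The new dot is a broadcasted multiply-and-reduce instead of a conv2d (the conv path lives on behind @image_dot_decorator when IMAGE is set). A numpy sketch of the shape plumbing for an assumed batched (bs, m, k) @ (k, n) case:

import numpy as np

bs, m, k, n = 2, 3, 4, 5
a = np.random.randn(bs, m, k).astype(np.float32)
b = np.random.randn(k, n).astype(np.float32)

x = a.reshape(bs, m, 1, k)                # self.reshape(*shape[0:-1], 1, shape[-1])
w = b.reshape(1, k, n).swapaxes(-1, -2)   # reshape to (1, k, n), then swap the last two axes -> (1, n, k)
out = (x * w).sum(-1)                     # broadcasts to (bs, m, n, k), reduces the shared k axis
np.testing.assert_allclose(out, a @ b, rtol=1e-5, atol=1e-4)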
@@ -363,6 +353,7 @@ class Tensor:
   def __neg__(self): return 0.0-self
   def sqrt(self): return self.pow(0.5)
+  def rsqrt(self): return self.pow(-0.5)
   def square(self): return self*self
   def clip(self, min_, max_): return ((self-min_).relu()+min_) - (self-max_).relu()
   def abs(self): return self.relu() + (-self).relu()