mirror of https://github.com/commaai/tinygrad.git
shapetracker from newgpu (#456)
* shapetracker from newgpu
* touchup ops
* test
* testst
* thneed deletes unused inputs
* test
* bugfix
parent 538b1d7f5b
commit 4885fce56e
@@ -187,9 +187,8 @@ class LLVMBuffer(ExplicitExecAST):
     k = ASTKernel(ast)

     # cached kernel
-    key = str(ast) # TODO: does this uniquely determine the AST? No! The shapetracker can change. Do this better.
-    if key in LLVMBuffer.func_cache:
-      LLVMBuffer.func_cache[key](*[x._buf for x in k.bufs])
+    if k.key in LLVMBuffer.func_cache:
+      LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
       return k.ret

     # cache miss, we have to process the kernel
@@ -362,5 +361,5 @@ class LLVMBuffer(ExplicitExecAST):
     loop_entry[-1].branch(loop_exit[-1]._block)
     loop_exit[0].ret_void()
-    LLVMBuffer.func_cache[key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs))
+    LLVMBuffer.func_cache[k.key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs))
     return k.ret
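The two hunks above move the kernel-cache key out of the LLVM backend: instead of rebuilding key = str(ast) locally (which, per the TODO, does not capture the shapetracker), the backend now looks up and stores compiled functions under k.key, computed once inside ASTKernel. A toy, self-contained sketch of that memoization pattern (func_cache, get_kernel and the key string are illustrative names, not tinygrad API):

func_cache = {}

def get_kernel(key):
  # cache hit: reuse the previously built function
  if key in func_cache:
    return func_cache[key]
  # cache miss: build a function and remember it under the key
  def kernel(*bufs): print(f"running kernel {key} on {len(bufs)} buffers")
  func_cache[key] = kernel
  return kernel

get_kernel("mul->sum over (4,4)")(b"buf0", b"buf1")  # miss: builds, then runs
get_kernel("mul->sum over (4,4)")(b"buf0", b"buf1")  # hit: reuses the same function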
@@ -35,9 +35,15 @@ class Thneed:
       if len(nodes[n]['out_edges']) == 0:
         self.outputs.append(n)

-    for n in self.inputs.values():
-      assert n in self.buffers_to_save, f"{n} was not an input"
-      self.buffers_to_save.remove(n)
+    fake_inputs = []
+    for k,n in self.inputs.items():
+      if n in self.buffers_to_save:
+        self.buffers_to_save.remove(n)
+      else:
+        print(f"WARNING: {k} was not a used input, removing it")
+        fake_inputs.append(k)
+    for k in fake_inputs:
+      del self.inputs[k]

   def load(self, input_fn):
     float32 = not FLOAT16
@@ -266,8 +272,8 @@ class Thneed:
       events.append(prg.clprg(CL().cl_queue, *args))
     mt = time.monotonic()
     CL().cl_queue.finish()
-    et = time.monotonic()
-    print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {(et-st)*1000.0:.2f} ms")
+    et = time.monotonic() - st
+    print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")

     if DEBUGCL:
       total_runtime = 0
@@ -275,7 +281,7 @@ class Thneed:
         runtime = (e.profile.end - e.profile.start)
         print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {prg.options}")
         total_runtime += runtime
-      print(f"total runtime: {total_runtime/1e6:.2f} ms")
+      print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")

   def optimize_local_workgroup(self):
     MAX_WORKGROUP = CL.cl_ctx.devices[0].max_work_group_size
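The first Thneed hunk above replaces a hard assert with pruning: inputs whose buffers never show up in buffers_to_save are warned about and dropped instead of aborting. A self-contained toy version of that pattern (the dict contents are made up for illustration):

inputs = {"img": "buf_a", "desire": "buf_b"}   # input name -> backing buffer
buffers_to_save = {"buf_a"}                    # buffers the captured kernels actually use

fake_inputs = []
for name, buf in inputs.items():
  if buf in buffers_to_save:
    buffers_to_save.remove(buf)
  else:
    print(f"WARNING: {name} was not a used input, removing it")
    fake_inputs.append(name)
for name in fake_inputs:  # deleted after the loop so the dict isn't mutated while iterating
  del inputs[name]
print(inputs)  # {'img': 'buf_a'}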
@@ -2,7 +2,8 @@
 import os, time, io, pathlib, sys
 sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))

-os.environ['OPT'] = '99'
+if os.getenv("OPT", None) is None:
+  os.environ['OPT'] = '99'
 if os.getenv("GPU", None) is None:
   os.environ['OPENCL'] = '1'
@@ -70,7 +71,7 @@ def compile(dat, output_fn):
   # initial run(s) to load weights
   for _ in range(2):
     st = time.monotonic()
-    tinygrad_out = run_onnx(inputs)['outputs']
+    tinygrad_out = next(iter(run_onnx(inputs).values()))
     mt = time.monotonic()
     tinygrad_out.realize()
     mt2 = time.monotonic()
@@ -87,7 +88,7 @@ def compile(dat, output_fn):
   # real run
   inputs, np_inputs = get_random_input_tensors(input_shapes)
   print("***** REAL RUN *****")
-  tinygrad_out = run_onnx(inputs)['outputs']
+  tinygrad_out = next(iter(run_onnx(inputs).values()))

   # note, since CL.CACHE is enabled, it doesn't actually run the kernels
   CL.CACHE = []
@@ -101,7 +102,8 @@ def compile(dat, output_fn):
   from extra.thneed import Thneed
   t = Thneed(CL.CACHE, {k:inputs[k].lazydata.realized.cl for k in inputs.keys()})
   CL.CACHE = None
-  t.optimize_local_workgroup()
+  if int(os.getenv("OPTWG", "0")):
+    t.optimize_local_workgroup()

   # save thneed (before run)
   t.save(output_fn)
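Two small idioms recur in the compile.py hunks above; a runnable sketch of both (the dict is a stand-in for run_onnx's return value):

import os

# 1. only apply a default when the caller hasn't already set the variable
if os.getenv("OPT", None) is None:
  os.environ['OPT'] = '99'

# 2. take the first (here, the only) output without hard-coding its name (was ['outputs'] before)
onnx_outputs = {"outputs": [1.0, 2.0]}  # stand-in for run_onnx(inputs)
tinygrad_out = next(iter(onnx_outputs.values()))
print(os.environ['OPT'], tinygrad_out)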
@@ -14,7 +14,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_ato
   else:
     ts = [torch.tensor((np.random.random(size=x).astype(np.float32)+a)*b, requires_grad=True) for x in shps]

-  tst = [Tensor(x.detach().numpy(), requires_grad=True) for x in ts]
+  tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts]

   st = time.monotonic()
   out = torch_fxn(*ts)
@@ -225,6 +225,7 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x, w),
       lambda x,w: x.conv2d(w), atol=1e-2)

+  #@unittest.skip("not supported with IMAGE=1")
   def test_large_bs_conv(self):
     # large batch size can cause OpenCL image to exceed max image height on macOS
     # (or cause the conv kernel to overflow short sampling coords)
@@ -232,6 +233,7 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x, w),
       lambda x,w: x.conv2d(w), atol=1e-4, rtol=1e-2)

+  #@unittest.skip("not supported with IMAGE=1")
   def test_large_ic_conv(self):
     # large input channel count can cause OpenCL image to exceed max image width on macOS
     helper_test_op([(1,2048,3,3), (1,2048,3,3)],
@@ -245,10 +247,15 @@ class TestOps(unittest.TestCase):
       lambda x,w,b: Tensor.conv2d(x,w,b).relu().conv2d(w,b), atol=1e-4)

   def test_simple_conv2d(self):
-    helper_test_op([(1,1,9,9), (1,1,3,3)],
+    helper_test_op([(1,4,9,9), (4,4,3,3)],
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
       lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)

+  def test_nested_conv2d(self):
+    helper_test_op([(1,32,9,9), (32,32,3,3), (32,32,3,3)],
+      lambda x,w1,w2: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w1).relu(), w2).relu(),
+      lambda x,w1,w2: x.conv2d(w1).relu().conv2d(w2).relu(), atol=1e-4, grad_rtol=1e-5)
+
   # expect reduce nodes == 3
   def test_simple_conv2d_nhwc(self):
     # weights (from tf): filter_height x filter_width x in_channels x out_channels
@@ -301,6 +308,15 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
       lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)

+  def test_depthwise_conv2d(self):
+    bs = 1
+    groups = 32
+    rcout = 1
+    cin = 1
+    helper_test_op([(bs,groups*cin,32,32), (groups*rcout,cin,1,1)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
+      lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
+
   def test_grouped_conv2d(self):
     bs = 4
     groups = 5
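For the test added above, depthwise convolution is just grouped convolution with one input channel (cin) and one output channel (rcout) per group; a quick shape check with torch (the only assumption is that torch is installed, which the tests already require):

import torch
bs, groups, rcout, cin = 1, 32, 1, 1
x = torch.randn(bs, groups*cin, 32, 32)   # (1, 32, 32, 32)
w = torch.randn(groups*rcout, cin, 1, 1)  # one 1x1 filter per group
out = torch.nn.functional.conv2d(x, w, groups=groups)
print(out.shape)                          # torch.Size([1, 32, 32, 32])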
@@ -94,6 +94,12 @@ class TestComplexShapeTracker(unittest.TestCase):
     self.st.permute(1, 0, 2, 3)
     assert self.st.contiguous

+  def test_fancy_factorize(self):
+    self.st = ShapeTracker((32, 3, 3, 1))
+    self.st.strided(*zip((32, 3, 3, 1), (1, 4096, 32, 1)))
+    self.st.reshape(*(8, 4, 3, 3))
+    assert len(self.st.views) == 1
+
   def test_super_complex_2_fail(self):
     self.st = ShapeTracker((4, 4, 4))
     self.st.permute(2, 0, 1)
@@ -9,10 +9,7 @@ from functools import partial
 from tinygrad.ops import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d
-try:
-  from termcolor import colored
-except ImportError:
-  colored = lambda x, _: x
+from tinygrad.helpers import colored
 try:
   from tinygrad.llops.ops_gpu import CL
 except ImportError:
@@ -24,15 +21,12 @@ torch_device = torch.device('mps' if int(os.getenv("MPS", "0")) else 'cpu')

 def colorize_float(x):
   ret = f"{x:7.2f}x"
-  if colored:
-    if x < 0.75:
-      return colored(ret, 'green')
-    elif x > 1.5:
-      return colored(ret, 'red')
-    else:
-      return colored(ret, 'yellow')
-  return ret
+  if x < 0.75:
+    return colored(ret, 'green')
+  elif x > 1.5:
+    return colored(ret, 'red')
+  else:
+    return colored(ret, 'yellow')

 save_ops, save_mem = 0, 0
 CNT = 8
@@ -101,6 +95,14 @@ class TestSpeed(unittest.TestCase):
     # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
     def f(a, b): return a.permute(1,0).contiguous()
     helper_test_generic_square('permute', N, f, f)

+  def test_double_permute(self):
+    N = 64
+    torch.manual_seed(0)
+    torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
+    tiny_a = Tensor(torch_a.cpu().numpy())
+    def f(a): return a.permute(1,0,3,2).contiguous()
+    helper_test_generic(f"double_permute {tiny_a.shape}", partial(f, torch_a), partial(f, tiny_a))
+
   def test_neg(self):
     def f(a, b): return -a
@@ -159,6 +161,20 @@ class TestSpeed(unittest.TestCase):
     def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
     helper_test_generic_square('gemm_unrolled_permute_lr', N, f1, f2)

+  def test_openpilot_conv2d(self):
+    bs, in_chans, out_chans = 1,12,32
+    torch.manual_seed(0)
+    torch_dat = torch.rand(bs, 64, 128, 12).to(torch_device)
+    torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1).to(torch_device)
+
+    tiny_dat = Tensor(torch_dat.cpu().numpy())
+    tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
+    tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
+
+    def f1(): return torch_conv(torch_dat.permute(0,3,1,2))
+    def f2(): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()
+    helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, f2)
+
   def test_conv2d(self):
     torch.manual_seed(0)
     for bs in [32]:
@@ -6,6 +6,7 @@ def prod(x): return math.prod(x)
 def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], tuple) or isinstance(x[0], list) else tuple(x)
 def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
 def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
+def colored(st, color): return f"\u001b[{30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" # replace the termcolor library with one line

 def reduce_shape(shape, axis): return tuple(1 if i in axis else shape[i] for i in range(len(shape)))
 def shape_to_axis(old_shape, new_shape):
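For reference, what the new one-line colored helper emits: the 30 + index arithmetic picks the standard ANSI foreground code for the named color. A quick check you can run anywhere:

def colored(st, color): return f"\u001b[{30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m"

print(repr(colored("ok", "green")))  # '\x1b[32mok\x1b[0m' (32 is the ANSI code for green)
print(colored("ok", "green"))        # renders "ok" in green on an ANSI terminal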
@@ -84,7 +84,16 @@ class ExplicitExecAST(DeviceBuffer):

   # universal for shape tracked
   def movement_op(self, op:MovementOps, arg): return type(self)(ShapeTracker(self.st).movement_op(op, arg), self)

+  # TODO: creating a new object is making a copy, breaking the thneed compiler
+  def contiguous(self): return self if self.st.contiguous else self.unary_op(UnaryOps.NOOP)
+  #def contiguous(self): return type(self)(self.shape, hostbuf=self) if self.st.contiguous else self.unary_op(UnaryOps.NOOP)
+
+def get_first_reduce(shapes):
+  for i in range(len(shapes[0])):
+    if not all_same([x[i] for x in shapes]):
+      return i
+  return len(shapes[0]) # off the end
+
 # ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops
 class ASTKernel:
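A worked example of the get_first_reduce helper added above: it scans the buffer shapes axis by axis and returns the first axis where they disagree (the start of the reduce dimensions), or the rank if they all match. Copying the two helpers from the diff so it runs standalone:

def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True

def get_first_reduce(shapes):
  for i in range(len(shapes[0])):
    if not all_same([x[i] for x in shapes]):
      return i
  return len(shapes[0]) # off the end

print(get_first_reduce([(16, 16), (16, 1)]))   # 1 -> the last axis is being reduced
print(get_first_reduce([(16, 16), (16, 16)]))  # 2 -> no reduce, "off the end"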
@@ -105,6 +114,9 @@ class ASTKernel:
     assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape"
     assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"

+    # key for lookup in cache (can change, str might not be right)
+    self.key = str(ast)
+
   def process(self):
     # get shape, strides, and offset
     # if it's a multiview buffer we take the final view
@@ -121,11 +133,7 @@ class ASTKernel:
     strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in strides]

     # find first mismatch, don't reduce this
-    first_reduce = -1
-    for i in range(len(shapes[0])):
-      if not all_same([x[i] for x in shapes]):
-        first_reduce = i
-        break
+    first_reduce = get_first_reduce(shapes)

     # merge dimensions if we can, multi get_shape_strides
     # TODO: does this always preserve the reduce dimension, NO
@@ -144,10 +152,14 @@ class ASTKernel:
       else:
         rets[j].append((shapes[j][i], strides[j][i]))
     self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets]
+    self.first_reduce = get_first_reduce(self.shapes) # update this if axis merged

     # include the offsets (as is)
     self.offsets = [x.st.views[-1].offset for x in self.bufs]

+  @property
+  def shape_len(self): return len(self.shapes[0])
+
   # this should be aware of the three parts to the shape
   # * the input/output dimensions
   # * the reduce dimensions
@@ -1,9 +1,13 @@
 # ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
 from __future__ import annotations
+import os
 import functools
 from typing import Tuple, Union, List
 from tinygrad.helpers import prod

+# TODO: fix DEBUG import
+DEBUG = int(os.getenv("DEBUG", "0"))
+
 def divmodidx(acc, d, mod=True):
   lr = f"(idx//{acc})" if acc != 1 else "idx"
   return f"({lr}%{d})" if mod else lr # don't mod the top shape dimension
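A worked example of divmodidx above: it builds the string index expression for one dimension of size d, where acc looks to be the combined size of the dimensions inside this one (so acc is 1 for the innermost dimension), and the modulus is skipped for the topmost dimension. For a contiguous (4, 3, 2) shape:

def divmodidx(acc, d, mod=True):
  lr = f"(idx//{acc})" if acc != 1 else "idx"
  return f"({lr}%{d})" if mod else lr # don't mod the top shape dimension

print(divmodidx(1, 2))             # (idx%2)      -> innermost dim, size 2
print(divmodidx(2, 3))             # ((idx//2)%3)
print(divmodidx(6, 4, mod=False))  # (idx//6)     -> top dim needs no modulus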
@@ -102,7 +106,7 @@ class ShapeTracker:
     return self

   def reshape(self, *new_shape):
-    assert all(isinstance(x, int) for x in new_shape)
+    assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
     assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"

     # check if this is adding or removing 1s (only)
@@ -127,21 +131,26 @@ class ShapeTracker:
         while len(new_strides) != len(new_shape):
           assert new_shape[len(new_strides)] == 1
           new_strides.append(1)
-        self.views[-1] = View(new_shape, new_strides, self.offset)
-        return self # early return, it factorized!
+        break
       curr_dim, curr_stride = min_shape_strides.pop(0)
     else:
       break # didn't factorize

+    if len(new_shape) == len(new_strides):
+      self.views[-1] = View(new_shape, new_strides, self.offset)
+      return self
+
     view = View(new_shape, strides_for_shape(new_shape))
     if self.contiguous:
       self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
     else:
+      if DEBUG >= 2:
+        print(f"WARNING: reshape from {self.shape} w strides {self.strides} -> {new_shape} is creating another view")
       self.views.append(view)
     return self

   def permute(self, *axis):
-    assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis)
+    assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
     assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
     self.views[-1] = View([self.shape[a] for a in axis], [self.strides[a] for a in axis], self.offset)
     return self
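The fresh view created above uses strides_for_shape(new_shape), whose body is not part of this diff; as a reminder of what contiguous strides look like, here is a minimal reimplementation of that idea (each stride is the product of the dimensions to its right), not the library's actual code:

def strides_for_shape(shape):
  strides = [1]
  for d in reversed(shape[1:]):
    strides = [d * strides[0]] + strides
  return tuple(strides)

print(strides_for_shape((8, 4, 3, 3)))  # (36, 9, 3, 1)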