shapetracker from newgpu (#456)

* shapetracker from newgpu

* touchup ops

* test

* tests

* thneed deletes unused inputs

* test

* bugfix
George Hotz 2023-01-09 12:40:01 -08:00 committed by GitHub
parent 538b1d7f5b
commit 4885fce56e
9 changed files with 104 additions and 37 deletions

View File

@ -187,9 +187,8 @@ class LLVMBuffer(ExplicitExecAST):
k = ASTKernel(ast)
# cached kernel
key = str(ast) # TODO: does this uniquely determine the AST? No! The shapetracker can change. Do this better.
if key in LLVMBuffer.func_cache:
LLVMBuffer.func_cache[key](*[x._buf for x in k.bufs])
if k.key in LLVMBuffer.func_cache:
LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs])
return k.ret
# cache miss, we have to process the kernel
@ -362,5 +361,5 @@ class LLVMBuffer(ExplicitExecAST):
loop_entry[-1].branch(loop_exit[-1]._block)
loop_exit[0].ret_void()
LLVMBuffer.func_cache[key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs))
LLVMBuffer.func_cache[k.key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs))
return k.ret
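
Note: the change above swaps the local key = str(ast) for k.key, computed on the ASTKernel itself (see the ops.py hunk further down), so the cache lookup and the cache store can no longer drift apart. A minimal sketch of the compile-once pattern, with hypothetical names (compile_fn, exec_ast) standing in for the LLVM plumbing; the key is still str(ast) underneath, which the TODO above flags as not fully unique since the ShapeTracker can differ for an identical AST.

# sketch only: func_cache maps a kernel key to a compiled function, mirroring
# LLVMBuffer.func_cache above; compile_fn is a hypothetical stand-in for codegen + LLVM().exec
func_cache = {}

def exec_ast(kernel, compile_fn):
    if kernel.key not in func_cache:            # cache miss: compile the kernel once
        func_cache[kernel.key] = compile_fn(kernel)
    return func_cache[kernel.key](*[b._buf for b in kernel.bufs])  # reuse thereafter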

View File

@ -35,9 +35,15 @@ class Thneed:
if len(nodes[n]['out_edges']) == 0:
self.outputs.append(n)
for n in self.inputs.values():
assert n in self.buffers_to_save, f"{n} was not an input"
self.buffers_to_save.remove(n)
fake_inputs = []
for k,n in self.inputs.items():
if n in self.buffers_to_save:
self.buffers_to_save.remove(n)
else:
print(f"WARNING: {k} was not a used input, removing it")
fake_inputs.append(k)
for k in fake_inputs:
del self.inputs[k]
def load(self, input_fn):
float32 = not FLOAT16
@ -266,8 +272,8 @@ class Thneed:
events.append(prg.clprg(CL().cl_queue, *args))
mt = time.monotonic()
CL().cl_queue.finish()
et = time.monotonic()
print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {(et-st)*1000.0:.2f} ms")
et = time.monotonic() - st
print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")
if DEBUGCL:
total_runtime = 0
@ -275,7 +281,7 @@ class Thneed:
runtime = (e.profile.end - e.profile.start)
print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {prg.options}")
total_runtime += runtime
print(f"total runtime: {total_runtime/1e6:.2f} ms")
print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms")
def optimize_local_workgroup(self):
MAX_WORKGROUP = CL.cl_ctx.devices[0].max_work_group_size
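
Note: the new loop replaces the hard assert with a warning and prunes inputs that no kernel actually reads. Since a dict cannot be mutated while it is being iterated, the keys are collected into fake_inputs first and deleted afterwards. A standalone sketch of that pattern with made-up data:

# hedged sketch of the "collect then delete" pattern above (example data only)
inputs = {"img": 1, "desire": 2, "unused_state": 3}
buffers_to_save = {1, 2}            # stand-in for the buffers the kernels actually read

fake_inputs = []
for name, buf in inputs.items():
    if buf in buffers_to_save:
        buffers_to_save.remove(buf)
    else:
        print(f"WARNING: {name} was not a used input, removing it")
        fake_inputs.append(name)    # can't delete from the dict while iterating it
for name in fake_inputs:
    del inputs[name]

assert "unused_state" not in inputs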

View File

@ -2,7 +2,8 @@
import os, time, io, pathlib, sys
sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
os.environ['OPT'] = '99'
if os.getenv("OPT", None) is None:
os.environ['OPT'] = '99'
if os.getenv("GPU", None) is None:
os.environ['OPENCL'] = '1'
@ -70,7 +71,7 @@ def compile(dat, output_fn):
# initial run(s) to load weights
for _ in range(2):
st = time.monotonic()
tinygrad_out = run_onnx(inputs)['outputs']
tinygrad_out = next(iter(run_onnx(inputs).values()))
mt = time.monotonic()
tinygrad_out.realize()
mt2 = time.monotonic()
@ -87,7 +88,7 @@ def compile(dat, output_fn):
# real run
inputs, np_inputs = get_random_input_tensors(input_shapes)
print("***** REAL RUN *****")
tinygrad_out = run_onnx(inputs)['outputs']
tinygrad_out = next(iter(run_onnx(inputs).values()))
# note, since CL.CACHE is enabled, it doesn't actually run the kernels
CL.CACHE = []
@ -101,7 +102,8 @@ def compile(dat, output_fn):
from extra.thneed import Thneed
t = Thneed(CL.CACHE, {k:inputs[k].lazydata.realized.cl for k in inputs.keys()})
CL.CACHE = None
t.optimize_local_workgroup()
if int(os.getenv("OPTWG", "0")):
t.optimize_local_workgroup()
# save thneed (before run)
t.save(output_fn)
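
Note: two behavioral changes in this file: environment overrides are now respected (OPT is only forced to 99 when the caller didn't set it, and the local-workgroup pass is gated behind OPTWG), and the ONNX output is taken as the first value of the returned dict instead of assuming it is literally named 'outputs'. A short sketch of the three idioms:

import os

# only apply a default if the user hasn't overridden it
if os.getenv("OPT", None) is None:
    os.environ['OPT'] = '99'

# feature gate: off unless OPTWG is exported as a nonzero integer
if int(os.getenv("OPTWG", "0")):
    print("would run optimize_local_workgroup()")

# grab the first output of a dict without hardcoding its name
run_onnx_result = {"out0": "tensor"}            # hypothetical stand-in for run_onnx(inputs)
first_output = next(iter(run_onnx_result.values()))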

View File

@ -14,7 +14,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_ato
else:
ts = [torch.tensor((np.random.random(size=x).astype(np.float32)+a)*b, requires_grad=True) for x in shps]
tst = [Tensor(x.detach().numpy(), requires_grad=True) for x in ts]
tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts]
st = time.monotonic()
out = torch_fxn(*ts)
@ -225,6 +225,7 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-2)
#@unittest.skip("not supported with IMAGE=1")
def test_large_bs_conv(self):
# large batch size can cause OpenCL image to exceed max image height on macOS
# (or cause the conv kernel to overflow short sampling coords)
@ -232,6 +233,7 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-4, rtol=1e-2)
#@unittest.skip("not supported with IMAGE=1")
def test_large_ic_conv(self):
# large input channel count can cause OpenCL image to exceed max image width on macOS
helper_test_op([(1,2048,3,3), (1,2048,3,3)],
@ -245,10 +247,15 @@ class TestOps(unittest.TestCase):
lambda x,w,b: Tensor.conv2d(x,w,b).relu().conv2d(w,b), atol=1e-4)
def test_simple_conv2d(self):
helper_test_op([(1,1,9,9), (1,1,3,3)],
helper_test_op([(1,4,9,9), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
def test_nested_conv2d(self):
helper_test_op([(1,32,9,9), (32,32,3,3), (32,32,3,3)],
lambda x,w1,w2: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w1).relu(), w2).relu(),
lambda x,w1,w2: x.conv2d(w1).relu().conv2d(w2).relu(), atol=1e-4, grad_rtol=1e-5)
# expect reduce nodes == 3
def test_simple_conv2d_nhwc(self):
# weights (from tf): filter_height x filter_width x in_channels x out_channels
@ -301,6 +308,15 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
def test_depthwise_conv2d(self):
bs = 1
groups = 32
rcout = 1
cin = 1
helper_test_op([(bs,groups*cin,32,32), (groups*rcout,cin,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
def test_grouped_conv2d(self):
bs = 4
groups = 5
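
Note: the new test_depthwise_conv2d pins down the grouped-conv shape convention used throughout: the input has groups*cin channels, the weight is (groups*rcout, cin, kH, kW), and the output has groups*rcout channels; depthwise is the special case cin = rcout = 1. A torch-only sketch of that arithmetic (assuming torch and the same literal sizes as the test):

import torch

bs, groups, rcout, cin = 1, 32, 1, 1            # depthwise: one filter per input channel
x = torch.rand(bs, groups * cin, 32, 32)
w = torch.rand(groups * rcout, cin, 1, 1)       # weight shape (groups*rcout, cin, kH, kW)
out = torch.nn.functional.conv2d(x, w, groups=groups)
assert out.shape == (bs, groups * rcout, 32, 32)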

View File

@ -94,6 +94,12 @@ class TestComplexShapeTracker(unittest.TestCase):
self.st.permute(1, 0, 2, 3)
assert self.st.contiguous
def test_fancy_factorize(self):
self.st = ShapeTracker((32, 3, 3, 1))
self.st.strided(*zip((32, 3, 3, 1), (1, 4096, 32, 1)))
self.st.reshape(*(8, 4, 3, 3))
assert len(self.st.views) == 1
def test_super_complex_2_fail(self):
self.st = ShapeTracker((4, 4, 4))
self.st.permute(2, 0, 1)
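
Note: test_fancy_factorize asserts that reshaping the strided (32, 3, 3, 1) view with strides (1, 4096, 32, 1) into (8, 4, 3, 3) stays a single view: the 32-long axis of stride 1 factors into 8x4 with strides (4, 1), and the trailing size-1 axis drops out. A numpy check of that stride arithmetic, used here purely for illustration (numpy strides are in bytes, hence the itemsize factor):

import numpy as np
from numpy.lib.stride_tricks import as_strided

base = np.arange(3 * 4096, dtype=np.int32)      # covers every offset the strides can reach (max 8287)
i = base.itemsize
a = as_strided(base, shape=(32, 3, 3, 1), strides=(1 * i, 4096 * i, 32 * i, 1 * i))
b = as_strided(base, shape=(8, 4, 3, 3), strides=(4 * i, 1 * i, 4096 * i, 32 * i))
# same elements in the same order, so the reshape is just a stride change (one view, no copy)
assert (a.reshape(8, 4, 3, 3) == b).all()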

View File

@ -9,10 +9,7 @@ from functools import partial
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
try:
from termcolor import colored
except ImportError:
colored = lambda x, _: x
from tinygrad.helpers import colored
try:
from tinygrad.llops.ops_gpu import CL
except ImportError:
@ -24,15 +21,12 @@ torch_device = torch.device('mps' if int(os.getenv("MPS", "0")) else 'cpu')
def colorize_float(x):
ret = f"{x:7.2f}x"
if colored:
if x < 0.75:
return colored(ret, 'green')
elif x > 1.5:
return colored(ret, 'red')
else:
return colored(ret, 'yellow')
if x < 0.75:
return colored(ret, 'green')
elif x > 1.5:
return colored(ret, 'red')
else:
return ret
return colored(ret, 'yellow')
save_ops, save_mem = 0, 0
CNT = 8
@ -101,6 +95,14 @@ class TestSpeed(unittest.TestCase):
# to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size
def f(a, b): return a.permute(1,0).contiguous()
helper_test_generic_square('permute', N, f, f)
def test_double_permute(self):
N = 64
torch.manual_seed(0)
torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device)
tiny_a = Tensor(torch_a.cpu().numpy())
def f(a): return a.permute(1,0,3,2).contiguous()
helper_test_generic(f"double_permute {tiny_a.shape}", partial(f, torch_a), partial(f, tiny_a))
def test_neg(self):
def f(a, b): return -a
@ -159,6 +161,20 @@ class TestSpeed(unittest.TestCase):
def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2)
helper_test_generic_square('gemm_unrolled_permute_lr', N, f1, f2)
def test_openpilot_conv2d(self):
bs, in_chans, out_chans = 1,12,32
torch.manual_seed(0)
torch_dat = torch.rand(bs, 64, 128, 12).to(torch_device)
torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1).to(torch_device)
tiny_dat = Tensor(torch_dat.cpu().numpy())
tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1)
tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy())
def f1(): return torch_conv(torch_dat.permute(0,3,1,2))
def f2(): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize()
helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, f2)
def test_conv2d(self):
torch.manual_seed(0)
for bs in [32]:
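
Note: test_openpilot_conv2d feeds an NHWC tensor (the openpilot camera layout) and permutes it to NCHW before the convolution, comparing torch and tinygrad on identical weights. A torch-only sketch of the layout handling (assuming torch and the same sizes as the test):

import torch

bs, in_chans, out_chans = 1, 12, 32
conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=False, padding=1)
nhwc = torch.rand(bs, 64, 128, in_chans)        # channels-last, as the frames arrive
nchw = nhwc.permute(0, 3, 1, 2)                 # Conv2d expects NCHW
assert conv(nchw).shape == (bs, out_chans, 64, 128)   # padding=1 keeps the spatial dims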

View File

@ -6,6 +6,7 @@ def prod(x): return math.prod(x)
def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], tuple) or isinstance(x[0], list) else tuple(x)
def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True
def colored(st, color): return f"\u001b[{30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" # replace the termcolor library with one line
def reduce_shape(shape, axis): return tuple(1 if i in axis else shape[i] for i in range(len(shape)))
def shape_to_axis(old_shape, new_shape):
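
Note: the new one-line colored drops the termcolor dependency: it looks the color up in the standard 8-color ANSI table and wraps the string in the matching escape sequence. Shown standalone for illustration:

# the helper from tinygrad/helpers.py, copied here so the example is self-contained
def colored(st, color):
    return f"\u001b[{30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m"

print(colored("1.23x", "green"))     # emits "\x1b[32m1.23x\x1b[0m", rendered green in a terminal
print(repr(colored("slow", "red")))  # '\x1b[31mslow\x1b[0m'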

View File

@ -84,7 +84,16 @@ class ExplicitExecAST(DeviceBuffer):
# universal for shape tracked
def movement_op(self, op:MovementOps, arg): return type(self)(ShapeTracker(self.st).movement_op(op, arg), self)
# TODO: creating a new object is making a copy, breaking the thneed compiler
def contiguous(self): return self if self.st.contiguous else self.unary_op(UnaryOps.NOOP)
#def contiguous(self): return type(self)(self.shape, hostbuf=self) if self.st.contiguous else self.unary_op(UnaryOps.NOOP)
def get_first_reduce(shapes):
for i in range(len(shapes[0])):
if not all_same([x[i] for x in shapes]):
return i
return len(shapes[0]) # off the end
# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops
class ASTKernel:
@ -105,6 +114,9 @@ class ASTKernel:
assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape"
assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size"
# key for lookup in cache (can change, str might not be right)
self.key = str(ast)
def process(self):
# get shape, strides, and offset
# if it's a multiview buffer we take the final view
@ -121,11 +133,7 @@ class ASTKernel:
strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in strides]
# find first mismatch, don't reduce this
first_reduce = -1
for i in range(len(shapes[0])):
if not all_same([x[i] for x in shapes]):
first_reduce = i
break
first_reduce = get_first_reduce(shapes)
# merge dimensions if we can, multi get_shape_strides
# TODO: does this always preserve the reduce dimension, NO
@ -144,10 +152,14 @@ class ASTKernel:
else:
rets[j].append((shapes[j][i], strides[j][i]))
self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets]
self.first_reduce = get_first_reduce(self.shapes) # update this if axis merged
# include the offsets (as is)
self.offsets = [x.st.views[-1].offset for x in self.bufs]
@property
def shape_len(self): return len(self.shapes[0])
# this should be aware of the three parts to the shape
# * the input/output dimensions
# * the reduce dimensions
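
Note: get_first_reduce is pulled out to module level so it can be re-run after dimension merging rewrites self.shapes (the new self.first_reduce = get_first_reduce(self.shapes) line). It returns the first axis on which the buffers' shapes disagree, or one past the end when they all match. A standalone illustration with made-up shapes, reusing all_same from helpers:

def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True

def get_first_reduce(shapes):
    for i in range(len(shapes[0])):
        if not all_same([x[i] for x in shapes]):
            return i
    return len(shapes[0])  # off the end: no reduce axis

# a sum over the last axis: the output buffer has a 1 where the inputs have 16
assert get_first_reduce([(4, 16), (4, 16), (4, 1)]) == 1
# elementwise-only kernel: every buffer has the same shape
assert get_first_reduce([(4, 16), (4, 16)]) == 2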

View File

@ -1,9 +1,13 @@
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import os
import functools
from typing import Tuple, Union, List
from tinygrad.helpers import prod
# TODO: fix DEBUG import
DEBUG = int(os.getenv("DEBUG", "0"))
def divmodidx(acc, d, mod=True):
lr = f"(idx//{acc})" if acc != 1 else "idx"
return f"({lr}%{d})" if mod else lr # don't mod the top shape dimension
@ -102,7 +106,7 @@ class ShapeTracker:
return self
def reshape(self, *new_shape):
assert all(isinstance(x, int) for x in new_shape)
assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}"
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
# check if this is adding or removing 1s (only)
@ -127,21 +131,26 @@ class ShapeTracker:
while len(new_strides) != len(new_shape):
assert new_shape[len(new_strides)] == 1
new_strides.append(1)
self.views[-1] = View(new_shape, new_strides, self.offset)
return self # early return, it factorized!
break
curr_dim, curr_stride = min_shape_strides.pop(0)
else:
break # didn't factorize
if len(new_shape) == len(new_strides):
self.views[-1] = View(new_shape, new_strides, self.offset)
return self
view = View(new_shape, strides_for_shape(new_shape))
if self.contiguous:
self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
else:
if DEBUG >= 2:
print(f"WARNING: reshape from {self.shape} w strides {self.strides} -> {new_shape} is creating another view")
self.views.append(view)
return self
def permute(self, *axis):
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis)
assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}"
assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
self.views[-1] = View([self.shape[a] for a in axis], [self.strides[a] for a in axis], self.offset)
return self
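
Note: divmodidx builds the string index expressions a view emits: each axis divides the flat index by the combined size of the axes to its right, then mods by its own size, except the outermost axis, which skips the mod. Standalone illustration for a contiguous (4, 3, 2) shape:

def divmodidx(acc, d, mod=True):
    lr = f"(idx//{acc})" if acc != 1 else "idx"
    return f"({lr}%{d})" if mod else lr  # don't mod the top shape dimension

# contiguous shape (4, 3, 2): acc is the product of the dimensions to the right
print(divmodidx(1, 2))              # (idx%2)       -> innermost axis
print(divmodidx(2, 3))              # ((idx//2)%3)
print(divmodidx(6, 4, mod=False))   # (idx//6)      -> outermost axis, no mod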