From 4885fce56efae12e76e00356e75ef96c50e18623 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 9 Jan 2023 12:40:01 -0800 Subject: [PATCH] shapetracker from newgpu (#456) * shapetracker from newgpu * touchup ops * test * testst * thneed deletes unused inputs * test * bugfix --- accel/llvm/ops_llvm.py | 7 +++---- extra/thneed.py | 18 +++++++++++------ openpilot/compile.py | 10 ++++++---- test/test_ops.py | 20 +++++++++++++++++-- test/test_shapetracker.py | 6 ++++++ test/test_speed_v_torch.py | 40 ++++++++++++++++++++++++++------------ tinygrad/helpers.py | 1 + tinygrad/ops.py | 22 ++++++++++++++++----- tinygrad/shapetracker.py | 17 ++++++++++++---- 9 files changed, 104 insertions(+), 37 deletions(-) diff --git a/accel/llvm/ops_llvm.py b/accel/llvm/ops_llvm.py index d8b2dd00..5652cc7a 100644 --- a/accel/llvm/ops_llvm.py +++ b/accel/llvm/ops_llvm.py @@ -187,9 +187,8 @@ class LLVMBuffer(ExplicitExecAST): k = ASTKernel(ast) # cached kernel - key = str(ast) # TODO: does this uniquely determine the AST? No! The shapetracker can change. Do this better. - if key in LLVMBuffer.func_cache: - LLVMBuffer.func_cache[key](*[x._buf for x in k.bufs]) + if k.key in LLVMBuffer.func_cache: + LLVMBuffer.func_cache[k.key](*[x._buf for x in k.bufs]) return k.ret # cache miss, we have to process the kernel @@ -362,5 +361,5 @@ class LLVMBuffer(ExplicitExecAST): loop_entry[-1].branch(loop_exit[-1]._block) loop_exit[0].ret_void() - LLVMBuffer.func_cache[key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs)) + LLVMBuffer.func_cache[k.key] = LLVM().exec(module, k.bufs, k.info.flops, sum(len(x._buf) for x in k.bufs)) return k.ret diff --git a/extra/thneed.py b/extra/thneed.py index 58d4124a..4a957dfc 100644 --- a/extra/thneed.py +++ b/extra/thneed.py @@ -35,9 +35,15 @@ class Thneed: if len(nodes[n]['out_edges']) == 0: self.outputs.append(n) - for n in self.inputs.values(): - assert n in self.buffers_to_save, f"{n} was not an input" - self.buffers_to_save.remove(n) + fake_inputs = [] + for k,n in self.inputs.items(): + if n in self.buffers_to_save: + self.buffers_to_save.remove(n) + else: + print(f"WARNING: {k} was not a used input, removing it") + fake_inputs.append(k) + for k in fake_inputs: + del self.inputs[k] def load(self, input_fn): float32 = not FLOAT16 @@ -266,8 +272,8 @@ class Thneed: events.append(prg.clprg(CL().cl_queue, *args)) mt = time.monotonic() CL().cl_queue.finish() - et = time.monotonic() - print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {(et-st)*1000.0:.2f} ms") + et = time.monotonic() - st + print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms") if DEBUGCL: total_runtime = 0 @@ -275,7 +281,7 @@ class Thneed: runtime = (e.profile.end - e.profile.start) print(f"{i:3d} time {total_runtime/1e6:5.2f} ms running {prg.name:20s} with {str(args[0]):15s} {str(args[1]):15s} count {len(args)-2:2d} runtime {runtime/1e3:7.2f} us {prg.options}") total_runtime += runtime - print(f"total runtime: {total_runtime/1e6:.2f} ms") + print(f"total runtime: {total_runtime/1e6:.2f} ms wall time: {et*1000.0:.2f} ms") def optimize_local_workgroup(self): MAX_WORKGROUP = CL.cl_ctx.devices[0].max_work_group_size diff --git a/openpilot/compile.py b/openpilot/compile.py index 3eaa8807..2e990253 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -2,7 +2,8 @@ import os, time, io, pathlib, sys sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) -os.environ['OPT'] = '99' +if os.getenv("OPT", None) is 
None: + os.environ['OPT'] = '99' if os.getenv("GPU", None) is None: os.environ['OPENCL'] = '1' @@ -70,7 +71,7 @@ def compile(dat, output_fn): # initial run(s) to load weights for _ in range(2): st = time.monotonic() - tinygrad_out = run_onnx(inputs)['outputs'] + tinygrad_out = next(iter(run_onnx(inputs).values())) mt = time.monotonic() tinygrad_out.realize() mt2 = time.monotonic() @@ -87,7 +88,7 @@ def compile(dat, output_fn): # real run inputs, np_inputs = get_random_input_tensors(input_shapes) print("***** REAL RUN *****") - tinygrad_out = run_onnx(inputs)['outputs'] + tinygrad_out = next(iter(run_onnx(inputs).values())) # note, since CL.CACHE is enabled, it doesn't actually run the kernels CL.CACHE = [] @@ -101,7 +102,8 @@ def compile(dat, output_fn): from extra.thneed import Thneed t = Thneed(CL.CACHE, {k:inputs[k].lazydata.realized.cl for k in inputs.keys()}) CL.CACHE = None - t.optimize_local_workgroup() + if int(os.getenv("OPTWG", "0")): + t.optimize_local_workgroup() # save thneed (before run) t.save(output_fn) diff --git a/test/test_ops.py b/test/test_ops.py index 086cfb30..a6937b35 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -14,7 +14,7 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_ato else: ts = [torch.tensor((np.random.random(size=x).astype(np.float32)+a)*b, requires_grad=True) for x in shps] - tst = [Tensor(x.detach().numpy(), requires_grad=True) for x in ts] + tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts] st = time.monotonic() out = torch_fxn(*ts) @@ -225,6 +225,7 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(x, w), lambda x,w: x.conv2d(w), atol=1e-2) + #@unittest.skip("not supported with IMAGE=1") def test_large_bs_conv(self): # large batch size can cause OpenCL image to exceed max image height on macOS # (or cause the conv kernel to overflow short sampling coords) @@ -232,6 +233,7 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(x, w), lambda x,w: x.conv2d(w), atol=1e-4, rtol=1e-2) + #@unittest.skip("not supported with IMAGE=1") def test_large_ic_conv(self): # large input channel count can cause OpenCL image to exceed max image width on macOS helper_test_op([(1,2048,3,3), (1,2048,3,3)], @@ -245,10 +247,15 @@ class TestOps(unittest.TestCase): lambda x,w,b: Tensor.conv2d(x,w,b).relu().conv2d(w,b), atol=1e-4) def test_simple_conv2d(self): - helper_test_op([(1,1,9,9), (1,1,3,3)], + helper_test_op([(1,4,9,9), (4,4,3,3)], lambda x,w: torch.nn.functional.conv2d(x,w).relu(), lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) + def test_nested_conv2d(self): + helper_test_op([(1,32,9,9), (32,32,3,3), (32,32,3,3)], + lambda x,w1,w2: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w1).relu(), w2).relu(), + lambda x,w1,w2: x.conv2d(w1).relu().conv2d(w2).relu(), atol=1e-4, grad_rtol=1e-5) + # expect reduce nodes == 3 def test_simple_conv2d_nhwc(self): # weights (from tf): filter_height x filter_width x in_channels x out_channels @@ -301,6 +308,15 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5) + def test_depthwise_conv2d(self): + bs = 1 + groups = 32 + rcout = 1 + cin = 1 + helper_test_op([(bs,groups*cin,32,32), (groups*rcout,cin,1,1)], + lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(), + lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5) + def 
test_grouped_conv2d(self): bs = 4 groups = 5 diff --git a/test/test_shapetracker.py b/test/test_shapetracker.py index 80c7feaf..0b71fe5d 100644 --- a/test/test_shapetracker.py +++ b/test/test_shapetracker.py @@ -94,6 +94,12 @@ class TestComplexShapeTracker(unittest.TestCase): self.st.permute(1, 0, 2, 3) assert self.st.contiguous + def test_fancy_factorize(self): + self.st = ShapeTracker((32, 3, 3, 1)) + self.st.strided(*zip((32, 3, 3, 1), (1, 4096, 32, 1))) + self.st.reshape(*(8, 4, 3, 3)) + assert len(self.st.views) == 1 + def test_super_complex_2_fail(self): self.st = ShapeTracker((4, 4, 4)) self.st.permute(2, 0, 1) diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index 35345796..f182646e 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -9,10 +9,7 @@ from functools import partial from tinygrad.ops import GlobalCounters from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d -try: - from termcolor import colored -except ImportError: - colored = lambda x, _: x +from tinygrad.helpers import colored try: from tinygrad.llops.ops_gpu import CL except ImportError: @@ -24,15 +21,12 @@ torch_device = torch.device('mps' if int(os.getenv("MPS", "0")) else 'cpu') def colorize_float(x): ret = f"{x:7.2f}x" - if colored: - if x < 0.75: - return colored(ret, 'green') - elif x > 1.5: - return colored(ret, 'red') - else: - return colored(ret, 'yellow') + if x < 0.75: + return colored(ret, 'green') + elif x > 1.5: + return colored(ret, 'red') else: - return ret + return colored(ret, 'yellow') save_ops, save_mem = 0, 0 CNT = 8 @@ -101,6 +95,14 @@ class TestSpeed(unittest.TestCase): # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size def f(a, b): return a.permute(1,0).contiguous() helper_test_generic_square('permute', N, f, f) + + def test_double_permute(self): + N = 64 + torch.manual_seed(0) + torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device) + tiny_a = Tensor(torch_a.cpu().numpy()) + def f(a): return a.permute(1,0,3,2).contiguous() + helper_test_generic(f"double_permute {tiny_a.shape}", partial(f, torch_a), partial(f, tiny_a)) def test_neg(self): def f(a, b): return -a @@ -159,6 +161,20 @@ class TestSpeed(unittest.TestCase): def f2(a, b): return (a.permute(1,0).reshape(N, 1, N).expand(N, N, N) * b.permute(1,0).reshape(1, N, N).expand(N, N, N)).sum(axis=2) helper_test_generic_square('gemm_unrolled_permute_lr', N, f1, f2) + def test_openpilot_conv2d(self): + bs, in_chans, out_chans = 1,12,32 + torch.manual_seed(0) + torch_dat = torch.rand(bs, 64, 128, 12).to(torch_device) + torch_conv = torch.nn.Conv2d(in_chans, out_chans, 3, bias=None, padding=1).to(torch_device) + + tiny_dat = Tensor(torch_dat.cpu().numpy()) + tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1) + tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy()) + + def f1(): return torch_conv(torch_dat.permute(0,3,1,2)) + def f2(): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize() + helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, f2) + def test_conv2d(self): torch.manual_seed(0) for bs in [32]: diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 91c8da66..722bea3f 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -6,6 +6,7 @@ def prod(x): return math.prod(x) def argfix(*x): return tuple() if len(x) == 0 else tuple(x[0]) if isinstance(x[0], tuple) or isinstance(x[0], list) else tuple(x) def argsort(x): return sorted(range(len(x)), key=x.__getitem__) # 
https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True +def colored(st, color): return f"\u001b[{30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color)}m{st}\u001b[0m" # replace the termcolor library with one line def reduce_shape(shape, axis): return tuple(1 if i in axis else shape[i] for i in range(len(shape))) def shape_to_axis(old_shape, new_shape): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 1debd2fe..15e6da31 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -84,7 +84,16 @@ class ExplicitExecAST(DeviceBuffer): # universal for shape tracked def movement_op(self, op:MovementOps, arg): return type(self)(ShapeTracker(self.st).movement_op(op, arg), self) + + # TODO: creating a new object is making a copy, breaking the thneed compiler def contiguous(self): return self if self.st.contiguous else self.unary_op(UnaryOps.NOOP) + #def contiguous(self): return type(self)(self.shape, hostbuf=self) if self.st.contiguous else self.unary_op(UnaryOps.NOOP) + +def get_first_reduce(shapes): + for i in range(len(shapes[0])): + if not all_same([x[i] for x in shapes]): + return i + return len(shapes[0]) # off the end # ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops class ASTKernel: @@ -105,6 +114,9 @@ class ASTKernel: assert all_same([x.shape for x in self.bufs if x not in self.earlybufs]), "all latebufs must have the same shape" assert all_same([len(x.shape) for x in self.bufs]), "all bufs must have the same shape size" + # key for lookup in cache (can change, str might not be right) + self.key = str(ast) + def process(self): # get shape, strides, and offset # if it's a multiview buffer we take the final view @@ -121,11 +133,7 @@ class ASTKernel: strides = [[s[i] for i in range(len(s)) if not all_ones[i]] for s in strides] # find first mismatch, don't reduce this - first_reduce = -1 - for i in range(len(shapes[0])): - if not all_same([x[i] for x in shapes]): - first_reduce = i - break + first_reduce = get_first_reduce(shapes) # merge dimensions if we can, multi get_shape_strides # TODO: does this always preserve the reduce dimension, NO @@ -144,10 +152,14 @@ class ASTKernel: else: rets[j].append((shapes[j][i], strides[j][i])) self.shapes, self.strides = [[y[0] for y in x] for x in rets], [[y[1] for y in x] for x in rets] + self.first_reduce = get_first_reduce(self.shapes) # update this if axis merged # include the offsets (as is) self.offsets = [x.st.views[-1].offset for x in self.bufs] + @property + def shape_len(self): return len(self.shapes[0]) + # this should be aware of the three parts to the shape # * the input/output dimensions # * the reduce dimensions diff --git a/tinygrad/shapetracker.py b/tinygrad/shapetracker.py index 64f7da8a..db9e37ff 100644 --- a/tinygrad/shapetracker.py +++ b/tinygrad/shapetracker.py @@ -1,9 +1,13 @@ # ShapeTracker allows movement operations to a buffer that don't require a copy to be made. 
from __future__ import annotations +import os import functools from typing import Tuple, Union, List from tinygrad.helpers import prod +# TODO: fix DEBUG import +DEBUG = int(os.getenv("DEBUG", "0")) + def divmodidx(acc, d, mod=True): lr = f"(idx//{acc})" if acc != 1 else "idx" return f"({lr}%{d})" if mod else lr # don't mod the top shape dimension @@ -102,7 +106,7 @@ class ShapeTracker: return self def reshape(self, *new_shape): - assert all(isinstance(x, int) for x in new_shape) + assert all(isinstance(x, int) and x != 0 for x in new_shape), f"shape must be ints and can't contain 0 {new_shape}" assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}" # check if this is adding or removing 1s (only) @@ -127,21 +131,26 @@ class ShapeTracker: while len(new_strides) != len(new_shape): assert new_shape[len(new_strides)] == 1 new_strides.append(1) - self.views[-1] = View(new_shape, new_strides, self.offset) - return self # early return, it factorized! + break curr_dim, curr_stride = min_shape_strides.pop(0) else: break # didn't factorize + if len(new_shape) == len(new_strides): + self.views[-1] = View(new_shape, new_strides, self.offset) + return self + view = View(new_shape, strides_for_shape(new_shape)) if self.contiguous: self.views[-1] = view # NOTE: if it's contiguous it can't have an offset else: + if DEBUG >= 2: + print(f"WARNING: reshape from {self.shape} w strides {self.strides} -> {new_shape} is creating another view") self.views.append(view) return self def permute(self, *axis): - assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis) + assert all(isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis), f"invalid permute {axis} for {self.shape}" assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}" self.views[-1] = View([self.shape[a] for a in axis], [self.strides[a] for a in axis], self.offset) return self
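
A note on the ops_llvm.py cache change: the removed TODO points out that str(ast) alone does not uniquely determine a kernel, because the same AST reached through a different ShapeTracker view needs differently-indexed code, so the cache key has to fold in the buffer layouts; this patch only centralizes the key in ASTKernel.key (still str(ast), flagged "can change"). A minimal standalone sketch of the underlying idea; ToyView and ToyKernelCache are illustrative stand-ins, not tinygrad classes.

class ToyView:
  def __init__(self, shape, strides, offset=0):
    self.shape, self.strides, self.offset = shape, strides, offset
  def key(self):
    return (tuple(self.shape), tuple(self.strides), self.offset)

class ToyKernelCache:
  def __init__(self):
    self.cache = {}
  def get(self, ast_str, views):
    # the view layouts must be part of the key: the same AST over a permuted
    # or offset buffer needs a differently-indexed loop nest
    key = (ast_str, tuple(v.key() for v in views))
    if key not in self.cache:
      self.cache[key] = f"compiled<{key!r}>"   # stand-in for codegen + JIT
    return self.cache[key]

cache = ToyKernelCache()
a = ToyView((4, 4), (4, 1))   # contiguous buffer
b = ToyView((4, 4), (1, 4))   # permuted view of the same data
assert cache.get("ADD", [a]) != cache.get("ADD", [b])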
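
compile.py now takes the first value of run_onnx's result instead of hard-coding the 'outputs' key, so a model whose single output tensor is named differently still compiles. A tiny illustration of the next(iter(...)) idiom; the dict contents below are made up, and it relies on dicts preserving insertion order (Python 3.7+).

outputs = {"logits": [1.0, 2.0], "aux": [0.5]}   # made-up output dict; real keys depend on the ONNX model
first = next(iter(outputs.values()))             # first value, whatever its key is called
assert first == [1.0, 2.0]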
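
helpers.py swaps the optional termcolor dependency for a one-line colored() built on ANSI escape codes (SGR foreground colors 30-37), which lets test_speed_v_torch.py simplify colorize_float. The sketch below just restates that pairing so it can be run on its own; the speed thresholds are the ones from the test file.

COLORS = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']
def colored(st, color): return f"\u001b[{30+COLORS.index(color)}m{st}\u001b[0m"  # ANSI SGR foreground codes 30-37

def colorize_float(x):
  ret = f"{x:7.2f}x"
  if x < 0.75: return colored(ret, 'green')    # well under torch's time
  elif x > 1.5: return colored(ret, 'red')     # well over torch's time
  else: return colored(ret, 'yellow')          # roughly on par

print(colorize_float(0.5), colorize_float(1.0), colorize_float(3.0))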
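
ops.py pulls the "first mismatching axis" scan out into get_first_reduce so it can be recomputed after dimensions are merged. Standalone sketch with example shapes: an early (pre-reduce) buffer of shape (4, 8, 16) against an output of (4, 8, 1) gives first_reduce == 2, and identical shapes return len(shape), i.e. "off the end".

def all_same(items): return all(x == items[0] for x in items) if len(items) > 0 else True

def get_first_reduce(shapes):
  for i in range(len(shapes[0])):
    if not all_same([x[i] for x in shapes]):
      return i                 # first axis where the buffer shapes disagree
  return len(shapes[0])        # off the end: nothing is being reduced

assert get_first_reduce([(4, 8, 16), (4, 8, 1)]) == 2   # reduce over the last axis
assert get_first_reduce([(4, 8), (4, 8)]) == 2          # no mismatch -> off the end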
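
The tightened ShapeTracker assertions (reshape rejecting zero-sized dims, permute validating its axes) protect the core invariant that a View is just (shape, strides, offset) over the same buffer, so a permute never moves data. A minimal standalone sketch of that invariant; flat_offset is an illustrative helper, and strides_for_shape here is a local equivalent of the one in shapetracker.py.

def strides_for_shape(shape):
  strides = [1]
  for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
  return tuple(strides)

def flat_offset(idx, strides, offset=0):
  return offset + sum(i*s for i, s in zip(idx, strides))

shape, strides = (4, 5), strides_for_shape((4, 5))     # (5, 1), row-major
pshape, pstrides = (5, 4), (strides[1], strides[0])    # permuted view: swap shape and strides together
assert flat_offset((2, 3), strides) == flat_offset((3, 2), pstrides) == 13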