From 42f599870cebd1d1e5c9ed14becc2c6130b3a164 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 5 Aug 2024 12:15:07 -0700
Subject: [PATCH] unroll arange is broken (#5918)

* unroll arange is broken

* fix unrolled arange

* one more test
---
 test/test_arange.py          | 25 ++++++++++++++++++-------
 tinygrad/codegen/uopgraph.py | 17 ++++++++++++-----
 tinygrad/codegen/uops.py     |  5 +++--
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/test/test_arange.py b/test/test_arange.py
index 32d2e045..cdec73b3 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -3,21 +3,32 @@ import numpy as np
 from tinygrad import Tensor, GlobalCounters, dtypes
 from tinygrad.helpers import Context, getenv
 from tinygrad.engine.realize import run_schedule
+from tinygrad.codegen.kernel import Opt, OptOps, Kernel
 
 class TestArange(unittest.TestCase):
-  def _get_flops(self, N):
+  def _get_flops(self, N, opts=None):
     GlobalCounters.reset()
-    with Context(NOOPT=1):
-      Tensor.arange(N).realize()
-    return GlobalCounters.global_ops
+    sched = Tensor.arange(N).schedule()
+    self.assertEqual(len(sched), 1)
+    k = Kernel(sched[-1].ast)
+    if opts is not None:
+      for o in opts: k.apply_opt(o)
+    p = k.to_program()
+    print(p.name)
+    print(p.src)
+    return p.op_estimate
 
-  def test_complexity(self):
+  def test_complexity(self, opts=None):
     # add 1 to avoid divide by 0. arange is 0 flops now!
-    f1 = self._get_flops(256) + 1
-    f2 = self._get_flops(2560) + 1
+    f1 = self._get_flops(256, opts) + 1
+    f2 = self._get_flops(2560, opts) + 1
     print(f"{f1=}, {f2=}")
     assert f2 / f1 < 15, f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
 
+  def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)])
+  def test_complexity_w_unroll(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)])
+  def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)])
+
 class TestIndexing(unittest.TestCase):
   def test_arange_2_reduce(self):
     needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 758dd2e3..4aeaf8a0 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -132,7 +132,7 @@ def reduce_before_expand(reduce, expand, x):
   red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
   return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), i) for i in range(x.dtype.count)), expand.arg)
 
-def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None):
+def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None):
   if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None  # must be the right REDUCE
   if mval.arg >= 0 or loop_start.arg != 0:
     # TODO: support and test this with other mvals and loop_starts
@@ -141,8 +141,10 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
   if idx2 is not None: idx = idx + idx2
   if idx3 is not None: idx = idx + idx3
   comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
-  return UOp(UOps.REDUCE, reduce.dtype, (comprange.cast(multconst.dtype) * multconst,) +
-             tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
+  new_reduce_op = comprange.cast(multconst.dtype) * multconst
+  ret = UOp(UOps.REDUCE, reduce.dtype, (new_reduce_op,) + tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
+  if extra is not None: ret = ret + UOp(UOps.REDUCE, reduce.dtype, (extra,) + reduce.src[1:], reduce.arg)
+  return ret
 
 def index_collapse(idx,rng,buf,add,mul,ld,reduce):
   if rng not in reduce.src: return None
@@ -181,8 +183,12 @@ constant_folder = PatternMatcher([
   (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
     .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
   (NOp(UOps.REDUCE, src=((NOp.var("idx") - NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
-    .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
-    lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
+    .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
+   lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
+  # arange loop folding (unrolled)
+  (NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
+    .lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)) + NOp.var("extra"),),
+    arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
   # indexing (with a multiply offset)!
   (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
     NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),),
@@ -499,6 +505,7 @@ class UOpGraph:
     UOpGraph.cnt += 1
     if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
       sink = graph_rewrite(sink, self.folder+expander+float4_folding if self.opts is not None and self.opts.supports_float4 else self.folder+expander)
+      if UOpGraph.cnt != getenv("DEBUG_REDUCE", 0): sink = graph_rewrite(sink, self.folder+expander+reducer)
 
     # for PTX only
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 65bc9b41..ed392a78 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -218,7 +218,7 @@ def type_verify(uops):
     if uop is UOps.ALU:
       if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
       elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
-        assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
+        assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
         assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
       elif arg is BinaryOps.IDIV:
         assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
@@ -228,7 +228,8 @@ def type_verify(uops):
         assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
       elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
       elif arg == TernaryOps.WHERE:
-        assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
+        assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
+          f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
         assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
 
 def uop_alu_resolve(u:UOp) -> sint:
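
Note (not part of the patch): below is a minimal, standalone sketch of the closed form that loop_collapse relies on for the arange pattern (mval = -1, loop_start = 0). The helper names are hypothetical and the check is plain Python rather than tinygrad UOps.

  # un-collapsed arange reduce: sum over the RANGE of WHERE(idx - rng < compval, multconst, 0)
  def brute_force(idx: int, compval: int, loop_end: int, multconst: int) -> int:
    return sum(multconst if (idx - rng) < compval else 0 for rng in range(loop_end))

  # collapsed form: comprange = min(loop_end, max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start));
  # with mval = -1 and loop_start = 0 this simplifies to loop_end - (idx - compval + 1), clamped to [0, loop_end]
  def collapsed(idx: int, compval: int, loop_end: int, multconst: int) -> int:
    comprange = min(loop_end, max(loop_end - (idx - compval + 1), 0))
    return comprange * multconst

  if __name__ == "__main__":
    for idx in range(-4, 40):
      for compval in range(-4, 40):
        assert brute_force(idx, compval, 32, 3) == collapsed(idx, compval, 32, 3)

When the kernel is unrolled, the reduce source becomes WHERE(...) + extra, which is what the new "arange loop folding (unrolled)" pattern matches: the WHERE term still folds to the closed form above, and the leftover extra term is kept in its own REDUCE via the new extra argument to loop_collapse.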