unroll arange is broken (#5918)

* unroll arange is broken

* fix unrolled arange

* one more test
This commit is contained in:
George Hotz 2024-08-05 12:15:07 -07:00 committed by GitHub
parent 6740a0a6a0
commit 42f599870c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 33 additions and 14 deletions

View File

@ -3,21 +3,32 @@ import numpy as np
from tinygrad import Tensor, GlobalCounters, dtypes
from tinygrad.helpers import Context, getenv
from tinygrad.engine.realize import run_schedule
from tinygrad.codegen.kernel import Opt, OptOps, Kernel
class TestArange(unittest.TestCase):
def _get_flops(self, N):
def _get_flops(self, N, opts=None):
GlobalCounters.reset()
with Context(NOOPT=1):
Tensor.arange(N).realize()
return GlobalCounters.global_ops
sched = Tensor.arange(N).schedule()
self.assertEqual(len(sched), 1)
k = Kernel(sched[-1].ast)
if opts is not None:
for o in opts: k.apply_opt(o)
p = k.to_program()
print(p.name)
print(p.src)
return p.op_estimate
def test_complexity(self):
def test_complexity(self, opts=None):
# add 1 to avoid divide by 0. arange is 0 flops now!
f1 = self._get_flops(256) + 1
f2 = self._get_flops(2560) + 1
f1 = self._get_flops(256, opts) + 1
f2 = self._get_flops(2560, opts) + 1
print(f"{f1=}, {f2=}")
assert f2 / f1 < 15, f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)])
def test_complexity_w_unroll(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)])
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)])
class TestIndexing(unittest.TestCase):
def test_arange_2_reduce(self):
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()

View File

@ -132,7 +132,7 @@ def reduce_before_expand(reduce, expand, x):
red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), i) for i in range(x.dtype.count)), expand.arg)
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None):
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None):
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
if mval.arg >= 0 or loop_start.arg != 0:
# TODO: support and test this with other mvals and loop_starts
@ -141,8 +141,10 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
if idx2 is not None: idx = idx + idx2
if idx3 is not None: idx = idx + idx3
comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
return UOp(UOps.REDUCE, reduce.dtype, (comprange.cast(multconst.dtype) * multconst,) +
tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
new_reduce_op = comprange.cast(multconst.dtype) * multconst
ret = UOp(UOps.REDUCE, reduce.dtype, (new_reduce_op,) + tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
if extra is not None: ret = ret + UOp(UOps.REDUCE, reduce.dtype, (extra,) + reduce.src[1:], reduce.arg)
return ret
def index_collapse(idx,rng,buf,add,mul,ld,reduce):
if rng not in reduce.src: return None
@ -181,8 +183,12 @@ constant_folder = PatternMatcher([
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
(NOp(UOps.REDUCE, src=((NOp.var("idx") - NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
# arange loop folding (unrolled)
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)) + NOp.var("extra"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# indexing (with a multiply offset)!
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),),
@ -499,6 +505,7 @@ class UOpGraph:
UOpGraph.cnt += 1
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
sink = graph_rewrite(sink, self.folder+expander+float4_folding if self.opts is not None and self.opts.supports_float4 else self.folder+expander)
if UOpGraph.cnt != getenv("DEBUG_REDUCE", 0):
sink = graph_rewrite(sink, self.folder+expander+reducer)
# for PTX only

View File

@ -218,7 +218,7 @@ def type_verify(uops):
if uop is UOps.ALU:
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg is BinaryOps.IDIV:
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
@ -228,7 +228,8 @@ def type_verify(uops):
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg == TernaryOps.WHERE:
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
def uop_alu_resolve(u:UOp) -> sint: