mirror of https://github.com/commaai/tinygrad.git
unroll arange is broken (#5918)
* unroll arange is broken
* fix unrolled arange
* one more test
This commit is contained in:
parent
6740a0a6a0
commit
42f599870c
|
@ -3,21 +3,32 @@ import numpy as np
|
|||
from tinygrad import Tensor, GlobalCounters, dtypes
|
||||
from tinygrad.helpers import Context, getenv
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.codegen.kernel import Opt, OptOps, Kernel
|
||||
|
||||
class TestArange(unittest.TestCase):
|
||||
def _get_flops(self, N):
|
||||
def _get_flops(self, N, opts=None):
|
||||
GlobalCounters.reset()
|
||||
with Context(NOOPT=1):
|
||||
Tensor.arange(N).realize()
|
||||
return GlobalCounters.global_ops
|
||||
sched = Tensor.arange(N).schedule()
|
||||
self.assertEqual(len(sched), 1)
|
||||
k = Kernel(sched[-1].ast)
|
||||
if opts is not None:
|
||||
for o in opts: k.apply_opt(o)
|
||||
p = k.to_program()
|
||||
print(p.name)
|
||||
print(p.src)
|
||||
return p.op_estimate
|
||||
|
||||
def test_complexity(self):
|
||||
def test_complexity(self, opts=None):
|
||||
# add 1 to avoid divide by 0. arange is 0 flops now!
|
||||
f1 = self._get_flops(256) + 1
|
||||
f2 = self._get_flops(2560) + 1
|
||||
f1 = self._get_flops(256, opts) + 1
|
||||
f2 = self._get_flops(2560, opts) + 1
|
||||
print(f"{f1=}, {f2=}")
|
||||
assert f2 / f1 < 15, f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
|
||||
|
||||
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)])
|
||||
def test_complexity_w_unroll(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)])
|
||||
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)])
|
||||
|
||||
class TestIndexing(unittest.TestCase):
|
||||
def test_arange_2_reduce(self):
|
||||
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
|
||||
|
|
|
@ -132,7 +132,7 @@ def reduce_before_expand(reduce, expand, x):
|
|||
red = UOp(UOps.REDUCE, x.dtype, (x,)+reduce.src[1:], reduce.arg)
|
||||
return UOp(expand.op, expand.dtype, tuple(UOp(UOps.GEP, reduce.dtype, (red,), i) for i in range(x.dtype.count)), expand.arg)
|
||||
|
||||
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None):
|
||||
def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, reduce, idx2=None, idx3=None, extra=None):
|
||||
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
|
||||
if mval.arg >= 0 or loop_start.arg != 0:
|
||||
# TODO: support and test this with other mvals and loop_starts
|
||||
|
@ -141,8 +141,10 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst, rng, redu
|
|||
if idx2 is not None: idx = idx + idx2
|
||||
if idx3 is not None: idx = idx + idx3
|
||||
comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
|
||||
return UOp(UOps.REDUCE, reduce.dtype, (comprange.cast(multconst.dtype) * multconst,) +
|
||||
tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
|
||||
new_reduce_op = comprange.cast(multconst.dtype) * multconst
|
||||
ret = UOp(UOps.REDUCE, reduce.dtype, (new_reduce_op,) + tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
|
||||
if extra is not None: ret = ret + UOp(UOps.REDUCE, reduce.dtype, (extra,) + reduce.src[1:], reduce.arg)
|
||||
return ret
|
||||
|
||||
def index_collapse(idx,rng,buf,add,mul,ld,reduce):
|
||||
if rng not in reduce.src: return None
|
||||
|
@ -181,8 +183,12 @@ constant_folder = PatternMatcher([
|
|||
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
|
||||
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
(NOp(UOps.REDUCE, src=((NOp.var("idx") - NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
|
||||
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
|
||||
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
|
||||
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
|
||||
lambda **kwargs: loop_collapse(mval=UOp.const(dtypes.int, -1), **kwargs)),
|
||||
# arange loop folding (unrolled)
|
||||
(NOp(UOps.REDUCE, src=((NOp.var("idx") + NOp.cvar("mval") * NOp(UOps.RANGE, src=(NOp.var("loop_start"), NOp.var("loop_end")), name="rng"))
|
||||
.lt(NOp.cvar("compval")).where(NOp.cvar("multconst"), NOp.const(None, 0)) + NOp.var("extra"),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# indexing (with a multiply offset)!
|
||||
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
|
||||
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),),
|
||||
|
@ -499,6 +505,7 @@ class UOpGraph:
|
|||
UOpGraph.cnt += 1
|
||||
if UOpGraph.cnt != getenv("DEBUG_EXPAND", 0):
|
||||
sink = graph_rewrite(sink, self.folder+expander+float4_folding if self.opts is not None and self.opts.supports_float4 else self.folder+expander)
|
||||
if UOpGraph.cnt != getenv("DEBUG_REDUCE", 0):
|
||||
sink = graph_rewrite(sink, self.folder+expander+reducer)
|
||||
|
||||
# for PTX only
|
||||
|
|
|
@ -218,7 +218,7 @@ def type_verify(uops):
|
|||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
|
||||
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
|
||||
assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
|
||||
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg is BinaryOps.IDIV:
|
||||
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
|
||||
|
@ -228,7 +228,8 @@ def type_verify(uops):
|
|||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg == TernaryOps.WHERE:
|
||||
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
|
||||
assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
|
||||
f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
|
||||
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||
|
||||
def uop_alu_resolve(u:UOp) -> sint:
|
||||
|
|
Loading…
Reference in New Issue