mirror of https://github.com/commaai/tinygrad.git
fix unrolled arange folding (#6606)
* fix unrolled arange folding also added flop test to test_arange to make sure it's 0 flop * skip PTX
This commit is contained in:
parent
eebd23155c
commit
a37e92081a
|
@ -23,18 +23,21 @@ class TestArange(unittest.TestCase):
|
|||
np.testing.assert_equal(tt.numpy(), np.arange(N))
|
||||
return p.op_estimate
|
||||
|
||||
def test_complexity(self, opts=None):
|
||||
def test_complexity(self, opts=None, limit=None):
|
||||
# add 1 to avoid divide by 0. arange is 0 flops now!
|
||||
f1 = self._get_flops(256, opts) + 1
|
||||
f2 = self._get_flops(2560, opts) + 1
|
||||
print(f"{f1=}, {f2=}")
|
||||
assert (f1 < 5000 and f2 < 5000) or (f2 / f1 < 15), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
|
||||
if limit is not None and not getenv("PTX"):
|
||||
# PTX counts index ALU in flops
|
||||
assert f1 <= limit, f"{f1=}, {limit=}"
|
||||
|
||||
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)])
|
||||
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)])
|
||||
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)])
|
||||
def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)])
|
||||
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)])
|
||||
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=1)
|
||||
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=1)
|
||||
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=1)
|
||||
def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=1)
|
||||
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1)
|
||||
|
||||
@unittest.skip("doesn't work yet")
|
||||
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, amt=32)])
|
||||
|
|
|
@ -431,15 +431,13 @@ class TestSymbolic(unittest.TestCase):
|
|||
|
||||
def test_arange_unrolled4(self):
|
||||
gidx = Variable("gidx", 0, 2559)
|
||||
alu0 = -gidx
|
||||
unrolled_div = (alu0+2561)//-4+(alu0+2562)//-4+(alu0+2560)//-4+(alu0+2559)//-4+2559
|
||||
self.helper_test_variable(unrolled_div, 0, 2559, "gidx")
|
||||
unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4
|
||||
self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)")
|
||||
|
||||
def test_arange_unrolled2(self):
|
||||
gidx = Variable("gidx", 0, 2559)
|
||||
alu0 = -gidx
|
||||
unrolled_div = (alu0+2559)//-2+(alu0+2560)//-2+2559
|
||||
self.helper_test_variable(unrolled_div, 0, 2559, "gidx")
|
||||
unrolled_div = (gidx+2559)//2+(gidx+2560)//2+3
|
||||
self.helper_test_variable(unrolled_div, 2562, 5121, "(gidx+2562)")
|
||||
|
||||
def test_gated_load(self):
|
||||
idx = Variable("idx", 0, 24)
|
||||
|
|
|
@ -146,19 +146,20 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]:
|
|||
if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV: return newx.src[0].lt(newx.src[1])
|
||||
return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None
|
||||
|
||||
def fold_unrolled_divs(divs:UOp, c:UOp):
|
||||
def fold_unrolled_divs(divs:UOp):
|
||||
# div pattern in unrolled arange
|
||||
# example: (-x+2561)//-4+(-x+2562)//-4+(-x+2560)//-4+(-x+2559)//-4+2559 -> x
|
||||
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
|
||||
add_chain, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), [], None
|
||||
for u in add_chain:
|
||||
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==-len(add_chain)): return None
|
||||
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==len(add_chain)): return None
|
||||
# assumed CONST is the last of an ADD
|
||||
if not ((s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST): return None
|
||||
if not ((neg:=s0.src[0]).op is UOps.ALU and neg.arg is BinaryOps.MUL and neg.src[1].op is UOps.CONST and neg.src[1].arg==-1): return None
|
||||
if ans is None: ans = neg.src[0]
|
||||
if ans != neg.src[0]: return None
|
||||
seen_const.append(s0.src[1].arg)
|
||||
return ans if sorted(seen_const)==list(range(c.arg, c.arg+len(add_chain))) and ans is not None and (ans.vmin, ans.vmax)==(0, c.arg) else None
|
||||
if (s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST:
|
||||
seen_const.append(s0.src[1].arg)
|
||||
s0 = s0.src[0]
|
||||
else: seen_const.append(0)
|
||||
if ans is None: ans = s0
|
||||
if ans != s0: return None
|
||||
return ans if ans is not None and sorted(seen_const)==list(range(len(add_chain))) else None
|
||||
|
||||
# ***** image load valid simplification *****
|
||||
|
||||
|
@ -349,7 +350,7 @@ constant_folder = PatternMatcher([
|
|||
.where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# unrolled arange div folding
|
||||
(UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
|
||||
(UPat(UOps.ALU, name="divs", src=[UPat(), UPat(UOps.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
|
||||
# indexing, with cast or where
|
||||
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
|
||||
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
|
||||
|
|
Loading…
Reference in New Issue