fix unrolled arange folding (#6606)

* fix unrolled arange folding

also added flop test to test_arange to make sure it's 0 flop

* skip PTX
This commit is contained in:
chenyu 2024-09-19 09:03:01 -04:00 committed by GitHub
parent eebd23155c
commit a37e92081a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 24 additions and 22 deletions

View File

@ -23,18 +23,21 @@ class TestArange(unittest.TestCase):
np.testing.assert_equal(tt.numpy(), np.arange(N))
return p.op_estimate
def test_complexity(self, opts=None):
def test_complexity(self, opts=None, limit=None):
# add 1 to avoid divide by 0. arange is 0 flops now!
f1 = self._get_flops(256, opts) + 1
f2 = self._get_flops(2560, opts) + 1
print(f"{f1=}, {f2=}")
assert (f1 < 5000 and f2 < 5000) or (f2 / f1 < 15), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
if limit is not None and not getenv("PTX"):
# PTX counts index ALU in flops
assert f1 <= limit, f"{f1=}, {limit=}"
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)])
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)])
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)])
def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)])
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)])
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=1)
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=1)
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=1)
def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=1)
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1)
@unittest.skip("doesn't work yet")
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, amt=32)])

View File

@ -431,15 +431,13 @@ class TestSymbolic(unittest.TestCase):
def test_arange_unrolled4(self):
gidx = Variable("gidx", 0, 2559)
alu0 = -gidx
unrolled_div = (alu0+2561)//-4+(alu0+2562)//-4+(alu0+2560)//-4+(alu0+2559)//-4+2559
self.helper_test_variable(unrolled_div, 0, 2559, "gidx")
unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4
self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)")
def test_arange_unrolled2(self):
gidx = Variable("gidx", 0, 2559)
alu0 = -gidx
unrolled_div = (alu0+2559)//-2+(alu0+2560)//-2+2559
self.helper_test_variable(unrolled_div, 0, 2559, "gidx")
unrolled_div = (gidx+2559)//2+(gidx+2560)//2+3
self.helper_test_variable(unrolled_div, 2562, 5121, "(gidx+2562)")
def test_gated_load(self):
idx = Variable("idx", 0, 24)

View File

@ -146,19 +146,20 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]:
if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV: return newx.src[0].lt(newx.src[1])
return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None
def fold_unrolled_divs(divs:UOp, c:UOp):
def fold_unrolled_divs(divs:UOp):
# div pattern in unrolled arange
# example: (-x+2561)//-4+(-x+2562)//-4+(-x+2560)//-4+(-x+2559)//-4+2559 -> x
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
add_chain, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), [], None
for u in add_chain:
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==-len(add_chain)): return None
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==len(add_chain)): return None
# assumed CONST is the last of an ADD
if not ((s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST): return None
if not ((neg:=s0.src[0]).op is UOps.ALU and neg.arg is BinaryOps.MUL and neg.src[1].op is UOps.CONST and neg.src[1].arg==-1): return None
if ans is None: ans = neg.src[0]
if ans != neg.src[0]: return None
seen_const.append(s0.src[1].arg)
return ans if sorted(seen_const)==list(range(c.arg, c.arg+len(add_chain))) and ans is not None and (ans.vmin, ans.vmax)==(0, c.arg) else None
if (s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST:
seen_const.append(s0.src[1].arg)
s0 = s0.src[0]
else: seen_const.append(0)
if ans is None: ans = s0
if ans != s0: return None
return ans if ans is not None and sorted(seen_const)==list(range(len(add_chain))) else None
# ***** image load valid simplification *****
@ -349,7 +350,7 @@ constant_folder = PatternMatcher([
.where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
# unrolled arange div folding
(UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
(UPat(UOps.ALU, name="divs", src=[UPat(), UPat(UOps.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
# indexing, with cast or where
(UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*
UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),