From a37e92081aa66fdf66c7c7a1a8215d744ff285a5 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 19 Sep 2024 09:03:01 -0400 Subject: [PATCH] fix unrolled arange folding (#6606) * fix unrolled arange folding also added flop test to test_arange to make sure it's 0 flop * skip PTX --- test/test_arange.py | 15 +++++++++------ test/unit/test_uop_symbolic.py | 10 ++++------ tinygrad/codegen/uopgraph.py | 21 +++++++++++---------- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/test/test_arange.py b/test/test_arange.py index 8d44ae1c..09597cec 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -23,18 +23,21 @@ class TestArange(unittest.TestCase): np.testing.assert_equal(tt.numpy(), np.arange(N)) return p.op_estimate - def test_complexity(self, opts=None): + def test_complexity(self, opts=None, limit=None): # add 1 to avoid divide by 0. arange is 0 flops now! f1 = self._get_flops(256, opts) + 1 f2 = self._get_flops(2560, opts) + 1 print(f"{f1=}, {f2=}") assert (f1 < 5000 and f2 < 5000) or (f2 / f1 < 15), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X" + if limit is not None and not getenv("PTX"): + # PTX counts index ALU in flops + assert f1 <= limit, f"{f1=}, {limit=}" - def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)]) - def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)]) - def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)]) - def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)]) - def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)]) + def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=1) + def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=1) + def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=1) + def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=1) + def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1) @unittest.skip("doesn't work yet") def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, amt=32)]) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 211d6328..d65a53a9 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -431,15 +431,13 @@ class TestSymbolic(unittest.TestCase): def test_arange_unrolled4(self): gidx = Variable("gidx", 0, 2559) - alu0 = -gidx - unrolled_div = (alu0+2561)//-4+(alu0+2562)//-4+(alu0+2560)//-4+(alu0+2559)//-4+2559 - self.helper_test_variable(unrolled_div, 0, 2559, "gidx") + unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4 + self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)") def test_arange_unrolled2(self): gidx = Variable("gidx", 0, 2559) - alu0 = -gidx - unrolled_div = (alu0+2559)//-2+(alu0+2560)//-2+2559 - self.helper_test_variable(unrolled_div, 0, 2559, "gidx") + unrolled_div = (gidx+2559)//2+(gidx+2560)//2+3 + self.helper_test_variable(unrolled_div, 2562, 5121, "(gidx+2562)") def test_gated_load(self): idx = Variable("idx", 0, 24) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 1582f735..7d05226e 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -146,19 +146,20 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]: if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV: return newx.src[0].lt(newx.src[1]) return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None -def fold_unrolled_divs(divs:UOp, c:UOp): +def fold_unrolled_divs(divs:UOp): # div pattern in unrolled arange - # example: (-x+2561)//-4+(-x+2562)//-4+(-x+2560)//-4+(-x+2559)//-4+2559 -> x + # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x add_chain, seen_const, ans = list(_get_chain(divs, BinaryOps.ADD)), [], None for u in add_chain: - if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==-len(add_chain)): return None + if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST and u.src[1].arg==len(add_chain)): return None # assumed CONST is the last of an ADD - if not ((s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST): return None - if not ((neg:=s0.src[0]).op is UOps.ALU and neg.arg is BinaryOps.MUL and neg.src[1].op is UOps.CONST and neg.src[1].arg==-1): return None - if ans is None: ans = neg.src[0] - if ans != neg.src[0]: return None - seen_const.append(s0.src[1].arg) - return ans if sorted(seen_const)==list(range(c.arg, c.arg+len(add_chain))) and ans is not None and (ans.vmin, ans.vmax)==(0, c.arg) else None + if (s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST: + seen_const.append(s0.src[1].arg) + s0 = s0.src[0] + else: seen_const.append(0) + if ans is None: ans = s0 + if ans != s0: return None + return ans if ans is not None and sorted(seen_const)==list(range(len(add_chain))) else None # ***** image load valid simplification ***** @@ -349,7 +350,7 @@ constant_folder = PatternMatcher([ .where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),), arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse), # unrolled arange div folding - (UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs), + (UPat(UOps.ALU, name="divs", src=[UPat(), UPat(UOps.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs), # indexing, with cast or where (UPat(UOps.REDUCE, src=(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()* UPat(UOps.LOAD, src=(UPat.var("buf"), UPat.any(UPat.var("add")+UPat.var("mul")*UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),