From 1923932339866c64a8065906fbf1c77ae3ae80e4 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 22 Sep 2024 23:04:47 -0400 Subject: [PATCH] canonicalize simplex lt (#6658) (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints --- .github/workflows/test.yml | 2 +- test/unit/test_image_valid.py | 4 ++-- test/unit/test_uop_symbolic.py | 13 +++++++++++++ tinygrad/codegen/uopgraph.py | 16 ++++++++++++++++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b00ffb0f..2afc877f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -197,7 +197,7 @@ jobs: - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model compile and size run: | - PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=37 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py + PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=43 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000' - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model correctness (float32) diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py index ebc05ae0..e960855d 100644 --- a/test/unit/test_image_valid.py +++ b/test/unit/test_image_valid.py @@ -120,7 +120,7 @@ class TestValidSimplification(unittest.TestCase): # TODO: simplify further self.assertEqual(render(shape, valid, idx), - "(((((idx1*8)+ridx2)<1)!=1)?read_imagef(data0, smp, (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501 + "(((((idx2+ridx1)<1)!=1)&(((idx1+ridx2)<1)!=1))?read_imagef(data0, smp, (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501 def test_openpilot_conv2(self): # conv in test/external/external_test_valid_remove.py @@ -141,7 +141,7 @@ class TestValidSimplification(unittest.TestCase): idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))) self.assertEqual(render(shape, valid, idx), - "(((((idx1*8)+ridx2)<1)!=1)?read_imagef(data0, smp, (int2)((((idx1*24)+(ridx2*3)+ridx0+765)%768),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501 + "(((((idx2+ridx1)<1)!=1)&(((idx1+ridx2)<1)!=1))?read_imagef(data0, smp, (int2)((((idx1*24)+(ridx2*3)+ridx0+765)%768),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501 def test_simplify1(self): # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 5dc0b1bb..ad222612 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -448,6 +448,19 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)") self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//(-4))<(-3))") + def test_simplex_lt(self): + a = Variable("a", 0, 3) + b = Variable("b", 0, 3) + c = Variable("c", 0, 3) + d = Variable("d", -3, 3) + self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=1)") + self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)") + self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)") + self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*(-3))+(b*4))<1)!=1)") # negative coeff, should not be simplified + self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=1)") # var can be negative, should not be simplified + self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)") + self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)") + @unittest.skip("not supported on uops yet") class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 3c1943e4..cdefdeff 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -163,6 +163,19 @@ def fold_unrolled_divs(divs:UOp): # ***** image load valid simplification ***** +def canonicalize_simplex(X:UOp) -> Optional[UOp]: + # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints. + # returns x0 + x1 + ... in such case, or None if not + changed, ret = False, [] + for u in _get_chain(X, BinaryOps.ADD): + # assumed the const is the last src of MUL + if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0: + changed = True + u = u.src[0] + if not (u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE) and u.vmin >= 0): return None + ret.append(u) + return functools.reduce(operator.add, ret) if changed else None + def is_increasing(f:UOp): # is f a monotonically increasing function regards its input if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True @@ -415,6 +428,9 @@ constant_folder = PatternMatcher([ # generic lt folding (UPat.var("x").lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None), + # canonicalize a simplex with positive coefficients > 0 + # not x < 1 -> X > 0 + (UPat.var("x").lt(1).ne(True), lambda x: newx.lt(1).ne(True) if dtypes.is_int(x.dtype) and (newx:=canonicalize_simplex(x)) is not None else None), # ** div ** # # div folding (UPat.var("x") // UPat.cvar("c", vec=False), lambda x,c: