From 570750304824ae65dff25688d69330c41d4538e3 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 20 Sep 2024 04:38:47 -0400 Subject: [PATCH] x//a x 37 --- .github/workflows/test.yml | 2 +- test/unit/test_image_valid.py | 3 +++ test/unit/test_uop_symbolic.py | 5 +++++ tinygrad/codegen/uopgraph.py | 3 +++ 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fa74c53e..b00ffb0f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -197,7 +197,7 @@ jobs: - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model compile and size run: | - PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=47 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py + PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=37 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000' - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model correctness (float32) diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py index f512ea92..2111a125 100644 --- a/test/unit/test_image_valid.py +++ b/test/unit/test_image_valid.py @@ -69,6 +69,9 @@ class TestValidSimplification(unittest.TestCase): gidx1 = Special("gidx1", 32) self.assertEqual(render((10, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))), "read_imagef(data0, smp, (int2)(gidx0,gidx1))") + # same thing, valid has a div + self.assertEqual(render((10, 10, 4), (gidx1//2).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))), + "read_imagef(data0, smp, (int2)(gidx0,gidx1))") # 10x20 image, not out of bound self.assertEqual(render((20, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))), "((gidx1<10)?read_imagef(data0, smp, (int2)(gidx0,gidx1)):(float4)(0.0f,0.0f,0.0f,0.0f))") diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index d65a53a9..34da16f6 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -445,6 +445,11 @@ class TestSymbolic(unittest.TestCase): # TODO: simplify the true branch self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx<4)?(idx//4):(-1))") + def test_idiv_lt(self): + idx = Variable("idx", 0, 24) + self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)") + self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//(-4))<(-3))") + @unittest.skip("not supported on uops yet") class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 92d90030..908803ae 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -405,6 +405,9 @@ constant_folder = PatternMatcher([ # c0*x 0 else None), # mul add lt (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)), lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),