From b14c1bc417a917af0ad6d77e3c0bbd747d808beb Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 20 Sep 2024 03:14:52 -0400 Subject: [PATCH] UOps.RANGE is_increasing (#6615) * UOps.RANGE is_increasing 283 -> 47 valids * test --- .github/workflows/test.yml | 2 +- examples/openpilot/compile2.py | 2 +- test/unit/test_image_valid.py | 4 ++++ tinygrad/codegen/uopgraph.py | 2 +- 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0f3beb6c..fa74c53e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -197,7 +197,7 @@ jobs: - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model compile and size run: | - PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=283 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py + PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=47 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000' - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model correctness (float32) diff --git a/examples/openpilot/compile2.py b/examples/openpilot/compile2.py index a1b8caea..f907097a 100644 --- a/examples/openpilot/compile2.py +++ b/examples/openpilot/compile2.py @@ -186,7 +186,7 @@ if __name__ == "__main__": "arg_size": [8]*len(ei.bufs), }) - if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", 0)): + if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1: assert gated_read_image_count <= allowed_gated_read_image, \ f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}" diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py index 83a94d48..f512ea92 100644 --- a/test/unit/test_image_valid.py +++ b/test/unit/test_image_valid.py @@ -39,6 +39,10 @@ class TestHelpers(unittest.TestCase): self.assertTrue(is_increasing(f2)) self.assertTrue(is_increasing(f3)) + rng = UOp(UOps.RANGE, dtypes.int, arg=(2, True), src=(UOp(UOps.CONST, dtypes.int, arg=0, src=()), UOp(UOps.CONST, dtypes.int, arg=5, src=()),)) + self.assertTrue(is_increasing(rng)) + self.assertTrue(is_increasing(rng+2)) + class TestValidSimplification(unittest.TestCase): def test_idx_gt_c(self): # (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index fed66304..92d90030 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -165,7 +165,7 @@ def fold_unrolled_divs(divs:UOp): def is_increasing(f:UOp): # is f a monotonically increasing function regards its input - if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL]: return True + if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) return False # False if not sure