UOps.RANGE is_increasing (#6615)

* UOps.RANGE is_increasing

283 -> 47 valids

* test
This commit is contained in:
chenyu 2024-09-20 03:14:52 -04:00 committed by GitHub
parent 76aa6416d7
commit b14c1bc417
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 7 additions and 3 deletions

View File

@ -197,7 +197,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model compile and size
run: |
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=283 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=47 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model correctness (float32)

View File

@ -186,7 +186,7 @@ if __name__ == "__main__":
"arg_size": [8]*len(ei.bufs),
})
if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", 0)):
if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
assert gated_read_image_count <= allowed_gated_read_image, \
f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"

View File

@ -39,6 +39,10 @@ class TestHelpers(unittest.TestCase):
self.assertTrue(is_increasing(f2))
self.assertTrue(is_increasing(f3))
rng = UOp(UOps.RANGE, dtypes.int, arg=(2, True), src=(UOp(UOps.CONST, dtypes.int, arg=0, src=()), UOp(UOps.CONST, dtypes.int, arg=5, src=()),))
self.assertTrue(is_increasing(rng))
self.assertTrue(is_increasing(rng+2))
class TestValidSimplification(unittest.TestCase):
def test_idx_gt_c(self):
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid

View File

@ -165,7 +165,7 @@ def fold_unrolled_divs(divs:UOp):
def is_increasing(f:UOp):
# is f a monotonically increasing function regards its input
if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL]: return True
if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
return False # False if not sure