mirror of https://github.com/commaai/tinygrad.git
UOps.RANGE is_increasing (#6615)
* UOps.RANGE is_increasing 283 -> 47 valids * test
This commit is contained in:
parent
76aa6416d7
commit
b14c1bc417
|
@ -197,7 +197,7 @@ jobs:
|
|||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model compile and size
|
||||
run: |
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=283 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=47 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model correctness (float32)
|
||||
|
|
|
@ -186,7 +186,7 @@ if __name__ == "__main__":
|
|||
"arg_size": [8]*len(ei.bufs),
|
||||
})
|
||||
|
||||
if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", 0)):
|
||||
if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
|
||||
assert gated_read_image_count <= allowed_gated_read_image, \
|
||||
f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"
|
||||
|
||||
|
|
|
@ -39,6 +39,10 @@ class TestHelpers(unittest.TestCase):
|
|||
self.assertTrue(is_increasing(f2))
|
||||
self.assertTrue(is_increasing(f3))
|
||||
|
||||
rng = UOp(UOps.RANGE, dtypes.int, arg=(2, True), src=(UOp(UOps.CONST, dtypes.int, arg=0, src=()), UOp(UOps.CONST, dtypes.int, arg=5, src=()),))
|
||||
self.assertTrue(is_increasing(rng))
|
||||
self.assertTrue(is_increasing(rng+2))
|
||||
|
||||
class TestValidSimplification(unittest.TestCase):
|
||||
def test_idx_gt_c(self):
|
||||
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
|
||||
|
|
|
@ -165,7 +165,7 @@ def fold_unrolled_divs(divs:UOp):
|
|||
|
||||
def is_increasing(f:UOp):
|
||||
# is f a monotonically increasing function regards its input
|
||||
if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL]: return True
|
||||
if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True
|
||||
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
|
||||
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
|
||||
return False # False if not sure
|
||||
|
|
Loading…
Reference in New Issue