x//a<b -> x <a*b for positive a (#6622)

openpilot valids 47 -> 37
This commit is contained in:
chenyu 2024-09-20 04:38:47 -04:00 committed by GitHub
parent 72c7087420
commit 5707503048
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 1 deletions

View File

@ -197,7 +197,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model compile and size
run: |
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=47 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=37 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model correctness (float32)

View File

@ -69,6 +69,9 @@ class TestValidSimplification(unittest.TestCase):
gidx1 = Special("gidx1", 32)
self.assertEqual(render((10, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
# same thing, valid has a div
self.assertEqual(render((10, 10, 4), (gidx1//2).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
# 10x20 image, not out of bound
self.assertEqual(render((20, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
"((gidx1<10)?read_imagef(data0, smp, (int2)(gidx0,gidx1)):(float4)(0.0f,0.0f,0.0f,0.0f))")

View File

@ -445,6 +445,11 @@ class TestSymbolic(unittest.TestCase):
# TODO: simplify the true branch
self.helper_test_variable(idx.lt(4).where(idx//4, idx.const_like(-1)), -1, 6, "((idx<4)?(idx//4):(-1))")
def test_idiv_lt(self):
idx = Variable("idx", 0, 24)
self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)")
self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//(-4))<(-3))")
@unittest.skip("not supported on uops yet")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):

View File

@ -405,6 +405,9 @@ constant_folder = PatternMatcher([
# c0*x<c1 for negative int c0 and non-positive c1
((UPat.cvar("c0", vec=False)*UPat.var("x")).lt(UPat.cvar("c1", vec=False)),
lambda x,c0,c1: (-x).lt(-(math.floor(-c1.arg/-c0.arg))) if dtypes.is_int(x.dtype) and c0.arg < 0 and c0.arg != -1 and c1.arg <= 0 else None),
# x//c0<c1 for positive int c0
((UPat.var("x")//UPat.cvar("c0", vec=False)).lt(UPat.cvar("c1", vec=False)),
lambda x,c0,c1: x.lt(c1.arg*c0.arg) if dtypes.is_int(x.dtype) and c0.arg > 0 else None),
# mul add lt
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),