canonicalize simplex lt (#6658)

(X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints
This commit is contained in:
chenyu 2024-09-22 23:04:47 -04:00 committed by GitHub
parent 46e360fdc0
commit 1923932339
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 32 additions and 3 deletions

View File

@ -197,7 +197,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model compile and size
run: |
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=37 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=43 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model correctness (float32)

View File

@ -120,7 +120,7 @@ class TestValidSimplification(unittest.TestCase):
# TODO: simplify further
self.assertEqual(render(shape, valid, idx),
"(((((idx1*8)+ridx2)<1)!=1)?read_imagef(data0, smp, (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
"(((((idx2+ridx1)<1)!=1)&(((idx1+ridx2)<1)!=1))?read_imagef(data0, smp, (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
def test_openpilot_conv2(self):
# conv in test/external/external_test_valid_remove.py
@ -141,7 +141,7 @@ class TestValidSimplification(unittest.TestCase):
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
self.assertEqual(render(shape, valid, idx),
"(((((idx1*8)+ridx2)<1)!=1)?read_imagef(data0, smp, (int2)((((idx1*24)+(ridx2*3)+ridx0+765)%768),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
"(((((idx2+ridx1)<1)!=1)&(((idx1+ridx2)<1)!=1))?read_imagef(data0, smp, (int2)((((idx1*24)+(ridx2*3)+ridx0+765)%768),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
def test_simplify1(self):
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)

View File

@ -448,6 +448,19 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)")
self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//(-4))<(-3))")
def test_simplex_lt(self):
a = Variable("a", 0, 3)
b = Variable("b", 0, 3)
c = Variable("c", 0, 3)
d = Variable("d", -3, 3)
self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=1)")
self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*(-3))+(b*4))<1)!=1)") # negative coeff, should not be simplified
self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=1)") # var can be negative, should not be simplified
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
@unittest.skip("not supported on uops yet")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):

View File

@ -163,6 +163,19 @@ def fold_unrolled_divs(divs:UOp):
# ***** image load valid simplification *****
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
# returns x0 + x1 + ... in such case, or None if not
changed, ret = False, []
for u in _get_chain(X, BinaryOps.ADD):
# assumed the const is the last src of MUL
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
changed = True
u = u.src[0]
if not (u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE) and u.vmin >= 0): return None
ret.append(u)
return functools.reduce(operator.add, ret) if changed else None
def is_increasing(f:UOp):
# is f a monotonically increasing function regards its input
if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True
@ -415,6 +428,9 @@ constant_folder = PatternMatcher([
# generic lt folding
(UPat.var("x").lt(UPat.cvar("c", vec=False)),
lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
# canonicalize a simplex with positive coefficients > 0
# not x < 1 -> X > 0
(UPat.var("x").lt(1).ne(True), lambda x: newx.lt(1).ne(True) if dtypes.is_int(x.dtype) and (newx:=canonicalize_simplex(x)) is not None else None),
# ** div **
# # div folding
(UPat.var("x") // UPat.cvar("c", vec=False), lambda x,c: