mirror of https://github.com/commaai/tinygrad.git
canonicalize simplex lt (#6658)
(X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints
This commit is contained in:
parent
46e360fdc0
commit
1923932339
|
@ -197,7 +197,7 @@ jobs:
|
|||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model compile and size
|
||||
run: |
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=37 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=43 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model correctness (float32)
|
||||
|
|
|
@ -120,7 +120,7 @@ class TestValidSimplification(unittest.TestCase):
|
|||
|
||||
# TODO: simplify further
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
"(((((idx1*8)+ridx2)<1)!=1)?read_imagef(data0, smp, (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
|
||||
"(((((idx2+ridx1)<1)!=1)&(((idx1+ridx2)<1)!=1))?read_imagef(data0, smp, (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
|
||||
|
||||
def test_openpilot_conv2(self):
|
||||
# conv in test/external/external_test_valid_remove.py
|
||||
|
@ -141,7 +141,7 @@ class TestValidSimplification(unittest.TestCase):
|
|||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
|
||||
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
"(((((idx1*8)+ridx2)<1)!=1)?read_imagef(data0, smp, (int2)((((idx1*24)+(ridx2*3)+ridx0+765)%768),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
|
||||
"(((((idx2+ridx1)<1)!=1)&(((idx1+ridx2)<1)!=1))?read_imagef(data0, smp, (int2)((((idx1*24)+(ridx2*3)+ridx0+765)%768),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
|
|
|
@ -448,6 +448,19 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)")
|
||||
self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//(-4))<(-3))")
|
||||
|
||||
def test_simplex_lt(self):
  """Check canonicalization of `not (sum < 1)` over sums of scaled variables.

  Positive constant coefficients on variables whose lower bound is >= 0 are expected to be dropped;
  a negative coefficient or a variable that may be negative must leave the expression untouched.
  """
  a, b, c = (Variable(name, 0, 3) for name in ("a", "b", "c"))
  d = Variable("d", -3, 3)
  # (expression, expected rendering) pairs, checked in order
  cases = [
    ((a).lt(1).ne(True), "((a<1)!=1)"),
    ((a+b).lt(1).ne(True), "(((a+b)<1)!=1)"),
    ((a*3+b*4).lt(1).ne(True), "(((a+b)<1)!=1)"),
    # negative coeff, should not be simplified
    ((a*(-3)+b*4).lt(1).ne(True), "((((a*(-3))+(b*4))<1)!=1)"),
    # var can be negative, should not be simplified
    ((a*3+d*4).lt(1).ne(True), "((((a*3)+(d*4))<1)!=1)"),
    ((a+b+c*2).lt(1).ne(True), "(((a+b+c)<1)!=1)"),
    ((a+b*2+c*4).lt(1).ne(True), "(((a+b+c)<1)!=1)"),
  ]
  for expr, rendered in cases:
    self.helper_test_variable(expr, 0, 1, rendered)
|
||||
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
|
|
|
@ -163,6 +163,19 @@ def fold_unrolled_divs(divs:UOp):
|
|||
|
||||
# ***** image load valid simplification *****
|
||||
|
||||
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
  """Drop positive constant coefficients from a sum of non-negative variables.

  For ints, (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0.

  Returns x0 + x1 + ... when every term of X is such a variable (optionally scaled by a positive const)
  and at least one coefficient was removed; returns None otherwise.
  """
  terms, stripped = [], False
  for term in _get_chain(X, BinaryOps.ADD):
    base = term
    # assumed the const is the last src of MUL
    if term.op is UOps.ALU and term.arg is BinaryOps.MUL and term.src[1].op is UOps.CONST and term.src[1].arg > 0:
      stripped = True
      base = term.src[0]
    # every remaining term has to be a plain variable-like uop ...
    if base.op not in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE): return None
    # ... whose lower bound is non-negative
    if not base.vmin >= 0: return None
    terms.append(base)
  if not stripped: return None
  return functools.reduce(operator.add, terms)
|
||||
|
||||
def is_increasing(f:UOp):
|
||||
# is f a monotonically increasing function with respect to its input
|
||||
if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True
|
||||
|
@ -415,6 +428,9 @@ constant_folder = PatternMatcher([
|
|||
# generic lt folding
|
||||
(UPat.var("x").lt(UPat.cvar("c", vec=False)),
|
||||
lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
|
||||
# canonicalize a simplex with positive coefficients > 0
|
||||
# not (X < 1) -> X > 0
|
||||
(UPat.var("x").lt(1).ne(True), lambda x: newx.lt(1).ne(True) if dtypes.is_int(x.dtype) and (newx:=canonicalize_simplex(x)) is not None else None),
|
||||
# ** div **
|
||||
# # div folding
|
||||
(UPat.var("x") // UPat.cvar("c", vec=False), lambda x,c:
|
||||
|
|
Loading…
Reference in New Issue