mirror of https://github.com/commaai/tinygrad.git
arange folding with new ge (#6604)
* arange folding with new ge * bump allowed gated * bump allowed speed
This commit is contained in:
parent
224151a958
commit
a1a882b006
|
@ -197,7 +197,7 @@ jobs:
|
|||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model compile and size
|
||||
run: |
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=397 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=482 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model correctness (float32)
|
||||
|
|
|
@ -291,7 +291,7 @@ class TestHCQ(unittest.TestCase):
|
|||
et = TestHCQ.d0._gpu2cpu_time(sig_en.timestamp, True) - TestHCQ.d0._gpu2cpu_time(sig_st.timestamp, True)
|
||||
|
||||
print(f"exec kernel time: {et:.2f} us")
|
||||
assert 1 <= et <= (2500 if CI else 30)
|
||||
assert 1 <= et <= (3000 if CI else 30)
|
||||
|
||||
def test_speed_copy_bandwidth(self):
|
||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||
|
|
|
@ -242,10 +242,11 @@ def threefry2x32(x: UOp, seed: UOp):
|
|||
|
||||
# ***** main rewriter *****
|
||||
|
||||
def loop_collapse(compval, idx, mval, multconst, rng:UOp, reduce, idx2=None, idx3=None, extra=None, vec=None):
|
||||
def loop_collapse(compval, idx, multconst, rng:UOp, reduce, idx2=None, idx3=None, extra=None, vec=None, ne=None, mval:UOp=UOp.const(dtypes.int32, 1)):
|
||||
if getenv("DISABLE_LOOP_COLLAPSE") or rng not in reduce.src: return None # must be the right REDUCE
|
||||
loop_start, loop_end = rng.src
|
||||
if mval.arg >= 0 or loop_start.arg != 0:
|
||||
mval_arg = mval.arg
|
||||
if loop_start.arg != 0:
|
||||
# TODO: support and test this with other mvals and loop_starts
|
||||
if DEBUG >= 1: print(f"WARNING, NOT FOLDING: mval:{mval.arg} loop_start:{loop_start.arg}")
|
||||
return None
|
||||
|
@ -255,7 +256,12 @@ def loop_collapse(compval, idx, mval, multconst, rng:UOp, reduce, idx2=None, idx
|
|||
# idx, mval, loop_start, loop_end
|
||||
def dvec(x): return UOp(UOps.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
|
||||
idx, mval, loop_start, loop_end = dvec(idx), dvec(mval), dvec(loop_start), dvec(loop_end)
|
||||
comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
|
||||
if mval_arg > 0 and ne is not None:
|
||||
comprange = UOp.min(loop_end, UOp.max((idx-compval)//mval + (loop_end-loop_start), loop_start))
|
||||
elif mval_arg < 0 and ne is None:
|
||||
comprange = UOp.min(loop_end, UOp.max((idx-compval-mval)//mval + (loop_end-loop_start), loop_start))
|
||||
else:
|
||||
return None
|
||||
new_reduce_op = comprange.cast(multconst.dtype) * multconst
|
||||
ret = UOp(UOps.REDUCE, reduce.dtype, (new_reduce_op,) + tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
|
||||
if extra is not None: ret = ret + UOp(UOps.REDUCE, reduce.dtype, (extra,) + reduce.src[1:], reduce.arg)
|
||||
|
@ -335,6 +341,13 @@ constant_folder = PatternMatcher([
|
|||
m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1))
|
||||
.lt(UPat.cvar("compval")).where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# arange loop folding (new ge)
|
||||
(UPat(UOps.REDUCE, src=(UPat.any(m2:=UPat.any(
|
||||
m1:=(UPat.var("idx") + UPat.any(UPat.cvar("mval") * UPat(UOps.RANGE, name="rng"), UPat(UOps.RANGE, name="rng"))),
|
||||
m1 + UPat.var("idx2"), m1 + UPat.var("idx2") + UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=m1))
|
||||
.lt(UPat.cvar("compval")).ne(UPat(UOps.CONST, name="ne", arg=True))
|
||||
.where(UPat.cvar("multconst"), UPat.const(None, 0)), m2 + UPat.var("extra")),),
|
||||
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), loop_collapse),
|
||||
# unrolled arange div folding
|
||||
(UPat.var("divs") + UPat.cvar("c"), fold_unrolled_divs),
|
||||
# indexing, with cast or where
|
||||
|
|
|
@ -65,8 +65,8 @@ class MathTrait:
|
|||
def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x))
|
||||
def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self)
|
||||
# TODO: use this one instead
|
||||
#def ge(self, x): return self.lt(x).ne(True)
|
||||
def ge(self, x): return (-self).lt(-x+1)
|
||||
def ge(self, x): return self.lt(x).ne(True)
|
||||
#def ge(self, x): return (-self).lt(-x+1)
|
||||
def max(self, x): return self.alu(BinaryOps.MAX, self.ufix(x))
|
||||
def min(self, x): return -(-self).max(-x)
|
||||
def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
|
||||
|
|
Loading…
Reference in New Issue