mirror of https://github.com/commaai/tinygrad.git
Fix int child count (#2882)
* pad ops broke coder
* that contiguous fixes it
* Update lazy.py
* recursive add
* fix all
* revert that
* todo test
This commit is contained in:
parent
8a04107d30
commit
8c4a0f8e15
|
@ -1,8 +1,10 @@
|
|||
import unittest
|
||||
import time
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.device import InterpretedASTRunner
|
||||
from tinygrad.lazy import create_schedule
|
||||
from tinygrad.realize import run_schedule
|
||||
from tinygrad.realize import run_schedule, lower_schedule_item
|
||||
|
||||
class TestFusionOp(unittest.TestCase):
|
||||
def test_contiguous_add(self):
|
||||
|
@ -21,5 +23,16 @@ class TestFusionOp(unittest.TestCase):
|
|||
outd = out.data().tolist()
|
||||
assert all(x == 20.0 for x in outd)
|
||||
|
||||
# TODO: fix this test to be fast and remove O(2^n) behavior
|
||||
def test_recursive_add(self):
|
||||
st = time.perf_counter()
|
||||
a = Tensor([1,2,3,4])
|
||||
for _ in range(12): a = a + a
|
||||
sched = create_schedule([a.lazydata], None)
|
||||
ji = lower_schedule_item(sched[-1])
|
||||
et = time.perf_counter()
|
||||
self.assertLess(et-st, 10.0)
|
||||
assert isinstance(ji, InterpretedASTRunner) or len(ji.prg) < 5000
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@ -74,7 +74,7 @@ def log_lazybuffer(lb, scheduled=False):
|
|||
lb = lb.base
|
||||
if lb.realized is None:
|
||||
for x in lb.srcs:
|
||||
log_lazybuffer(x)
|
||||
if nm(x) not in G.nodes: log_lazybuffer(x)
|
||||
G.add_edge(nm(x), nm(lb), color='#a0a0a0')
|
||||
label = '"' + \
|
||||
(str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
|
||||
|
|
|
@ -158,7 +158,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||
val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX and not getenv("EXPAND_SSA"):
|
||||
if child_count[u] <= 1 and args != BinaryOps.MAX and not getenv("EXPAND_SSA"):
|
||||
r[u] = val
|
||||
else:
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")
|
||||
|
|
Loading…
Reference in New Issue