Fix int child count (#2882)

* pad ops broke coder

* that contiguous fixes it

* Update lazy.py

* recursive add

* fix all

* revert that

* todo test
This commit is contained in:
George Hotz 2023-12-20 21:06:27 -08:00 committed by GitHub
parent 8a04107d30
commit 8c4a0f8e15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 3 deletions

View File

@ -1,8 +1,10 @@
import unittest
import time
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.device import InterpretedASTRunner
from tinygrad.lazy import create_schedule
from tinygrad.realize import run_schedule
from tinygrad.realize import run_schedule, lower_schedule_item
class TestFusionOp(unittest.TestCase):
def test_contiguous_add(self):
@ -21,5 +23,16 @@ class TestFusionOp(unittest.TestCase):
outd = out.data().tolist()
assert all(x == 20.0 for x in outd)
# TODO: fix this test to be fast and remove O(2^n) behavior
def test_recursive_add(self):
st = time.perf_counter()
a = Tensor([1,2,3,4])
for _ in range(12): a = a + a
sched = create_schedule([a.lazydata], None)
ji = lower_schedule_item(sched[-1])
et = time.perf_counter()
self.assertLess(et-st, 10.0)
assert isinstance(ji, InterpretedASTRunner) or len(ji.prg) < 5000
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@ -74,7 +74,7 @@ def log_lazybuffer(lb, scheduled=False):
lb = lb.base
if lb.realized is None:
for x in lb.srcs:
log_lazybuffer(x)
if nm(x) not in G.nodes: log_lazybuffer(x)
G.add_edge(nm(x), nm(lb), color='#a0a0a0')
label = '"' + \
(str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \

View File

@ -158,7 +158,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX and not getenv("EXPAND_SSA"):
if child_count[u] <= 1 and args != BinaryOps.MAX and not getenv("EXPAND_SSA"):
r[u] = val
else:
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u,'alu')} = {val};")