diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index c99c678e..9701e2dd 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -110,6 +110,13 @@ class TestLinearizer(unittest.TestCase):
     assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9)
     lin.limit_dims_to_max(global_max=[16, 16, 16], local_max=[16, 16, 16])

+  def test_sum_collapse(self):
+    t = Tensor.ones(256,256).sum()
+    sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps]
+    assert len(sched) == 1
+    lin = Linearizer(sched[0].ast)
+    assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
+
 def helper_realized_ast(r:Tensor):
   s = r.lazydata.schedule()
   run_schedule(s[:-1]) # run all kernels except the last one
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 74d4d9d4..f330f9a2 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -363,17 +363,10 @@ class Linearizer(Kernel):
       self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)

     # graph helper functions
-    def get_recursive_parents(x:List[UOp]) -> List[UOp]:
-      ret: Set[UOp] = set()
-      this_round: Set[UOp] = set(x)
-      while len(this_round):
-        ret = ret.union(this_round)
-        next_round: Set[UOp] = set()
-        for r in this_round: next_round = next_round.union(set(r.vin))
-        this_round = next_round
-      return list(ret)
+    @functools.lru_cache(None)
+    def get_recursive_parents(x:UOp) -> Set[UOp]: return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin])

-    def get_recursive_children(x:UOp) -> List[UOp]:
+    def get_recursive_children(x:UOp) -> Set[UOp]:
       deps = set([x])
       ssize = 0
       while ssize != len(deps):
@@ -381,7 +374,7 @@
         for u in self.uops:
           if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
             deps.add(u)
-      return sorted(list(deps), key=self.uops.index) # get the last one
+      return deps

     def replace_op(old:UOp, new:UOp):
       for u in self.uops:
@@ -395,7 +388,7 @@
       elif u.uop == UOps.LOOP: loop_stack.append([u])
       elif u.uop not in [UOps.CONST, UOps.ALU]: loop_stack[-1].append(u)
       else:
-        parents = get_recursive_parents([u])
+        parents = get_recursive_parents(u)
         for i in reversed(range(len(loop_stack))):
           # check backwards and put the uop in the first encounter with some dependency
           if any(x in parents for x in loop_stack[i]) or i == 0:
@@ -411,7 +404,7 @@
       if u.uop == UOps.PHI and len(u.vin) == 3:
         # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
         # TODO: ADD becomes a MUL, MAX can just become nothing
-        if all(x.uop != UOps.LOOP for x in get_recursive_parents(list(u.vin[0:2]))) and u.vin[1].arg == BinaryOps.ADD:
+        if all(x.uop != UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) and u.vin[1].arg == BinaryOps.ADD:
           if DEBUG >= 4: print(f"removing PHI node {u}")
           del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
           # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
@@ -438,7 +431,7 @@
     for u in self.uops:
       if u.uop == UOps.LOOP:
         # add END of loops after the last thing that (recursively) depends on them
-        self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(get_recursive_children(u)[-1])+1)
+        self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1)
       elif u.uop == UOps.IF:
         # END any if statements at the end of the uops
         self.uop(UOps.END, None, (u,), cachable=False)
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 1fbe5dec..78dc6b23 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -1,8 +1,8 @@
 from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
-import itertools, random, math
+import itertools, random, math, time
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
-from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int
+from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int, colored, Timing
 from tinygrad.codegen.linearizer import Linearizer, UOp
 from tinygrad.runtime.lib import RawBuffer
 from collections import defaultdict
@@ -18,8 +18,8 @@ actions += [
   Opt(op=OptOps.LOCAL, axis=0, amt=32),
   Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
   Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
-  Opt(op=OptOps.NOLOCALS),
 ]
+if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

 # returns time in seconds
 def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
@@ -116,28 +116,31 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
   # NOTE: real uops use a weird compare method that's only valid inside a linearizer
   seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}

-  while 1:
-    acted_lins = lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
+  exiting, st = False, time.perf_counter()
+  while not exiting:
+    with Timing("linearize: ", enabled=DEBUG>=3):
+      acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])

-    # dedup with uops (TODO: double linearize not needed)
-    acted_lins_dedup = []
-    for lin in acted_lins:
-      tuops = tuplize_uops(lin.linearize().uops)
-      if tuops in seen_uops:
-        #print(seen_uops[tuops], lin.applied_opts)
-        continue
-      seen_uops[tuops] = tuple(lin.applied_opts)
-      acted_lins_dedup.append(lin)
-    acted_lins = acted_lins_dedup
+      # linearize all
+      for x in acted_lins: x.linearize()

-    # time linearizers
-    timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins]
-    opts = sorted(timed_lins, key=lambda x: x[1])
-    if len(opts) == 0 or beam[0][1] <= opts[0][1]: break # we didn't get faster
+      # dedup with uops
+      acted_lins_dedup = []
+      for lin in acted_lins:
+        tuops = tuplize_uops(lin.uops)
+        if tuops in seen_uops: continue
+        seen_uops[tuops] = tuple(lin.applied_opts)
+        acted_lins_dedup.append(lin)

-    # keep the BEAM best
-    beam = opts[:amt]
-    if DEBUG >= 2: print(f"{opts[0][1]*1e6:12.2f} us from {len(lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
+    with Timing("compile: ",enabled=DEBUG>=3):
+      # time linearizers
+      timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins_dedup]
+      opts = sorted(timed_lins, key=lambda x: x[1])
+
+    # done
+    exiting = len(opts) == 0 or beam[0][1] <= opts[0][1]
+    if not exiting: beam = opts[:amt]
+    if DEBUG >= 2: print(f"{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())

   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
   if DEBUG >= 3: print(beam[0][0].applied_opts)
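
Note on the get_recursive_parents change: the linearizer now walks transitive parents recursively and memoizes per UOp with functools.lru_cache, instead of the old iterative frontier loop. A minimal standalone sketch of that pattern, using a hypothetical Node class as a stand-in (not tinygrad's UOp):

import functools
from dataclasses import dataclass
from typing import Set, Tuple

@dataclass(frozen=True)
class Node:
  # hypothetical stand-in for a UOp: frozen so instances are hashable for lru_cache
  name: str
  vin: Tuple["Node", ...] = ()

@functools.lru_cache(None)
def get_recursive_parents(x: Node) -> Set[Node]:
  # union of the direct parents with the (recursive) parents of each parent
  return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin])

a = Node("a"); b = Node("b", (a,)); c = Node("c", (a, b))
assert get_recursive_parents(c) == {a, b}   # a is reached both directly and through b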
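
Note on test_sum_collapse: the PHI fold above drops the reduce LOOP when the accumulated value does not depend on it, which is why summing Tensor.ones(256,256) should linearize with no LOOP uops. The identity it relies on (per the "ADD becomes a MUL" TODO) is that adding a loop-invariant value n times equals multiplying it by n; a toy illustration, not tinygrad code:

def sum_with_loop(c: float, n: int) -> float:
  acc = 0.0
  for _ in range(n): acc += c   # what the un-collapsed kernel would compute
  return acc

def sum_collapsed(c: float, n: int) -> float:
  return c * n                  # what the folded kernel computes instead

assert sum_with_loop(1.0, 256*256) == sum_collapsed(1.0, 256*256) == 65536.0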