faster get_recursive_parents (#2392)

* faster get_recursive_parents

* skip test for those

* full sum works everywhere

* timing

* debug print
This commit is contained in:
George Hotz 2023-11-22 20:37:19 -08:00 committed by GitHub
parent 8798d120bb
commit 80e4ad8bf5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 36 deletions

View File

@ -110,6 +110,13 @@ class TestLinearizer(unittest.TestCase):
assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9)
lin.limit_dims_to_max(global_max=[16, 16, 16], local_max=[16, 16, 16])
def test_sum_collapse(self):
t = Tensor.ones(256,256).sum()
sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
def helper_realized_ast(r:Tensor):
s = r.lazydata.schedule()
run_schedule(s[:-1]) # run all kernels except the last one

View File

@ -363,17 +363,10 @@ class Linearizer(Kernel):
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
# graph helper functions
def get_recursive_parents(x:List[UOp]) -> List[UOp]:
ret: Set[UOp] = set()
this_round: Set[UOp] = set(x)
while len(this_round):
ret = ret.union(this_round)
next_round: Set[UOp] = set()
for r in this_round: next_round = next_round.union(set(r.vin))
this_round = next_round
return list(ret)
@functools.lru_cache(None)
def get_recursive_parents(x:UOp) -> Set[UOp]: return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin])
def get_recursive_children(x:UOp) -> List[UOp]:
def get_recursive_children(x:UOp) -> Set[UOp]:
deps = set([x])
ssize = 0
while ssize != len(deps):
@ -381,7 +374,7 @@ class Linearizer(Kernel):
for u in self.uops:
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
deps.add(u)
return sorted(list(deps), key=self.uops.index) # get the last one
return deps
def replace_op(old:UOp, new:UOp):
for u in self.uops:
@ -395,7 +388,7 @@ class Linearizer(Kernel):
elif u.uop == UOps.LOOP: loop_stack.append([u])
elif u.uop not in [UOps.CONST, UOps.ALU]: loop_stack[-1].append(u)
else:
parents = get_recursive_parents([u])
parents = get_recursive_parents(u)
for i in reversed(range(len(loop_stack))):
# check backwards and put the uop in the first encounter with some dependency
if any(x in parents for x in loop_stack[i]) or i == 0:
@ -411,7 +404,7 @@ class Linearizer(Kernel):
if u.uop == UOps.PHI and len(u.vin) == 3:
# if the parents of the PHI node don't have the LOOP in their parents, it can be folded
# TODO: ADD becomes a MUL, MAX can just become nothing
if all(x.uop != UOps.LOOP for x in get_recursive_parents(list(u.vin[0:2]))) and u.vin[1].arg == BinaryOps.ADD:
if all(x.uop != UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) and u.vin[1].arg == BinaryOps.ADD:
if DEBUG >= 4: print(f"removing PHI node {u}")
del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
# NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
@ -438,7 +431,7 @@ class Linearizer(Kernel):
for u in self.uops:
if u.uop == UOps.LOOP:
# add END of loops after the last thing that (recursively) depends on them
self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(get_recursive_children(u)[-1])+1)
self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1)
elif u.uop == UOps.IF:
# END any if statements at the end of the uops
self.uop(UOps.END, None, (u,), cachable=False)

View File

@ -1,8 +1,8 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random, math
import itertools, random, math, time
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import Device, Compiled, MemBuffer
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int, colored, Timing
from tinygrad.codegen.linearizer import Linearizer, UOp
from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict
@ -18,8 +18,8 @@ actions += [
Opt(op=OptOps.LOCAL, axis=0, amt=32),
Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
Opt(op=OptOps.NOLOCALS),
]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
# returns time in seconds
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
@ -116,28 +116,31 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
# NOTE: real uops use a weird compare method that's only valid inside a linearizer
seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}
while 1:
acted_lins = lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
exiting, st = False, time.perf_counter()
while not exiting:
with Timing("linearize: ", enabled=DEBUG>=3):
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
# dedup with uops (TODO: double linearize not needed)
# linearize all
for x in acted_lins: x.linearize()
# dedup with uops
acted_lins_dedup = []
for lin in acted_lins:
tuops = tuplize_uops(lin.linearize().uops)
if tuops in seen_uops:
#print(seen_uops[tuops], lin.applied_opts)
continue
tuops = tuplize_uops(lin.uops)
if tuops in seen_uops: continue
seen_uops[tuops] = tuple(lin.applied_opts)
acted_lins_dedup.append(lin)
acted_lins = acted_lins_dedup
with Timing("compile: ",enabled=DEBUG>=3):
# time linearizers
timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins]
timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins_dedup]
opts = sorted(timed_lins, key=lambda x: x[1])
if len(opts) == 0 or beam[0][1] <= opts[0][1]: break # we didn't get faster
# keep the BEAM best
beam = opts[:amt]
if DEBUG >= 2: print(f"{opts[0][1]*1e6:12.2f} us from {len(lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
# done
exiting = len(opts) == 0 or beam[0][1] <= opts[0][1]
if not exiting: beam = opts[:amt]
if DEBUG >= 2: print(f"{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
if DEBUG >= 3: print(beam[0][0].applied_opts)