mirror of https://github.com/commaai/tinygrad.git
faster get_recursive_parents (#2392)
* faster get_recursive_parents
* skip test for those
* full sum works everywhere
* timing
* debug print
parent 8798d120bb
commit 80e4ad8bf5

@@ -110,6 +110,13 @@ class TestLinearizer(unittest.TestCase):
     assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9)
     lin.limit_dims_to_max(global_max=[16, 16, 16], local_max=[16, 16, 16])
 
+  def test_sum_collapse(self):
+    t = Tensor.ones(256,256).sum()
+    sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps]
+    assert len(sched) == 1
+    lin = Linearizer(sched[0].ast)
+    assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
+
 def helper_realized_ast(r:Tensor):
   s = r.lazydata.schedule()
   run_schedule(s[:-1]) # run all kernels except the last one
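The new test pins down what the commit message calls "full sum works everywhere": a reduce whose accumulated value is loop-invariant should compile to straight-line code, with no UOps.LOOP left in the linearized program. A minimal sketch of the identity being exploited, in plain Python rather than tinygrad internals:

    def sum_with_loop(c: float, n: int) -> float:
      acc = 0.0
      for _ in range(n):  # the LOOP the linearizer wants to eliminate
        acc += c          # loop-invariant accumuland
      return acc

    def sum_collapsed(c: float, n: int) -> float:
      return c * n        # the repeated ADD folded into a single MUL

    assert sum_with_loop(1.0, 256 * 256) == sum_collapsed(1.0, 256 * 256) == 65536.0
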
@@ -363,17 +363,10 @@ class Linearizer(Kernel):
       self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
 
     # graph helper functions
-    def get_recursive_parents(x:List[UOp]) -> List[UOp]:
-      ret: Set[UOp] = set()
-      this_round: Set[UOp] = set(x)
-      while len(this_round):
-        ret = ret.union(this_round)
-        next_round: Set[UOp] = set()
-        for r in this_round: next_round = next_round.union(set(r.vin))
-        this_round = next_round
-      return list(ret)
+    @functools.lru_cache(None)
+    def get_recursive_parents(x:UOp) -> Set[UOp]: return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin])
 
-    def get_recursive_children(x:UOp) -> List[UOp]:
+    def get_recursive_children(x:UOp) -> Set[UOp]:
       deps = set([x])
       ssize = 0
       while ssize != len(deps):
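This hunk is the one the commit is named for. The old helper re-walked the graph breadth-first on every call; the replacement is a memoized recursion, so once functools.lru_cache has seen a UOp, every later query that reaches it is a cache hit. A self-contained sketch of the same pattern on a toy node type (Node here is illustrative, not tinygrad's UOp):

    import functools
    from typing import Set, Tuple

    class Node:
      # hashable by object identity, with a tuple of inputs like UOp.vin
      def __init__(self, name: str, vin: Tuple["Node", ...] = ()):
        self.name, self.vin = name, vin
      def __repr__(self): return self.name

    @functools.lru_cache(None)
    def get_recursive_parents(x: Node) -> Set[Node]:
      # direct parents unioned with the (cached) transitive parents of each
      return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin])

    a = Node("a"); b = Node("b", (a,)); c = Node("c", (a, b)); d = Node("d", (b, c))
    assert get_recursive_parents(d) == {a, b, c}
    assert get_recursive_parents(c) == {a, b}  # already computed while resolving d

The diamond above (a reachable from d through both b and c) is the shape where an uncached recursion revisits shared ancestors; with the cache, each node's parent set is computed once.
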
@@ -381,7 +374,7 @@ class Linearizer(Kernel):
         for u in self.uops:
           if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
             deps.add(u)
-      return sorted(list(deps), key=self.uops.index) # get the last one
+      return deps
 
     def replace_op(old:UOp, new:UOp):
       for u in self.uops:
@@ -395,7 +388,7 @@ class Linearizer(Kernel):
       elif u.uop == UOps.LOOP: loop_stack.append([u])
       elif u.uop not in [UOps.CONST, UOps.ALU]: loop_stack[-1].append(u)
       else:
-        parents = get_recursive_parents([u])
+        parents = get_recursive_parents(u)
         for i in reversed(range(len(loop_stack))):
           # check backwards and put the uop in the first encounter with some dependency
           if any(x in parents for x in loop_stack[i]) or i == 0:
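The call site changes in step with the signature: the old helper took a List[UOp], but functools.lru_cache keys on its arguments and a Python list is unhashable, so the memoized version must take the UOp itself. A short demonstration of the constraint, independent of tinygrad:

    import functools

    @functools.lru_cache(None)
    def cached_len(xs): return len(xs)

    try:
      cached_len([1, 2])            # TypeError: unhashable type: 'list'
    except TypeError:
      pass
    assert cached_len((1, 2)) == 2  # tuples (and UOp-like objects) hash fine
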
@@ -411,7 +404,7 @@ class Linearizer(Kernel):
       if u.uop == UOps.PHI and len(u.vin) == 3:
         # if the parents of the PHI node don't have the LOOP in their parents, it can be folded
         # TODO: ADD becomes a MUL, MAX can just become nothing
-        if all(x.uop != UOps.LOOP for x in get_recursive_parents(list(u.vin[0:2]))) and u.vin[1].arg == BinaryOps.ADD:
+        if all(x.uop != UOps.LOOP for x in get_recursive_parents(UOp(u.uop, u.dtype, u.vin[0:2], u.arg))) and u.vin[1].arg == BinaryOps.ADD:
           if DEBUG >= 4: print(f"removing PHI node {u}")
           del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
           # NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
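With the memoized helper accepting exactly one UOp, this check can no longer pass a list of sources, so it builds a throwaway UOp whose vin is the pair of interest and asks for that node's parents; the wrapper is never inserted into the program. Note that the wrapped sources themselves land in the result, matching the old list-based helper, which included its arguments. Roughly, with a toy node type redefined so the snippet stands alone:

    import functools
    from typing import Set, Tuple

    class Node:
      def __init__(self, vin: Tuple["Node", ...] = ()): self.vin = vin

    @functools.lru_cache(None)
    def get_recursive_parents(x: Node) -> Set[Node]:
      return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin])

    a = Node(); b = Node((a,)); c = Node((a, b))
    wrapper = Node((b, c))  # throwaway node, exists only to phrase the query
    assert get_recursive_parents(wrapper) == {a, b, c}  # parents of b and c, combined
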
@@ -438,7 +431,7 @@ class Linearizer(Kernel):
     for u in self.uops:
       if u.uop == UOps.LOOP:
         # add END of loops after the last thing that (recursively) depends on them
-        self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(get_recursive_children(u)[-1])+1)
+        self.uop(UOps.END, None, (u,), cachable=False, insert_before=self.uops.index(sorted(list(get_recursive_children(u)), key=self.uops.index)[-1])+1)
       elif u.uop == UOps.IF:
         # END any if statements at the end of the uops
         self.uop(UOps.END, None, (u,), cachable=False)
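Paired with get_recursive_children now returning a plain Set[UOp] (the "return deps" change above), the ordering cost moves to the one call site that needs it: only END insertion cares which dependent comes last in program order, and it recovers that by sorting with the uop list's index method as the key. A small sketch of the idiom, with a plain list of strings standing in for self.uops:

    program = ["loop", "load", "alu", "phi", "store"]  # stand-in for self.uops
    children = {"store", "alu", "phi"}                 # unordered dependents of "loop"

    last = sorted(children, key=program.index)[-1]     # last dependent in program order
    assert last == "store"
    insert_before = program.index(last) + 1            # the END lands right after it
    assert insert_before == 5
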
@@ -1,8 +1,8 @@
 from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
-import itertools, random, math
+import itertools, random, math, time
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
-from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int
+from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int, colored, Timing
 from tinygrad.codegen.linearizer import Linearizer, UOp
 from tinygrad.runtime.lib import RawBuffer
 from collections import defaultdict
@@ -18,8 +18,8 @@ actions += [
   Opt(op=OptOps.LOCAL, axis=0, amt=32),
   Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.GROUP, axis=1, amt=8),
   Opt(op=OptOps.UPCASTMID, axis=1, amt=4),
-  Opt(op=OptOps.NOLOCALS),
 ]
+if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
 
 # returns time in seconds
 def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
@@ -116,28 +116,31 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
   # NOTE: real uops use a weird compare method that's only valid inside a linearizer
   seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}
 
-  while 1:
-    acted_lins = lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
-
-    # dedup with uops (TODO: double linearize not needed)
-    acted_lins_dedup = []
-    for lin in acted_lins:
-      tuops = tuplize_uops(lin.linearize().uops)
-      if tuops in seen_uops:
-        #print(seen_uops[tuops], lin.applied_opts)
-        continue
-      seen_uops[tuops] = tuple(lin.applied_opts)
-      acted_lins_dedup.append(lin)
-    acted_lins = acted_lins_dedup
-    # linearize all
-    for x in acted_lins: x.linearize()
-
-    # time linearizers
-    timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins]
-    opts = sorted(timed_lins, key=lambda x: x[1])
-    if len(opts) == 0 or beam[0][1] <= opts[0][1]: break # we didn't get faster
-
-    # keep the BEAM best
-    beam = opts[:amt]
-    if DEBUG >= 2: print(f"{opts[0][1]*1e6:12.2f} us from {len(lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
+  exiting, st = False, time.perf_counter()
+  while not exiting:
+    with Timing("linearize: ", enabled=DEBUG>=3):
+      acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
+      # linearize all
+      for x in acted_lins: x.linearize()
+
+      # dedup with uops
+      acted_lins_dedup = []
+      for lin in acted_lins:
+        tuops = tuplize_uops(lin.uops)
+        if tuops in seen_uops: continue
+        seen_uops[tuops] = tuple(lin.applied_opts)
+        acted_lins_dedup.append(lin)
+
+    with Timing("compile: ",enabled=DEBUG>=3):
+      # time linearizers
+      timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins_dedup]
+      opts = sorted(timed_lins, key=lambda x: x[1])
+
+    # done
+    exiting = len(opts) == 0 or beam[0][1] <= opts[0][1]
+    if not exiting: beam = opts[:amt]
+    if DEBUG >= 2: print(f"{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
 
   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
+  if DEBUG >= 3: print(beam[0][0].applied_opts)
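The search loop is restructured around instrumentation rather than a new algorithm: "while 1"/"break" becomes an exiting flag so the per-generation report still runs on the final, non-improving generation; each generation splits into a "linearize" phase and a "compile" phase wrapped in tinygrad's Timing context manager; and the elapsed wall clock printed with each report comes from time.perf_counter(). A stripped-down skeleton of the new control flow, where candidates and measure are placeholders rather than tinygrad functions:

    import time

    def beam_step_skeleton(beam, candidates, measure, amt):
      # beam: list of (candidate, seconds) pairs, best first
      exiting, st = False, time.perf_counter()
      while not exiting:
        opts = sorted(((c, measure(c)) for c in candidates(beam)), key=lambda x: x[1])
        # exit when nothing beats the current best, but still fall through to
        # the report below (the old break-based loop skipped it on the way out)
        exiting = len(opts) == 0 or beam[0][1] <= opts[0][1]
        if not exiting: beam = opts[:amt]
        print(f"{time.perf_counter() - st:7.2f}s: best {beam[0][1]*1e6:12.2f} us")
      return beam

One detail visible in the diff itself: the DEBUG >= 2 report now also prints on the exiting generation, with the best time colored green, and the dedup no longer linearizes twice, since the single "linearize all" pass runs before deduplication and the dedup reads the already-built lin.uops.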