mirror of https://github.com/commaai/tinygrad.git
mcts search (#5598)
* mcts search
* mcts cleanups
* mcts cleanup
* random shuffle children order
* mcts in handcode_opt
* src and remove_node
* debug 3 to print ast
* print the type
* mcts in extra
This commit is contained in:
parent b991097d41
commit 06e336bccb
handcode_opt.py
@@ -1,5 +1,6 @@
 from typing import List
 from extra.models.resnet import ResNet50
+from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
@@ -77,7 +78,7 @@ if __name__ == "__main__":
   running_gflops = 0
   usage = {}
   for i,si in enumerate(sched):
-    if DEBUG >= 2: print(si.ast)
+    if DEBUG >= 3: print(si.ast)

     rawbufs = bufs_from_lin(Kernel(si.ast))

@@ -87,30 +88,37 @@ if __name__ == "__main__":
     # always try hand coded opt
     lin = Kernel(si.ast, opts=device.renderer)
     lin.hand_coded_optimizations()
-    lins.append(lin)
+    lins.append((lin, "HC"))

     # maybe try tensor cores
     lin = Kernel(si.ast, opts=device.renderer)
     if lin.apply_tensor_cores():
-      lins.append(lin)
+      lins.append((lin, "TC"))

     # try a beam search
     if beam:=getenv("BEAM"):
       lin = Kernel(si.ast, opts=device.renderer)
       lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
-      lins.append(lin)
+      lins.append((lin, "BEAM"))

+    # try MCTS
+    if mcts:=getenv("MCTS"):
+      lin = Kernel(si.ast, opts=device.renderer)
+      lin = mcts_search(lin, rawbufs, mcts)
+      lins.append((lin, "MCTS"))
+
     # benchmark the programs
     choices = []
-    for lin in lins:
+    for (lin, nm) in lins:
       tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
       ops = lin.to_program().op_estimate
       gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
       choices.append((tm, gflops, lin.linearize()))

       # print all kernels
-      if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
+      if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS -- {nm}")
     tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
+    if getenv("SRC"): print(lin.to_program().src)
     total_tm += tm
     running_gflops += gflops * tm
     if (key := str([str(m) for m in si.metadata] if si.metadata is not None else None)) not in usage: usage[key] = (0, 0)
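For reference, a minimal sketch of how a caller could reproduce what the loop above now does for a single schedule item: build a fresh Kernel, allocate dummy buffers, run the MCTS search, and time the result. The names si, device, and the playouts count are placeholders here, and the bufs_from_lin/time_linearizer import location is assumed (tinygrad.engine.search, as used elsewhere by this script); this is not part of the commit itself.

from tinygrad import Device
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import bufs_from_lin, time_linearizer  # assumed import location
from extra.mcts_search import mcts_search

def mcts_candidate(si, playouts=200):
  device = Device[Device.DEFAULT]
  rawbufs = bufs_from_lin(Kernel(si.ast))     # dummy buffers to benchmark against, as in the diff above
  lin = Kernel(si.ast, opts=device.renderer)  # fresh kernel for this AST
  lin = mcts_search(lin, rawbufs, playouts)   # run `playouts` MCTS iterations
  tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)  # same timing call the script uses
  return lin, tm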
extra/mcts_search.py (new file)
@@ -0,0 +1,81 @@
+from typing import List, Optional
+import math, functools, time, random
+from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put
+from tinygrad.codegen.kernel import Kernel
+from tinygrad.device import Buffer, Device
+from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _try_compile_linearized_w_idx, _time_program
+
+class MCTSNode:
+  def __init__(self, kernel, parent=None):
+    self.kernel = kernel
+    self.t = 0
+    self.n = 0
+    self.parent: Optional[MCTSNode] = parent
+    self.children: Optional[List[MCTSNode]] = None
+
+def expand_node(node:MCTSNode):
+  assert node.children is None
+  node.children = [MCTSNode(x, node) for x in get_kernel_actions(node.kernel, include_0=False).values()]
+  random.shuffle(node.children)
+
+C = math.sqrt(2)
+def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
+  # TODO: copied from BEAM
+  key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
+  if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None:
+    ret = lin.copy()
+    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
+    return ret
+
+  rawbufs = _ensure_buffer_alloc(rawbufs)
+  var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
+  dev = Device[lin.opts.device]
+  root = MCTSNode(lin)
+  _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
+
+  def remove_node(node):
+    if node.parent is not None:
+      assert node.parent.children is not None
+      node.parent.children.remove(node)
+
+  st = time.perf_counter()
+  best, best_tm = lin, math.inf
+  for i in range(amt):
+    # tree traversal
+    node = root
+    while node.children is not None and len(node.children) != 0:
+      #if DEBUG>=2: print(f"{(node.t/node.n)/best_tm:6.2f} value {node.n:3d}", node.kernel.name)
+      ucb = sorted([(math.inf if child.n == 0 else ((child.t/child.n)/best_tm) + C*math.sqrt(math.log(node.n)/child.n), child)
+                    for child in node.children], key=lambda x: x[0], reverse=True) # pylint: disable=not-an-iterable
+      node = ucb[0][1]
+
+    if node.children is not None: break # no more nodes?
+
+    # node expansion
+    expand_node(node)
+
+    # rollout
+    _, compile_ret = _compile_fn((0, node.kernel))
+    if compile_ret is None:
+      remove_node(node)
+      continue
+
+    p, lib, _ = compile_ret
+    try: tm = min(_time_program(p, lib, var_vals, rawbufs, early_stop=best_tm*10/1e6))*1e6
+    except RuntimeError:
+      remove_node(node)
+      continue
+
+    if DEBUG>=2: print(f"\r{time.perf_counter() - st:7.2f}s: {tm:12.2f} us best: {best_tm:12.2f} us {i+1:4d}/{amt:4d} {node.kernel.colored_shape()}\033[K", end="") # noqa: E501
+    if tm < best_tm: best, best_tm = node.kernel, tm
+
+    # backprop
+    bnode: Optional[MCTSNode] = node
+    while bnode is not None:
+      bnode.t += -tm
+      bnode.n += 1
+      bnode = bnode.parent
+
+  if DEBUG>=2: print()
+  if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
+  return best
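A note on the selection rule above, for readers unfamiliar with MCTS: each node's total value t accumulates negated measured runtimes and its mean is normalized by the best time seen so far, so a larger (less negative) mean indicates a faster kernel, while unvisited children score +inf and are therefore each tried once. The score being maximized is roughly

\mathrm{UCB}(c) \;=\; \frac{t_c/n_c}{t_{\mathrm{best}}} \;+\; C\,\sqrt{\frac{\ln n_{\mathrm{parent}}}{n_c}}, \qquad C=\sqrt{2}

which is the standard UCB1 trade-off between exploiting kernels that have timed well and exploring rarely visited optimization branches.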