From 06e336bccb5f0e0e623cdfaf81d281695e048e81 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Fri, 19 Jul 2024 21:38:39 -0700
Subject: [PATCH] mcts search (#5598)

* mcts search

* mcts cleanups

* mcts cleanup

* random shuffle children order

* mcts in handcode_opt

* src and remove_node

* debug 3 to print ast

* print the type

* mcts in extra
---
 examples/handcode_opt.py | 20 +++++++---
 extra/mcts_search.py     | 81 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 95 insertions(+), 6 deletions(-)
 create mode 100644 extra/mcts_search.py

diff --git a/examples/handcode_opt.py b/examples/handcode_opt.py
index 0961e7ad..3f258264 100644
--- a/examples/handcode_opt.py
+++ b/examples/handcode_opt.py
@@ -1,5 +1,6 @@
 from typing import List
 from extra.models.resnet import ResNet50
+from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
@@ -77,7 +78,7 @@ if __name__ == "__main__":
   running_gflops = 0
   usage = {}
   for i,si in enumerate(sched):
-    if DEBUG >= 2: print(si.ast)
+    if DEBUG >= 3: print(si.ast)
 
     rawbufs = bufs_from_lin(Kernel(si.ast))
 
@@ -87,30 +88,37 @@ if __name__ == "__main__":
     # always try hand coded opt
     lin = Kernel(si.ast, opts=device.renderer)
     lin.hand_coded_optimizations()
-    lins.append(lin)
+    lins.append((lin, "HC"))
 
     # maybe try tensor cores
     lin = Kernel(si.ast, opts=device.renderer)
     if lin.apply_tensor_cores():
-      lins.append(lin)
+      lins.append((lin, "TC"))
 
     # try a beam search
     if beam:=getenv("BEAM"):
       lin = Kernel(si.ast, opts=device.renderer)
       lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
-      lins.append(lin)
+      lins.append((lin, "BEAM"))
+
+    # try MCTS
+    if mcts:=getenv("MCTS"):
+      lin = Kernel(si.ast, opts=device.renderer)
+      lin = mcts_search(lin, rawbufs, mcts)
+      lins.append((lin, "MCTS"))
 
     # benchmark the programs
     choices = []
-    for lin in lins:
+    for (lin, nm) in lins:
       tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
       ops = lin.to_program().op_estimate
       gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
       choices.append((tm, gflops, lin.linearize()))
 
       # print all kernels
-      if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
+      if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS -- {nm}")
     tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
+    if getenv("SRC"): print(lin.to_program().src)
     total_tm += tm
     running_gflops += gflops * tm
     if (key := str([str(m) for m in si.metadata] if si.metadata is not None else None)) not in usage: usage[key] = (0, 0)
diff --git a/extra/mcts_search.py b/extra/mcts_search.py
new file mode 100644
index 00000000..1b0fb9ee
--- /dev/null
+++ b/extra/mcts_search.py
@@ -0,0 +1,81 @@
+from typing import List, Optional
+import math, functools, time, random
+from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put
+from tinygrad.codegen.kernel import Kernel
+from tinygrad.device import Buffer, Device
+from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _try_compile_linearized_w_idx, _time_program
+
+class MCTSNode:
+  def __init__(self, kernel, parent=None):
+    self.kernel = kernel
+    self.t = 0
+    self.n = 0
+    self.parent: Optional[MCTSNode] = parent
+    self.children: Optional[List[MCTSNode]] = None
+
+def expand_node(node:MCTSNode):
+  assert node.children is None
+  node.children = [MCTSNode(x, node) for x in get_kernel_actions(node.kernel, include_0=False).values()]
+  random.shuffle(node.children)
+
+C = math.sqrt(2)
+def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
+  # TODO: copied from BEAM
+  key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
+  if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None:
+    ret = lin.copy()
+    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
+    return ret
+
+  rawbufs = _ensure_buffer_alloc(rawbufs)
+  var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
+  dev = Device[lin.opts.device]
+  root = MCTSNode(lin)
+  _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
+
+  def remove_node(node):
+    if node.parent is not None:
+      assert node.parent.children is not None
+      node.parent.children.remove(node)
+
+  st = time.perf_counter()
+  best, best_tm = lin, math.inf
+  for i in range(amt):
+    # tree traversal
+    node = root
+    while node.children is not None and len(node.children) != 0:
+      #if DEBUG>=2: print(f"{(node.t/node.n)/best_tm:6.2f} value {node.n:3d}", node.kernel.name)
+      ucb = sorted([(math.inf if child.n == 0 else ((child.t/child.n)/best_tm) + C*math.sqrt(math.log(node.n)/child.n), child)
+                    for child in node.children], key=lambda x: x[0], reverse=True) # pylint: disable=not-an-iterable
+      node = ucb[0][1]
+
+    if node.children is not None: break  # no more nodes?
+
+    # node expansion
+    expand_node(node)
+
+    # rollout
+    _, compile_ret = _compile_fn((0, node.kernel))
+    if compile_ret is None:
+      remove_node(node)
+      continue
+
+    p, lib, _ = compile_ret
+    try: tm = min(_time_program(p, lib, var_vals, rawbufs, early_stop=best_tm*10/1e6))*1e6
+    except RuntimeError:
+      remove_node(node)
+      continue
+
+    if DEBUG>=2: print(f"\r{time.perf_counter() - st:7.2f}s: {tm:12.2f} us best: {best_tm:12.2f} us {i+1:4d}/{amt:4d} {node.kernel.colored_shape()}\033[K", end="")  # noqa: E501
+    if tm < best_tm: best, best_tm = node.kernel, tm
+
+    # backprop
+    bnode: Optional[MCTSNode] = node
+    while bnode is not None:
+      bnode.t += -tm
+      bnode.n += 1
+      bnode = bnode.parent
+
+  if DEBUG>=2: print()
+  if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
+  return best
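
Usage sketch (not part of the patch itself): the new search is reached through the MCTS
environment variable that this change adds to examples/handcode_opt.py, or by calling
extra.mcts_search.mcts_search on a Kernel directly. The snippet below is a minimal,
hedged example of the direct call; it assumes a tinygrad checkout at roughly this
revision where extra/ is importable, and the small matmul workload, the choice of the
last schedule item, and the bufs_from_lin import path are illustrative assumptions,
not something this patch prescribes.

  from tinygrad import Tensor, Device
  from tinygrad.codegen.kernel import Kernel
  from tinygrad.engine.search import bufs_from_lin, time_linearizer
  from extra.mcts_search import mcts_search

  # build a small matmul and take its last scheduled kernel (assumed to be the matmul reduce)
  out = Tensor.rand(64, 64) @ Tensor.rand(64, 64)
  si = out.schedule()[-1]

  lin = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)  # kernel to optimize
  rawbufs = bufs_from_lin(Kernel(si.ast))                     # scratch buffers for timing

  best = mcts_search(lin, rawbufs, 100)  # 100 MCTS iterations, same budget as MCTS=100
  print(best.applied_opts)
  print(time_linearizer(best, rawbufs, allow_test_size=False, cnt=10))

The same path is exercised end to end with something like MCTS=100 DEBUG=1
python3 examples/handcode_opt.py, mirroring how the existing BEAM variable is used there.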