mcts search (#5598)

* mcts search

* mcts cleanups

* mcts cleanup

* random shuffle children order

* mcts in handcode_opt

* src and remove_node

* debug 3 to print ast

* print the type

* mcts in extra
This commit is contained in:
George Hotz 2024-07-19 21:38:39 -07:00 committed by GitHub
parent b991097d41
commit 06e336bccb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 95 additions and 6 deletions

View File

@ -1,5 +1,6 @@
from typing import List
from extra.models.resnet import ResNet50
from extra.mcts_search import mcts_search
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.codegen.kernel import Kernel
@ -77,7 +78,7 @@ if __name__ == "__main__":
running_gflops = 0
usage = {}
for i,si in enumerate(sched):
if DEBUG >= 2: print(si.ast)
if DEBUG >= 3: print(si.ast)
rawbufs = bufs_from_lin(Kernel(si.ast))
@ -87,30 +88,37 @@ if __name__ == "__main__":
# always try hand coded opt
lin = Kernel(si.ast, opts=device.renderer)
lin.hand_coded_optimizations()
lins.append(lin)
lins.append((lin, "HC"))
# maybe try tensor cores
lin = Kernel(si.ast, opts=device.renderer)
if lin.apply_tensor_cores():
lins.append(lin)
lins.append((lin, "TC"))
# try a beam search
if beam:=getenv("BEAM"):
lin = Kernel(si.ast, opts=device.renderer)
lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
lins.append(lin)
lins.append((lin, "BEAM"))
# try MCTS
if mcts:=getenv("MCTS"):
lin = Kernel(si.ast, opts=device.renderer)
lin = mcts_search(lin, rawbufs, mcts)
lins.append((lin, "MCTS"))
# benchmark the programs
choices = []
for lin in lins:
for (lin, nm) in lins:
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10)
ops = lin.to_program().op_estimate
gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
choices.append((tm, gflops, lin.linearize()))
# print all kernels
if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS")
if DEBUG >= 1: print(f" kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(lin.global_size):18s} {str(lin.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS -- {nm}")
tm, gflops, lin = sorted(choices, key=lambda x: x[0])[0]
if getenv("SRC"): print(lin.to_program().src)
total_tm += tm
running_gflops += gflops * tm
if (key := str([str(m) for m in si.metadata] if si.metadata is not None else None)) not in usage: usage[key] = (0, 0)

81
extra/mcts_search.py Normal file
View File

@ -0,0 +1,81 @@
from typing import List, Optional
import math, functools, time, random
from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Buffer, Device
from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _try_compile_linearized_w_idx, _time_program
class MCTSNode:
def __init__(self, kernel, parent=None):
self.kernel = kernel
self.t = 0
self.n = 0
self.parent: Optional[MCTSNode] = parent
self.children: Optional[List[MCTSNode]] = None
def expand_node(node:MCTSNode):
assert node.children is None
node.children = [MCTSNode(x, node) for x in get_kernel_actions(node.kernel, include_0=False).values()]
random.shuffle(node.children)
C = math.sqrt(2)
def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
# TODO: copied from BEAM
key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None:
ret = lin.copy()
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
return ret
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
dev = Device[lin.opts.device]
root = MCTSNode(lin)
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=dev.compiler)
def remove_node(node):
if node.parent is not None:
assert node.parent.children is not None
node.parent.children.remove(node)
st = time.perf_counter()
best, best_tm = lin, math.inf
for i in range(amt):
# tree traversal
node = root
while node.children is not None and len(node.children) != 0:
#if DEBUG>=2: print(f"{(node.t/node.n)/best_tm:6.2f} value {node.n:3d}", node.kernel.name)
ucb = sorted([(math.inf if child.n == 0 else ((child.t/child.n)/best_tm) + C*math.sqrt(math.log(node.n)/child.n), child)
for child in node.children], key=lambda x: x[0], reverse=True) # pylint: disable=not-an-iterable
node = ucb[0][1]
if node.children is not None: break # no more nodes?
# node expansion
expand_node(node)
# rollout
_, compile_ret = _compile_fn((0, node.kernel))
if compile_ret is None:
remove_node(node)
continue
p, lib, _ = compile_ret
try: tm = min(_time_program(p, lib, var_vals, rawbufs, early_stop=best_tm*10/1e6))*1e6
except RuntimeError:
remove_node(node)
continue
if DEBUG>=2: print(f"\r{time.perf_counter() - st:7.2f}s: {tm:12.2f} us best: {best_tm:12.2f} us {i+1:4d}/{amt:4d} {node.kernel.colored_shape()}\033[K", end="") # noqa: E501
if tm < best_tm: best, best_tm = node.kernel, tm
# backprop
bnode: Optional[MCTSNode] = node
while bnode is not None:
bnode.t += -tm
bnode.n += 1
bnode = bnode.parent
if DEBUG>=2: print()
if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
return best