# tinygrad/extra/mcts_search.py (file-viewer listing header: 175 lines, 6.5 KiB, Python)

from __future__ import annotations
from typing import List, Optional, Dict, cast
import numpy as np
np.set_printoptions(suppress=True)
import math, functools, time, random, statistics
from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, colored, Profiling
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Buffer, Device, CompileError
from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _time_program
class MCTSNode:
  """A node of the search tree: one candidate kernel plus its search statistics."""
  def __init__(self, kernel:Kernel, parent=None):
    # --- tree links ---
    if parent is None:
      self.parents: List[MCTSNode] = []
    else:
      self.parents = [parent]
    self.children: Optional[List[MCTSNode]] = None    # None means "not expanded yet"
    self.removed_children: List[MCTSNode] = []         # children taken out of `children` by remove_node
    # --- payload ---
    self.kernel: Kernel = kernel
    # --- statistics ---
    self.n = 0             # visit weight, accumulated by backprop
    self.t = math.inf      # lowest time backprop has pushed through this node
    self.tm = math.inf     # time recorded for this node itself by the search loop
    self.i = -1            # search iteration that picked this node (-1: never picked)
def expand_node(node:MCTSNode):
  """Create node.children: one MCTSNode per kernel returned by get_kernel_actions.

  The node must not have been expanded before (children is still None).
  """
  assert node.children is None
  spawned: List[MCTSNode] = []
  for next_kernel in get_kernel_actions(node.kernel, include_0=False).values():
    spawned.append(MCTSNode(next_kernel, node))
  # assigned only once every child was built, so a failure above leaves the node unexpanded
  node.children = spawned
def remove_node(node:MCTSNode):
  """Move node from each parent's `children` list to that parent's `removed_children` list.

  A node without parents is left untouched.
  """
  for owner in node.parents:
    live = owner.children
    assert live is not None
    live.remove(node)
    owner.removed_children.append(node)
C = math.sqrt(2)  # weight of the exploration term in the UCB score computed in _sample_tree
TEMP = 0.5        # softmax temperature: UCB scores are divided by this before exponentiation
def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
  """Descend from `node` and return the node the search should look at next.

  At each level:
    * a node that is unexpanded (children is None) or has no children is returned as-is;
    * if any child has never been visited (n == 0), one of those is returned, chosen uniformly;
    * otherwise a child is drawn from a softmax over the UCB scores and the descent continues there.
  Children whose score is not a finite number are never drawn; if that leaves no candidate,
  `node` itself is returned.

  Args:
    node: where the descent starts.
    best_tm: best time found so far, used to normalise each child's `t`.
  """
  if node.children is None or len(node.children) == 0: return node
  unexplored_children = []
  explored_children = []
  ucb_explored_children = []
  for child in node.children:
    if child.n == 0: unexplored_children.append(child)
    else:
      ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
      # isfinite rejects NaN as well as +-inf. NaN shows up when child.t and best_tm are both inf
      # (nothing has been timed successfully yet) and would poison the probabilities below.
      if math.isfinite(ucb):
        explored_children.append(child)
        ucb_explored_children.append(ucb)
  if len(unexplored_children): return random.choice(unexplored_children)
  if not len(explored_children): return node
  # shift by the max before exp: same distribution, but the largest weight is exactly 1, so the
  # sum can no longer underflow to 0 (which made the division produce NaN probabilities)
  scaled = np.array(ucb_explored_children)/TEMP
  ucb_exp = np.exp(scaled - np.max(scaled))
  return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm)
def sample_tree(root:MCTSNode, best_tm:float) -> Optional[MCTSNode]:
  """Pick the next node to evaluate, expanding and pruning the tree along the way.

  Repeatedly asks _sample_tree for a candidate. Candidates that turn out to have no
  children are pruned with remove_node and the sampling is retried. A never-visited
  candidate (n == 0) is returned directly; an already-visited one is expanded if needed
  and a uniformly random child of it is returned instead.

  Returns None once root has no children left to sample from.
  """
  if root.children is None: expand_node(root)
  while root.children:
    # tree traversal
    picked = _sample_tree(root, best_tm)
    # node expansion: only visited nodes that were never expanded get expanded here
    if picked.n != 0 and picked.children is None: expand_node(picked)
    # dead end (either found that way, or just expanded into nothing): prune and resample
    if picked.children is not None and len(picked.children) == 0:
      remove_node(picked)
      continue
    if picked.n == 0: return picked
    assert picked.children is not None
    return random.choice(picked.children)
  return None
def backprop(bnode:MCTSNode, tm, strength=1.0):
  """Push a measured time up the tree, starting at bnode.

  Every node reached keeps the minimum time seen in `t` and adds its share of `strength`
  to `n`. A node's share is split evenly between its parents.
  """
  pending = [(bnode, strength)]
  while pending:
    cur, weight = pending.pop()
    if cur.t > tm: cur.t = tm
    cur.n += weight
    ups = cur.parents
    # pushed in reverse so they are popped (and fully processed) in their original order
    pending.extend((up, weight/len(ups)) for up in reversed(ups))
# number of search graphs written so far; gives every MCTSGRAPH dump its own filename
graph_mcts_cnt = 0
def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
  """Search for a fast set of optimizations for `lin` by sampling, compiling and timing kernels.

  Runs up to `amt` iterations. Each iteration samples a node with sample_tree, measures it
  (or reuses the time of an already-seen duplicate), tracks the best kernel, and backprops
  the time through the tree. The applied_opts of the best kernel are stored in the disk
  cache under "mcts_search" and replayed on later calls with the same key.

  Args:
    lin: the kernel to start from; also returned if nothing faster was timed.
    rawbufs: buffers handed to _ensure_buffer_alloc and then to _time_program.
    amt: maximum number of search iterations (also part of the cache key).

  Returns:
    The best Kernel found.

  Environment:
    IGNORE_MCTS_CACHE skips the cache lookup; MCTSGRAPH dumps the tree as .dot/.svg
    (needs networkx and the `dot` executable).
  """
  global graph_mcts_cnt
  # TODO: copied from BEAM
  key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not getenv("IGNORE_MCTS_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None:
    # cache hit: replay only the opts that `lin` does not already have applied
    ret = lin.copy()
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret
  rawbufs = _ensure_buffer_alloc(rawbufs)
  # time with every variable bound to the midpoint of its range
  var_vals = {k:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  dev = Device[lin.opts.device]
  root = MCTSNode(lin)
  st = time.perf_counter()
  best, best_idx, best_tm = lin, 0, math.inf
  # first node seen for each compiled binary / optimized AST key, used to detect duplicates
  seen_libs: Dict[bytes, MCTSNode] = {}
  seen_asts: Dict[bytes, MCTSNode] = {}
  compile_time, runtime_time = 0.0, 0.0
  for i in range(amt):
    node = sample_tree(root, best_tm) # sample and expand
    if node is None: break # finished the whole tree
    node.i = i # when was node explored
    opt_ast = node.kernel.get_optimized_ast()
    if (sibling_node:=seen_asts.get(opt_ast.key, None)) is not None:
      # early check for same optimized AST hit
      remove_node(node)
      tm = sibling_node.t
    else:
      seen_asts[opt_ast.key] = node
      # lowering (50% of the time)
      p = node.kernel.to_program(name_override="test")
      # rollout
      tm1 = time.perf_counter()
      try:
        lib = dev.compiler.compile(p.src)
      except CompileError:
        # NOTE: many of these "compiler errors" are caused by bad code output from the lowerer
        lib = None
      tm2 = time.perf_counter()
      if lib is None:
        # failed to compile: scored as infinitely slow
        tm = math.inf
      else:
        if (sibling_node:=seen_libs.get(lib, None)) is not None:
          # NOTE: these should all be caught by the AST check, need to canonicalize
          # remove this node, it's a duplicate
          remove_node(node)
          tm = sibling_node.t
        else:
          seen_libs[lib] = node
          # median of 3 runs, scaled by 1e6 (printed as "us" below); early_stop is best_tm*5 scaled back down
          try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6
          except RuntimeError: tm = math.inf
          # NOTE(review): nesting reconstructed from a listing without indentation -- confirm that
          # node.tm is meant to be set only for nodes that were actually timed
          node.tm = tm
      tm3 = time.perf_counter()
      compile_time += tm2-tm1
      runtime_time += tm3-tm2
    # mock rollout
    #node.tm = tm = random.random() + 0.1
    if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm
    et = time.perf_counter() - st
    if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {int(round((i+1)/et)):4d}/s {node.kernel.colored_shape()}\033[K", end="") # noqa: E501
    # backprop
    backprop(node, tm)
  if DEBUG>=2: print()
  if getenv("MCTSGRAPH"):
    # optional debug dump of the visited part of the tree
    import networkx as nx
    import os
    GRAPHPATH = "/tmp/net"
    def save_graph(G, fn, opt=""):
      # writes {fn}.dot, then shells out to `dot` to render {fn}.svg
      print("saving", G, f"to {fn}.svg")
      nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
      os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')
    G = nx.DiGraph()
    def add_node(node:MCTSNode):
      # recursively adds visited nodes (n != 0), including ones pruned by remove_node
      if node.n == 0: return
      for parent in node.parents: G.add_edge(parent, node)
      gopts = node.kernel.applied_opts
      # the edge label describes the last opt applied to reach this node
      edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].amt}" if len(gopts) else "ROOT"
      G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}",
                 fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '')
      if node.children is not None:
        for child in node.children+node.removed_children: add_node(child)
    add_node(root)
    save_graph(G, f"{GRAPHPATH}.{graph_mcts_cnt}.mcts", '-Grankdir=LR')
    graph_mcts_cnt += 1
  # remember the winning opts so the next call with the same key can skip the search
  if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
  return best