# mirror of https://github.com/commaai/tinygrad.git
# 175 lines, 6.5 KiB, Python
from __future__ import annotations
|
|
from typing import List, Optional, Dict, cast
|
|
import numpy as np
|
|
np.set_printoptions(suppress=True)
|
|
import math, functools, time, random, statistics
|
|
from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache_put, colored, Profiling
|
|
from tinygrad.codegen.kernel import Kernel
|
|
from tinygrad.device import Buffer, Device, CompileError
|
|
from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _time_program
|
|
|
|
class MCTSNode:
  """A vertex of the search graph holding one candidate kernel plus its statistics.

  Fields (all plain attributes, mutated by the helpers in this module):
    kernel            the Kernel this node represents
    t                 lowest time propagated to this node by backprop (starts at inf)
    n                 accumulated visit weight added by backprop (starts at 0)
    tm                time measured for this node itself (starts at inf)
    i                 search iteration at which the node was explored, -1 until then
    parents           nodes this one hangs off; empty for a root
    children          None until the node is expanded, afterwards a list
    removed_children  children that were detached by remove_node
  """
  def __init__(self, kernel:Kernel, parent=None):
    # a node created without a parent starts with an empty (per-instance) parent list
    linked_parents: List[MCTSNode] = []
    if parent is not None: linked_parents.append(parent)

    self.kernel: Kernel = kernel
    # statistics
    self.t = math.inf
    self.n = 0
    self.tm = math.inf
    self.i = -1
    # graph links
    self.parents: List[MCTSNode] = linked_parents
    self.children: Optional[List[MCTSNode]] = None
    self.removed_children: List[MCTSNode] = []
|
|
|
|
def expand_node(node:MCTSNode):
  """Give an unexpanded node its children.

  One child MCTSNode (with `node` as parent) is created for every kernel in the
  values of `get_kernel_actions(node.kernel, include_0=False)`.
  Asserts that the node has not been expanded before.
  """
  assert node.children is None
  actions = get_kernel_actions(node.kernel, include_0=False)
  fresh_children = []
  for next_kernel in actions.values():
    fresh_children.append(MCTSNode(next_kernel, node))
  # assigned only once the whole list is built
  node.children = fresh_children
|
|
|
|
def remove_node(node:MCTSNode):
  """Detach `node` from each of its parents.

  For every parent the node is taken out of `children` and appended to
  `removed_children`. `node.parents` itself is left untouched.
  Asserts that each parent has been expanded; `list.remove` raises ValueError
  if the node is not currently among a parent's children.
  """
  for owner in node.parents:
    live_children = owner.children
    assert live_children is not None
    live_children.remove(node)
    owner.removed_children.append(node)
|
|
|
|
C = math.sqrt(2)
|
|
TEMP = 0.5
|
|
def _sample_tree(node:MCTSNode, best_tm:float) -> MCTSNode:
|
|
if node.children is None or len(node.children) == 0: return node
|
|
unexplored_children = []
|
|
explored_children = []
|
|
ucb_explored_children = []
|
|
for child in node.children:
|
|
if child.n == 0: unexplored_children.append(child)
|
|
else:
|
|
ucb = -child.t/best_tm + C*math.sqrt(math.log(node.n)/child.n)
|
|
if not math.isinf(ucb):
|
|
explored_children.append(child)
|
|
ucb_explored_children.append(ucb)
|
|
if len(unexplored_children): return random.choice(unexplored_children)
|
|
if not len(explored_children): return node
|
|
ucb_exp = np.exp(np.array(ucb_explored_children)/TEMP)
|
|
return _sample_tree(explored_children[np.random.choice(len(ucb_exp), p=ucb_exp/np.sum(ucb_exp))], best_tm)
|
|
|
|
# this will expand/remove sometimes
|
|
def sample_tree(root:MCTSNode, best_tm:float) -> Optional[MCTSNode]:
  """Pick the next node to evaluate, expanding and pruning the tree as needed.

  The root is expanded on first use. Each round descends with `_sample_tree`;
  a node that turns out to have an empty child list is detached with
  `remove_node` and the descent is retried. A node that was never visited
  (n == 0) is returned directly; a visited one is expanded if necessary and a
  uniformly random child of it is returned instead.

  Returns None once the root has no children left.
  """
  if root.children is None: expand_node(root)
  while root.children:
    # tree traversal
    picked = _sample_tree(root, best_tm)
    visited = picked.n != 0

    # node expansion: only nodes that were already visited get their children created
    if visited and picked.children is None: expand_node(picked)

    # expanded but childless: nothing to explore below, prune and try again
    if picked.children is not None and not picked.children:
      remove_node(picked)
      continue

    if not visited: return picked
    assert picked.children is not None
    return random.choice(picked.children)
  return None
|
|
|
|
def backprop(bnode:MCTSNode, tm, strength=1.0):
  """Propagate a result from `bnode` up through all of its ancestors.

  Each node on the way keeps the smaller of its current `t` and `tm`, and has
  `strength` added to `n`. The strength is split evenly between the parents
  before recursing into them.
  """
  if tm < bnode.t: bnode.t = tm
  bnode.n += strength
  fanout = len(bnode.parents)
  if fanout == 0: return
  share = strength/fanout
  for ancestor in bnode.parents:
    backprop(ancestor, tm, share)
|
|
|
|
# number of search graphs written so far; gives each MCTSGRAPH dump its own file name
graph_mcts_cnt = 0
def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
  """Search for a fast variant of `lin` with Monte Carlo tree search.

  Up to `amt` nodes are sampled from the tree rooted at `lin`. Each sampled
  kernel is lowered, compiled and timed on `rawbufs`, and the measured time is
  propagated back through the tree. Nodes whose optimized AST or compiled
  binary was already seen are detached and reuse the earlier node's time.

  Args:
    lin: the starting kernel (root of the search tree).
    rawbufs: buffers handed to `_time_program` after `_ensure_buffer_alloc`.
    amt: maximum number of search iterations; also part of the cache key.

  Returns:
    The kernel with the lowest measured time, or `lin` itself if nothing beat
    an infinite time. On a disk-cache hit, a copy of `lin` with the cached
    opts applied is returned without searching.

  Side effects: prints progress when DEBUG>=2, writes a .dot/.svg graph when
  the MCTSGRAPH env var is set, and stores the winning opts in the disk cache
  when CACHELEVEL >= 1.
  """
  global graph_mcts_cnt
  # TODO: copied from BEAM
  key = {"ast": lin.ast.key, "amt": amt, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not getenv("IGNORE_MCTS_CACHE") and CACHELEVEL >= 1 and (val:=diskcache_get("mcts_search", key)) is not None:
    # cache hit: replay only the opts that are not already applied to lin
    ret = lin.copy()
    for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
    return ret

  rawbufs = _ensure_buffer_alloc(rawbufs)
  # every variable in the AST is bound to the midpoint of its [vmin, vmax] range
  var_vals = {k:(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  dev = Device[lin.opts.device]
  root = MCTSNode(lin)

  st = time.perf_counter()
  best, best_idx, best_tm = lin, 0, math.inf
  # first node seen for each compiled binary / optimized-AST key, used to detect duplicates
  seen_libs: Dict[bytes, MCTSNode] = {}
  seen_asts: Dict[bytes, MCTSNode] = {}
  compile_time, runtime_time = 0.0, 0.0
  for i in range(amt):
    node = sample_tree(root, best_tm)  # sample and expand
    if node is None: break  # finished the whole tree
    node.i = i  # when was node explored

    opt_ast = node.kernel.get_optimized_ast()
    if (sibling_node:=seen_asts.get(opt_ast.key, None)) is not None:
      # early check for same optimized AST hit
      remove_node(node)
      tm = sibling_node.t
    else:
      seen_asts[opt_ast.key] = node

      # lowering (50% of the time)
      p = node.kernel.to_program(name_override="test")

      # rollout
      tm1 = time.perf_counter()
      try:
        lib = dev.compiler.compile(p.src)
      except CompileError:
        # NOTE: many of these "compiler errors" are caused by bad code output from the lowerer
        lib = None
      tm2 = time.perf_counter()
      if lib is None:
        # a kernel that does not compile is scored as infinitely slow
        tm = math.inf
      else:
        if (sibling_node:=seen_libs.get(lib, None)) is not None:
          # NOTE: these should all be caught by the AST check, need to canonicalize
          # remove this node, it's a duplicate
          remove_node(node)
          tm = sibling_node.t
        else:
          seen_libs[lib] = node
          # median of 3 timings, scaled by 1e6 (shown as "us" in the progress line)
          # NOTE(review): early_stop is presumably a cutoff at 5x the best time so far -- confirm against _time_program
          try: tm = statistics.median(_time_program(p, lib, var_vals, rawbufs, cnt=3, early_stop=best_tm*5/1e6))*1e6
          except RuntimeError: tm = math.inf
          node.tm = tm
      tm3 = time.perf_counter()
      # tm1..tm2 covers compilation; tm2..tm3 covers the duplicate check and timing
      compile_time += tm2-tm1
      runtime_time += tm3-tm2

    # mock rollout
    #node.tm = tm = random.random() + 0.1

    if tm < best_tm: best, best_idx, best_tm = node.kernel, i, tm
    et = time.perf_counter() - st
    if DEBUG>=2: print(f"\r{et:7.2f}s {colored(f'{compile_time*100/et:3.0f}%', 'cyan')} {colored(f'{runtime_time*100/et:3.0f}%', 'red')}: {tm:12.2f} us best: {best_tm:12.2f} us @ {best_idx+1:4d} {i+1:4d}/{amt:4d} {int(round((i+1)/et)):4d}/s {node.kernel.colored_shape()}\033[K", end="") # noqa: E501

    # backprop
    backprop(node, tm)
  # the progress line is written with \r and no newline, so finish it here
  if DEBUG>=2: print()

  if getenv("MCTSGRAPH"):
    # optional visualisation of the explored tree; needs networkx and the `dot` executable
    import networkx as nx
    import os
    GRAPHPATH = "/tmp/net"
    def save_graph(G, fn, opt=""):
      # write the graph as {fn}.dot, then render it to {fn}.svg by shelling out to `dot`
      print("saving", G, f"to {fn}.svg")
      nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
      os.system(f'dot {opt} -Tsvg {fn}.dot -o {fn}.svg')

    G = nx.DiGraph()
    def add_node(node:MCTSNode):
      # recursively add every visited node (n != 0), including children that were removed
      if node.n == 0: return
      for parent in node.parents: G.add_edge(parent, node)
      gopts = node.kernel.applied_opts
      # label each node with the last opt applied to reach it
      edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].amt}" if len(gopts) else "ROOT"
      G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}",
                 fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '')
      if node.children is not None:
        for child in node.children+node.removed_children: add_node(child)
    add_node(root)
    save_graph(G, f"{GRAPHPATH}.{graph_mcts_cnt}.mcts", '-Grankdir=LR')
    graph_mcts_cnt += 1

  if CACHELEVEL >= 1: diskcache_put("mcts_search", key, best.applied_opts)
  return best
|