mirror of https://github.com/commaai/tinygrad.git
KOPT is over, BEAM is upstream (#2071)
* create cache for q learning
* make linter happy
* global beam
* where it belongs
* bugfix
* ditch the kopt, use the beam
* faster lin and DEBUG=2 okay
* remove kopt, move search to features
This commit is contained in:
parent e4660b024f
commit c36d306606
@@ -170,9 +170,6 @@ jobs:
      run: |
        PYTHONPATH="." python test/external/dist/test_world.py
        PYTHONPATH="." python test/external/dist/test_collectives.py
    - if: ${{ matrix.task == 'realworld' }}
      name: Test KOPT
      run: PYTHONPATH="." KOPT=1 BUDGET=20 GPU=1 DEBUG=1 python -m pytest -rA -n=auto test/models/test_real_world.py
    - if: ${{ matrix.task == 'realworld' }}
      name: Run GPT2
      run: |
@@ -43,7 +43,6 @@ LLVM | [1] | enable LLVM backend
LLVMOPT | [1] | enable slightly more expensive LLVM optimizations
LAZY | [1] | enable lazy operations (this is the default)
OPT | [1-3] | optimization level
KOPT | [1-2] | kernel optimization, 1 turns it on, 2 caches the found optimizations
BUDGET | [#] | kernel optimization search budget
GRAPH | [1] | create a graph of all operations (requires graphviz)
GRAPHPATH | [/path/to] | where to put the generated graph
@@ -177,7 +176,6 @@ TORCHCUDA | [1] | enable the torch cuda backend

Variable | Possible Value(s) | Description
---|---|---
KOPT | [1] | enable kernel optimization
KCACHE | [1] | enable kernel cache

### test/external/external_test_opt.py
@@ -3,8 +3,8 @@ from models.resnet import ResNet50
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps, Device, Compiled
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
from tinygrad.helpers import ansilen, DEBUG, getenv, flatten
from tinygrad.features.search import time_linearizer, beam_search
from tinygrad.helpers import ansilen, DEBUG, getenv
from tinygrad.graph import print_tree
from tinygrad.lazy import vars_from_ast
from tinygrad.shape.symbolic import sym_infer
@@ -62,17 +62,7 @@ if __name__ == "__main__":
      for ao in global_db[str(lin.ast)]:
        lin.apply_opt(ao)
    else:
      best_tm = float('inf')
      beam = [lin]
      while 1:
        acted_lins = flatten([get_linearizer_actions(lin).items() for lin in beam])
        timed_lins = [(v,time_linearizer(v, rawbufs)) for k,v in acted_lins if k != 0]
        opts = sorted(timed_lins, key=lambda x: x[1])
        if len(opts) == 0 or best_tm <= opts[0][1]: break  # we didn't get faster
        best_tm = opts[0][1]
        beam = [x[0] for x in opts[:getenv("BEAM")]]
        if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", beam[0].colored_shape())
      lin = beam[0]
      lin = beam_search(lin, rawbufs, getenv("BEAM"))
    global_db[str(lin.ast)] = lin.applied_opts
    lins.append(lin)
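The inline loop deleted above now lives in tinygrad/features/search.py as beam_search. A minimal sketch of how a search script can drive the shared helper after this change; it only assumes the names imported elsewhere in this diff (load_worlds, ast_str_to_lin, bufs_from_lin, time_linearizer, beam_search, getenv), and the default beam width of 4 is an assumption for illustration:

```python
# Hedged sketch, not part of the commit: drive the upstreamed search by hand.
from tinygrad.helpers import getenv
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.features.search import bufs_from_lin, time_linearizer, beam_search

if __name__ == "__main__":
  for ast_str in load_worlds()[:1]:                      # first stored kernel only, for illustration
    lin = ast_str_to_lin(ast_str)                        # rebuild a Linearizer from the serialized AST
    rawbufs = bufs_from_lin(lin)                         # scratch buffers sized from the kernel's membufs
    lin = beam_search(lin, rawbufs, getenv("BEAM", 4))   # beam width default of 4 is an assumption
    print(lin.colored_shape(), f"{time_linearizer(lin, rawbufs)*1e3:.2f} ms")
```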
@@ -1,6 +1,6 @@
from tqdm import tqdm
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.codegen.search import actions
from tinygrad.features.search import actions
from tinygrad.codegen.linearizer import Linearizer

tactions = set()
@@ -4,7 +4,7 @@ from tinygrad.nn import Linear
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.codegen.search import actions
from tinygrad.features.search import actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin

INNER = 32
@@ -3,7 +3,7 @@ import numpy as np
import math, random
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.codegen.search import actions, bufs_from_lin, time_linearizer, get_linearizer_actions
from tinygrad.features.search import actions, bufs_from_lin, time_linearizer, get_linearizer_actions
from tinygrad.nn.optim import Adam
from extra.optimization.pretrain_policynet import PolicyNet
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
@@ -6,7 +6,7 @@ from copy import deepcopy
from tinygrad.helpers import getenv, colored
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, actions, get_linearizer_actions
from tinygrad.features.search import bufs_from_lin, time_linearizer, actions, get_linearizer_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from extra.optimization.pretrain_policynet import PolicyNet
from extra.optimization.pretrain_valuenet import ValueNet
@@ -1,5 +1,5 @@
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
from tinygrad.features.search import bufs_from_lin, time_linearizer, get_linearizer_actions

if __name__ == "__main__":
  ast_strs = load_worlds()
@@ -13,9 +13,6 @@ OSX = platform.system() == "Darwin"

def compile_and_test_ast(ast, local_size=None):
  k = CLCodegen(ast)
  if getenv("KOPT", 0):
    from extra.kernel_search import apply_optimization
    apply_optimization(k, ast, 10, getenv("KCACHE", 0))
  prg = k.codegen().build(CLProgram)
  if local_size is not None: prg.local_size = local_size
  for i in range(5): prg(prg.lower(k.bufs))
@@ -6,34 +6,12 @@ from tinygrad.nn.state import get_parameters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.ops import Device, GlobalCounters, LazyOp, LoadOps
from tinygrad.helpers import CI, dtypes, getenv, prod
from tinygrad.features.kopt import kernel_optimize_opts

from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
from examples.hlb_cifar10 import SpeedyResNet
from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
from examples.stable_diffusion import UNetModel

def kopt_search_hook(k, create_k, to_prg, baseline, bufs, var_vals):
  import nevergrad as ng
  wanna_output = bufs[0].toCPU().copy()
  def check_opt(x):
    try:
      k = create_k()
      for o in x: k.apply_opt(o)
      prg = to_prg(k)
      first_tm = prg.exec(bufs, var_vals, force_wait=True, optimizing=True)
      np.testing.assert_allclose(wanna_output, bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
      return first_tm
    except Exception:
      return 10000_000  # 10000 seconds is infinity
  opts = kernel_optimize_opts(k)
  if not opts: return "BASELINE"
  search_space = prod([len(x.choices) for x in opts])
  budget = getenv("BUDGET", 20)  # THIS IS TEST BUDGET
  optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, budget))
  recommendation = optimizer.minimize(check_opt)
  return recommendation.value if recommendation.loss < baseline else "BASELINE"

def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
  tms = []
  for _ in range(4):
@@ -70,15 +48,9 @@ class TestRealWorld(unittest.TestCase):
  def setUp(self):
    self.old_type = Tensor.default_type
    np.random.seed(2002)
    # TODO: abstract better to remove this junk
    if getenv("KOPT"):
      self.oldfunc = getattr(__import__("tinygrad.features.kopt", fromlist=["kernel_optimize_search"]), "kernel_optimize_search")
      setattr(__import__("tinygrad.features.kopt", fromlist=["kernel_optimize_search"]), "kernel_optimize_search", kopt_search_hook)

  def tearDown(self):
    Tensor.default_type = self.old_type
    if getenv("KOPT"):
      setattr(__import__("tinygrad.features.kopt", fromlist=["kernel_optimize_search"]), "kernel_optimize_search", self.oldfunc)

  @unittest.skipUnless(not CI, "too big for CI")
  def test_stable_diffusion(self):
@@ -111,7 +83,6 @@
    def test(t): return model(t, 0).realize()
    helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 129 if CI else 369, all_jitted=True)

  @unittest.skipIf(getenv("KOPT"), "cifar hangs with KOPT")
  @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
  def test_train_cifar(self):
    # TODO: with default device
@@ -1,4 +1,6 @@
from __future__ import annotations
from typing import NamedTuple, Optional, List, Tuple, cast, Dict
from copy import deepcopy
import itertools
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, MemBuffer, BufferOps, Device, Compiled
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen
@@ -84,6 +86,9 @@ class Kernel:
    self.global_size: Optional[List[int]] = None
    self.local_size: Optional[List[int]] = None

  def copy(self):
    return deepcopy(self)

  @property
  def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
@@ -1,84 +0,0 @@
from typing import Callable
import time
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.optimizer import Opt, OptOps
from tinygrad.helpers import DEBUG, prod, getenv
from tinygrad.lazy import vars_from_ast

def get_divisors(n, min_div = 1, max_div = 512):
  if min_div > 1: yield 1
  for d in range(min_div, min(max_div, n//2) + 1):
    if n % d == 0: yield d

def kernel_optimize_opts(k:Linearizer):
  import nevergrad as ng
  opts = []
  for i in range(k.first_reduce):
    # TODO: the upcast always happen first, you might want to reverse this?
    # TODO: the order of the locals might improve things too
    opts.append(ng.p.TransitionChoice([Opt(OptOps.UPCAST,i,s) for s in get_divisors(k.full_shape[i], max_div=8)]))
    opts.append(ng.p.TransitionChoice([Opt(OptOps.LOCAL,i,s) for s in get_divisors(k.full_shape[i], min_div=4)]))
  for i in range(k.shape_len-k.first_reduce):
    opts.append(ng.p.TransitionChoice([Opt(OptOps.UNROLL,i,s) for s in get_divisors(k.full_shape[k.first_reduce+i], max_div=8)]))
  return opts

def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline, bufs, var_vals):
  import nevergrad as ng
  def opt(x):
    try:
      k = create_k()
      for o in x: k.apply_opt(o)
      prg = to_prg(k)
      first_tm = prg.exec(bufs, var_vals, force_wait=True, optimizing=True)
      if baseline*5 < first_tm*1000: return first_tm*1000  # very slow
      tm = min([first_tm]+[prg.exec(bufs, var_vals, force_wait=True, optimizing=True) for _ in range(2)])*1000
      return tm
    except Exception:
      if DEBUG >= 3:
        import traceback
        traceback.print_exc()
      return 10000_000  # 10000 seconds is infinity
  opts = kernel_optimize_opts(k)
  if not opts: return "BASELINE"
  search_space = prod([len(x.choices) for x in opts])
  st = time.perf_counter()
  budget = getenv("BUDGET", 200)
  optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, budget))
  recommendation = optimizer.minimize(opt)
  et = time.perf_counter() - st
  if DEBUG >= 1: print(f"optimizer({et:6.2f} s to search) space {search_space:8d} with tm {recommendation.loss:5.2f} ms vs baseline {baseline:5.2f} ms, a {baseline/recommendation.loss:5.2f}x gain : {k.colored_shape()}")
  return recommendation.value if recommendation.loss < baseline else "BASELINE"

# optimization
global_db = None
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, bufs, key):
  global global_db

  skey = str(key)

  if getenv("KOPT") == 2 and global_db is None:
    import shelve
    global_db = shelve.open("/tmp/kopt_cache")

  if global_db is not None and skey in global_db:
    choice = global_db[skey]
  elif k.has_variable_shape():
    # don't optimize variable shapes
    choice = "BASELINE"
  else:
    var_vals = {k:k.min for k in vars_from_ast(k.ast)}
    # get baseline
    def get_baseline():
      k = create_k()
      k.hand_coded_optimizations()
      prg = to_prg(k)
      return min([prg.exec(bufs, var_vals, force_wait=True, optimizing=True) for _ in range(5)])*1000
    choice = kernel_optimize_search(k, create_k, to_prg, get_baseline(), bufs, var_vals)
    if global_db is not None:
      global_db[skey] = choice
      global_db.sync()

  if choice == "BASELINE":
    k.hand_coded_optimizations()
  else:
    for o in choice: k.apply_opt(o)
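For readers unfamiliar with the deleted dependency: the old KOPT path handed kernel-runtime minimization to nevergrad's NGOpt over a tuple of TransitionChoice parameters. The toy below uses only the nevergrad calls visible in the deleted file (ng.p.Tuple, ng.p.TransitionChoice, NGOpt, minimize, recommendation.value/.loss) with a made-up objective, just to show the pattern the BEAM search replaces:

```python
# Toy illustration of the nevergrad pattern used by the deleted kopt.py; the objective is made up.
import nevergrad as ng

param = ng.p.Tuple(ng.p.TransitionChoice([1, 2, 4, 8]), ng.p.TransitionChoice([1, 2, 4, 8]))
optimizer = ng.optimizers.NGOpt(parametrization=param, budget=20)
recommendation = optimizer.minimize(lambda x: abs(x[0]*x[1] - 16))  # stand-in for "time this kernel"
print(recommendation.value, recommendation.loss)
```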
@@ -1,8 +1,7 @@
from typing import Dict, List, cast, DefaultDict, Optional
from copy import deepcopy
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import Device, Compiled, MemBuffer
from tinygrad.helpers import prod, getenv, ImageDType, flatten
from tinygrad.helpers import prod, getenv, ImageDType, flatten, DEBUG
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict
@@ -18,16 +17,18 @@ actions += [
  Opt(op=OptOps.GROUPTOP, axis=1, amt=16), Opt(op=OptOps.GROUPTOP, axis=1, amt=256),
  Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256)
]
device:Compiled = cast(Compiled, Device[Device.DEFAULT])

# returns time in seconds
logtm = open(getenv("LOGTM", ""),"a") if getenv("LOGTM", "") else None
import shelve
logtm = shelve.open(getenv("LOGTM", "")) if getenv("LOGTM", "") else None
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True) -> float:
  if should_copy: lin = deepcopy(lin)  # TODO: remove the need for this
  key = str((lin.ast, lin.applied_opts))
  if should_copy and logtm is not None and key in logtm: return min(logtm[key])  # pylint: disable=E1135 # NOTE: we check should_copy since this may have side effects
  if should_copy: lin = lin.copy()  # TODO: remove the need for this
  var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
  try:
    lin.linearize()
    prg = device.to_program(lin)
    prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin)
    real_global_size = prg.global_size[:]
    if allow_test_size:
      test_global_size = prg.global_size[:]
@@ -41,14 +42,16 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
      #print(real_global_size, test_global_size, factor)
    else:
      factor = 1
    tms = [prg(rawbufs, var_vals, force_wait=True)*factor for _ in range(cnt)]
    # TODO: this is super broken for var_vals
    global_size, local_size = prg.launch_dims(var_vals)
    tms = [prg.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=True)*factor for _ in range(cnt)]
    prg.global_size = real_global_size
  except Exception:
    #print("FAILED")
    #print(lin.ast)
    #print(lin.applied_opts)
    tms = [float('inf')]
  if logtm: logtm.write(str((lin.ast, lin.applied_opts, tms))+"\n")
  if logtm is not None: logtm[key] = tms
  return min(tms)

# get (scrap) buffers for timing the linearizer
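Note the LOGTM change across these two hunks: timings are no longer appended to a text file but stored in a shelve database keyed by str((lin.ast, lin.applied_opts)), and time_linearizer returns the cached minimum on a hit. A small sketch for inspecting such a cache; the path is only an example of what LOGTM might have been set to:

```python
# Hedged sketch: peek at a LOGTM shelve cache written by time_linearizer.
# "/tmp/logtm" is an example path; use whatever LOGTM was set to.
import shelve

with shelve.open("/tmp/logtm") as db:
  for key, tms in list(db.items())[:5]:
    print(f"{min(tms)*1e3:8.3f} ms  {key[:80]}")   # times are stored in seconds
```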
@@ -57,17 +60,17 @@ def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
  for x in lin.membufs: bufsts[x.idx].append(x)
  rawbufs:List[Optional[RawBuffer]] = [None]*len(bufsts)
  for k,lx in bufsts.items():
    rawbufs[k] = device.buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
    rawbufs[k] = cast(Compiled, Device[Device.DEFAULT]).buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
  assert all(r is not None for r in rawbufs)
  return cast(List[RawBuffer], rawbufs)

# get dictionary of all possible actions
def get_linearizer_actions(lin:Linearizer) -> Dict[int, Linearizer]:
  acted_lins = {0:deepcopy(lin)}
def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
  acted_lins = {0:lin.copy()} if include_0 else {}
  for i,a in enumerate(actions):
    if a.axis >= lin.shape_len: continue
    if lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
    lin2 = deepcopy(lin)
    lin2 = lin.copy()
    try:
      lin2.apply_opt(a)
      up, lcl = 1, 1
@@ -79,3 +82,16 @@ def get_linearizer_actions(lin:Linearizer) -> Dict[int, Linearizer]:
    except Exception:
      pass
  return acted_lins

def beam_search(lin, rawbufs, amt):
  best_tm = float('inf')
  beam: List[Linearizer] = [lin]
  while 1:
    acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin in beam])
    timed_lins = [(v,time_linearizer(v, rawbufs)) for v in acted_lins]
    opts = sorted(timed_lins, key=lambda x: x[1])
    if len(opts) == 0 or best_tm <= opts[0][1]: break  # we didn't get faster
    best_tm = opts[0][1]
    beam = [x[0] for x in opts[:amt]]
    if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", beam[0].colored_shape())
  return beam[0]
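beam_search is a plain greedy beam: every kernel in the beam is expanded with every applicable action, each candidate is timed once, and the amt fastest survive; the loop stops as soon as the best candidate stops improving. The arithmetic below is illustrative only, since the real number of applicable actions depends on the kernel's shape:

```python
# Illustrative cost estimate for one beam iteration; both numbers are assumptions.
beam_width = 4           # e.g. BEAM=4
applicable_actions = 30  # varies per kernel; the global actions list is larger
timed_per_iteration = beam_width * applicable_actions
print(timed_per_iteration, "calls to time_linearizer this iteration")  # -> 120
```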
@@ -268,8 +268,9 @@ class Compiled:
    from tinygrad.codegen.linearizer import Linearizer
    k = Linearizer(ast, self.linearizer_opts)
    assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
    from tinygrad.features.kopt import kernel_optimize
    if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, self.linearizer_opts), self.to_program, rawbuffers, ast)
    if getenv("BEAM"):
      from tinygrad.features.search import beam_search
      k = beam_search(k, rawbuffers, getenv("BEAM"))
    elif not getenv("NOOPT"):
      if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
    return self.to_program(k)
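With the KOPT hook removed from Compiled, the only switch left is the BEAM environment variable read right before codegen, as shown in the hunk above. A hedged end-to-end sketch; the model, sizes, and beam width here are illustrative and not part of this commit:

```python
# Hedged sketch: exercising the new BEAM path end to end.
import os
os.environ["BEAM"] = "2"          # set before any kernels are compiled

from tinygrad.tensor import Tensor

a, b = Tensor.rand(512, 512), Tensor.rand(512, 512)
(a @ b).realize()                 # each kernel compiled here goes through beam_search first
```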