mirror of https://github.com/commaai/tinygrad.git
try beam search on device (#2085)
* try beam search on device
* fix beam with nolocals
* ops too

Co-authored-by: Comma Device <device@comma.ai>
parent c36d306606
commit a7b18ac325
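The core of the change is a late, opt-in beam-search pass in Compiled: the hand-optimized kernel is kept as the baseline, a second Linearizer is beam-searched, both are timed on the device, and the faster one is used (see the last hunk below). A minimal standalone sketch of that select-the-faster idea, with hypothetical names (`time_fn`, `pick_faster`) that are not tinygrad's API:

```python
import time
from typing import Callable

def time_fn(fn: Callable[[], object], cnt: int = 3) -> float:
  # best-of-cnt wall clock time in seconds
  tms = []
  for _ in range(cnt):
    st = time.perf_counter()
    fn()
    tms.append(time.perf_counter() - st)
  return min(tms)

def pick_faster(baseline: Callable[[], object], candidate: Callable[[], object]) -> Callable[[], object]:
  # keep the beam-searched candidate only if it actually beats the baseline
  b, c = time_fn(baseline), time_fn(candidate)
  if c < b: print(f"candidate {c*1e6:12.2f} us beat baseline {b*1e6:12.2f} us by {b/c:.2f}x")
  return candidate if c < b else baseline

fast = pick_faster(lambda: sum(range(1_000_000)), lambda: sum(range(10_000)))
```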
@@ -235,7 +235,7 @@ def Dropout(data, ratio=0.5, training_mode=False, seed=None):
   mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device)
   return data * mask * (1/(1.0 - ratio)), mask
 
-def Shape(data, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int64)
+def Shape(data, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32)
 def Size(data): return prod(data if isinstance(data, list) else data.shape)
 def Flatten(input, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
 
@@ -66,6 +66,6 @@ if __name__ == "__main__":
   run_schedule(schedule_independent)
 
   print("**** running real kernels ****")
-  with Context(DEBUG=2):
+  with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
    GlobalCounters.reset()
    run_schedule(schedule)
@@ -21,9 +21,9 @@ actions += [
 # returns time in seconds
 import shelve
 logtm = shelve.open(getenv("LOGTM", "")) if getenv("LOGTM", "") else None
-def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True) -> float:
+def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True, disable_cache=False) -> float:
   key = str((lin.ast, lin.applied_opts))
-  if should_copy and logtm is not None and key in logtm: return min(logtm[key]) # pylint: disable=E1135 # NOTE: we check should_copy since this may have side effects
+  if should_copy and not disable_cache and logtm is not None and key in logtm: return min(logtm[key]) # pylint: disable=E1135 # NOTE: we check should_copy since this may have side effects
   if should_copy: lin = lin.copy() # TODO: remove the need for this
   var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
   try:
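The new `disable_cache` flag bypasses the shelve-backed timing cache so the baseline/beam comparison in Compiled always measures fresh. A minimal standalone sketch of the same caching pattern, with assumed names (`timings`, `timed`) rather than tinygrad's API:

```python
import os, shelve, tempfile, time

# timings plays the role of logtm: a persistent dict of raw measurements per kernel key
timings = shelve.open(os.path.join(tempfile.gettempdir(), "logtm_example"))

def timed(key: str, run, cnt: int = 3, disable_cache: bool = False) -> float:
  # return the cached best time unless the caller asks for a fresh measurement
  if not disable_cache and key in timings: return min(timings[key])
  tms = []
  for _ in range(cnt):
    st = time.perf_counter()
    run()
    tms.append(time.perf_counter() - st)
  timings[key] = tms  # persist the raw timings, like logtm above
  return min(tms)

print(timed("dummy_kernel", lambda: sum(range(100_000))))                      # may hit the cache
print(timed("dummy_kernel", lambda: sum(range(100_000)), disable_cache=True))  # always re-measures
```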
@@ -43,7 +43,11 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True) -> float:
     else:
       factor = 1
-    # TODO: this is super broken for var_vals
+    # TODO: this is copied from prg.__call__
+    global_size, local_size = prg.launch_dims(var_vals)
+    if local_size is None:
+      local_size = prg.optimize_local_size(global_size, rawbufs)
+      global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
     tms = [prg.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=True)*factor for _ in range(cnt)]
     prg.global_size = real_global_size
   except Exception:
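The added launch-dims handling falls back to `optimize_local_size` when the program has no local size, then re-expresses each global dimension as a count of workgroups. A tiny worked example of that split (illustrative values only):

```python
# illustrative values only: split a flat launch into workgroups of local_size
global_size, local_size = [4096, 512, 1], [16, 8, 1]
groups = [g//l if g % l == 0 else g/l for g, l in zip(global_size, local_size)]
assert groups == [256, 64, 1]
```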
@@ -93,5 +97,5 @@ def beam_search(lin, rawbufs, amt):
     if len(opts) == 0 or best_tm <= opts[0][1]: break # we didn't get faster
     best_tm = opts[0][1]
     beam = [x[0] for x in opts[:amt]]
-    if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", beam[0].colored_shape())
+    if DEBUG >= 1: print(f"{opts[0][1]*1e6:12.2f} us from {len(opts):3d} actions", beam[0].colored_shape())
   return beam[0]
@@ -59,7 +59,7 @@ class ContextVar:
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x
 
-DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)
+DEBUG, IMAGE, BEAM = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0)
 GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
 
 class Timing(contextlib.ContextDecorator):
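With `BEAM` now a ContextVar in tinygrad.helpers, the beam width can be set per-block as well as via the environment (as the `LATEBEAM` example script above does). A hedged usage sketch, assuming tinygrad at this revision is importable:

```python
# sketch assuming tinygrad at this revision is importable; Context temporarily
# overrides ContextVars such as DEBUG and the newly added BEAM
from tinygrad.helpers import Context, BEAM
from tinygrad.tensor import Tensor

with Context(BEAM=2):  # beam width 2 while these kernels are compiled
  assert BEAM.value == 2
  (Tensor.rand(32, 32) @ Tensor.rand(32, 32)).realize()
print(BEAM.value)      # back to its previous value after the block
```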
@@ -3,7 +3,7 @@ import time, importlib, inspect, functools, pathlib, itertools, random
 import numpy as np
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
-from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
+from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM
 from tinygrad.runtime.lib import RawBuffer
 from tinygrad.shape.symbolic import Variable, sym_infer
 from dataclasses import dataclass
@@ -268,11 +268,18 @@ class Compiled:
     from tinygrad.codegen.linearizer import Linearizer
     k = Linearizer(ast, self.linearizer_opts)
     assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
-    if getenv("BEAM"):
-      from tinygrad.features.search import beam_search
-      k = beam_search(k, rawbuffers, getenv("BEAM"))
-    elif not getenv("NOOPT"):
+    if not getenv("NOOPT"):
       if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
+    if BEAM:
+      kb = Linearizer(ast, self.linearizer_opts)
+      kb.required_optimizations()
+      kb.dont_use_locals = bool(getenv("NOLOCALS"))
+      from tinygrad.features.search import beam_search, time_linearizer
+      kb = beam_search(kb, rawbuffers, BEAM.value)
+      baseline, beamtime = time_linearizer(k, rawbuffers, allow_test_size=False, disable_cache=True), time_linearizer(kb, rawbuffers, allow_test_size=False, disable_cache=True)
+      if beamtime < baseline:
+        if DEBUG >= 1: print(f"beam search {beamtime*1e6:<12.2f} beat baseline {baseline*1e6:<12.2f} by {baseline/beamtime:.2f}x")
+        k = kb
     return self.to_program(k)
 
     if getenv("ENABLE_METHOD_CACHE", 1):
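The whole path can presumably also be driven from the environment, since `BEAM` and `NOLOCALS` are read via `getenv`/ContextVar; a hedged sketch with arbitrary values and workload:

```python
# hedged sketch: BEAM and NOLOCALS are read from the environment (getenv /
# ContextVar defaults), so exporting them before tinygrad is imported should
# enable the late beam path; the workload is arbitrary
import os
os.environ["BEAM"] = "4"      # beam width passed to beam_search
os.environ["NOLOCALS"] = "1"  # beam over kernels without local dims ("fix beam with nolocals")

from tinygrad.tensor import Tensor
out = (Tensor.rand(64, 64) @ Tensor.rand(64, 64)).realize()
```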