try beam search on device (#2085)

* try beam search on device

* fix beam with nolocals

* ops too

---------

Co-authored-by: Comma Device <device@comma.ai>
Author: George Hotz
Date: 2023-10-16 12:52:42 -07:00 (committed by GitHub)
parent c36d306606
commit a7b18ac325
5 changed files with 22 additions and 11 deletions


@@ -235,7 +235,7 @@ def Dropout(data, ratio=0.5, training_mode=False, seed=None):
   mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device)
   return data * mask * (1/(1.0 - ratio)), mask
-def Shape(data, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int64)
+def Shape(data, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32)
 def Size(data): return prod(data if isinstance(data, list) else data.shape)
 def Flatten(input, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
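A standalone sketch (not part of the diff) of what the change means for the ONNX Shape handler: dimension values now come back as int32 rather than int64. Import paths assume the tinygrad layout at this commit, where dtypes lives in tinygrad.helpers.

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

# same one-liner as the new version of the op above
def Shape(data, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32)

t = Tensor.zeros(2, 3, 4)
assert Shape(t).dtype == dtypes.int32  # was dtypes.int64 before this commit
print(Shape(t, start=1).numpy())       # [3 4]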


@@ -66,6 +66,6 @@ if __name__ == "__main__":
   run_schedule(schedule_independent)
   print("**** running real kernels ****")
-  with Context(DEBUG=2):
+  with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
     GlobalCounters.reset()
     run_schedule(schedule)
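The LATEBEAM hookup works because Context temporarily overrides ContextVar values for the duration of the with-block, so beam search is enabled only for the real kernels and not for the independent schedule above. A minimal sketch of that pattern, assuming the tinygrad.helpers API at this commit:

from tinygrad.helpers import Context, getenv, BEAM

print(BEAM.value)                       # 0 unless the BEAM env var is set
with Context(BEAM=getenv("LATEBEAM")):  # run with e.g. LATEBEAM=4
  print(BEAM.value)                     # 4 inside the block
print(BEAM.value)                       # restored to the default afterwards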


@@ -21,9 +21,9 @@ actions += [
 # returns time in seconds
 import shelve
 logtm = shelve.open(getenv("LOGTM", "")) if getenv("LOGTM", "") else None
-def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True) -> float:
+def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True, disable_cache=False) -> float:
   key = str((lin.ast, lin.applied_opts))
-  if should_copy and logtm is not None and key in logtm: return min(logtm[key]) # pylint: disable=E1135 # NOTE: we check should_copy since this may have side effects
+  if should_copy and not disable_cache and logtm is not None and key in logtm: return min(logtm[key]) # pylint: disable=E1135 # NOTE: we check should_copy since this may have side effects
   if should_copy: lin = lin.copy() # TODO: remove the need for this
   var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
   try:
@@ -43,7 +43,11 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True) -> float:
     else:
       factor = 1
     # TODO: this is super broken for var_vals
+    # TODO: this is copied from prg.__call__
     global_size, local_size = prg.launch_dims(var_vals)
+    if local_size is None:
+      local_size = prg.optimize_local_size(global_size, rawbufs)
+      global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
     tms = [prg.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=True)*factor for _ in range(cnt)]
     prg.global_size = real_global_size
   except Exception:
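A standalone worked example of the if-block added above, with hard-coded stand-ins for the prg values: when a program carries no local size, one is chosen and the global size is divided down by it, mirroring what prg.__call__ does at dispatch time.

global_size, local_size = [256, 256, 1], None  # stand-in for prg.launch_dims(var_vals)
if local_size is None:
  local_size = [16, 16, 1]  # stand-in for prg.optimize_local_size(global_size, rawbufs)
  global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
print(global_size, local_size)  # [16, 16, 1] [16, 16, 1]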
@@ -93,5 +97,5 @@ def beam_search(lin, rawbufs, amt):
     if len(opts) == 0 or best_tm <= opts[0][1]: break # we didn't get faster
     best_tm = opts[0][1]
     beam = [x[0] for x in opts[:amt]]
-    if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", beam[0].colored_shape())
+    if DEBUG >= 1: print(f"{opts[0][1]*1e6:12.2f} us from {len(opts):3d} actions", beam[0].colored_shape())
   return beam[0]
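Two notes on this file. The print now reports microseconds, a better fit for kernels that often run well under a millisecond on device. And disable_cache exists so the baseline-vs-beam comparison in ops.py is made from fresh on-device timings rather than stale shelve entries. A minimal sketch of the LOGTM cache pattern, with a hypothetical path and key (real code opens getenv("LOGTM") and keys on str((lin.ast, lin.applied_opts))):

import shelve

logtm = shelve.open("/tmp/logtm_example")   # hypothetical path
logtm["example_kernel"] = [1.5e-4, 1.2e-4]  # recorded times, in seconds
print(min(logtm["example_kernel"]))         # 0.00012 -- returned early unless disable_cache=True
logtm.close()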


@@ -59,7 +59,7 @@ class ContextVar:
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x
-DEBUG, IMAGE = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0)
+DEBUG, IMAGE, BEAM = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0)
 GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
 class Timing(contextlib.ContextDecorator):
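For reference, a sketch of how the new BEAM ContextVar behaves, assuming ContextVar truthiness delegates to its value (which the `if BEAM:` check in the ops.py hunk below relies on): the default comes from the BEAM environment variable, and Context can override it temporarily.

from tinygrad.helpers import BEAM, Context

print(bool(BEAM), BEAM.value)    # False 0 when the BEAM env var is unset
with Context(BEAM=2):
  print(bool(BEAM), BEAM.value)  # True 2
print(bool(BEAM), BEAM.value)    # back to the default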


@@ -3,7 +3,7 @@ import time, importlib, inspect, functools, pathlib, itertools, random
 import numpy as np
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
-from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
+from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM
 from tinygrad.runtime.lib import RawBuffer
 from tinygrad.shape.symbolic import Variable, sym_infer
 from dataclasses import dataclass
@@ -268,11 +268,18 @@ class Compiled:
       from tinygrad.codegen.linearizer import Linearizer
       k = Linearizer(ast, self.linearizer_opts)
       assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
-      if getenv("BEAM"):
-        from tinygrad.features.search import beam_search
-        k = beam_search(k, rawbuffers, getenv("BEAM"))
-      elif not getenv("NOOPT"):
+      if not getenv("NOOPT"):
         if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
+        if BEAM:
+          kb = Linearizer(ast, self.linearizer_opts)
+          kb.required_optimizations()
+          kb.dont_use_locals = bool(getenv("NOLOCALS"))
+          from tinygrad.features.search import beam_search, time_linearizer
+          kb = beam_search(kb, rawbuffers, BEAM.value)
+          baseline, beamtime = time_linearizer(k, rawbuffers, allow_test_size=False, disable_cache=True), time_linearizer(kb, rawbuffers, allow_test_size=False, disable_cache=True)
+          if beamtime < baseline:
+            if DEBUG >= 1: print(f"beam search {beamtime*1e6:<12.2f} beat baseline {baseline*1e6:<12.2f} by {baseline/beamtime:.2f}x")
+            k = kb
       return self.to_program(k)

     if getenv("ENABLE_METHOD_CACHE", 1):