fix no locals search and search both (#2171)

* fix no locals search and search both

* pretty print

* nolocals default no other search
This commit is contained in:
George Hotz 2023-10-30 10:22:50 -07:00 committed by GitHub
parent 194e4ad6f8
commit 608e3ee800
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 9 deletions

View File

@ -232,6 +232,7 @@ class OptimizedKernel(Kernel):
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
assert isinstance(amt, int) and amt != 1, "shift of amt 1 or Node is meaningless"
assert not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals"
if opt.op == OptOps.LOCAL: # cyan
assert axis < self.first_reduce, "can't local a reduce"
assert not(self.tensor_core), "can't local with tensor cores"

View File

@ -20,7 +20,7 @@ actions += [
# returns time in seconds
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True, disable_cache=False) -> float:
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size}
if should_copy and not disable_cache and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
if should_copy and not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
if should_copy: lin = lin.copy() # TODO: remove the need for this
var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
try:
@ -85,11 +85,12 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
pass
return acted_lins
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt}
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE"):
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True, dont_use_locals=False) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "dont_use_locals": dont_use_locals}
if dont_use_locals: lin.dont_use_locals = True
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
ret = lin.copy()
for o in val: ret.apply_opt(o)
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
return ret
# init the BEAM with the base linearizer
@ -122,6 +123,6 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
beam = opts[:amt]
if DEBUG >= 2: print(f"{opts[0][1]*1e6:12.2f} us from {len(lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
diskcache_put("beam_search", key, beam[0][0].applied_opts)
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
if DEBUG >= 2: print(beam[0][0].applied_opts)
return beam[0][0]

View File

@ -169,7 +169,7 @@ def cache_compiled(func):
CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
CACHELEVEL = getenv("CACHELEVEL", 2)
VERSION = 2
VERSION = 3
_db_connection = None
def db_connection():
global _db_connection

View File

@ -288,13 +288,16 @@ class Compiled:
if not getenv("NOOPT"):
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1 and not vars_from_ast(ast):
lins = [(("tc" if used_tensor_cores else "hc"), k)]
# allocate a scratch buffer if output buffer is also input
test_rawbuffers = [self.buffer(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
kb = Linearizer(ast, self.linearizer_opts)
kb.required_optimizations()
kb.dont_use_locals = bool(getenv("NOLOCALS"))
from tinygrad.features.search import beam_search, time_linearizer
lins = [(f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))), (("tc" if used_tensor_cores else "hc"), k)]
if not bool(getenv("NOLOCALS")) or getenv("NOLOCALS") >= 2:
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
if bool(getenv("NOLOCALS")):
lins.append((f"beam{BEAM.value}n", beam_search(kb.copy(), test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)), dont_use_locals=True)))
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
lins[-1][1].hand_coded_optimizations()