parallel beam search (#2610)

* better print

* fix beam search with vars

* cleanups

* parallel is not default

* restore that

* bugfix

* cleanups

* bugfix
George Hotz 2023-12-05 10:09:45 -08:00 committed by GitHub
parent 9996f1adf9
commit 35b5e95097
7 changed files with 111 additions and 82 deletions

View File

@@ -378,7 +378,7 @@ jobs:
echo 'Acquire::http::Pipeline-Depth "5";' | sudo tee -a /etc/apt/apt.conf.d/99parallel
sudo apt update -y
sudo apt install -y --no-install-recommends git g++ cmake ninja-build llvm-15-dev zlib1g-dev libglew-dev \
flex bison libfl-dev libboost-thread-dev libboost-filesystem-dev nvidia-cuda-toolkit-gcc
flex bison libfl-dev libboost-thread-dev libboost-filesystem-dev nvidia-cuda-toolkit-gcc libzstd-dev
- name: Cache gpuocelot
if: matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton'
id: cache-build

View File

@@ -1,7 +1,17 @@
#!/usr/bin/env python
import unittest
import unittest, pickle
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, sym_render, sym_infer, create_rednode
class TestSymbolicPickle(unittest.TestCase):
def test_pickle_variable(self):
dat = Variable("a", 3, 8)
datp = pickle.loads(pickle.dumps(dat))
self.assertEqual(str(datp), "<a[3-8]>")
def test_pickle_variable_times_2(self):
dat = Variable("a", 3, 8)*2
datp = pickle.loads(pickle.dumps(dat))
self.assertEqual(str(datp), "<(a[3-8]*2)>")
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s):
self.assertEqual(v.render(), s)
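
Note: these pickling tests are most likely motivated by the parallel path in `search.py` below — with `PARALLEL` enabled, candidate Linearizers are shipped to `multiprocessing` workers, and any symbolic `Variable`s in their ASTs ride along via pickle (the "fix beam search with vars" bullet above). A quick round-trip check, assuming tinygrad is importable:

```python
# Round-trip sketch: symbolic nodes must survive pickle so that candidate
# kernels containing Variables can be sent to multiprocessing workers.
import pickle
from tinygrad.shape.symbolic import Variable

v = Variable("i", 1, 10) * 2                          # a MulNode wrapping a Variable, as in the tests above
assert str(pickle.loads(pickle.dumps(v))) == str(v)   # survives the trip to a worker process
```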

View File

@@ -468,7 +468,16 @@ class Kernel:
assert padded, "nothing was padded"
return self.simplify_ones()
def required_optimizations(self):
if self.bufs[0].dtype.__class__ is ImageDType:
unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
def hand_coded_optimizations(self):
self.required_optimizations()
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
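
Note: `required_optimizations` is split out of `hand_coded_optimizations` so it can run even when the other optimizations are skipped; for image-dtype outputs it upcasts a unit-stride axis whose length is divisible by 4, so whole 4-float image texels can be addressed. A rough sketch of just the selection rule on plain shape/stride lists, omitting the reduce/upcast-position checks — `pick_upcast_axis` is a hypothetical helper for illustration, not tinygrad API:

```python
from typing import List

def pick_upcast_axis(shape: List[int], strides: List[int]) -> int:
  # unit-stride axes whose length is a multiple of 4 can be packed into float4s
  candidates = [i for i, (sz, st) in enumerate(zip(shape, strides)) if st == 1 and sz % 4 == 0]
  assert len(candidates) >= 1, "image kernels need a unit-stride axis divisible by 4"
  return candidates[0]

# e.g. a contiguous image-backed output of shape (32, 64, 4):
print(pick_upcast_axis([32, 64, 4], [256, 4, 1]))   # -> 2; that axis gets Opt(OptOps.UPCAST, 2, 4)
```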

View File

@@ -285,6 +285,7 @@ class Compiled:
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, self.linearizer_opts)
k.required_optimizations()
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1:
@@ -293,6 +294,7 @@ class Compiled:
lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
kb = Linearizer(ast, self.linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
# TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization

View File

@@ -1,10 +1,11 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, random, math, time
import itertools, random, math, time, multiprocessing, traceback
from tinygrad.lazy import vars_from_ast
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.ops import MemBuffer
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int, colored, Timing
from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.codegen.linearizer import Linearizer, UOp
from tinygrad.shape.symbolic import sym_infer
from collections import defaultdict
from tinygrad.tensor import Tensor
@@ -21,57 +22,17 @@ actions += [
]
if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]
# returns time in seconds
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
# Set the midpoint value value for var_vals to optimize shapes.
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
try:
lin.linearize()
prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin)
real_global_size = prg.global_size
if allow_test_size and prg.global_size and all_int(tuple(prg.global_size)):
test_global_size = prg.global_size[:]
while prod(test_global_size) > max_global_size:
for j in range(2,-1,-1):
if test_global_size[j] > 16:
test_global_size[j] //= 2
break
factor = prod(prg.global_size) / prod(test_global_size)
prg.global_size = test_global_size
#print(real_global_size, test_global_size, factor)
else:
factor = 1
# TODO: this is copied from prg.__call__
global_size, local_size = prg.launch_dims(var_vals)
prg.global_size = real_global_size
if global_size is not None and prg.global_size is not None and local_size is None and all_int(tuple(prg.global_size)):
local_size = optimize_local_size(prg.clprg, global_size, rawbufs)
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
lra = prg.runtime_args.copy()
if global_size: lra['global_size'] = global_size
if local_size: lra['local_size'] = local_size
tms = []
for _ in range(cnt):
if clear_l2:
# TODO: this is too small for many L2 caches
with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
tms.append(prg.clprg(*[x._buf for x in rawbufs], *var_vals.values(), **lra, wait=True)*factor)
except Exception:
if DEBUG >= 4:
import traceback
traceback.print_exc()
print("FAILED")
print(lin.ast)
print(lin.applied_opts)
tms = [float('inf')]
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
return min(tms)
def get_test_global_size(global_size, max_global_size):
test_global_size = global_size[:]
while prod(test_global_size) > max_global_size:
for j in range(2,-1,-1):
if test_global_size[j] > 16:
test_global_size[j] //= 2
break
factor = prod(global_size) / prod(test_global_size)
return test_global_size, factor
# get (scrap) buffers for timing the linearizer
def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
@@ -102,7 +63,43 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
pass
return acted_lins
def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
def try_compile_linearized_w_idx(x):
try: return (x[0], compile_linearizer(Device.DEFAULT, x[1], "test"))
except Exception:
if DEBUG >= 4: traceback.print_exc()
return (x[0], None)
def compile_linearizer(dev:str, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]]]:
lin.linearize()
rdev = Device[dev]
assert isinstance(rdev, Compiled)
src, _ = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
return rdev.compiler(src), lin.global_size, lin.local_size
def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
rdev = Device[dev]
assert isinstance(rdev, Compiled)
clprg = rdev.runtime(name, lib)
factor = 1
if global_size is not None:
global_size = [sym_infer(sz, var_vals) for sz in global_size] + [1]*(3-len(global_size))
if local_size is None:
local_size = optimize_local_size(clprg, global_size, rawbufs)
global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
else:
local_size = [sym_infer(sz, var_vals) for sz in local_size] + [1]*(3-len(local_size))
if max_global_size is not None:
global_size, factor = get_test_global_size(global_size, max_global_size=max_global_size)
lra = {}
if global_size: lra['global_size'] = global_size
if local_size: lra['local_size'] = local_size
tms = []
for _ in range(cnt):
if clear_l2:
with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
tms.append(clprg(*[x._buf for x in rawbufs], **lra, vals=var_vals.values(), wait=True)*factor)
if early_stop is not None and early_stop < tms[-1]: break
return tms
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
@@ -111,45 +108,43 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
return ret
# init the BEAM with the base linearizer
beam: List[Tuple[Linearizer, float]] = [(lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size))]
beam: List[Tuple[Linearizer, float]] = []
seen_libs = set()
# NOTE: real uops use a weird compare method that's only valid inside a linearizer
seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}
default_parallel = 1 if Device.DEFAULT == "HIP" else 0
pool = multiprocessing.Pool(multiprocessing.cpu_count()) if getenv("PARALLEL", default_parallel) else None
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
exiting, st = False, time.perf_counter()
dev = Device[Device.DEFAULT]
assert isinstance(dev, Compiled)
while not exiting:
with Timing("linearize: ", enabled=DEBUG>=3):
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
# linearize all
for x in acted_lins: x.linearize()
# dedup with uops
acted_lins_dedup = []
for lin in acted_lins:
tuops = tuplize_uops(lin.uops)
if tuops in seen_uops: continue
seen_uops[tuops] = tuple(lin.applied_opts)
acted_lins_dedup.append(lin)
with Timing("compile: ",enabled=DEBUG>=3):
# time linearizers
timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins_dedup]
opts = sorted(timed_lins, key=lambda x: x[1])
acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
timed_lins: List[Tuple[Linearizer, float]] = []
for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
if proc is None: continue
lib, global_size, local_size = proc
if lib in seen_libs: continue
seen_libs.add(lib)
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
timed_lins.append((acted_lins[i], min(tms)))
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")
# done
exiting = len(opts) == 0 or beam[0][1] <= opts[0][1]
opts = sorted(timed_lins, key=lambda x: x[1])
exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
if not exiting: beam = opts[:amt]
if DEBUG >= 2: print(f"{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
assert len(beam) > 0, "no BEAM items succeeded?!?"
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())
if pool is not None: pool.close() # the pool is closed
if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
if DEBUG >= 3: print(beam[0][0].applied_opts)
return beam[0][0]
def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024
MAX_WORKGROUP = 1024
local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2 # try each valid size twice
def try_exec(local_size):
@@ -160,3 +155,14 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
return ret[1]
def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
lib, global_size, local_size = compile_linearizer(Device.DEFAULT, lin)
tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
return min(tms)
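
Note on the search changes: timing is split into `compile_linearizer` (linearize, render, compile — no device state, so safe in a worker process) and `time_program` (needs the device, stays in the parent), and compilation is fanned out with `multiprocessing.Pool.imap_unordered` when `PARALLEL` is set (on by default only for HIP). Each worker returns its input index so out-of-order results still map back to the right candidate, candidates that compile to a binary already in `seen_libs` are skipped, and `early_stop` cuts off anything slower than 3x the current best. A self-contained sketch of that fan-out/dedup pattern with a stand-in compile step — `fake_compile` and `compile_with_idx` are illustrative names, not tinygrad code:

```python
import multiprocessing, hashlib
from typing import List, Optional, Tuple

def fake_compile(src: str) -> bytes:
  # stand-in for the real compile step; anything slow and picklable works here
  return hashlib.sha256(src.encode()).digest()

def compile_with_idx(x: Tuple[int, str]) -> Tuple[int, Optional[bytes]]:
  # workers return (index, result) so out-of-order results stay matched to their
  # candidate; exceptions become None instead of killing the whole search
  try: return (x[0], fake_compile(x[1]))
  except Exception: return (x[0], None)

if __name__ == "__main__":
  candidates: List[str] = ["kernel_a", "kernel_b", "kernel_a", "kernel_c"]   # note the duplicate
  pool = multiprocessing.Pool(multiprocessing.cpu_count())
  seen_libs, timed = set(), []
  for i, lib in pool.imap_unordered(compile_with_idx, enumerate(candidates)):
    if lib is None or lib in seen_libs: continue   # failed or duplicate binary -> skip it
    seen_libs.add(lib)
    timed.append((candidates[i], len(lib)))        # the timing step stays serial in the parent
  pool.close()
  print(timed)   # three entries: the duplicate compiled to the same bytes and was deduped
```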

View File

@@ -45,7 +45,7 @@ class CUDAProgram:
def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False):
if not CUDACPU: check(cuda.cuCtxSetCurrent(self.device.context))
c_kernel_input_config = encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else (bufs+vals)
c_kernel_input_config = encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else (bufs+tuple(vals))
return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait)
class CUDAAllocator(LRUAllocator):
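
Note on the CUDA change: `time_program` now passes `vals=var_vals.values()`, a `dict_values` view rather than a tuple, so the `CUDACPU` fallback's `bufs+vals` concatenation needs an explicit `tuple(...)`. A two-line illustration of the failure mode (plain Python, no CUDA required):

```python
# Why the CUDACPU branch needs tuple(vals): a dict view can't be concatenated onto a tuple.
bufs, var_vals = ("buf0", "buf1"), {"i": 5}
try:
  bufs + var_vals.values()                # TypeError: can only concatenate tuple (not "dict_values") to tuple
except TypeError as e:
  print(e)
print(bufs + tuple(var_vals.values()))    # ('buf0', 'buf1', 5)
```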

View File

@@ -131,7 +131,9 @@ class Node:
# 4 basic node types
class Variable(Node):
def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
def __new__(cls, *args):
if len(args) == 0: return super().__new__(cls) # fix pickle
expr, nmin, nmax = args
assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
if nmin == nmax: return NumNode(nmin)
return super().__new__(cls)
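
Note on the `# fix pickle` branch: pickle's default reconstruction (protocol 2+) calls `cls.__new__(cls)` with no further arguments and then restores `__dict__`, so the old `__new__(cls, expr, nmin, nmax)` raised during `pickle.loads`. A minimal standalone reproduction of the same pattern on a toy `Interval` class (illustration only, not tinygrad code):

```python
import pickle

class Interval:                      # toy stand-in for Variable
  def __new__(cls, *args):
    if len(args) == 0: return super().__new__(cls)   # the zero-arg path unpickling takes: cls.__new__(cls)
    lo, hi = args
    assert 0 <= lo <= hi, f"invalid Interval {lo=} {hi=}"
    return super().__new__(cls)
  def __init__(self, lo, hi): self.lo, self.hi = lo, hi
  def __repr__(self): return f"<[{self.lo}-{self.hi}]>"

x = pickle.loads(pickle.dumps(Interval(3, 8)))
print(x)   # <[3-8]> -- __new__ ran with zero args, then __dict__ was restored; __init__ never ran
```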