diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index fd166f5d..d705cd1f 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -378,7 +378,7 @@ jobs:
         echo 'Acquire::http::Pipeline-Depth "5";' | sudo tee -a /etc/apt/apt.conf.d/99parallel
         sudo apt update -y
         sudo apt install -y --no-install-recommends git g++ cmake ninja-build llvm-15-dev zlib1g-dev libglew-dev \
-          flex bison libfl-dev libboost-thread-dev libboost-filesystem-dev nvidia-cuda-toolkit-gcc
+          flex bison libfl-dev libboost-thread-dev libboost-filesystem-dev nvidia-cuda-toolkit-gcc libzstd-dev
     - name: Cache gpuocelot
       if: matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton'
       id: cache-build
diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py
index 664a3c95..2b4397c6 100644
--- a/test/unit/test_symbolic.py
+++ b/test/unit/test_symbolic.py
@@ -1,7 +1,17 @@
 #!/usr/bin/env python
-import unittest
+import unittest, pickle
 from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, sym_render, sym_infer, create_rednode

+class TestSymbolicPickle(unittest.TestCase):
+  def test_pickle_variable(self):
+    dat = Variable("a", 3, 8)
+    datp = pickle.loads(pickle.dumps(dat))
+    self.assertEqual(str(datp), "<a[3-8]>")
+  def test_pickle_variable_times_2(self):
+    dat = Variable("a", 3, 8)*2
+    datp = pickle.loads(pickle.dumps(dat))
+    self.assertEqual(str(datp), "<(a[3-8]*2)>")
+
 class TestSymbolic(unittest.TestCase):
   def helper_test_variable(self, v, n, m, s):
     self.assertEqual(v.render(), s)
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 3bcbc806..6d39053d 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -468,7 +468,16 @@ class Kernel:
     assert padded, "nothing was padded"
     return self.simplify_ones()

+  def required_optimizations(self):
+    if self.bufs[0].dtype.__class__ is ImageDType:
+      unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
+      assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
+      if len(unit_stride_axes_mul_4) and all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
+        self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+
   def hand_coded_optimizations(self):
+    self.required_optimizations()
+
     # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
     MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
     if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 8494eb28..2c4607b7 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -285,6 +285,7 @@ class Compiled:
       print_tree(ast)
     from tinygrad.codegen.linearizer import Linearizer
     k = Linearizer(ast, self.linearizer_opts)
+    k.required_optimizations()
     if not NOOPT:
       if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
       if BEAM >= 1:
@@ -293,6 +294,7 @@ class Compiled:
         lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
         lins[-1][1].hand_coded_optimizations()
         kb = Linearizer(ast, self.linearizer_opts)
+        kb.required_optimizations()
         from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
         # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
         test_rawbuffers = bufs_from_lin(kb)  # allocate scratch buffers for optimization
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index 9be8bb89..4b017853 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -1,10 +1,11 @@
 from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
-import itertools, random, math, time
+import itertools, random, math, time, multiprocessing, traceback
 from tinygrad.lazy import vars_from_ast
 from tinygrad.device import Device, Compiled, Buffer
 from tinygrad.ops import MemBuffer
-from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int, colored, Timing
+from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.codegen.linearizer import Linearizer, UOp
+from tinygrad.shape.symbolic import sym_infer
 from collections import defaultdict
 from tinygrad.tensor import Tensor

@@ -21,57 +22,17 @@ actions += [
 ]
 if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)]

-# returns time in seconds
-def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
-  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT}
-  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
+def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])

-  # Set the midpoint value value for var_vals to optimize shapes.
-  var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
-  try:
-    lin.linearize()
-    prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin)
-    real_global_size = prg.global_size
-    if allow_test_size and prg.global_size and all_int(tuple(prg.global_size)):
-      test_global_size = prg.global_size[:]
-      while prod(test_global_size) > max_global_size:
-        for j in range(2,-1,-1):
-          if test_global_size[j] > 16:
-            test_global_size[j] //= 2
-            break
-      factor = prod(prg.global_size) / prod(test_global_size)
-      prg.global_size = test_global_size
-      #print(real_global_size, test_global_size, factor)
-    else:
-      factor = 1
-
-    # TODO: this is copied from prg.__call__
-    global_size, local_size = prg.launch_dims(var_vals)
-    prg.global_size = real_global_size
-    if global_size is not None and prg.global_size is not None and local_size is None and all_int(tuple(prg.global_size)):
-      local_size = optimize_local_size(prg.clprg, global_size, rawbufs)
-      global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
-
-    lra = prg.runtime_args.copy()
-    if global_size: lra['global_size'] = global_size
-    if local_size: lra['local_size'] = local_size
-
-    tms = []
-    for _ in range(cnt):
-      if clear_l2:
-        # TODO: this is too small for many L2 caches
-        with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
-      tms.append(prg.clprg(*[x._buf for x in rawbufs], *var_vals.values(), **lra, wait=True)*factor)
-  except Exception:
-    if DEBUG >= 4:
-      import traceback
-      traceback.print_exc()
-      print("FAILED")
-      print(lin.ast)
-      print(lin.applied_opts)
-    tms = [float('inf')]
-  if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
-  return min(tms)
+def get_test_global_size(global_size, max_global_size):
+  test_global_size = global_size[:]
+  while prod(test_global_size) > max_global_size:
+    for j in range(2,-1,-1):
+      if test_global_size[j] > 16:
+        test_global_size[j] //= 2
+        break
+  factor = prod(global_size) / prod(test_global_size)
+  return test_global_size, factor

 # get (scrap) buffers for timing the linearizer
 def bufs_from_lin(lin:Linearizer) -> List[Buffer]:
@@ -102,7 +63,43 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
       pass
   return acted_lins

-def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
+def try_compile_linearized_w_idx(x):
+  try: return (x[0], compile_linearizer(Device.DEFAULT, x[1], "test"))
+  except Exception:
+    if DEBUG >= 4: traceback.print_exc()
+    return (x[0], None)
+
+def compile_linearizer(dev:str, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]]]:
+  lin.linearize()
+  rdev = Device[dev]
+  assert isinstance(rdev, Compiled)
+  src, _ = rdev.renderer(name if name is not None else to_function_name(lin.name), lin.uops)  # NOTE: these all have the same name for deduping
+  return rdev.compiler(src), lin.global_size, lin.local_size
+
+def time_program(dev:str, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
+  rdev = Device[dev]
+  assert isinstance(rdev, Compiled)
+  clprg = rdev.runtime(name, lib)
+  factor = 1
+  if global_size is not None:
+    global_size = [sym_infer(sz, var_vals) for sz in global_size] + [1]*(3-len(global_size))
+    if local_size is None:
+      local_size = optimize_local_size(clprg, global_size, rawbufs)
+      global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
+    else:
+      local_size = [sym_infer(sz, var_vals) for sz in local_size] + [1]*(3-len(local_size))
+    if max_global_size is not None:
+      global_size, factor = get_test_global_size(global_size, max_global_size=max_global_size)
+  lra = {}
+  if global_size: lra['global_size'] = global_size
+  if local_size: lra['local_size'] = local_size
+  tms = []
+  for _ in range(cnt):
+    if clear_l2:
+      with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
+    tms.append(clprg(*[x._buf for x in rawbufs], **lra, vals=var_vals.values(), wait=True)*factor)
+    if early_stop is not None and early_stop < tms[-1]: break
+  return tms

 def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
   key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
@@ -111,45 +108,43 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
     for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
     return ret

-  # init the BEAM with the base linearizer
-  beam: List[Tuple[Linearizer, float]] = [(lin, time_linearizer(lin, rawbufs, allow_test_size=allow_test_size))]
+  beam: List[Tuple[Linearizer, float]] = []
+  seen_libs = set()

-  # NOTE: real uops use a weird compare method that's only valid inside a linearizer
-  seen_uops = {tuplize_uops(lin.linearize().uops): tuple(lin.applied_opts)}
+  default_parallel = 1 if Device.DEFAULT == "HIP" else 0
+  pool = multiprocessing.Pool(multiprocessing.cpu_count()) if getenv("PARALLEL", default_parallel) else None

+  var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
   exiting, st = False, time.perf_counter()
+  dev = Device[Device.DEFAULT]
+  assert isinstance(dev, Compiled)
   while not exiting:
-    with Timing("linearize: ", enabled=DEBUG>=3):
-      acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam])
-
-      # linearize all
-      for x in acted_lins: x.linearize()
-
-      # dedup with uops
-      acted_lins_dedup = []
-      for lin in acted_lins:
-        tuops = tuplize_uops(lin.uops)
-        if tuops in seen_uops: continue
-        seen_uops[tuops] = tuple(lin.applied_opts)
-        acted_lins_dedup.append(lin)
-
-    with Timing("compile: ",enabled=DEBUG>=3):
-      # time linearizers
-      timed_lins: List[Tuple[Linearizer, float]] = [(v,time_linearizer(v,rawbufs,allow_test_size=allow_test_size)) for v in acted_lins_dedup]
-      opts = sorted(timed_lins, key=lambda x: x[1])
+    acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin,_ in beam]) if len(beam) else [lin]
+    timed_lins: List[Tuple[Linearizer, float]] = []
+    for i,proc in (pool.imap_unordered(try_compile_linearized_w_idx, enumerate(acted_lins)) if pool is not None else map(try_compile_linearized_w_idx, enumerate(acted_lins))):
+      if proc is None: continue
+      lib, global_size, local_size = proc
+      if lib in seen_libs: continue
+      seen_libs.add(lib)
+      tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else None)
+      timed_lins.append((acted_lins[i], min(tms)))
+      if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")

     # done
-    exiting = len(opts) == 0 or beam[0][1] <= opts[0][1]
+    opts = sorted(timed_lins, key=lambda x: x[1])
+    exiting = len(opts) == 0 or (len(beam) > 0 and beam[0][1] <= opts[0][1])
     if not exiting: beam = opts[:amt]
-    if DEBUG >= 2: print(f"{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions", beam[0][0].colored_shape())
+    assert len(beam) > 0, "no BEAM items succeeded?!?"
+    if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s:", colored(f"{beam[0][1]*1e6:12.2f} us", "green" if exiting else None), f"from {len(acted_lins):3d} -> {len(opts):3d} actions\033[K", beam[0][0].colored_shape())

+  if pool is not None: pool.close()  # the pool is closed
   if CACHELEVEL >= 1: diskcache_put("beam_search", key, beam[0][0].applied_opts)
   if DEBUG >= 3: print(beam[0][0].applied_opts)
   return beam[0][0]

 def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buffer]) -> List[int]:
   test_rawbuffers = [Buffer(rawbufs[0].device, rawbufs[0].size, rawbufs[0].dtype), *rawbufs[1:]] if rawbufs[0] in rawbufs[1:] else rawbufs
-  MAX_WORKGROUP = clprg.max_work_group_size() if hasattr(clprg, 'max_work_group_size') else 1024
+  MAX_WORKGROUP = 1024
   local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in global_size]
   local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2  # try each valid size twice
   def try_exec(local_size):
@@ -160,3 +155,14 @@ def optimize_local_size(clprg:Callable, global_size:List[int], rawbufs:List[Buff
   ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
   assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
   return ret[1]
+
+def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:
+  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size, "clear_l2": clear_l2, "device": Device.DEFAULT}
+  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
+
+  var_vals = {k:(k.max+k.min)//2 for k in vars_from_ast(lin.ast)}
+  lib, global_size, local_size = compile_linearizer(Device.DEFAULT, lin)
+  tms = time_program(Device.DEFAULT, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
+
+  if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
+  return min(tms)
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 92857411..5c6eda4d 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -45,7 +45,7 @@ class CUDAProgram:

   def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], vals:Tuple[int, ...]=(), wait=False):
     if not CUDACPU: check(cuda.cuCtxSetCurrent(self.device.context))
-    c_kernel_input_config = encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else (bufs+vals)
+    c_kernel_input_config = encode_args_cuda_style(bufs, vals, cuda.CUdeviceptr_v2, (1,2,0))[0] if not CUDACPU else (bufs+tuple(vals))
     return cu_time_execution(lambda: check(cuda.cuLaunchKernel(self.prg, *global_size, *local_size, 0, None, None, c_kernel_input_config)), enable=wait)

 class CUDAAllocator(LRUAllocator):
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 5ffeafa9..3099339a 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -131,7 +131,9 @@ class Node:
 # 4 basic node types

 class Variable(Node):
-  def __new__(cls, expr:Optional[str], nmin:int, nmax:int):
+  def __new__(cls, *args):
+    if len(args) == 0: return super().__new__(cls)  # fix pickle
+    expr, nmin, nmax = args
     assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
     if nmin == nmax: return NumNode(nmin)
     return super().__new__(cls)
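
Note on the tinygrad/shape/symbolic.py hunk: for a plain object with no custom __reduce__, pickle (protocol 2+) reconstructs the instance by calling cls.__new__(cls) with no arguments and then restoring __dict__, so the old three-argument __new__ could not be invoked during unpickling; the *args form with the early "fix pickle" return is what lets the new TestSymbolicPickle tests pass. A minimal standalone sketch of the same pattern (the Var class below is illustrative, not tinygrad's Variable):

import pickle

class Var:
  def __new__(cls, *args):
    # unpickling calls Var.__new__(Var) with no arguments and then restores __dict__,
    # so the unpack below must be skipped on that path (mirrors the "fix pickle" branch)
    if len(args) == 0: return super().__new__(cls)
    expr, nmin, nmax = args
    assert 0 <= nmin <= nmax, f"invalid Var {expr=} {nmin=} {nmax=}"
    return super().__new__(cls)
  def __init__(self, expr, nmin, nmax):
    self.expr, self.min, self.max = expr, nmin, nmax

v = pickle.loads(pickle.dumps(Var("a", 3, 8)))
assert (v.expr, v.min, v.max) == ("a", 3, 8)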
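The get_test_global_size helper factored out of the old time_linearizer keeps the same shrinking rule: starting from the last axis, halve the first dimension still above 16 until the product fits under max_global_size, and return the factor used to scale measured times back up. A rough standalone check of that arithmetic (using math.prod in place of tinygrad.helpers.prod, which should be equivalent here):

from math import prod

def get_test_global_size(global_size, max_global_size):
  # same logic as the helper in the patch
  test_global_size = global_size[:]
  while prod(test_global_size) > max_global_size:
    for j in range(2,-1,-1):
      if test_global_size[j] > 16:
        test_global_size[j] //= 2
        break
  factor = prod(global_size) / prod(test_global_size)
  return test_global_size, factor

# a 1024x1024x4 launch shrinks along axis 1 until the total fits in 65536 work items
print(get_test_global_size([1024, 1024, 4], 65536))  # ([1024, 16, 4], 64.0)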
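The beam_search change moves compilation into a multiprocessing.Pool via imap_unordered while timing stays in the parent process; try_compile_linearized_w_idx carries the enumerate index so out-of-order results can be matched back to acted_lins, and sending Linearizer objects to worker processes is presumably why Variable needed to become picklable in the same change. A toy sketch of that dispatch pattern (the _compile stand-in below is illustrative only, not tinygrad's compile_linearizer):

import multiprocessing

def _compile(x):
  i, src = x
  try: return i, src.upper()   # stand-in for the real compile step; None marks a failure
  except Exception: return i, None

if __name__ == "__main__":
  srcs = ["kernel_a", "kernel_b", "kernel_c"]
  with multiprocessing.Pool(2) as pool:
    # results arrive in completion order; the index i maps each one back to srcs[i]
    for i, lib in pool.imap_unordered(_compile, enumerate(srcs)):
      if lib is None: continue
      print(i, lib)  # timing of lib would happen here, in the parent, as in beam_search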