From bfbb8d3d0f085cb547c57fa82405f837b71f7864 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 21 Jul 2023 09:55:49 -0700 Subject: [PATCH] fix ones, BS=2 stable diffusion, caching optimizer (#1312) * fix ones, BS=2 stable diffusion * caching optimizer * print search time * minor bug fix --- examples/stable_diffusion.py | 7 ++--- tinygrad/codegen/cstyle.py | 26 ++++++++----------- tinygrad/codegen/optimizer.py | 48 ++++++++++++++++++++++++++++------- tinygrad/ops.py | 2 +- 4 files changed, 55 insertions(+), 28 deletions(-) diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index b51447a7..2293f6f6 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -9,7 +9,7 @@ from collections import namedtuple from tqdm import tqdm from tinygrad.tensor import Tensor -from tinygrad.helpers import dtypes +from tinygrad.helpers import dtypes, GlobalCounters from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from extra.utils import download_file from tinygrad.state import torch_load, load_state_dict @@ -614,8 +614,8 @@ if __name__ == "__main__": def get_model_output(latent, timestep): # put into diffuser - unconditional_latent = model.model.diffusion_model(latent, timestep, unconditional_context) - latent = model.model.diffusion_model(latent, timestep, context) + latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep.expand(2, *timestep.shape[1:]), unconditional_context.cat(context, dim=0)) + unconditional_latent, latent = latents[0:1], latents[1:2] unconditional_guidance_scale = 7.5 e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent) @@ -647,6 +647,7 @@ if __name__ == "__main__": # this is diffusion for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])): + GlobalCounters.reset() t.set_description("%3d %3d" % (index, timestep)) e_t = get_model_output(latent, Tensor([timestep])) x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index) diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index b42b9873..26f7a833 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -104,12 +104,14 @@ def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:Li lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1) lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1) assert lidx.max == var.max and lidx.min == var.min - return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */" + return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */" local_size.append(var.max+1) - return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */" + return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */" def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]: - kernel,global_size,local_size,prekernel = [],[],[],[] + global_size: List[int] = [] + local_size: List[int] = [] + kernel,prekernel = [],[] pend_close = None bufs = [] depth = 0 @@ -118,19 +120,13 @@ def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], for uop,newvar,vin,args in uops: if uop == UOps.LOOP: for i,var in enumerate(args[0]): - if isinstance(var, NumNode): - if args[1] == "global" and lang.gid: global_size.append(1) - if args[1] == "local" and lang.lid: local_size.append(1) - # one number, not an index - kk("{") + if args[1] == "global" and lang.gid: + kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid)) + elif args[1] == "local" and lang.lid: + kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid)) else: - if args[1] == "global" and lang.gid: - kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid)) - elif args[1] == "local" and lang.lid: - kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid)) - else: - if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling - kk(lang.render_for(var.expr, var.min, var.max)) + if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling + kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max)) depth += 1 elif uop == UOps.BARRIER: kk(lang.barrier) diff --git a/tinygrad/codegen/optimizer.py b/tinygrad/codegen/optimizer.py index 813875d6..faed4055 100644 --- a/tinygrad/codegen/optimizer.py +++ b/tinygrad/codegen/optimizer.py @@ -1,5 +1,5 @@ from typing import Callable -import itertools +import itertools, time from tinygrad.helpers import DEBUG, prod, getenv, ImageDType from tinygrad.ops import ReduceOps, BinaryOps, LazyOp from tinygrad.codegen.linearizer import Linearizer @@ -22,9 +22,7 @@ def apply_opt(k, x): UPCASTS = [1,2,3,4,5,6,7,8] LOCALS = [1,2,3,4,5,6,7,8,16,24,32] - -# optimization -def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime): +def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], runtime, baseline): import nevergrad as ng def opt(x): try: @@ -32,14 +30,15 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime): k.process() apply_opt(k, x) prg = k.codegen().build(runtime) - tm = min([prg.exec(k.bufs, force_wait=True) for _ in range(3)])*1000 + first_tm = prg.exec(k.bufs, force_wait=True) + if baseline*5 < first_tm*1000: return first_tm*1000 # very slow + tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True) for _ in range(2)])*1000 return tm except Exception: if DEBUG >= 3: import traceback traceback.print_exc() return 10000_000 # 10000 seconds is infinity - k.process() opts = [] for i in range(k.first_reduce): # TODO: the upcast always happen first, you might want to reverse this? @@ -48,12 +47,43 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime): opts.append(ng.p.TransitionChoice([(i,s,"L") for s in LOCALS if k.full_shape[i]%s == 0])) for i in range(k.shape_len-k.first_reduce): opts.append(ng.p.TransitionChoice([(i,s,"R") for s in UPCASTS if k.full_shape[k.first_reduce+i]%s == 0])) - if len(opts) == 0: return + if len(opts) == 0: return "BASELINE" search_space = prod([len(x.choices) for x in opts]) + st = time.perf_counter() optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, 200)) recommendation = optimizer.minimize(opt) - apply_opt(k, recommendation.value) - if DEBUG >= 1: print("optimizer hit", k.colored_shape(), "in search space", search_space) + et = time.perf_counter() - st + if DEBUG >= 1: print(f"optimizer({et:6.2f} s to search) space {search_space:8d} with tm {recommendation.loss:5.2f} ms vs baseline {baseline:5.2f} ms, a {baseline/recommendation.loss:5.2f}x gain : {k.colored_shape()}") + return recommendation.value if recommendation.loss < baseline else "BASELINE" + +# optimization +global_db = None +def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime): + global global_db + + k.process() + skey = str(k.key) + + if getenv("KOPT") == 2 and global_db is None: + import shelve + global_db = shelve.open("/tmp/kopt_cache") + + if global_db is not None and skey in global_db: + choice = global_db[skey] + else: + # get baseline + def get_baseline(): + k = create_k() + hand_coded_optimizations(k) + prg = k.codegen().build(runtime) + return min([prg.exec(k.bufs, force_wait=True) for _ in range(5)])*1000 + choice = kernel_optimize_search(k, create_k, runtime, get_baseline()) + if global_db is not None: + global_db[skey] = choice + global_db.sync() + + if choice == "BASELINE": hand_coded_optimizations(k) + else: apply_opt(k, choice) def required_optimizations(k:Linearizer, early_only=False): for buf_index,buf in enumerate(k.bufs): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e74544c7..78647dca 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -143,7 +143,7 @@ class ASTRunner: (self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None, *rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et if DEBUG >= 2: - print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-ansilen(self.display_name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + + print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) GlobalCounters.kernel_count += 1 GlobalCounters.global_ops += self.op_estimate