fix ones, BS=2 stable diffusion, caching optimizer (#1312)

* fix ones, BS=2 stable diffusion

* caching optimizer

* print search time

* minor bug fix
This commit is contained in:
George Hotz 2023-07-21 09:55:49 -07:00 committed by GitHub
parent 9746f6d094
commit bfbb8d3d0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 55 additions and 28 deletions

View File

@ -9,7 +9,7 @@ from collections import namedtuple
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from tinygrad.helpers import dtypes, GlobalCounters
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from extra.utils import download_file
from tinygrad.state import torch_load, load_state_dict
@ -614,8 +614,8 @@ if __name__ == "__main__":
def get_model_output(latent, timestep):
# put into diffuser
unconditional_latent = model.model.diffusion_model(latent, timestep, unconditional_context)
latent = model.model.diffusion_model(latent, timestep, context)
latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep.expand(2, *timestep.shape[1:]), unconditional_context.cat(context, dim=0))
unconditional_latent, latent = latents[0:1], latents[1:2]
unconditional_guidance_scale = 7.5
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
@ -647,6 +647,7 @@ if __name__ == "__main__":
# this is diffusion
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
GlobalCounters.reset()
t.set_description("%3d %3d" % (index, timestep))
e_t = get_model_output(latent, Tensor([timestep]))
x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index)

View File

@ -104,12 +104,14 @@ def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:Li
lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
assert lidx.max == var.max and lidx.min == var.min
return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
local_size.append(var.max+1)
return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
kernel,global_size,local_size,prekernel = [],[],[],[]
global_size: List[int] = []
local_size: List[int] = []
kernel,prekernel = [],[]
pend_close = None
bufs = []
depth = 0
@ -118,19 +120,13 @@ def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int],
for uop,newvar,vin,args in uops:
if uop == UOps.LOOP:
for i,var in enumerate(args[0]):
if isinstance(var, NumNode):
if args[1] == "global" and lang.gid: global_size.append(1)
if args[1] == "local" and lang.lid: local_size.append(1)
# one number, not an index
kk("{")
if args[1] == "global" and lang.gid:
kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
elif args[1] == "local" and lang.lid:
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
else:
if args[1] == "global" and lang.gid:
kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
elif args[1] == "local" and lang.lid:
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
else:
if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
kk(lang.render_for(var.expr, var.min, var.max))
if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max))
depth += 1
elif uop == UOps.BARRIER:
kk(lang.barrier)

View File

@ -1,5 +1,5 @@
from typing import Callable
import itertools
import itertools, time
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType
from tinygrad.ops import ReduceOps, BinaryOps, LazyOp
from tinygrad.codegen.linearizer import Linearizer
@ -22,9 +22,7 @@ def apply_opt(k, x):
UPCASTS = [1,2,3,4,5,6,7,8]
LOCALS = [1,2,3,4,5,6,7,8,16,24,32]
# optimization
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], runtime, baseline):
import nevergrad as ng
def opt(x):
try:
@ -32,14 +30,15 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
k.process()
apply_opt(k, x)
prg = k.codegen().build(runtime)
tm = min([prg.exec(k.bufs, force_wait=True) for _ in range(3)])*1000
first_tm = prg.exec(k.bufs, force_wait=True)
if baseline*5 < first_tm*1000: return first_tm*1000 # very slow
tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True) for _ in range(2)])*1000
return tm
except Exception:
if DEBUG >= 3:
import traceback
traceback.print_exc()
return 10000_000 # 10000 seconds is infinity
k.process()
opts = []
for i in range(k.first_reduce):
# TODO: the upcast always happen first, you might want to reverse this?
@ -48,12 +47,43 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
opts.append(ng.p.TransitionChoice([(i,s,"L") for s in LOCALS if k.full_shape[i]%s == 0]))
for i in range(k.shape_len-k.first_reduce):
opts.append(ng.p.TransitionChoice([(i,s,"R") for s in UPCASTS if k.full_shape[k.first_reduce+i]%s == 0]))
if len(opts) == 0: return
if len(opts) == 0: return "BASELINE"
search_space = prod([len(x.choices) for x in opts])
st = time.perf_counter()
optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, 200))
recommendation = optimizer.minimize(opt)
apply_opt(k, recommendation.value)
if DEBUG >= 1: print("optimizer hit", k.colored_shape(), "in search space", search_space)
et = time.perf_counter() - st
if DEBUG >= 1: print(f"optimizer({et:6.2f} s to search) space {search_space:8d} with tm {recommendation.loss:5.2f} ms vs baseline {baseline:5.2f} ms, a {baseline/recommendation.loss:5.2f}x gain : {k.colored_shape()}")
return recommendation.value if recommendation.loss < baseline else "BASELINE"
# optimization
# Process-wide handle to the on-disk optimization cache; stays None unless KOPT=2.
global_db = None
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
  """Choose and apply optimizations for kernel `k`.

  With KOPT=2 a shelve cache at /tmp/kopt_cache is consulted first; on a miss
  (or with the cache disabled) a baseline is timed and kernel_optimize_search
  is run. The choice is either the sentinel string "BASELINE" (keep the
  hand-coded optimizations) or an opt tuple for apply_opt. `k` is mutated in
  place; nothing is returned.
  """
  global global_db
  k.process()
  # shelve requires string keys, so stringify the kernel's cache key
  skey = str(k.key)
  if getenv("KOPT") == 2 and global_db is None:
    # lazily open the cache once per process (import here keeps shelve off the
    # hot path when KOPT=2 is not set); note it is never explicitly closed
    import shelve
    global_db = shelve.open("/tmp/kopt_cache")
  if global_db is not None and skey in global_db:
    # cache hit: reuse the previously searched choice
    choice = global_db[skey]
  else:
    # get baseline
    def get_baseline():
      # time the hand-coded optimizations on a fresh kernel: best of 5 runs,
      # converted to milliseconds (same unit kernel_optimize_search compares against)
      k = create_k()
      hand_coded_optimizations(k)
      prg = k.codegen().build(runtime)
      return min([prg.exec(k.bufs, force_wait=True) for _ in range(5)])*1000
    choice = kernel_optimize_search(k, create_k, runtime, get_baseline())
    if global_db is not None:
      global_db[skey] = choice
      # sync immediately so the result survives even if the process dies later
      global_db.sync()
  # apply the winner to the caller's kernel object
  if choice == "BASELINE": hand_coded_optimizations(k)
  else: apply_opt(k, choice)
def required_optimizations(k:Linearizer, early_only=False):
for buf_index,buf in enumerate(k.bufs):

View File

@ -143,7 +143,7 @@ class ASTRunner:
(self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
*rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-ansilen(self.display_name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += self.op_estimate