fix ones, BS=2 stable diffusion, caching optimizer (#1312)

* fix ones, BS=2 stable diffusion

* caching optimizer

* print search time

* minor bug fix
George Hotz 2023-07-21 09:55:49 -07:00 committed by GitHub
parent 9746f6d094
commit bfbb8d3d0f
4 changed files with 55 additions and 28 deletions

View File

@@ -9,7 +9,7 @@ from collections import namedtuple
 from tqdm import tqdm
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import dtypes
+from tinygrad.helpers import dtypes, GlobalCounters
 from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
 from extra.utils import download_file
 from tinygrad.state import torch_load, load_state_dict
@@ -614,8 +614,8 @@ if __name__ == "__main__":
   def get_model_output(latent, timestep):
     # put into diffuser
-    unconditional_latent = model.model.diffusion_model(latent, timestep, unconditional_context)
-    latent = model.model.diffusion_model(latent, timestep, context)
+    latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep.expand(2, *timestep.shape[1:]), unconditional_context.cat(context, dim=0))
+    unconditional_latent, latent = latents[0:1], latents[1:2]

     unconditional_guidance_scale = 7.5
     e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
@@ -647,6 +647,7 @@ if __name__ == "__main__":
   # this is diffusion
   for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
+    GlobalCounters.reset()
     t.set_description("%3d %3d" % (index, timestep))
     e_t = get_model_output(latent, Tensor([timestep]))
     x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index)
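The BS=2 change above halves the UNet invocations per denoising step: instead of one forward pass with the unconditional context and a second with the prompt context, both contexts are concatenated along the batch dimension, run once, and split afterward. A minimal sketch of the pattern, assuming a `unet` callable with the same (latent, timestep, context) signature (the `guided_eps` helper is illustrative, not part of this commit):

    from tinygrad.tensor import Tensor

    def guided_eps(unet, latent: Tensor, t: Tensor, uncond_ctx: Tensor, ctx: Tensor, scale=7.5) -> Tensor:
      # stack the unconditional and conditional passes into one batch of 2
      eps = unet(latent.expand(2, *latent.shape[1:]),
                 t.expand(2, *t.shape[1:]),
                 uncond_ctx.cat(ctx, dim=0))
      eps_uncond, eps_cond = eps[0:1], eps[1:2]
      # classifier-free guidance: e_t = e_uncond + scale * (e_cond - e_uncond)
      return eps_uncond + scale * (eps_cond - eps_uncond)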

View File

@@ -104,12 +104,14 @@ def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]):
     lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
     lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
     assert lidx.max == var.max and lidx.min == var.min
-    return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
+    return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
   local_size.append(var.max+1)
-  return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
+  return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"

 def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
-  kernel,global_size,local_size,prekernel = [],[],[],[]
+  global_size: List[int] = []
+  local_size: List[int] = []
+  kernel,prekernel = [],[]
   pend_close = None
   bufs = []
   depth = 0
@@ -118,19 +120,13 @@ def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
   for uop,newvar,vin,args in uops:
     if uop == UOps.LOOP:
       for i,var in enumerate(args[0]):
-        if isinstance(var, NumNode):
-          if args[1] == "global" and lang.gid: global_size.append(1)
-          if args[1] == "local" and lang.lid: local_size.append(1)
-          # one number, not an index
-          kk("{")
-        else:
-          if args[1] == "global" and lang.gid:
-            kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
-          elif args[1] == "local" and lang.lid:
-            kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
-          else:
-            if getenv("NOUNROLL"): kk("#pragma unroll(1)")  # prevent loop unrolling
-            kk(lang.render_for(var.expr, var.min, var.max))
+        if args[1] == "global" and lang.gid:
+          kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
+        elif args[1] == "local" and lang.lid:
+          kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
+        else:
+          if getenv("NOUNROLL"): kk("#pragma unroll(1)")  # prevent loop unrolling
+          kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max))
       depth += 1
     elif uop == UOps.BARRIER:
       kk(lang.barrier)
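Both renderer changes hinge on the same observation: when a loop variable is a NumNode (a compile-time constant, so the dimension has extent 1) there is no index to declare and nothing to size, but an opening "{" is still emitted so the matching close brace stays balanced. A self-contained sketch of that decision, using stand-in classes rather than tinygrad's real Node types:

    class Node: pass
    class NumNode(Node):   # compile-time constant: the "loop" runs exactly once
      def __init__(self, b): self.b = b
    class Variable(Node):
      def __init__(self, expr, min_, max_): self.expr, self.min, self.max = expr, min_, max_

    def render_for(expr, lo, hi) -> str:
      return f"for (int {expr} = {lo}; {expr} <= {hi}; {expr}++) {{"

    def render_loop(var: Node) -> str:
      # constant dimensions need no index variable, just a scope to close later
      return "{" if isinstance(var, NumNode) else render_for(var.expr, var.min, var.max)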

View File

@@ -1,5 +1,5 @@
 from typing import Callable
-import itertools
+import itertools, time
 from tinygrad.helpers import DEBUG, prod, getenv, ImageDType
 from tinygrad.ops import ReduceOps, BinaryOps, LazyOp
 from tinygrad.codegen.linearizer import Linearizer
@@ -22,9 +22,7 @@ def apply_opt(k, x):
 UPCASTS = [1,2,3,4,5,6,7,8]
 LOCALS = [1,2,3,4,5,6,7,8,16,24,32]

-# optimization
-def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
+def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], runtime, baseline):
   import nevergrad as ng
   def opt(x):
     try:
@@ -32,14 +30,15 @@ def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], runtime, baseline):
       k.process()
       apply_opt(k, x)
       prg = k.codegen().build(runtime)
-      tm = min([prg.exec(k.bufs, force_wait=True) for _ in range(3)])*1000
+      first_tm = prg.exec(k.bufs, force_wait=True)
+      if baseline*5 < first_tm*1000: return first_tm*1000  # very slow
+      tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True) for _ in range(2)])*1000
       return tm
     except Exception:
       if DEBUG >= 3:
         import traceback
         traceback.print_exc()
       return 10000_000  # 10000 seconds is infinity
-  k.process()
   opts = []
   for i in range(k.first_reduce):
     # TODO: the upcast always happen first, you might want to reverse this?
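The reworked opt() times each candidate kernel once and, if that single run is already more than 5x slower than the baseline, returns immediately instead of paying for two more runs (first_tm is in seconds while baseline is in milliseconds, hence the *1000); otherwise the minimum of three runs is reported to filter scheduling noise. The same early-exit pattern in isolation (`bench_early_exit` is a hypothetical helper, not this commit's API):

    import time

    def bench_early_exit(fn, baseline_ms: float, extra_runs: int = 2) -> float:
      st = time.perf_counter()
      fn()
      first_ms = (time.perf_counter() - st) * 1000
      if baseline_ms * 5 < first_ms: return first_ms  # hopeless, skip the reruns
      times = [first_ms]
      for _ in range(extra_runs):
        st = time.perf_counter()
        fn()
        times.append((time.perf_counter() - st) * 1000)
      return min(times)  # min over runs filters scheduling noise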
@@ -48,12 +47,43 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
       opts.append(ng.p.TransitionChoice([(i,s,"L") for s in LOCALS if k.full_shape[i]%s == 0]))
   for i in range(k.shape_len-k.first_reduce):
     opts.append(ng.p.TransitionChoice([(i,s,"R") for s in UPCASTS if k.full_shape[k.first_reduce+i]%s == 0]))
-  if len(opts) == 0: return
+  if len(opts) == 0: return "BASELINE"
   search_space = prod([len(x.choices) for x in opts])
+  st = time.perf_counter()
   optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, 200))
   recommendation = optimizer.minimize(opt)
-  apply_opt(k, recommendation.value)
-  if DEBUG >= 1: print("optimizer hit", k.colored_shape(), "in search space", search_space)
+  et = time.perf_counter() - st
+  if DEBUG >= 1: print(f"optimizer({et:6.2f} s to search) space {search_space:8d} with tm {recommendation.loss:5.2f} ms vs baseline {baseline:5.2f} ms, a {baseline/recommendation.loss:5.2f}x gain : {k.colored_shape()}")
+  return recommendation.value if recommendation.loss < baseline else "BASELINE"
+
+# optimization
+global_db = None
+def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
+  global global_db
+  k.process()
+  skey = str(k.key)
+  if getenv("KOPT") == 2 and global_db is None:
+    import shelve
+    global_db = shelve.open("/tmp/kopt_cache")
+  if global_db is not None and skey in global_db:
+    choice = global_db[skey]
+  else:
+    # get baseline
+    def get_baseline():
+      k = create_k()
+      hand_coded_optimizations(k)
+      prg = k.codegen().build(runtime)
+      return min([prg.exec(k.bufs, force_wait=True) for _ in range(5)])*1000
+    choice = kernel_optimize_search(k, create_k, runtime, get_baseline())
+    if global_db is not None:
+      global_db[skey] = choice
+      global_db.sync()
+  if choice == "BASELINE": hand_coded_optimizations(k)
+  else: apply_opt(k, choice)
+
 def required_optimizations(k:Linearizer, early_only=False):
   for buf_index,buf in enumerate(k.bufs):
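With KOPT=2, kernel_optimize now memoizes each search result in a shelve database keyed by the kernel's AST key, so a second run of the same model reuses earlier decisions instead of repeating the nevergrad search; the string "BASELINE" is a sentinel meaning the hand-coded optimizations beat everything the search tried. The caching pattern in isolation (a sketch; `cached_choice` is illustrative, not this commit's API):

    import shelve

    _db = None
    def cached_choice(key: str, search):
      global _db
      if _db is None:
        _db = shelve.open("/tmp/kopt_cache")  # persists across runs
      if key in _db: return _db[key]
      choice = search()
      _db[key] = choice
      _db.sync()  # flush now so a later crash doesn't lose the entry
      return choice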

View File

@@ -143,7 +143,7 @@ class ASTRunner:
                         (self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
                         *rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
     if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-ansilen(self.display_name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
             (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
     GlobalCounters.kernel_count += 1
     GlobalCounters.global_ops += self.op_estimate