mirror of https://github.com/commaai/tinygrad.git
fix ones, BS=2 stable diffusion, caching optimizer (#1312)
* fix ones, BS=2 stable diffusion * caching optimizer * print search time * minor bug fix
This commit is contained in:
parent
9746f6d094
commit
bfbb8d3d0f
|
@ -9,7 +9,7 @@ from collections import namedtuple
|
|||
|
||||
from tqdm import tqdm
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.helpers import dtypes, GlobalCounters
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
|
@ -614,8 +614,8 @@ if __name__ == "__main__":
|
|||
|
||||
def get_model_output(latent, timestep):
|
||||
# put into diffuser
|
||||
unconditional_latent = model.model.diffusion_model(latent, timestep, unconditional_context)
|
||||
latent = model.model.diffusion_model(latent, timestep, context)
|
||||
latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep.expand(2, *timestep.shape[1:]), unconditional_context.cat(context, dim=0))
|
||||
unconditional_latent, latent = latents[0:1], latents[1:2]
|
||||
|
||||
unconditional_guidance_scale = 7.5
|
||||
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
|
||||
|
@ -647,6 +647,7 @@ if __name__ == "__main__":
|
|||
|
||||
# this is diffusion
|
||||
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
|
||||
GlobalCounters.reset()
|
||||
t.set_description("%3d %3d" % (index, timestep))
|
||||
e_t = get_model_output(latent, Tensor([timestep]))
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index)
|
||||
|
|
|
@ -104,12 +104,14 @@ def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:Li
|
|||
lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
|
||||
lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
|
||||
assert lidx.max == var.max and lidx.min == var.min
|
||||
return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
|
||||
return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
|
||||
local_size.append(var.max+1)
|
||||
return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
||||
return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
||||
|
||||
def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
|
||||
kernel,global_size,local_size,prekernel = [],[],[],[]
|
||||
global_size: List[int] = []
|
||||
local_size: List[int] = []
|
||||
kernel,prekernel = [],[]
|
||||
pend_close = None
|
||||
bufs = []
|
||||
depth = 0
|
||||
|
@ -118,19 +120,13 @@ def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int],
|
|||
for uop,newvar,vin,args in uops:
|
||||
if uop == UOps.LOOP:
|
||||
for i,var in enumerate(args[0]):
|
||||
if isinstance(var, NumNode):
|
||||
if args[1] == "global" and lang.gid: global_size.append(1)
|
||||
if args[1] == "local" and lang.lid: local_size.append(1)
|
||||
# one number, not an index
|
||||
kk("{")
|
||||
if args[1] == "global" and lang.gid:
|
||||
kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
|
||||
elif args[1] == "local" and lang.lid:
|
||||
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
|
||||
else:
|
||||
if args[1] == "global" and lang.gid:
|
||||
kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
|
||||
elif args[1] == "local" and lang.lid:
|
||||
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
|
||||
else:
|
||||
if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
|
||||
kk(lang.render_for(var.expr, var.min, var.max))
|
||||
if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
|
||||
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max))
|
||||
depth += 1
|
||||
elif uop == UOps.BARRIER:
|
||||
kk(lang.barrier)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Callable
|
||||
import itertools
|
||||
import itertools, time
|
||||
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, LazyOp
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
|
@ -22,9 +22,7 @@ def apply_opt(k, x):
|
|||
|
||||
UPCASTS = [1,2,3,4,5,6,7,8]
|
||||
LOCALS = [1,2,3,4,5,6,7,8,16,24,32]
|
||||
|
||||
# optimization
|
||||
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
|
||||
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], runtime, baseline):
|
||||
import nevergrad as ng
|
||||
def opt(x):
|
||||
try:
|
||||
|
@ -32,14 +30,15 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
|
|||
k.process()
|
||||
apply_opt(k, x)
|
||||
prg = k.codegen().build(runtime)
|
||||
tm = min([prg.exec(k.bufs, force_wait=True) for _ in range(3)])*1000
|
||||
first_tm = prg.exec(k.bufs, force_wait=True)
|
||||
if baseline*5 < first_tm*1000: return first_tm*1000 # very slow
|
||||
tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True) for _ in range(2)])*1000
|
||||
return tm
|
||||
except Exception:
|
||||
if DEBUG >= 3:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 10000_000 # 10000 seconds is infinity
|
||||
k.process()
|
||||
opts = []
|
||||
for i in range(k.first_reduce):
|
||||
# TODO: the upcast always happen first, you might want to reverse this?
|
||||
|
@ -48,12 +47,43 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
|
|||
opts.append(ng.p.TransitionChoice([(i,s,"L") for s in LOCALS if k.full_shape[i]%s == 0]))
|
||||
for i in range(k.shape_len-k.first_reduce):
|
||||
opts.append(ng.p.TransitionChoice([(i,s,"R") for s in UPCASTS if k.full_shape[k.first_reduce+i]%s == 0]))
|
||||
if len(opts) == 0: return
|
||||
if len(opts) == 0: return "BASELINE"
|
||||
search_space = prod([len(x.choices) for x in opts])
|
||||
st = time.perf_counter()
|
||||
optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, 200))
|
||||
recommendation = optimizer.minimize(opt)
|
||||
apply_opt(k, recommendation.value)
|
||||
if DEBUG >= 1: print("optimizer hit", k.colored_shape(), "in search space", search_space)
|
||||
et = time.perf_counter() - st
|
||||
if DEBUG >= 1: print(f"optimizer({et:6.2f} s to search) space {search_space:8d} with tm {recommendation.loss:5.2f} ms vs baseline {baseline:5.2f} ms, a {baseline/recommendation.loss:5.2f}x gain : {k.colored_shape()}")
|
||||
return recommendation.value if recommendation.loss < baseline else "BASELINE"
|
||||
|
||||
# optimization
|
||||
global_db = None
|
||||
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
|
||||
global global_db
|
||||
|
||||
k.process()
|
||||
skey = str(k.key)
|
||||
|
||||
if getenv("KOPT") == 2 and global_db is None:
|
||||
import shelve
|
||||
global_db = shelve.open("/tmp/kopt_cache")
|
||||
|
||||
if global_db is not None and skey in global_db:
|
||||
choice = global_db[skey]
|
||||
else:
|
||||
# get baseline
|
||||
def get_baseline():
|
||||
k = create_k()
|
||||
hand_coded_optimizations(k)
|
||||
prg = k.codegen().build(runtime)
|
||||
return min([prg.exec(k.bufs, force_wait=True) for _ in range(5)])*1000
|
||||
choice = kernel_optimize_search(k, create_k, runtime, get_baseline())
|
||||
if global_db is not None:
|
||||
global_db[skey] = choice
|
||||
global_db.sync()
|
||||
|
||||
if choice == "BASELINE": hand_coded_optimizations(k)
|
||||
else: apply_opt(k, choice)
|
||||
|
||||
def required_optimizations(k:Linearizer, early_only=False):
|
||||
for buf_index,buf in enumerate(k.bufs):
|
||||
|
|
|
@ -143,7 +143,7 @@ class ASTRunner:
|
|||
(self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
|
||||
*rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
|
||||
if DEBUG >= 2:
|
||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-ansilen(self.display_name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
|
||||
GlobalCounters.kernel_count += 1
|
||||
GlobalCounters.global_ops += self.op_estimate
|
||||
|
|
Loading…
Reference in New Issue