mirror of https://github.com/commaai/tinygrad.git
fix ones, BS=2 stable diffusion, caching optimizer (#1312)
* fix ones, BS=2 stable diffusion * caching optimizer * print search time * minor bug fix
This commit is contained in:
parent
9746f6d094
commit
bfbb8d3d0f
|
@ -9,7 +9,7 @@ from collections import namedtuple
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import dtypes
|
from tinygrad.helpers import dtypes, GlobalCounters
|
||||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||||
from extra.utils import download_file
|
from extra.utils import download_file
|
||||||
from tinygrad.state import torch_load, load_state_dict
|
from tinygrad.state import torch_load, load_state_dict
|
||||||
|
@ -614,8 +614,8 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
def get_model_output(latent, timestep):
|
def get_model_output(latent, timestep):
|
||||||
# put into diffuser
|
# put into diffuser
|
||||||
unconditional_latent = model.model.diffusion_model(latent, timestep, unconditional_context)
|
latents = model.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep.expand(2, *timestep.shape[1:]), unconditional_context.cat(context, dim=0))
|
||||||
latent = model.model.diffusion_model(latent, timestep, context)
|
unconditional_latent, latent = latents[0:1], latents[1:2]
|
||||||
|
|
||||||
unconditional_guidance_scale = 7.5
|
unconditional_guidance_scale = 7.5
|
||||||
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
|
e_t = unconditional_latent + unconditional_guidance_scale * (latent - unconditional_latent)
|
||||||
|
@ -647,6 +647,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# this is diffusion
|
# this is diffusion
|
||||||
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
|
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
|
||||||
|
GlobalCounters.reset()
|
||||||
t.set_description("%3d %3d" % (index, timestep))
|
t.set_description("%3d %3d" % (index, timestep))
|
||||||
e_t = get_model_output(latent, Tensor([timestep]))
|
e_t = get_model_output(latent, Tensor([timestep]))
|
||||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index)
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index)
|
||||||
|
|
|
@ -104,12 +104,14 @@ def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:Li
|
||||||
lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
|
lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
|
||||||
lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
|
lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
|
||||||
assert lidx.max == var.max and lidx.min == var.min
|
assert lidx.max == var.max and lidx.min == var.min
|
||||||
return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
|
return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
|
||||||
local_size.append(var.max+1)
|
local_size.append(var.max+1)
|
||||||
return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
|
||||||
|
|
||||||
def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
|
def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
|
||||||
kernel,global_size,local_size,prekernel = [],[],[],[]
|
global_size: List[int] = []
|
||||||
|
local_size: List[int] = []
|
||||||
|
kernel,prekernel = [],[]
|
||||||
pend_close = None
|
pend_close = None
|
||||||
bufs = []
|
bufs = []
|
||||||
depth = 0
|
depth = 0
|
||||||
|
@ -118,19 +120,13 @@ def uops_to_cstyle(uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int],
|
||||||
for uop,newvar,vin,args in uops:
|
for uop,newvar,vin,args in uops:
|
||||||
if uop == UOps.LOOP:
|
if uop == UOps.LOOP:
|
||||||
for i,var in enumerate(args[0]):
|
for i,var in enumerate(args[0]):
|
||||||
if isinstance(var, NumNode):
|
|
||||||
if args[1] == "global" and lang.gid: global_size.append(1)
|
|
||||||
if args[1] == "local" and lang.lid: local_size.append(1)
|
|
||||||
# one number, not an index
|
|
||||||
kk("{")
|
|
||||||
else:
|
|
||||||
if args[1] == "global" and lang.gid:
|
if args[1] == "global" and lang.gid:
|
||||||
kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
|
kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
|
||||||
elif args[1] == "local" and lang.lid:
|
elif args[1] == "local" and lang.lid:
|
||||||
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
|
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
|
||||||
else:
|
else:
|
||||||
if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
|
if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
|
||||||
kk(lang.render_for(var.expr, var.min, var.max))
|
kk("{" if isinstance(var, NumNode) else lang.render_for(var.expr, var.min, var.max))
|
||||||
depth += 1
|
depth += 1
|
||||||
elif uop == UOps.BARRIER:
|
elif uop == UOps.BARRIER:
|
||||||
kk(lang.barrier)
|
kk(lang.barrier)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
import itertools
|
import itertools, time
|
||||||
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType
|
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType
|
||||||
from tinygrad.ops import ReduceOps, BinaryOps, LazyOp
|
from tinygrad.ops import ReduceOps, BinaryOps, LazyOp
|
||||||
from tinygrad.codegen.linearizer import Linearizer
|
from tinygrad.codegen.linearizer import Linearizer
|
||||||
|
@ -22,9 +22,7 @@ def apply_opt(k, x):
|
||||||
|
|
||||||
UPCASTS = [1,2,3,4,5,6,7,8]
|
UPCASTS = [1,2,3,4,5,6,7,8]
|
||||||
LOCALS = [1,2,3,4,5,6,7,8,16,24,32]
|
LOCALS = [1,2,3,4,5,6,7,8,16,24,32]
|
||||||
|
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], runtime, baseline):
|
||||||
# optimization
|
|
||||||
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
|
|
||||||
import nevergrad as ng
|
import nevergrad as ng
|
||||||
def opt(x):
|
def opt(x):
|
||||||
try:
|
try:
|
||||||
|
@ -32,14 +30,15 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
|
||||||
k.process()
|
k.process()
|
||||||
apply_opt(k, x)
|
apply_opt(k, x)
|
||||||
prg = k.codegen().build(runtime)
|
prg = k.codegen().build(runtime)
|
||||||
tm = min([prg.exec(k.bufs, force_wait=True) for _ in range(3)])*1000
|
first_tm = prg.exec(k.bufs, force_wait=True)
|
||||||
|
if baseline*5 < first_tm*1000: return first_tm*1000 # very slow
|
||||||
|
tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True) for _ in range(2)])*1000
|
||||||
return tm
|
return tm
|
||||||
except Exception:
|
except Exception:
|
||||||
if DEBUG >= 3:
|
if DEBUG >= 3:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return 10000_000 # 10000 seconds is infinity
|
return 10000_000 # 10000 seconds is infinity
|
||||||
k.process()
|
|
||||||
opts = []
|
opts = []
|
||||||
for i in range(k.first_reduce):
|
for i in range(k.first_reduce):
|
||||||
# TODO: the upcast always happen first, you might want to reverse this?
|
# TODO: the upcast always happen first, you might want to reverse this?
|
||||||
|
@ -48,12 +47,43 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
|
||||||
opts.append(ng.p.TransitionChoice([(i,s,"L") for s in LOCALS if k.full_shape[i]%s == 0]))
|
opts.append(ng.p.TransitionChoice([(i,s,"L") for s in LOCALS if k.full_shape[i]%s == 0]))
|
||||||
for i in range(k.shape_len-k.first_reduce):
|
for i in range(k.shape_len-k.first_reduce):
|
||||||
opts.append(ng.p.TransitionChoice([(i,s,"R") for s in UPCASTS if k.full_shape[k.first_reduce+i]%s == 0]))
|
opts.append(ng.p.TransitionChoice([(i,s,"R") for s in UPCASTS if k.full_shape[k.first_reduce+i]%s == 0]))
|
||||||
if len(opts) == 0: return
|
if len(opts) == 0: return "BASELINE"
|
||||||
search_space = prod([len(x.choices) for x in opts])
|
search_space = prod([len(x.choices) for x in opts])
|
||||||
|
st = time.perf_counter()
|
||||||
optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, 200))
|
optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, 200))
|
||||||
recommendation = optimizer.minimize(opt)
|
recommendation = optimizer.minimize(opt)
|
||||||
apply_opt(k, recommendation.value)
|
et = time.perf_counter() - st
|
||||||
if DEBUG >= 1: print("optimizer hit", k.colored_shape(), "in search space", search_space)
|
if DEBUG >= 1: print(f"optimizer({et:6.2f} s to search) space {search_space:8d} with tm {recommendation.loss:5.2f} ms vs baseline {baseline:5.2f} ms, a {baseline/recommendation.loss:5.2f}x gain : {k.colored_shape()}")
|
||||||
|
return recommendation.value if recommendation.loss < baseline else "BASELINE"
|
||||||
|
|
||||||
|
# optimization
|
||||||
|
global_db = None
|
||||||
|
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
|
||||||
|
global global_db
|
||||||
|
|
||||||
|
k.process()
|
||||||
|
skey = str(k.key)
|
||||||
|
|
||||||
|
if getenv("KOPT") == 2 and global_db is None:
|
||||||
|
import shelve
|
||||||
|
global_db = shelve.open("/tmp/kopt_cache")
|
||||||
|
|
||||||
|
if global_db is not None and skey in global_db:
|
||||||
|
choice = global_db[skey]
|
||||||
|
else:
|
||||||
|
# get baseline
|
||||||
|
def get_baseline():
|
||||||
|
k = create_k()
|
||||||
|
hand_coded_optimizations(k)
|
||||||
|
prg = k.codegen().build(runtime)
|
||||||
|
return min([prg.exec(k.bufs, force_wait=True) for _ in range(5)])*1000
|
||||||
|
choice = kernel_optimize_search(k, create_k, runtime, get_baseline())
|
||||||
|
if global_db is not None:
|
||||||
|
global_db[skey] = choice
|
||||||
|
global_db.sync()
|
||||||
|
|
||||||
|
if choice == "BASELINE": hand_coded_optimizations(k)
|
||||||
|
else: apply_opt(k, choice)
|
||||||
|
|
||||||
def required_optimizations(k:Linearizer, early_only=False):
|
def required_optimizations(k:Linearizer, early_only=False):
|
||||||
for buf_index,buf in enumerate(k.bufs):
|
for buf_index,buf in enumerate(k.bufs):
|
||||||
|
|
|
@ -143,7 +143,7 @@ class ASTRunner:
|
||||||
(self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
|
(self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
|
||||||
*rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
|
*rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
|
||||||
if DEBUG >= 2:
|
if DEBUG >= 2:
|
||||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(29-ansilen(self.display_name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
|
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
|
||||||
GlobalCounters.kernel_count += 1
|
GlobalCounters.kernel_count += 1
|
||||||
GlobalCounters.global_ops += self.op_estimate
|
GlobalCounters.global_ops += self.op_estimate
|
||||||
|
|
Loading…
Reference in New Issue