from typing import List, Tuple
from extra.models.resnet import ResNet50
from extra.mcts_search import mcts_search
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.codegen.kernel import Kernel
from tinygrad.device import Compiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv, colored
from tinygrad.ops import MetaOps
from tinygrad.shape.symbolic import sym_infer
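
# Usage: MODEL=resnet|bert selects the workload below. Other environment variables
# read by this script (all via getenv):
#   BS            batch size (default 64 for resnet, 2 for bert)
#   BACKWARD      also schedule the backward pass and the optimizer step
#   HALF          use float16 as the default float dtype (default 1)
#   KERNEL        index of a single kernel in the schedule to focus on
#   BEAM          beam width for beam_search over kernel optimizations
#   MCTS          search budget for mcts_search
#   BEAM_ESTIMATE use estimates during beam search (default 1)
#   LARS          use the LARS optimizer for resnet instead of SGD
#   SRC           print the AST, applied opts, and generated source of the best kernel
#   DEBUG         >=1 prints every candidate kernel, >=3 prints each AST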

def get_sched_resnet():
  mdl = ResNet50()
  optim = (nn.optim.LARS if getenv("LARS") else nn.optim.SGD)(nn.state.get_parameters(mdl))
  BS = getenv("BS", 64)

  # run the model twice: the first pass marks the weight-init kernels as seen,
  # so the second create_schedule returns only the kernels of the model itself
  seen = set()
  for _ in range(2):
    out = mdl(Tensor.empty(BS, 3, 224, 224))
    targets = [out.lazydata]
    if getenv("BACKWARD"):
      optim.zero_grad()
      out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
      targets += [x.lazydata for x in optim.schedule_step()]
    sched = create_schedule(targets, seen)
    print(f"schedule length {len(sched)}")
  return sched

def get_sched_bert():
  mdl = get_mlperf_bert_model()
  optim = nn.optim.LAMB(nn.state.get_parameters(mdl))

  # fake data (shapes match MLPerf BERT pretraining: sequence length 512, up to 76 masked positions)
  BS = getenv("BS", 2)
  input_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
  segment_ids = Tensor.empty((BS, 512), dtype=dtypes.float32)
  attention_mask = Tensor.empty((BS, 512), dtype=dtypes.default_float)
  masked_positions = Tensor.empty((BS, 76), dtype=dtypes.float32)
  masked_lm_ids = Tensor.empty((BS, 76), dtype=dtypes.float32)
  masked_lm_weights = Tensor.empty((BS, 76), dtype=dtypes.float32)
  next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32)

  # run the model twice: the first pass marks the weight-init kernels as seen,
  # so the second create_schedule returns only the kernels of the model itself
  seen = set()
  for _ in range(2):
    lm_logits, seq_relationship_logits = mdl(input_ids, attention_mask, masked_positions, segment_ids)
    targets = [lm_logits.lazydata, seq_relationship_logits.lazydata]
    if getenv("BACKWARD"):
      optim.zero_grad()
      loss = mdl.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
      # ignore grad norm and loss scaler for now
      loss.backward()
      targets += [x.lazydata for x in optim.schedule_step()]
    sched = create_schedule(targets, seen)
    print(f"schedule length {len(sched)}")
  return sched
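
# models are dispatched by name: MODEL=<name> resolves to get_sched_<name>() via
# globals() below, so adding another workload just means defining a new get_sched_* function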
if __name__ == "__main__":
  if getenv("HALF", 1):
    dtypes.default_float = dtypes.half

  # the device we are optimizing for
  device: Compiled = Device[Device.DEFAULT]
  if getenv("BACKWARD"): Tensor.training = True
  print(f"optimizing for {Device.DEFAULT}")

  # build the schedule for the selected model, keeping only compute kernels
  sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
  sched = [x for x in sched if x.ast.op is MetaOps.KERNEL]

  # focus on one kernel
  if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]

  # work with the schedule
  total_tm = 0
  running_gflops = 0
  usage = {}
  for i,si in enumerate(sched):
    if DEBUG >= 3: print(si.ast)

    rawbufs = bufs_from_lin(Kernel(si.ast))

    # "linearize" the op into uops in different ways
    lins: List[Tuple[Kernel, str]] = []

    # always try hand coded opt
    lin = Kernel(si.ast, opts=device.renderer)
    lin.hand_coded_optimizations()
    lins.append((lin, "HC"))

    # maybe try tensor cores
    lin = Kernel(si.ast, opts=device.renderer)
    if lin.apply_tensor_cores():
      lins.append((lin, "TC"))

    # try a beam search
    if beam:=getenv("BEAM"):
      lin = Kernel(si.ast, opts=device.renderer)
      lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
      lins.append((lin, "BEAM"))

    # try MCTS
    if mcts:=getenv("MCTS"):
      lin = Kernel(si.ast, opts=device.renderer)
      lin = mcts_search(lin, rawbufs, mcts)
      lins.append((lin, "MCTS"))

    # benchmark the programs
    choices = []
    for (lin, nm) in lins:
      tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
      ops = (prg:=lin.to_program()).op_estimate
      # bind symbolic variables to their minimum to get a concrete op count, then ops/sec -> GFLOPS
      gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
      choices.append((tm, gflops, lin, prg, nm))

    # the fastest candidate wins
    sorted_choices = sorted(choices, key=lambda x: x[0])
    if DEBUG >= 1: # print all kernels
      for tm, gflops, lin, prg, nm in choices:
        print(f"  kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS -- {colored(nm, 'green') if lin is sorted_choices[0][2] else nm}")

    tm, gflops, lin, prg, nm = sorted_choices[0]
    if getenv("SRC"):
      print(si.ast)
      print(lin.applied_opts)
      print(lin.to_program().src)
    total_tm += tm
    running_gflops += gflops * tm
    # aggregate time and kernel count per metadata key (roughly, per source op in the model)
    if (key := str([str(m) for m in si.metadata] if si.metadata is not None else None)) not in usage: usage[key] = (0, 0)
    usage[key] = (usage[key][0] + tm, usage[key][1] + 1)
    print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} {lin.name+' '*(37-ansilen(lin.name))} {str(prg.global_size):18s} {str(prg.local_size):12s} takes {tm*1000:7.2f} ms, {gflops:6.0f} GFLOPS {[str(m) for m in si.metadata] if si.metadata is not None else ''}")
  print(f"******* total {total_tm*1000:.2f} ms, {running_gflops/total_tm:6.0f} GFLOPS")
  print("usage:")
  for k in sorted(usage, key=lambda x: -usage[x][0])[:10]:
    print(f"{usage[k][0]*1000:.2f} ms: {k} ({usage[k][1]} times)")
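
# example invocations (filename is hypothetical; use whatever this script is saved as):
#   MODEL=resnet BS=64 BEAM=4 python3 handcode_opt.py
#   MODEL=bert BACKWARD=1 DEBUG=1 python3 handcode_opt.py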