# tinygrad/examples/handcode_resnet50_opt.py
#
# 80 lines
# 3.0 KiB
# Python

from typing import List
from models.resnet import ResNet50
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.runtime.ops_metal import renderer, MetalProgram, RawMetalBuffer
from tinygrad.helpers import ansilen, DEBUG
from extra.utils import print_tree
def pick_fastest(choices):
  """Return the ``(time, linearizer)`` pair with the smallest time, or None when `choices` is empty.

  Ties resolve to the earliest entry, the same as the former ``sorted(..., key=...)[0]``.
  """
  return min(choices, key=lambda c: c[0], default=None)

def describe_kernel(i, lin, tm):
  """Format one benchmark report line for kernel `i`.

  The line contains the padded display name, global/local launch sizes, the time `tm`
  (printed as ``tm*1000`` ms) and the GFLOPS implied by ``lin.info.flops``.
  """
  name = lin.display_name + ' '*(37-ansilen(lin.display_name))
  return (f"kernel {i:2d} {name} {str(lin.global_size):18s} {str(lin.local_size):12s} "
          f"takes {tm*1000:7.2f} ms, {lin.info.flops*1e-9/tm:6.0f} GFLOPS")

def candidate_linearizers(i, op):
  """Build the Linearizer variants to benchmark for schedule entry `i` (AST `op`), all for METAL.

  Kernel 1 gets a hand-written sweep over reshape/upcast/local settings (2*5*3 variants).
  Every other kernel gets ``hand_coded_optimizations``, once without and once with tensor cores.
  """
  lins: List[Linearizer] = []
  if i == 1:
    # through careful work, we discovered 1,8,0
    for big_chomp in [1,2]: #[1,2,4,8,16]:
      for lil_chomp in [2,4,7,8,14]:
        for upcasted in [0,1,2]:
          lin = Linearizer(op, LinearizerOptions(device="METAL"))
          # NOTE(review): assumes this kernel's shapes factor as 4096*56*56*<last two dims>
          # (ResNet50 at batch 64) -- TODO confirm before reusing for other models/batch sizes.
          lin.reshape_and_permute(lambda x: (4096//big_chomp,big_chomp,56//lil_chomp,lil_chomp,56//lil_chomp,lil_chomp)+x[-2:], [0,2,4,1,3,5,6,7])
          # set the upcast/local dim counters by hand instead of via hand_coded_optimizations
          lin.upcasted += upcasted
          lin.local_dims += 3
          lins.append(lin)
  else:
    # try with and without tensor cores
    for tc in [0,1]:
      lin = Linearizer(op, LinearizerOptions(device="METAL"))
      lin.hand_coded_optimizations(use_tensor_cores=tc)
      lins.append(lin)
  return lins

def benchmark_linearizer(lin, rout, rin):
  """Linearize, render and compile `lin` for Metal, then time it on the given buffers.

  Returns the best (minimum) of 10 launches, in the unit MetalProgram reports
  (printed downstream as seconds*1000 = ms). Returns None if a launch raises AssertionError.
  Rendering/compile errors are not caught.
  """
  # render the code and create the program
  lin.linearize()
  code = renderer(lin.function_name, lin.uops)
  prg = MetalProgram(lin.function_name, code)
  # print the kernel code if you want
  #print(code)
  try:
    return min([prg(lin.global_size, lin.local_size, rout, *rin, wait=True) for _ in range(10)])
  except AssertionError:
    return None

if __name__ == "__main__":
  # Benchmark candidate Linearizer variants for every ResNet50 kernel on Metal and keep the fastest.
  mdl = ResNet50()
  seen = set()
  # first model run to init the weights, they are saved in seen
  mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)
  # run model again to get only what changes, these are the kernels of the model
  x = Tensor.empty(64, 3, 224, 224)
  out = mdl(x)
  sched = [si for si in out.lazydata.schedule(seen) if si[0].op not in LoadOps]
  # work with the schedule
  total_tm = 0
  for i,(op,kout,inp) in enumerate(sched):
    if DEBUG >= 2: print_tree(op)
    # enable only one kernel to focus on it
    #if i != 1: continue
    # "linearize" the op into uops in different ways
    lins = candidate_linearizers(i, op)
    # create output/input buffers
    rout = RawMetalBuffer(kout.st.size(), kout.dtype)
    rin = [RawMetalBuffer(buf.st.size(), buf.dtype) for buf in inp]
    # benchmark the programs; failed launches are reported with inf time but never chosen
    choices = []
    for lin in lins:
      tm = benchmark_linearizer(lin, rout, rin)
      if tm is None: tm = float('inf')
      else: choices.append((tm, lin))
      # print all kernels
      if DEBUG >= 1: print(" " + describe_kernel(i, lin, tm))
    best = pick_fastest(choices)
    if best is None:
      # every candidate failed to launch: report it instead of crashing the whole run
      print(f"*** {total_tm*1000:7.2f} ms : kernel {i:2d} no candidate ran successfully, skipping")
      continue
    tm, lin = best
    print(f"*** {total_tm*1000:7.2f} ms : " + describe_kernel(i, lin, tm))
    total_tm += tm
  print(f"******* total {total_tm*1000:.2f} ms")