2024-08-21 14:36:58 +08:00
|
|
|
from typing import List
|
2024-05-12 02:02:44 +08:00
|
|
|
from extra.models.resnet import ResNet50
|
2024-10-12 22:03:04 +08:00
|
|
|
from tinygrad import Tensor, Device, nn
|
2024-09-30 13:52:33 +08:00
|
|
|
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
|
2024-08-17 06:17:57 +08:00
|
|
|
from tinygrad.ops import UOps
|
2024-07-13 11:02:19 +08:00
|
|
|
from tinygrad.codegen.kernel import Kernel
|
2024-10-11 14:19:10 +08:00
|
|
|
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
|
2024-10-11 15:27:33 +08:00
|
|
|
from tinygrad.codegen.linearize import linearize_uop
|
|
|
|
from tinygrad.codegen.uopgraph import full_graph_rewrite
|
2024-08-24 09:10:46 +08:00
|
|
|
from tinygrad.engine.search import beam_search, bufs_from_lin
|
2024-05-12 02:02:44 +08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
  # Time each stage of the compile pipeline for a ResNet50 graph:
  #   tensor graph construction -> schedule -> kernel opts -> lower -> graph rewrite -> linearize.
  # Environment knobs read in this block (via getenv):
  #   BS              batch size of the dummy input image (default 64)
  #   PROFILE         passed to Profiling around the rewrite stage (output: /tmp/rewrite.prof)
  #   FORWARD_ONLY    stop after building the forward tensor graph
  #   SCHEDULE_ONLY   stop after scheduling
  #   RESTRICT_KERNEL keep only the kernel at this index of the deduplicated AST list (-1 = keep all)
  #   VERBOSE         additionally time each kernel's rewrite on its own line
  #   LINEARIZE       linearize the rewritten uops (default 1); printing/rendering below needs this
  #   SRC             print the rendered source of every kernel (only when LINEARIZE is on)
  # BEAM / NOOPT (tinygrad helpers) select beam search / no optimization instead of hand-coded opts.
  mdl = ResNet50()
  # Replace every parameter with an uninitialized tensor of the same shape: this script never
  # realizes anything, it only builds and compiles the graph, so real weight values are not needed.
  for p in nn.state.get_parameters(mdl): p.replace(Tensor.empty(p.shape))
  img = Tensor.empty(getenv("BS", 64), 3, 224, 224)

  PROFILE = getenv("PROFILE", 0)
  FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
  SCHEDULE_ONLY = getenv("SCHEDULE_ONLY", 0)

  with Timing("all "):
    with Timing("***** model tensor in "):
      out = mdl(img)

    if not FORWARD_ONLY:
      with Timing("***** model schedule in "):
        sched = out.schedule()

      if not SCHEDULE_ONLY:
        # One AST per unique key, keeping only SINK-rooted schedule items (others are not kernels to compile).
        asts = list({x.ast.key:x.ast for x in sched if x.ast.op is UOps.SINK}.values())
        # -1 is the "no restriction" sentinel; any other value (including other negatives) selects one AST.
        if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]

        kernels: List[Kernel] = []
        with Timing(f"***** model opts({len(asts):2d}) in "):
          for ast in asts:
            k = Kernel(ast)
            if BEAM:
              # DEBUG is raised to at least 2 for the duration of the search (presumably to show search progress).
              with Context(DEBUG=max(2, DEBUG.value)): k = beam_search(k, bufs_from_lin(k), BEAM.value)
            elif NOOPT: pass
            else: k.hand_coded_optimizations()
            kernels.append(k)

        with Timing("***** model lower in "):
          uops = [rewrite_shapetracker_with_index(k.get_optimized_ast(), k.opts) for k in kernels]

        with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
          with Timing("***** model rewrite in "):
            rewritten_uops = []
            for i,(k,u) in enumerate(zip(kernels, uops)):
              # Pad the kernel name so per-kernel timings line up in a column.
              with Timing(f"rewrite {i:2d} {k.name}{' '*(50-ansilen(k.name))}", enabled=getenv("VERBOSE", 0)):
                rewritten_uops.append(full_graph_rewrite(u, k.opts))
            uops = rewritten_uops

        if getenv("LINEARIZE", 1):
          with Timing("***** model linearize in "): uops = [linearize_uop(u) for u in uops]
          # Counting and rendering operate on linearized uop lists, so they stay inside this branch.
          print(sum(len(u) for u in uops))
          if getenv("SRC", 0):
            renderer = Device[Device.DEFAULT].renderer
            for k,u in zip(kernels, uops): print(renderer.render(k.name, u))