# Source: tinygrad/test/external/external_benchmark_ast.py (60 lines, 2.6 KiB, Python)

import time, pickle
import plotly.graph_objects as go
from typing import Dict, List, Tuple
from extra.models.resnet import ResNet50
from tinygrad import Tensor
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import Context, getenv, to_function_name
from tinygrad.engine.schedule import _get_output_groups, _lower_lazybuffer
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.ops import UOp, UOps
if __name__ == "__main__":
  # Benchmark script: times `_lower_lazybuffer` for every output group of a ResNet50 forward
  # pass, once in the default context and once under Context(AST_REWRITE=1), prints the ten
  # kernels with the lowest percentage improvement, pickles all timings to /tmp/kernel_tms and,
  # when GRAPH_TIMING is set, shows a grouped bar chart of the 20 lowest.
  mdl = ResNet50()
  img = Tensor.empty(64, 3, 224, 224)
  out = mdl(img)
  output_groups, realizes, _ = _get_output_groups(out.lazydata.lbs, set())

  # Pass 1 (default context). `asts` and `no_rewrite` are index-aligned; EXT asts are skipped.
  asts: List[UOp] = []
  no_rewrite: List[float] = []
  for v in output_groups.values():
    st = time.perf_counter_ns()
    lsi = _lower_lazybuffer(v, realizes)[0]
    et = time.perf_counter_ns() - st
    if lsi.ast.op is UOps.EXT: continue
    no_rewrite.append(et*1e-6)  # ns -> ms
    asts.append(lsi.ast)

  # Pass 2 (AST_REWRITE=1). `rewrite` and `bufs` must stay index-aligned with `asts`, so the
  # buffers are recorded only for groups that survive the EXT skip, and outside the timed region
  # so both passes measure exactly the same work.
  rewrite: List[float] = []
  bufs: List[List[LazyBuffer]] = []
  with Context(AST_REWRITE=1):
    for v in output_groups.values():
      st = time.perf_counter_ns()
      lsi = _lower_lazybuffer(v, realizes)[0]
      et = time.perf_counter_ns() - st
      if lsi.ast.op is UOps.EXT: continue
      rewrite.append(et*1e-6)  # ns -> ms
      bufs.append(v)

  # All four lists are indexed together below; a mismatch would silently pair wrong entries.
  assert len(rewrite) == len(no_rewrite) == len(asts) == len(bufs)

  # ast key -> (ast, ms without rewrite, ms with rewrite, buffers of the output group).
  # Asts sharing a key collapse into one entry; the last occurrence wins.
  kernel_tms: Dict[bytes, Tuple[UOp, float, float, List[LazyBuffer]]] = {
    ast.key: (ast, no_rewrite[i], rewrite[i], bufs[i]) for i, ast in enumerate(asts)
  }
  # Positive means the rewrite pass was faster, negative means it was slower.
  pct_change: Dict[bytes, float] = {k: ((x-y)/x)*100 for k, (_, x, y, _) in kernel_tms.items()}
  # Ascending order puts the largest slowdowns first.
  slowest_kernels = sorted(pct_change.items(), key=lambda item: item[1])
  names = {ast.key: Kernel(ast).name for ast, _, _, _ in kernel_tms.values()}

  print("slowest ast rewrites:")
  for k, pct in slowest_kernels[:10]:
    _, no_rw, rw, _ = kernel_tms[k]
    print(f"{names[k]:10s} {no_rw:4.2f} ms -> {rw:4.2f} ms {pct:4.2f}%")

  with open("/tmp/kernel_tms", "wb") as f:
    pickle.dump(kernel_tms, f)

  if getenv("GRAPH_TIMING"):
    sample = slowest_kernels[:20]
    x: List[str] = [to_function_name(names[k]) for k, _ in sample]
    y1 = [kernel_tms[k][1] for k, _ in sample]
    y2 = [kernel_tms[k][2] for k, _ in sample]
    fig = go.Figure(data=[
      go.Bar(name="no graph_rewrite", x=x, y=y1, marker=dict(color="#524eed", line=dict(color='rgba(0,0,0,0)'))),
      go.Bar(name="graph_rewrite", x=x, y=y2, marker=dict(color="#6fcf97", line=dict(color='rgba(0,0,0,0)'))),
    ])
    fig.update_layout(barmode="group", paper_bgcolor="black", plot_bgcolor="black",
                      font={"color": "white"}, yaxis={"gridcolor": "rgba(255, 255, 255, 0.3)"})
    fig.show()