# Mirror of https://github.com/commaai/tinygrad.git (Python, 60 lines, 2.6 KiB)
import time, pickle
|
|
import plotly.graph_objects as go
|
|
from typing import Dict, List, Tuple
|
|
from extra.models.resnet import ResNet50
|
|
from tinygrad import Tensor
|
|
from tinygrad.codegen.kernel import Kernel
|
|
from tinygrad.helpers import Context, getenv, to_function_name
|
|
from tinygrad.engine.schedule import _get_output_groups, _lower_lazybuffer
|
|
from tinygrad.engine.lazy import LazyBuffer
|
|
from tinygrad.ops import UOp, UOps
|
|
|
|
if __name__ == "__main__":
    # Benchmark: time _lower_lazybuffer for every output group of one ResNet50
    # forward pass, once outside and once inside Context(AST_REWRITE=1), then
    # report the kernels whose lowering got slowest under the rewrite.
    # The output tensor is never realized here; only the lowering calls are timed.
    mdl = ResNet50()
    img = Tensor.empty(64, 3, 224, 224)
    out = mdl(img)
    output_groups, realizes, _ = _get_output_groups(out.lazydata.lbs, set())

    # Baseline pass (outside the AST_REWRITE context).
    asts: List[UOp] = []
    no_rewrite: List[float] = []
    for group in output_groups.values():
        st = time.perf_counter_ns()
        lsi = _lower_lazybuffer(group, realizes)[0]
        et = time.perf_counter_ns() - st
        if lsi.ast.op is UOps.EXT:
            continue
        no_rewrite.append(et*1e-6)  # ns -> ms
        asts.append(lsi.ast)

    # Rewrite pass. The group is recorded only AFTER the timer has stopped and
    # AFTER the EXT filter, so bufs stays index-aligned with asts/no_rewrite/rewrite
    # and the bookkeeping is not counted in the measured time.
    rewrite: List[float] = []
    bufs: List[List[LazyBuffer]] = []
    with Context(AST_REWRITE=1):
        for group in output_groups.values():
            st = time.perf_counter_ns()
            lsi = _lower_lazybuffer(group, realizes)[0]
            et = time.perf_counter_ns() - st
            if lsi.ast.op is UOps.EXT:
                continue
            rewrite.append(et*1e-6)  # ns -> ms
            bufs.append(group)

    # All four lists are indexed together below, so they must have equal length.
    assert len(rewrite) == len(no_rewrite) == len(asts) == len(bufs)

    # ast key -> (ast, baseline ms, rewrite ms, output buffers of that group)
    kernel_tms: Dict[bytes, Tuple[UOp, float, float, List[LazyBuffer]]] = {
        ast.key: (ast, no_rewrite[i], rewrite[i], bufs[i]) for i, ast in enumerate(asts)
    }
    # Relative change in percent; negative means the rewrite pass took longer.
    pct_change: Dict[bytes, float] = {
        key: ((base_ms - rw_ms) / base_ms) * 100 for key, (_, base_ms, rw_ms, _) in kernel_tms.items()
    }
    # Ascending order puts the largest slowdowns (most negative values) first.
    slowest_kernels = sorted(pct_change.items(), key=lambda item: item[1])
    names = {ast.key: Kernel(ast).name for ast, _, _, _ in kernel_tms.values()}

    print("slowest ast rewrites:")
    for k, pct in slowest_kernels[:10]:
        _, no_rw, rw, _ = kernel_tms[k]
        print(f"{names[k]:10s} {no_rw:4.2f} ms -> {rw:4.2f} ms {pct:4.2f}%")

    # Persist the raw timings for later inspection.
    with open("/tmp/kernel_tms", "wb") as f:
        pickle.dump(kernel_tms, f)

    # Optional grouped bar chart of the 20 slowest kernels (set GRAPH_TIMING to enable).
    if getenv("GRAPH_TIMING"):
        sample = slowest_kernels[:20]
        x: List[str] = [to_function_name(names[k]) for k, _ in sample]
        y1 = [kernel_tms[k][1] for k, _ in sample]
        y2 = [kernel_tms[k][2] for k, _ in sample]
        fig = go.Figure(data=[
            go.Bar(name="no graph_rewrite", x=x, y=y1, marker=dict(color="#524eed", line=dict(color='rgba(0,0,0,0)'))),
            go.Bar(name="graph_rewrite", x=x, y=y2, marker=dict(color="#6fcf97", line=dict(color='rgba(0,0,0,0)'))),
        ])
        fig.update_layout(barmode="group", paper_bgcolor="black", plot_bgcolor="black",
                          font={"color": "white"}, yaxis={"gridcolor": "rgba(255, 255, 255, 0.3)"})
        fig.show()
|