mirror of https://github.com/commaai/tinygrad.git
faster rewrite, no folder in expand/reduce [run_process_replay] (#6216)
* faster rewrite, no folder in expand/reduce [run_process_replay]
* is removing the expander there okay
* parens
* don't reconstruct exact match uop
* fast do_reduce
* expand pyint
* most of the parents gains with less lines
This commit is contained in:
parent 16f420f7a7
commit 2c42e9c2c6
@@ -1,57 +1,44 @@
 from typing import List
 from extra.models.resnet import ResNet50
 from tinygrad import Tensor
 from tinygrad.helpers import Profiling, Timing, getenv
 from tinygrad.ops import UOps
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.codegen.lowerer import ast_to_uop
 from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
 
 if __name__ == "__main__":
   mdl = ResNet50()
   img = Tensor.empty(64, 3, 224, 224)
 
-  PROFILE = getenv("PROFILE", 1)
+  PROFILE = getenv("PROFILE", 0)
   FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
   SCHEDULE_ONLY = getenv("SCHEDULE_ONLY", 0)
 
-  with Profiling(PROFILE):
-    with Timing("***** model forward in "):
+  with Timing("all "):
+    with Timing("***** model tensor in "):
       out = mdl(img)
 
     if not FORWARD_ONLY:
-      with Profiling(PROFILE):
       with Timing("***** model schedule in "):
         sched = out.schedule()
 
       if not SCHEDULE_ONLY:
         asts = {x.ast.key:x.ast for x in sched if x.ast.op is UOps.SINK}.values()
-        kernels = []
-        with Profiling(PROFILE):
-          with Timing("***** model uops in "):
+        kernels: List[Kernel] = []
         with Timing("***** model opts in "):
           for ast in asts:
             k = Kernel(ast)
             k.hand_coded_optimizations()
             kernels.append(k)
 
-        with Profiling(PROFILE, fn="/tmp/schedule.prof"):
-          with Timing("***** model linearize in "):
-            for k in kernels: k.linearize()
-
-        #renderer = Device[Device.DEFAULT].renderer
-        #with Profiling(PROFILE, fn="/tmp/schedule.prof"):
-        #  with Timing("***** model render in "):
-        #    for n,u in uops: renderer.render(n, u)
-
-        # snakeviz /tmp/schedule.prof
-        #with Profiling(PROFILE, fn="/tmp/schedule.prof"):
-        #  with Timing("***** model lower in "):
-        #    eis = list(lower_schedule(sched))
-
-        # random makes this slow
-        #with Profiling(PROFILE):
-        #  with Timing("***** model run in "):
-        #    for ei in eis: ei.run()
-
-        # this is all wait
-        #with Profiling(PROFILE):
-        #  with Timing("***** model finish in "):
-        #    out.data()
+        with Timing("***** model lower in "): uops = [ast_to_uop(k.get_optimized_ast(), k.opts) for k in kernels]
+        with Profiling(PROFILE, fn="/tmp/rewrite.prof"):
+          with Timing("***** model rewrite in "): uops = [full_graph_rewrite(u, k.opts) for u in uops]
+        if getenv("LINEARIZE", 1):
+          with Timing("***** model linearize in "): uops = [linearize_uop(u, skip_check=False) for u in uops]
+          print(sum(len(u) for u in uops))
+        if getenv("GRAPHUOPS", 0):
+          for u in uops:
+            from tinygrad.engine.graph import graph_uops
+            graph_uops(u)
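The benchmark above times each compilation stage (opts, lower, rewrite, linearize) by nesting timing scopes, so one run reports both per-phase and total wall time. A minimal sketch of that pattern, with a hypothetical timing helper standing in for tinygrad's Timing/Profiling classes:

import time
from contextlib import contextmanager

@contextmanager
def timing(label: str):
    # hypothetical stand-in for tinygrad.helpers.Timing: prints wall time on exit
    st = time.perf_counter()
    try:
        yield
    finally:
        print(f"{label}{(time.perf_counter() - st) * 1e3:.2f} ms")

with timing("all "):
    with timing("***** phase one in "):
        sum(i * i for i in range(10**6))      # stand-in for a compiler phase
    with timing("***** phase two in "):
        sorted(range(10**5), reverse=True)    # stand-in for another phase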
@@ -417,7 +417,7 @@ def do_expand(root:UOp):
 acc_number = 0
 def do_reduce(root:UOp):
   global acc_number
-  reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].parents)
+  reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].sparents)
   ret = root.src[0]
   if len(reduce_parented):
     assert root.dtype is not None
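The do_reduce change swaps parents for sparents (parents plus the node itself) when deciding which range sources the reduce actually depends on. A toy illustration of why that matters, with a hypothetical Node type and a partition helper that splits a list on a predicate (neither is tinygrad's actual UOp or helpers code):

from typing import Callable, Dict, Iterable, List, Tuple, TypeVar

T = TypeVar("T")

def partition(itr: Iterable[T], pred: Callable[[T], bool]) -> Tuple[List[T], List[T]]:
    # split items into (matching, non-matching)
    a: List[T] = []; b: List[T] = []
    for x in itr:
        (a if pred(x) else b).append(x)
    return a, b

class Node:
    def __init__(self, src: Tuple["Node", ...] = ()): self.src = src
    @property
    def parents(self) -> Dict["Node", None]:
        out: Dict["Node", None] = {}
        for x in self.src:
            out[x] = None
            out.update(x.parents)
        return out
    @property  # parents with self
    def sparents(self) -> Dict["Node", None]:
        return {**self.parents, self: None}

r = Node()  # a RANGE-like node that is itself the value being reduced
parented, unparented = partition([r], lambda x: x in r.sparents)
assert parented == [r]  # with .parents instead of .sparents, r would land in unparented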
@@ -495,7 +495,7 @@ reducer = PatternMatcher([
   (UPat(UOps.LOAD, src=(UPat(name="buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
 ])
 
-no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE}, dtype=dtypes.pyint, name="x"),
+no_pyint = PatternMatcher([(UPat({UOps.CONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND}, dtype=dtypes.pyint, name="x"),
   lambda x: UOp(x.op, dtypes.int32, x.src, x.arg))])
 
 # *** uop graph ***
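no_pyint is a rewrite rule that concretizes the compile-time pyint dtype to int32; this hunk extends it to EXPAND ops (the "expand pyint" bullet). A toy analogue of such a dtype-only rewrite applied bottom-up, using a plain dataclass instead of tinygrad's UPat machinery:

from dataclasses import dataclass, replace
from typing import Optional, Tuple

@dataclass(frozen=True)
class Op:
    op: str
    dtype: str
    src: Tuple["Op", ...] = ()

def no_pyint(x: Op) -> Optional[Op]:
    # rule: ops carrying the abstract "pyint" dtype become concrete int32
    return replace(x, dtype="int32") if x.dtype == "pyint" else None

def rewrite(x: Op) -> Op:
    # bottom-up: rewrite children first, then try the rule on the node itself
    x = replace(x, src=tuple(rewrite(s) for s in x.src))
    return no_pyint(x) or x

tree = Op("EXPAND", "pyint", (Op("CONST", "pyint"), Op("RANGE", "int32")))
out = rewrite(tree)
assert out.dtype == "int32" and all(s.dtype == "int32" for s in out.src)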
@@ -527,8 +527,8 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   # expand
   linearize_cnt += 1
   if linearize_cnt != getenv("DEBUG_EXPAND", 0):
-    sink = graph_rewrite(sink, folder+expander+float4_folding if opts is not None and opts.supports_float4 else folder+expander)
-    sink = graph_rewrite(sink, folder+expander+reducer)
+    sink = graph_rewrite(sink, folder+(expander+float4_folding if opts is not None and opts.supports_float4 else expander))
+    sink = graph_rewrite(sink, folder+reducer)
 
   # for PTX only
   if opts is not None and opts.extra_matcher is not None: sink = graph_rewrite(sink, folder+opts.extra_matcher)
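The "parens" bullet appears to refer to Python's conditional-expression precedence: the old first line already groups as (folder+expander+float4_folding) if cond else (folder+expander), so wrapping the conditional in parentheses after folder+ is equivalent here but states the grouping directly. A quick check with lists standing in for pattern matchers:

folder, expander, float4_folding = ["fold"], ["expand"], ["float4"]

for supports_float4 in (True, False):
    old = folder + expander + float4_folding if supports_float4 else folder + expander
    new = folder + (expander + float4_folding if supports_float4 else expander)
    assert old == new  # same result; the parentheses just make the intent explicit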
@@ -5,7 +5,7 @@ import math, operator, ctypes, struct, functools, hashlib, itertools
 from enum import Enum, auto
 from dataclasses import dataclass
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
-from tinygrad.helpers import merge_dicts, pretty_print, prod
+from tinygrad.helpers import pretty_print, prod
 from tinygrad.shape.symbolic import Variable, sint
 if TYPE_CHECKING:
   from tinygrad.shape.shapetracker import ShapeTracker
@@ -166,7 +166,7 @@ class UOp:
   @staticmethod
   def store(*src:UOp, **kwargs): return type((src:=(*src, *kwargs.values()))[0])(UOps.STORE, None, src)
   @functools.cached_property
-  def parents(self) -> Dict[UOp, None]: return merge_dicts([{x:None for x in self.src}]+[x.parents for x in self.src])
+  def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
   @property # parents with self
   def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
   @functools.cached_property
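UOp.parents is the memoized transitive-ancestor map; the new body builds it with two dict unpackings instead of merge_dicts over a list of dicts, dropping the intermediate list (and the merge_dicts import removed in the ops.py hunk above) while keeping dict-as-ordered-set semantics. The same pattern on a self-contained toy node:

import functools
from typing import Dict, Tuple

class Node:
    def __init__(self, src: Tuple["Node", ...] = ()): self.src = src
    @functools.cached_property
    def parents(self) -> Dict["Node", None]:
        # dict as an ordered set: direct sources first, then all their ancestors
        return {**{x: None for x in self.src},
                **{k: None for x in self.src for k in x.parents}}
    @property  # parents with self
    def sparents(self) -> Dict["Node", None]:
        return {**self.parents, self: None}

a = Node(); b = Node((a,)); c = Node((a, b))
assert list(c.parents) == [a, b]              # each ancestor appears exactly once
assert c in c.sparents and c not in c.parents

Because parents is a cached_property, each node computes its dict once and every consumer above it reuses the cached result.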
@@ -301,9 +301,11 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
   replace: Dict[UOp, UOp] = {}
   def __inner_rewrite(n:UOp) -> UOp:
     if rn := replace.get(n): return rn
-    replace_source = (n.op, n.dtype, tuple(__inner_rewrite(y) for y in n.src), n.arg)
+    replace_source = (n.op, n.dtype, new_src:=tuple(__inner_rewrite(y) for y in n.src), n.arg)
     if found := nodes.get(replace_source): replace[n] = found
-    else: nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x:=UOp(*replace_source))) else x
+    else:
+      x = UOp(*replace_source) if new_src != n.src else n
+      nodes[replace_source] = replace[n] = found = __inner_rewrite(new_x) if (new_x := pm.rewrite(x)) else x
     return found
   return __inner_rewrite(sink)
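This hunk is the "don't reconstruct exact match uop" bullet: after rewriting a node's sources, a fresh UOp is built only if some source actually changed (new_src != n.src); otherwise the original object is reused, so identity-keyed caches stay warm for untouched subgraphs. A minimal sketch of the same idea with a toy node type and a single rewrite rule (not tinygrad's full matcher, and without the nodes dedup table):

from typing import Dict, Optional, Tuple

class N:
    __slots__ = ("op", "src")
    def __init__(self, op: str, src: Tuple["N", ...] = ()): self.op, self.src = op, src

def rule(n: N) -> Optional[N]:
    # example rewrite: neg(neg(x)) -> x
    return n.src[0].src[0] if n.op == "neg" and n.src and n.src[0].op == "neg" else None

def graph_rewrite(sink: N) -> N:
    replace: Dict[N, N] = {}
    def inner(n: N) -> N:
        if (rn := replace.get(n)) is not None: return rn
        new_src = tuple(inner(y) for y in n.src)
        # key point: only construct a new node if a child actually changed
        x = N(n.op, new_src) if new_src != n.src else n
        replace[n] = found = inner(new_x) if (new_x := rule(x)) else x
        return found
    return inner(sink)

x = N("x"); tree = N("add", (N("neg", (N("neg", (x,)),)), x))
out = graph_rewrite(tree)
assert out.src == (x, x)      # double negation folded away
assert graph_rewrite(x) is x  # untouched node returned as-is, not rebuilt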