mirror of https://github.com/commaai/tinygrad.git
generic elementwise view rewrite rule + merge_views (#7078)
* generic elementwise view rewrite rule + merge_views [pr]
* no pr, views merge
This commit is contained in:
parent
fb29de6cc3
commit
bddba5897a
|
@ -2,7 +2,7 @@ import sys, pickle, atexit
|
|||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
|
||||
from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
|
||||
graph_rewrite, track_rewrites, sint
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \
|
||||
colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
|
||||
|
@ -103,14 +103,19 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
|||
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
|
||||
return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
|
||||
|
||||
view_left = PatternMatcher([
|
||||
merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
|
||||
|
||||
# push VIEW to loads
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view on reduce
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
|
||||
# view on valid
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.ALU, src=(UPat(UOps.VALID), UPat.var(), UPat.var()), name="alu", arg=TernaryOps.WHERE),), name="root"),
|
||||
lambda root,alu: UOp(UOps.VALID, dtypes.bool, (root.st.to_uop(),)).where(*alu.src[1:]) if root.st != alu.st else alu),
|
||||
# view on elementwise
|
||||
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
||||
])
|
||||
view_right = PatternMatcher([
|
||||
|
||||
# push VIEW to stores
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
|
||||
|
|
Loading…
Reference in New Issue