generic elementwise view rewrite rule + merge_views (#7078)

* generic elementwise view rewrite rule + merge_views [pr]

* no pr, views merge
This commit is contained in:
qazal 2024-10-16 04:36:21 +03:00 committed by GitHub
parent fb29de6cc3
commit bddba5897a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 6 deletions

View File

@ -2,7 +2,7 @@ import sys, pickle, atexit
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
graph_rewrite, track_rewrites, sint
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \
colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
@ -103,14 +103,19 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
view_left = PatternMatcher([
merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# view on reduce
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
# view on valid
(UPat(UOps.VIEW, src=(UPat(UOps.ALU, src=(UPat(UOps.VALID), UPat.var(), UPat.var()), name="alu", arg=TernaryOps.WHERE),), name="root"),
lambda root,alu: UOp(UOps.VALID, dtypes.bool, (root.st.to_uop(),)).where(*alu.src[1:]) if root.st != alu.st else alu),
# view on elementwise
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
])
view_right = PatternMatcher([
# push VIEW to stores
view_right = merge_views+PatternMatcher([
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)