test a rewrite of permuted reduce [pr] (#7093)

* test a rewrite of permuted reduce [pr]

* addd rewrite tracker

* expected

* passes
This commit is contained in:
qazal 2024-10-16 12:49:54 +03:00 committed by GitHub
parent 56fbd408a1
commit 6acda43a2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 4 deletions

View File

@ -1,6 +1,7 @@
# this will be the new test_ops for the next level
# schedule confirms the right things are capable of fusing
# NOTE: this has overlap with external_test_opt.py
# ruff: noqa: E501
import unittest
import numpy as np
@ -11,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.dtype import DType, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
@ -1763,5 +1764,36 @@ class TestIndexing(unittest.TestCase):
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
def test_permute_rewrite(self):
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
x1:=UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
x1,
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
x2,)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(8, ('METAL', 256, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(10, ('METAL', 16, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
x11,)),)),)),)),))
@track_rewrites
def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
ret = rewrite(sink)
assert len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@ -105,15 +105,15 @@ merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),)
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# view on reduce
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
# view on elementwise
# view before ALU
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
])
# push VIEW to stores
view_right = merge_views+PatternMatcher([
# view on reduce creates a new VIEW
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)