test a rewrite of permuted reduce [pr] (#7093)

* test a rewrite of permuted reduce [pr]

* addd rewrite tracker

* expected

* passes
This commit is contained in:
qazal 2024-10-16 12:49:54 +03:00 committed by GitHub
parent 56fbd408a1
commit 6acda43a2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 4 deletions

View File

@ -1,6 +1,7 @@
# this will be the new test_ops for the next level # this will be the new test_ops for the next level
# schedule confirms the right things are capable of fusing # schedule confirms the right things are capable of fusing
# NOTE: this has overlap with external_test_opt.py # NOTE: this has overlap with external_test_opt.py
# ruff: noqa: E501
import unittest import unittest
import numpy as np import numpy as np
@ -11,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.dtype import DType, PtrDType from tinygrad.dtype import DType, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
@ -1763,5 +1764,36 @@ class TestIndexing(unittest.TestCase):
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape)) self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27)) self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
def test_permute_rewrite(self):
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
x1:=UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
x1,
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
x2,)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(8, ('METAL', 256, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(10, ('METAL', 16, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
x11,)),)),)),)),))
@track_rewrites
def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
ret = rewrite(sink)
assert len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=2) unittest.main(verbosity=2)

View File

@ -105,15 +105,15 @@ merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),)
# push VIEW to loads # push VIEW to loads
view_left = merge_views+PatternMatcher([ view_left = merge_views+PatternMatcher([
# view on reduce # view before ALU
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
# view on elementwise
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"), (UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))), lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
]) ])
# push VIEW to stores # push VIEW to stores
view_right = merge_views+PatternMatcher([ view_right = merge_views+PatternMatcher([
# view on reduce creates a new VIEW
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes) # push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes) # push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)