diff --git a/test/test_schedule.py b/test/test_schedule.py
index 20d024bd..89ac83b5 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -1,6 +1,7 @@
 # this will be the new test_ops for the next level
 # schedule confirms the right things are capable of fusing
 # NOTE: this has overlap with external_test_opt.py
+# ruff: noqa: E501
 
 import unittest
 import numpy as np
@@ -11,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
 from tinygrad.dtype import DType, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
+from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
 from tinygrad.codegen.kernel import Kernel, verify_ast
 from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
@@ -1763,5 +1764,39 @@ class TestIndexing(unittest.TestCase):
     self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
     self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
 
+  def test_permute_rewrite(self):
+    # hand-built STORE sink: below the CONTIGUOUS, VIEWs sit on top of ALU, REDUCE_AXIS and LOAD uops
+    # rewriting with view_left and then view_right must leave no VIEW that still has a source
+    sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
+      x1:=UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
+      x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
+      UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
+        x1,
+        UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
+          UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+            UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+              UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
+                UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
+                  UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+                    UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
+                      x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+                        UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
+                        x2,)),)),
+                    UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
+                      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+                        UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(8, ('METAL', 256, dtypes.float)), src=()),
+                        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
+              UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
+                UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+                  UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(10, ('METAL', 16, dtypes.float)), src=()),
+                  UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
+            UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
+              x11,)),)),)),)),))
+    @track_rewrites
+    def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
+    ret = rewrite(sink)
+    # in the input graph the VIEWs with src=() are the direct LOAD/STORE operands; only VIEWs still wrapping another uop are counted
+    self.assertEqual(len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]), 0, f"unmerged views left in sink {ret}")
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index b3fc4b1b..96cfefca 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -105,15 +105,15 @@ merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),)
 
 # push VIEW to loads
 view_left = merge_views+PatternMatcher([
-  # view on reduce
-  (UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
-  # view on elementwise
+  # view on ALU/CAST/BITCAST/ASSIGN/CONTIGUOUS or any of BUFFER_UOPS: re-apply the view to every source that has_st
   (UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
    lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
 ])
 
 # push VIEW to stores
 view_right = merge_views+PatternMatcher([
+  # view on reduce creates a new VIEW
+  (UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
   # push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
   (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
   # push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)