mirror of https://github.com/commaai/tinygrad.git
test a rewrite of permuted reduce [pr] (#7093)
* test a rewrite of permuted reduce [pr] * add rewrite tracker * expected * passes
This commit is contained in:
parent
56fbd408a1
commit
6acda43a2c
|
@ -1,6 +1,7 @@
|
|||
# this will be the new test_ops for the next level
|
||||
# schedule confirms the right things are capable of fusing
|
||||
# NOTE: this has overlap with external_test_opt.py
|
||||
# ruff: noqa: E501
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
@ -11,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
|
|||
from tinygrad.dtype import DType, PtrDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
|
||||
|
@ -1763,5 +1764,36 @@ class TestIndexing(unittest.TestCase):
|
|||
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
|
||||
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
|
||||
|
||||
def test_permute_rewrite(self):
|
||||
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
x1:=UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
|
||||
x1,
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2,)),)),
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11,)),)),)),)),))
|
||||
@track_rewrites
|
||||
def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
ret = rewrite(sink)
|
||||
assert len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@ -105,15 +105,15 @@ merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),)
|
|||
|
||||
# push VIEW to loads
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# view on reduce
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||
# view on elementwise
|
||||
# view before ALU
|
||||
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
||||
])
|
||||
|
||||
# push VIEW to stores
|
||||
view_right = merge_views+PatternMatcher([
|
||||
# view on reduce creates a new VIEW
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
|
||||
|
|
Loading…
Reference in New Issue