mirror of https://github.com/commaai/tinygrad.git
test a rewrite of permuted reduce [pr] (#7093)
* test a rewrite of permuted reduce [pr] * add rewrite tracker * expected * passes
This commit is contained in:
parent
56fbd408a1
commit
6acda43a2c
|
@ -1,6 +1,7 @@
|
||||||
# this will be the new test_ops for the next level
|
# this will be the new test_ops for the next level
|
||||||
# schedule confirms the right things are capable of fusing
|
# schedule confirms the right things are capable of fusing
|
||||||
# NOTE: this has overlap with external_test_opt.py
|
# NOTE: this has overlap with external_test_opt.py
|
||||||
|
# ruff: noqa: E501
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -11,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
|
||||||
from tinygrad.dtype import DType, PtrDType
|
from tinygrad.dtype import DType, PtrDType
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.view import View
|
from tinygrad.shape.view import View
|
||||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
|
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
|
||||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
|
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
|
||||||
|
@ -1763,5 +1764,36 @@ class TestIndexing(unittest.TestCase):
|
||||||
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
|
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
|
||||||
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
|
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
|
||||||
|
|
||||||
|
def test_permute_rewrite(self):
|
||||||
|
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
|
x1:=UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
||||||
|
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
|
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
|
||||||
|
x1,
|
||||||
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
|
||||||
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
|
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
|
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
||||||
|
x2,)),)),
|
||||||
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
|
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
||||||
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
|
||||||
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
|
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
||||||
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||||
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
|
x11,)),)),)),)),))
|
||||||
|
@track_rewrites
|
||||||
|
def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||||
|
ret = rewrite(sink)
|
||||||
|
assert len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main(verbosity=2)
|
unittest.main(verbosity=2)
|
||||||
|
|
|
@ -105,15 +105,15 @@ merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),)
|
||||||
|
|
||||||
# push VIEW to loads
|
# push VIEW to loads
|
||||||
view_left = merge_views+PatternMatcher([
|
view_left = merge_views+PatternMatcher([
|
||||||
# view on reduce
|
# view before ALU
|
||||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
|
||||||
# view on elementwise
|
|
||||||
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
|
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
|
||||||
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
||||||
])
|
])
|
||||||
|
|
||||||
# push VIEW to stores
|
# push VIEW to stores
|
||||||
view_right = merge_views+PatternMatcher([
|
view_right = merge_views+PatternMatcher([
|
||||||
|
# view on reduce creates a new VIEW
|
||||||
|
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||||
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
|
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
|
||||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||||
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
|
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
|
||||||
|
|
Loading…
Reference in New Issue