split schedule to view_left and view_right [pr] (#7077)

* split schedule to view_left and view_right [pr]

* move valid
This commit is contained in:
qazal 2024-10-16 03:39:38 +03:00 committed by GitHub
parent 8601115976
commit fb29de6cc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 17 deletions

View File

@ -14,7 +14,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, reduceop_fusor, st_fixup
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
from tinygrad.engine.realize import CompiledRunner, run_schedule
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
from test.helpers import ast_const, is_dtype_supported, Context, timeit
@ -1614,7 +1614,7 @@ class TestIndexing(unittest.TestCase):
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
rsink = graph_rewrite(sink, reduceop_fusor)
rsink = graph_rewrite(sink, view_right)
self.assertEqual(rsink.key, sink.key)
def test_simple_store_reshape(self):
@ -1624,7 +1624,7 @@ class TestIndexing(unittest.TestCase):
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, reduceop_fusor)
rsink = graph_rewrite(sink, view_right)
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
verify_ast(sink)
@ -1635,7 +1635,7 @@ class TestIndexing(unittest.TestCase):
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, reduceop_fusor)
rsink = graph_rewrite(sink, view_right)
verify_ast(sink)
self.assertEqual(sink.key, rsink.key)
@ -1646,7 +1646,7 @@ class TestIndexing(unittest.TestCase):
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, reduceop_fusor)
rsink, et = timeit(graph_rewrite, sink, view_right)
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
verify_ast(sink)
@ -1664,7 +1664,7 @@ class TestIndexing(unittest.TestCase):
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, reduceop_fusor)
rsink, et = timeit(graph_rewrite, sink, view_right)
with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
verify_ast(rsink)
tms.append(et)
@ -1693,7 +1693,7 @@ class TestIndexing(unittest.TestCase):
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x8,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
sink = graph_rewrite(sink, reduceop_fusor)
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
# verify output
k = Kernel(sink)
p = k.to_program()
@ -1716,7 +1716,7 @@ class TestIndexing(unittest.TestCase):
alu = swizzle_r+const
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
# graph rewrite
sink = graph_rewrite(sink, reduceop_fusor)
sink = graph_rewrite(sink, view_right)
# verify output
k = Kernel(sink)
p = k.to_program()
@ -1739,7 +1739,7 @@ class TestIndexing(unittest.TestCase):
alu = UOp(UOps.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(UOps.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
# graph rewrite
sink = graph_rewrite(sink, reduceop_fusor)
sink = graph_rewrite(sink, view_right)
# verify output
k = Kernel(sink)
p = k.to_program()
@ -1755,7 +1755,7 @@ class TestIndexing(unittest.TestCase):
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
# there's an EXPAND pushing through the REDUCE_AXIS
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
ret = graph_rewrite(swizzle, reduceop_fusor)
ret = graph_rewrite(graph_rewrite(swizzle, view_left), view_right)
# EXPAND is rewritten
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
# and pushed to the LOAD

View File

@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, reduceop_fusor
from tinygrad.engine.schedule import create_schedule, enumerate_bufs
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
@ -441,7 +441,7 @@ class TestIndexingOrdering(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
self.assertEqual(reduceop_fusor.patterns[0][0].location[0].split("/")[-1], "schedule.py")
self.assertEqual(enumerate_bufs.patterns[0][0].location[0].split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(UOps.CONST, dtypes.bool)

View File

@ -103,12 +103,14 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
reduceop_fusor = PatternMatcher([
# SWIZZLE on VALID merges the views
# view_left: rewrite rules rooted at a VIEW uop whose source is either a REDUCE_AXIS or a WHERE gated by VALID.
view_left = PatternMatcher([
# view on reduce: a VIEW whose only source is a REDUCE_AXIS is handed to push_swizzle_up_through_reduce
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
# view on valid: a VIEW over WHERE(VALID, x, y). When the VIEW's ShapeTracker differs from the WHERE's,
# a new bool VALID is built from the VIEW's ShapeTracker and the WHERE's two remaining sources are reused;
# otherwise the WHERE itself is returned unchanged.
(UPat(UOps.VIEW, src=(UPat(UOps.ALU, src=(UPat(UOps.VALID), UPat.var(), UPat.var()), name="alu", arg=TernaryOps.WHERE),), name="root"),
lambda root,alu: UOp(UOps.VALID, dtypes.bool, (root.st.to_uop(),)).where(*alu.src[1:]) if root.st != alu.st else alu),
# push a SWIZZLE up to LOAD, through a reduce (e.g. expands)
# NOTE(review): same pattern and handler as the "view on reduce" rule above; this looks like the
# pre-change line kept by the diff rendering -- confirm against the real file.
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
])
view_right = PatternMatcher([
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
@ -126,7 +128,7 @@ if getenv("RUN_PROCESS_REPLAY"):
# Rewrite an AST sink with the schedule's pattern matchers and record the result in PROCESS_REPLAY_CAPTURE.
#   base_sink: the UOp to rewrite.
#   bufs: tuple of ints passed as the third argument to the enumerate_bufs rewrite.
# Returns the UOp produced by the enumerate_bufs rewrite.
@track_rewrites
def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp:
# NOTE(review): the next two lines both assign sink; the reduceop_fusor one looks like the
# pre-change line kept by the diff rendering -- confirm against the real file.
sink = graph_rewrite(base_sink, reduceop_fusor)
# view_left is applied first, then view_right is applied to its result
sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
ret = graph_rewrite(sink, enumerate_bufs, bufs)
# append the (input sink, bufs, rewritten sink) triple to the capture list
PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret))
return ret