mirror of https://github.com/commaai/tinygrad.git
split schedule to view_left and view_right [pr] (#7077)
* split schedule to view_left and view_right [pr] * move valid
This commit is contained in:
parent
8601115976
commit
fb29de6cc3
|
@ -14,7 +14,7 @@ from tinygrad.shape.view import View
|
|||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, reduceop_fusor, st_fixup
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
|
||||
from test.helpers import ast_const, is_dtype_supported, Context, timeit
|
||||
|
@ -1614,7 +1614,7 @@ class TestIndexing(unittest.TestCase):
|
|||
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
|
||||
rsink = graph_rewrite(sink, reduceop_fusor)
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
self.assertEqual(rsink.key, sink.key)
|
||||
|
||||
def test_simple_store_reshape(self):
|
||||
|
@ -1624,7 +1624,7 @@ class TestIndexing(unittest.TestCase):
|
|||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, reduceop_fusor)
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
|
||||
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
verify_ast(sink)
|
||||
|
@ -1635,7 +1635,7 @@ class TestIndexing(unittest.TestCase):
|
|||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, reduceop_fusor)
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
verify_ast(sink)
|
||||
self.assertEqual(sink.key, rsink.key)
|
||||
|
||||
|
@ -1646,7 +1646,7 @@ class TestIndexing(unittest.TestCase):
|
|||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink, et = timeit(graph_rewrite, sink, reduceop_fusor)
|
||||
rsink, et = timeit(graph_rewrite, sink, view_right)
|
||||
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
|
||||
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
verify_ast(sink)
|
||||
|
@ -1664,7 +1664,7 @@ class TestIndexing(unittest.TestCase):
|
|||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink, et = timeit(graph_rewrite, sink, reduceop_fusor)
|
||||
rsink, et = timeit(graph_rewrite, sink, view_right)
|
||||
with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
verify_ast(rsink)
|
||||
tms.append(et)
|
||||
|
@ -1693,7 +1693,7 @@ class TestIndexing(unittest.TestCase):
|
|||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
x8,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
|
||||
sink = graph_rewrite(sink, reduceop_fusor)
|
||||
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
# verify output
|
||||
k = Kernel(sink)
|
||||
p = k.to_program()
|
||||
|
@ -1716,7 +1716,7 @@ class TestIndexing(unittest.TestCase):
|
|||
alu = swizzle_r+const
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
|
||||
# graph rewrite
|
||||
sink = graph_rewrite(sink, reduceop_fusor)
|
||||
sink = graph_rewrite(sink, view_right)
|
||||
# verify output
|
||||
k = Kernel(sink)
|
||||
p = k.to_program()
|
||||
|
@ -1739,7 +1739,7 @@ class TestIndexing(unittest.TestCase):
|
|||
alu = UOp(UOps.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(UOps.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
|
||||
# graph rewrite
|
||||
sink = graph_rewrite(sink, reduceop_fusor)
|
||||
sink = graph_rewrite(sink, view_right)
|
||||
# verify output
|
||||
k = Kernel(sink)
|
||||
p = k.to_program()
|
||||
|
@ -1755,7 +1755,7 @@ class TestIndexing(unittest.TestCase):
|
|||
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
|
||||
# there's an EXPAND pushing through the REDUCE_AXIS
|
||||
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
|
||||
ret = graph_rewrite(swizzle, reduceop_fusor)
|
||||
ret = graph_rewrite(graph_rewrite(swizzle, view_left), view_right)
|
||||
# EXPAND is rewritten
|
||||
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
|
||||
# and pushed to the LOAD
|
||||
|
|
|
@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes, DType, PtrDType
|
|||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.engine.schedule import create_schedule, reduceop_fusor
|
||||
from tinygrad.engine.schedule import create_schedule, enumerate_bufs
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||
|
@ -441,7 +441,7 @@ class TestIndexingOrdering(unittest.TestCase):
|
|||
class TestUPatHelpers(unittest.TestCase):
|
||||
def test_location(self):
|
||||
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
|
||||
self.assertEqual(reduceop_fusor.patterns[0][0].location[0].split("/")[-1], "schedule.py")
|
||||
self.assertEqual(enumerate_bufs.patterns[0][0].location[0].split("/")[-1], "schedule.py")
|
||||
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
|
||||
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
|
||||
test_upat = UPat(UOps.CONST, dtypes.bool)
|
||||
|
|
|
@ -103,12 +103,14 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
|||
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
|
||||
return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
|
||||
|
||||
reduceop_fusor = PatternMatcher([
|
||||
# SWIZZLE on VALID merges the views
|
||||
view_left = PatternMatcher([
|
||||
# view on reduce
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
|
||||
# view on valid
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.ALU, src=(UPat(UOps.VALID), UPat.var(), UPat.var()), name="alu", arg=TernaryOps.WHERE),), name="root"),
|
||||
lambda root,alu: UOp(UOps.VALID, dtypes.bool, (root.st.to_uop(),)).where(*alu.src[1:]) if root.st != alu.st else alu),
|
||||
# push a SWIZZLE up to LOAD, through a reduce (eg. expands)
|
||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_up_through_reduce),
|
||||
])
|
||||
view_right = PatternMatcher([
|
||||
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
|
||||
|
@ -126,7 +128,7 @@ if getenv("RUN_PROCESS_REPLAY"):
|
|||
|
||||
@track_rewrites
|
||||
def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp:
|
||||
sink = graph_rewrite(base_sink, reduceop_fusor)
|
||||
sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
|
||||
ret = graph_rewrite(sink, enumerate_bufs, bufs)
|
||||
PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret))
|
||||
return ret
|
||||
|
|
Loading…
Reference in New Issue