Proposal: Better UOps.SWIZZLE (#6309)

* better UOps.SWIZZLE

* test_swizzle_rewrite

* add it to docs

* show a diff

* a lil more verbose

* two teeny notes

* hotfix: sink
This commit is contained in:
qazal 2024-08-29 20:39:48 +08:00 committed by GitHub
parent 8c50ef8b7c
commit 07942ef361
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 94 additions and 10 deletions

View File

@ -8,13 +8,15 @@ import numpy as np
from typing import Dict, List, Optional, Union, cast
from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.dtype import PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup, reduceop_fusor
from tinygrad.engine.realize import CompiledRunner, run_schedule
from test.helpers import is_dtype_supported, Context
from tinygrad.lazy import LazyBuffer, view_supported_devices
from extra.models.llama import precompute_freqs_cis
@ -1642,5 +1644,31 @@ class TestScheduleRewrite(unittest.TestCase):
new_val = st_fixup(val, lambda st:st.reshape((4,)), {}, {})
self.assertIs(new_val, val)
def test_swizzle_rewrite(self):
# graph rewrite
sink = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x8:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x8,
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
sink = graph_rewrite(sink, reduceop_fusor)
# verify output
k = Kernel(sink)
p = k.to_program()
a = Tensor.randint(32, 32).realize()
b = Tensor.empty((), dtype=dtypes.int).realize()
CompiledRunner(p).exec([b.lazydata.buffer, a.lazydata.buffer])
expected_out = (a.numpy() + a.numpy().sum()).sum()
np.testing.assert_equal(b.numpy(), expected_out)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@ -82,7 +82,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
# reduce ops change ShapeTracker
if buf.op in ReduceOps:
swizzle = (UOp(UOps.SWIZZLE, src=(st.to_uop(),)),) if not st.contiguous and AST_REWRITE else ()
swizzle_arg = st if not st.contiguous and AST_REWRITE else None
rinfo: Optional[Tuple[ShapeTracker, Tuple[int, ...]]] = (ShapeTracker.from_shape(buf.srcs[0].shape), buf.arg) \
if AST_REWRITE else reduce_info.get((buf, st))
rsrc = _recursive_uop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
@ -91,7 +91,9 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
if rinfo is None:
assert rsrc.op is UOps.REDUCE_AXIS and rsrc.arg[0] is alu_op, f"can't merge reduceop {buf.op} with {rsrc}\n{st}"
return rsrc
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,)+swizzle, (alu_op, rinfo[1])))
ret = UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (alu_op, rinfo[1]))
if swizzle_arg is not None: ret = UOp(UOps.SWIZZLE, dtype, (ret,), swizzle_arg)
return cache.setdefault((buf, st), ret)
# elementwise ops pass shapetracker
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
@ -172,10 +174,11 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
# ***** reduceop fusor *****
def apply_swizzle(root:UOp, rsrc:UOp, swizzle:UOp) -> UOp:
def push_swizzle_through_reduce(swizzle:UOp, reduceop:UOp) -> UOp:
uop_sts: Dict[UOp, ShapeTracker] = {}
new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, root.arg[1])
return replace(root, src=(st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), arg=(root.arg[0], new_axis))
rsrc = reduceop.src[0]
new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, reduceop.arg[1])
return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), (reduceop.arg[0], new_axis))
def push_reduceop_shape(root:UOp) -> Optional[UOp]:
reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
@ -186,7 +189,7 @@ def push_reduceop_shape(root:UOp) -> Optional[UOp]:
return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {})
reduceop_fusor = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, src=(UPat(name="rsrc"), UPat(UOps.SWIZZLE, src=(UPat(name="swizzle"),))), name="root"), apply_swizzle),
(UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_through_reduce),
(UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_reduceop_shape),
])

View File

@ -46,7 +46,7 @@ class UOps(Enum):
Holds `UOps.STORE`. SINK defines the AST for a Kernel.
- **`dtype`**: `None`
- **`src`**: `Tuple[UOp, ...]`, Only local STOREs are allowed.
- **`src`**: `Tuple[UOp, ...]`, Only global STOREs are allowed.
- **`arg`**: `Optional[KernelInfo]`
NOTE: `ScheduleItem` ASTs do not have the `KernelInfo` arg, `Kernel` inserts this to the SINK later.
@ -70,6 +70,59 @@ class UOps(Enum):
- **`arg`**: `ShapeTracker`
"""
SWIZZLE = auto()
"""
Swizzle inserts a movement op between a UOp and its children. Because movement ops (reshape, expand, shrink, permute, pad) are not allowed in an AST,
the scheduler rewrites SWIZZLE by pushing its ShapeTracker through reduceops or elementwise ops to the edges of the graph.
Example:
```python
a = Tensor.empty(32, 32)
first_reduce = a.sum()
output = (a + first_reduce).sum()
```
`first_reduce` must broadcast to `(32, 32)` before ADD. We UOp this as:
```
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x3,
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
```
The scheduler rewrites this by pushing the expand in SWIZZLE through the reduce, to the LOAD:
```diff
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
- UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
- UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
- UOp(UOps.LOAD, dtypes.int, arg=None, src=(
- x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
- UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
+ UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2, 3)), src=(
+ UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+ x2:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
+ UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 32, 32), strides=(0, 0, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
- x3,
- UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
+ x2,
+ UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),))
```
NOTE: Pushing a SWIZZLE through a reduce changes the axis.
NOTE: Pushing a SWIZZLE changes the output shape of that UOp. We have to reshape every other adjacent node. eg. reshape of the second LOAD to `(32, 32, 1, 1)` above.
- **`dtype`**: Output DType
- **`src`**: `Tuple[UOp]`, a single UOp to swizzle.
- **`arg`**: ShapeTracker
""" # noqa E501
DEFINE_GLOBAL = auto()
DEFINE_VAR = auto()
DEFINE_LOCAL = auto()