mirror of https://github.com/commaai/tinygrad.git

Proposal: Better UOps.SWIZZLE (#6309)

* better UOps.SWIZZLE
* test_swizzle_rewrite
* add it to docs
* show a diff
* a lil more verbose
* two teeny notes
* hotfix: sink

parent 8c50ef8b7c
commit 07942ef361
@@ -8,13 +8,15 @@ import numpy as np
 from typing import Dict, List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
+from tinygrad.dtype import PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import View
 from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
+from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel, verify_ast
-from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup
-from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup, reduceop_fusor
+from tinygrad.engine.realize import CompiledRunner, run_schedule
 from test.helpers import is_dtype_supported, Context
 from tinygrad.lazy import LazyBuffer, view_supported_devices
 from extra.models.llama import precompute_freqs_cis
@@ -1642,5 +1644,31 @@ class TestScheduleRewrite(unittest.TestCase):
     new_val = st_fixup(val, lambda st:st.reshape((4,)), {}, {})
     self.assertIs(new_val, val)
 
+  def test_swizzle_rewrite(self):
+    # graph rewrite
+    sink = UOp(UOps.SINK, None, arg=None, src=(
+      UOp(UOps.STORE, None, arg=None, src=(
+        UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
+        UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
+          UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+            UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
+              UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
+                UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+                  x8:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
+                  UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
+            UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+              x8,
+              UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
+    sink = graph_rewrite(sink, reduceop_fusor)
+    # verify output
+    k = Kernel(sink)
+    p = k.to_program()
+    a = Tensor.randint(32, 32).realize()
+    b = Tensor.empty((), dtype=dtypes.int).realize()
+    CompiledRunner(p).exec([b.lazydata.buffer, a.lazydata.buffer])
+    expected_out = (a.numpy() + a.numpy().sum()).sum()
+    np.testing.assert_equal(b.numpy(), expected_out)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
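For orientation, the AST in `test_swizzle_rewrite` is the UOp form of a double reduce: sum a matrix, broadcast the scalar back over it, add, and sum again. A minimal Tensor-level sketch of the same computation (illustrative, not part of the diff; `Tensor.randint` as in the test):

```python
import numpy as np
from tinygrad import Tensor

a = Tensor.randint(32, 32).realize()
# a.sum() is the inner REDUCE_AXIS; broadcasting its scalar result back to
# (32, 32) for the ADD is what the stride-(0, 0) SWIZZLE view represents
out = (a + a.sum()).sum()
np.testing.assert_equal(out.numpy(), (a.numpy() + a.numpy().sum()).sum())
```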
@@ -82,7 +82,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
 
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
-    swizzle = (UOp(UOps.SWIZZLE, src=(st.to_uop(),)),) if not st.contiguous and AST_REWRITE else ()
+    swizzle_arg = st if not st.contiguous and AST_REWRITE else None
     rinfo: Optional[Tuple[ShapeTracker, Tuple[int, ...]]] = (ShapeTracker.from_shape(buf.srcs[0].shape), buf.arg) \
       if AST_REWRITE else reduce_info.get((buf, st))
     rsrc = _recursive_uop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache)
@@ -91,7 +91,9 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
     if rinfo is None:
       assert rsrc.op is UOps.REDUCE_AXIS and rsrc.arg[0] is alu_op, f"can't merge reduceop {buf.op} with {rsrc}\n{st}"
       return rsrc
-    return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,)+swizzle, (alu_op, rinfo[1])))
+    ret = UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (alu_op, rinfo[1]))
+    if swizzle_arg is not None: ret = UOp(UOps.SWIZZLE, dtype, (ret,), swizzle_arg)
+    return cache.setdefault((buf, st), ret)
 
   # elementwise ops pass shapetracker
   in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs)
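Two representation changes land in these two hunks: the swizzle's ShapeTracker moves out of a SHAPETRACKER-uop source (the old `st.to_uop()`) into the SWIZZLE's own `arg`, and the SWIZZLE becomes the parent of the REDUCE_AXIS instead of riding along as an extra source. A standalone sketch of the new encoding against this commit's tree (the `expand` producing the non-contiguous stride-0 view and the bare DEFINE_GLOBAL source are illustrative stand-ins):

```python
from tinygrad import dtypes
from tinygrad.dtype import PtrDType
from tinygrad.ops import BinaryOps, UOp, UOps
from tinygrad.shape.shapetracker import ShapeTracker

# a ShapeTracker that views the reduce output as (32, 32): strides (0, 0), not contiguous
st = ShapeTracker.from_shape((1, 1)).expand((32, 32))
assert not st.contiguous
rsrc = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=())  # placeholder source
# build the reduce clean, then wrap it: SWIZZLE carries st in arg, the reduce in src
ret = UOp(UOps.REDUCE_AXIS, dtypes.int, (rsrc,), (BinaryOps.ADD, (0, 1)))
ret = UOp(UOps.SWIZZLE, dtypes.int, (ret,), st)
```

Putting the view in `arg` is what lets the UPat below match SWIZZLE(REDUCE_AXIS) structurally, with no SHAPETRACKER node in between.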
@@ -172,10 +174,11 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
 
 # ***** reduceop fusor *****
 
-def apply_swizzle(root:UOp, rsrc:UOp, swizzle:UOp) -> UOp:
+def push_swizzle_through_reduce(swizzle:UOp, reduceop:UOp) -> UOp:
   uop_sts: Dict[UOp, ShapeTracker] = {}
-  new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, root.arg[1])
-  return replace(root, src=(st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), arg=(root.arg[0], new_axis))
+  rsrc = reduceop.src[0]
+  new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, reduceop.arg[1])
+  return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), (reduceop.arg[0], new_axis))
 
 def push_reduceop_shape(root:UOp) -> Optional[UOp]:
   reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS]
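The axis arithmetic inside `swizzle_reduceop` is easiest to see on the docstring example further down: a contiguous (32, 32) input reduced over axes (0, 1) and then expanded back to (32, 32) becomes a (32, 32, 32, 32) input with strides (0, 0, 32, 1) reduced over axes (2, 3). A pure-numpy hand-check that the two orderings agree:

```python
import numpy as np

a = np.arange(32 * 32, dtype=np.int32).reshape(32, 32)
# original order: reduce everything, then broadcast the scalar back over (32, 32)
swizzled = np.broadcast_to(a.sum(axis=(0, 1), keepdims=True), (32, 32))
# pushed order: expand the input first (this broadcast has element strides
# (0, 0, 32, 1), matching the rewritten View), then reduce the new axes (2, 3)
pushed = np.broadcast_to(a, (32, 32, 32, 32)).sum(axis=(2, 3))
np.testing.assert_equal(pushed, swizzled)
```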
@@ -186,7 +189,7 @@ def push_reduceop_shape(root:UOp) -> Optional[UOp]:
   return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {})
 
 reduceop_fusor = PatternMatcher([
-  (UPat(UOps.REDUCE_AXIS, src=(UPat(name="rsrc"), UPat(UOps.SWIZZLE, src=(UPat(name="swizzle"),))), name="root"), apply_swizzle),
+  (UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_through_reduce),
   (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), push_reduceop_shape),
 ])
 
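How these pairs fire: `test_swizzle_rewrite` above calls `graph_rewrite(sink, reduceop_fusor)`, which applies the matcher across the graph until nothing matches. A condensed, hypothetical sketch of such a bottom-up driver; the real `graph_rewrite` in tinygrad.ops differs in details such as memoization:

```python
from tinygrad.ops import PatternMatcher, UOp

def graph_rewrite_sketch(root:UOp, pm:PatternMatcher) -> UOp:
  # rewrite sources first, rebuild this node, then retry it until no rule fires
  node = UOp(root.op, root.dtype, tuple(graph_rewrite_sketch(s, pm) for s in root.src), root.arg)
  while (rewritten := pm.rewrite(node)) is not None:
    node = graph_rewrite_sketch(rewritten, pm)
  return node
```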
@@ -46,7 +46,7 @@ class UOps(Enum):
   Holds `UOps.STORE`. SINK defines the AST for a Kernel.
 
   - **`dtype`**: `None`
-  - **`src`**: `Tuple[UOp, ...]`, Only local STOREs are allowed.
+  - **`src`**: `Tuple[UOp, ...]`, Only global STOREs are allowed.
   - **`arg`**: `Optional[KernelInfo]`
 
   NOTE: `ScheduleItem` ASTs do not have the `KernelInfo` arg, `Kernel` inserts this to the SINK later.
@@ -70,6 +70,59 @@ class UOps(Enum):
   - **`arg`**: `ShapeTracker`
   """
+  SWIZZLE = auto()
+  """
+  SWIZZLE inserts a movement op between a UOp and its children. Because movement ops (reshape, expand, shrink, permute, pad) are not allowed in an AST,
+  the scheduler rewrites SWIZZLE by pushing its ShapeTracker through reduceops or elementwise ops to the edges of the graph.
+
+  Example:
+  ```python
+  a = Tensor.empty(32, 32)
+  first_reduce = a.sum()
+  output = (a + first_reduce).sum()
+  ```
+  `first_reduce` must broadcast to `(32, 32)` before the ADD. This is expressed in UOps as:
+
+  ```
+  UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+    UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
+      UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
+        UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+          x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
+          UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
+    UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+      x3,
+      UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
+  ```
+
+  The scheduler rewrites this by pushing the expand in SWIZZLE through the reduce, down to the LOAD:
+
+  ```diff
+  UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+  - UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
+  -   UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
+  -     UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+  -       x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
+  -       UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
+  + UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2, 3)), src=(
+  +   UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+  +     x2:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
+  +     UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 32, 32), strides=(0, 0, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+    UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+  -   x3,
+  -   UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
+  +   x2,
+  +   UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),))
+  ```
+
+  NOTE: Pushing a SWIZZLE through a reduce changes the axis.
+
+  NOTE: Pushing a SWIZZLE changes the output shape of that UOp. We have to reshape every other adjacent node, e.g. the reshape of the second LOAD to `(32, 32, 1, 1)` above.
+
+  - **`dtype`**: Output DType
+  - **`src`**: `Tuple[UOp]`, a single UOp to swizzle.
+  - **`arg`**: `ShapeTracker`
+  """ # noqa E501
   DEFINE_GLOBAL = auto()
   DEFINE_VAR = auto()
   DEFINE_LOCAL = auto()
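Putting the docs and the scheduler change together end to end: with the AST_REWRITE path from the schedule.py hunks enabled, scheduling the docstring's example program should yield ASTs with no SWIZZLE left, since `reduceop_fusor` pushes its view into the LOADs. A sketch, assuming AST_REWRITE is an env-settable flag like tinygrad's other ContextVars and that `ScheduleItem.ast` is the SINK UOp at this commit:

```python
# run as: AST_REWRITE=1 python swizzle_demo.py  (script name is hypothetical)
from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule

a = Tensor.empty(32, 32)
out = (a + a.sum()).sum()
for si in create_schedule([out.lazydata]):
  print(si.ast)  # the SWIZZLE should be gone: its expand now lives in the LOAD views
```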