push swizzle through dim change (#6801)

* push swizzle through dim change

* can this be generic

* generic version

* cleanups
This commit is contained in:
qazal 2024-09-30 09:04:59 +08:00 committed by GitHub
parent a76c6c740c
commit 2ec73d6f05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 5 deletions

View File

@ -1364,11 +1364,10 @@ class TestIndexing(unittest.TestCase):
self.check_schedule(xt, 3)
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
@unittest.expectedFailure
def test_advanced_indexing(self):
X = Tensor.arange(10)+1
xt = X[[0]]
self.check_schedule(xt, 3)
self.check_schedule(xt, 2)
np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0]])
@unittest.expectedFailure

View File

@ -84,10 +84,13 @@ def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
(reduceop.arg[0], new_axis)).swizzle(ShapeTracker.from_shape(swizzle_st.shape))
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
swizzle_st = unwrap(swizzle.st)
swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
assert swizzle_st.contiguous, "can't push a non contiguous SWIZZLE down to STORE"
assert prod(swizzle_st.shape) == prod(unwrap(swizzle.src[0].st).shape), "can't push expands down to STORE"
return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, root.arg).swizzle(ShapeTracker.from_shape(swizzle_st.reduce(root.arg[1])))
assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
op, axis = root.arg
output_shape = swizzle_st.reduce(axis)
new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, (op, new_axis)).swizzle(ShapeTracker.from_shape(output_shape))
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
swizzles = [x for x in root.src if x.op is UOps.SWIZZLE]