mirror of https://github.com/commaai/tinygrad.git
push swizzle through dim change (#6801)
* push swizzle through dim change * can this be generic * generic version * cleanups
This commit is contained in:
parent
a76c6c740c
commit
2ec73d6f05
|
@ -1364,11 +1364,10 @@ class TestIndexing(unittest.TestCase):
|
|||
self.check_schedule(xt, 3)
|
||||
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_advanced_indexing(self):
|
||||
X = Tensor.arange(10)+1
|
||||
xt = X[[0]]
|
||||
self.check_schedule(xt, 3)
|
||||
self.check_schedule(xt, 2)
|
||||
np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0]])
|
||||
|
||||
@unittest.expectedFailure
|
||||
|
|
|
@ -84,10 +84,13 @@ def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
|
|||
(reduceop.arg[0], new_axis)).swizzle(ShapeTracker.from_shape(swizzle_st.shape))
|
||||
|
||||
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
|
||||
swizzle_st = unwrap(swizzle.st)
|
||||
swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
|
||||
assert swizzle_st.contiguous, "can't push a non contiguous SWIZZLE down to STORE"
|
||||
assert prod(swizzle_st.shape) == prod(unwrap(swizzle.src[0].st).shape), "can't push expands down to STORE"
|
||||
return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, root.arg).swizzle(ShapeTracker.from_shape(swizzle_st.reduce(root.arg[1])))
|
||||
assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
|
||||
op, axis = root.arg
|
||||
output_shape = swizzle_st.reduce(axis)
|
||||
new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
|
||||
return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, (op, new_axis)).swizzle(ShapeTracker.from_shape(output_shape))
|
||||
|
||||
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
swizzles = [x for x in root.src if x.op is UOps.SWIZZLE]
|
||||
|
|
Loading…
Reference in New Issue