Better reshape (#1423)

* do reshaping without merge_views and reshape masks

* added tests

* properly do reshaping of zero or negative masks

* replace while loop with single expression

* remove old condition

* add more tests and comments

* remove empty file
This commit is contained in:
Sieds Lykles 2023-08-14 18:09:04 +02:00 committed by GitHub
parent e00acb1eaf
commit cf2bf1518d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 137 additions and 22 deletions

View File

@ -122,7 +122,6 @@ class TestRealSimplifies(unittest.TestCase):
View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
class TestIndexExpressions2d(unittest.TestCase):
def setUp(self):
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
offsets = [0, 1, 15, 28, 10000]
@ -188,6 +187,7 @@ class TestIndexExpressions2d(unittest.TestCase):
st.expand((base_shape[0], base_shape[1], base_shape[1]))
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
def test_permute_reshape_1(self): # This tests multiple views
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
st.permute((1, 0))
@ -417,6 +417,94 @@ class TestMaskedShapeTracker(unittest.TestCase):
self.st.pad(((1,1), (1,1)))
self.st.assert_same()
def test_reshaping_splitting(self):
self.st = CheckingShapeTracker((5,10,5,10))
self.st.permute((1, 0, 3, 2))
self.st.pad(((0,0), (0,5), (0,0), (0,5)))
self.st.reshape((10,2,5,10,2,5))
assert len(self.st.views) == 1
self.st.assert_same()
def test_reshape_combining_1(self):
self.st = CheckingShapeTracker((2,1,10))
self.st.pad(((2,6), (0,0), (0,0)))
self.st.reshape((100,))
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.skip("Can't make this optimization yet")
def test_reshape_combining_2(self):
self.st = CheckingShapeTracker((1,1,5))
self.st.pad(((3,6), (0,0), (0,5)))
self.st.reshape((100,))
assert len(self.st.views) == 1
self.st.assert_same()
@unittest.skip("Can't make this optimization yet")
def test_reshape_splitting_combining(self):
self.st = CheckingShapeTracker((1,5,5))
self.st.pad(((0,4), (0,5), (0,0)))
self.st.reshape((10,25))
assert len(self.st.views) == 1
self.st.assert_same()
def test_reshape_only_1s(self):
self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1))
self.st.pad(((0,4), (0,0), (0,0), (1,1), (0,0), (0,0), (0,0), (0,0)))
self.st.reshape((5, 6, 3, 5))
assert len(self.st.views) == 1
self.st.assert_same()
self.st.reshape((1, 1, 5, 6, 3, 5, 1, 1))
assert len(self.st.views) == 1
self.st.assert_same()
self.st.reshape((1, 5, 6, 1, 3, 1, 5, 1))
assert len(self.st.views) == 1
self.st.assert_same()
def test_zero_mask_1(self):
self.st = CheckingShapeTracker((1, 3, 2))
self.st.pad(((0,0), (0,3), (0,0)))
self.st.shrink(((0,1), (3,6), (0,2)))
self.st.reshape((3,2))
self.st.assert_same()
self.st.reshape((1, 3, 1, 2, 1))
self.st.assert_same()
def test_zero_mask_2(self):
self.st = CheckingShapeTracker((1, 3, 2))
self.st.pad(((0,2), (0,3), (0,0)))
self.st.shrink(((2,3), (3,6), (0,2)))
self.st.reshape((3,2))
self.st.assert_same()
self.st.reshape((1, 3, 1, 2, 1))
self.st.assert_same()
def test_expanded_reshaped(self):
self.st = CheckingShapeTracker((1, 3, 2, 1))
self.st.expand((5, 3, 2, 2))
self.st.pad(((0,0), (0,3), (0,0), (0, 0)))
self.st.reshape((5, 2, 3, 2, 2))
assert len(self.st.views) == 1
self.st.assert_same()
def test_splitting_big(self):
self.st = CheckingShapeTracker((1, 5, 1, 15, 1))
self.st.pad(((0,0), (0,5), (0,0), (0,15), (0,0)))
self.st.reshape((10, 1, 30))
self.st.permute((2,1,0))
self.st.reshape((2,3,5,2,5))
assert len(self.st.views) == 1
v = self.st.views[-1]
assert v.strides == (15, 5, 1, 75, 15) and v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5))
def test_combining_big(self):
self.st = CheckingShapeTracker((1,3,1,5,3,1))
self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0)))
self.st.reshape((1,1,1,105,1,1))
assert len(self.st.views) == 1
v = self.st.views[-1]
assert v.strides == (0, 0, 0, 1, 0, 0) and v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1)), v.offset == -30
class TestShapeTracker(unittest.TestCase):
def setUp(self):
self.st = CheckingShapeTracker((7,4))

View File

@ -95,30 +95,58 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
if None in strides: return None
return View(vm1.shape, strides, mst.real_offset(), vm1.mask)
def _reshape_mask(view: View, new_shape:Tuple[int, ...]) -> Tuple[Optional[Tuple[Tuple[int, int], ...]],bool]:
# assumes view can be reshaped to new_shape (if it had no mask), this implies we won't have to worry about strides
if view.mask is None: return view.mask, False
new_mask: List[Tuple[int, int]] = []
r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int],...], view.mask)), reversed(view.shape), reversed(new_shape)
stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
while len(new_mask) < len(new_shape):
if mask[1]-mask[0] < 1: # if the mask is never valid, just return all zeros
return ((0,0),)*len(new_shape), False
if old_dim == new_dim*stride: # easy, can just copy the mask
new_mask.append((mask[0]//stride, (mask[1]-1)//stride+1))
stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
elif old_dim > new_dim: # splitting the old mask
# we cannot split if the reshape cuts across the mask
if (mask[0]%(new_dim*stride)!=0 or mask[1]%(new_dim*stride)!=0) and mask[0]//(new_dim*stride)!=(mask[1]-1)//(new_dim*stride):
return view.mask, True
new_mask.append((mask[0]%(new_dim*stride)//stride, (mask[1]-1)%(new_dim*stride)//stride+1))
# the remaining mask still needs to be split, we need to determine the mask for the next dimension
# we maintain the stride
stride *= new_dim
new_dim = next(r_new_shape, 1)
elif old_dim < new_dim*stride: # combining masks
next_mask = next(r_masks, (0,1))
# if the current dimension is masked, we cannot merge unless the next masks have an index range of 1
if (mask[0]!=0 or mask[1]!=old_dim) and next_mask[1]-next_mask[0]!=1:
return view.mask, True
# we combine the current mask with the next and go through the loop again with the next dimension
mask = (next_mask[0]*old_dim+mask[0], (next_mask[1]-1)*old_dim+mask[1])
old_dim *= next(r_shape, 1)
for mask in (mask, *r_masks): # if the old shape has leading 1s, need to make sure their mask is (0,1), otherwise the mask is zero'd
if mask != (0,1): return ((0,0),)*len(new_shape), False
return tuple(reversed(new_mask)), False
@functools.lru_cache(maxsize=None)
def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset
# check if this is adding or removing 1s (only)
# NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
if [x for x in shape if x != 1] == [x for x in new_shape if x != 1]:
new_strides: List[int] = [y for x,y in zip(shape, strides) if x != 1]
new_strides_tuple: Tuple[int, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
new_mask_tuple = None
if mask:
for x,y in zip(shape, mask):
if x == 1 and y != (0, 1):
new_mask_tuple = ((0,0),) * len(new_shape)
break
else:
new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1]
new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False
if view.contiguous or (view.shape == new_shape): return View(new_shape), False
strides, reverse_shape = [], reversed(new_shape)
for d, s in reversed(view.shape_strides):
acc, new_stride = 1, s
while acc < d:
new_dim = next(reverse_shape)
acc *= new_dim
strides.append(new_stride)
new_stride *= new_dim
if acc != d: break
else:
strides += [0,] * (len(new_shape) - len(strides))
mask, extra = _reshape_mask(view, new_shape)
if not extra: return View(new_shape, tuple(reversed(strides)), view.offset, mask), False
new_view = View(new_shape)
if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
return new_view, True
return View(new_shape), True
@functools.lru_cache(maxsize=None)
def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]):
@ -147,7 +175,6 @@ class ShapeTracker:
# this is the real size (ish)
def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
# these are multiview strides, value is None if it's not a simple strided dimension
# TODO: this can be shared code between simplify and merge_views
def real_offset(self) -> int:
real_offset, mask = self.expr_node(Variable('zero', 0, 0))