mirror of https://github.com/commaai/tinygrad.git
Better reshape (#1423)
* do reshaping without merge_views and reshape masks * added tests * properly do reshaping of zero or negative masks * replace while loop with single expression * remove old condition * add more tests and comments * remove empty file
This commit is contained in:
parent
e00acb1eaf
commit
cf2bf1518d
|
@ -122,7 +122,6 @@ class TestRealSimplifies(unittest.TestCase):
|
|||
View((8, 1, 6, 10, 28, 3, 2, 1), (5544, 0, 0, 56, 1, 1848, 672, 0), 0, None)])
|
||||
|
||||
class TestIndexExpressions2d(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
|
||||
offsets = [0, 1, 15, 28, 10000]
|
||||
|
@ -188,6 +187,7 @@ class TestIndexExpressions2d(unittest.TestCase):
|
|||
st.expand((base_shape[0], base_shape[1], base_shape[1]))
|
||||
self.node_exprs.append(lambda idx, base_shape=base_shape, offset=offset: idx//(base_shape[1]*base_shape[1])%base_shape[0]*base_shape[1] + idx%base_shape[1] + offset)
|
||||
self.idxs_exprs.append(lambda idxs, base_shape=base_shape, offset=offset: idxs[0]*base_shape[1] + idxs[2] + offset)
|
||||
|
||||
def test_permute_reshape_1(self): # This tests multiple views
|
||||
for st, base_shape, offset in zip(self.sts, self.shapes, self.offset):
|
||||
st.permute((1, 0))
|
||||
|
@ -417,6 +417,94 @@ class TestMaskedShapeTracker(unittest.TestCase):
|
|||
self.st.pad(((1,1), (1,1)))
|
||||
self.st.assert_same()
|
||||
|
||||
def test_reshaping_splitting(self):
  # Permute + pad a 4-d tracker, then split every dim in two: the reshape
  # must be absorbed into a single view (no merge_views fallback).
  st = self.st = CheckingShapeTracker((5, 10, 5, 10))
  st.permute((1, 0, 3, 2))
  st.pad(((0, 0), (0, 5), (0, 0), (0, 5)))
  st.reshape((10, 2, 5, 10, 2, 5))
  assert len(st.views) == 1
  st.assert_same()
|
||||
|
||||
def test_reshape_combining_1(self):
  # Flattening a padded (10,1,10) tracker into 1-d must keep a single view.
  st = self.st = CheckingShapeTracker((2, 1, 10))
  st.pad(((2, 6), (0, 0), (0, 0)))
  st.reshape((100,))
  assert len(st.views) == 1
  st.assert_same()
|
||||
|
||||
@unittest.skip("Can't make this optimization yet")
def test_reshape_combining_2(self):
  # Combining across two padded dims at once is not yet a single-view reshape.
  st = self.st = CheckingShapeTracker((1, 1, 5))
  st.pad(((3, 6), (0, 0), (0, 5)))
  st.reshape((100,))
  assert len(st.views) == 1
  st.assert_same()
|
||||
|
||||
@unittest.skip("Can't make this optimization yet")
def test_reshape_splitting_combining(self):
  # Simultaneous split of one dim and merge of others; not yet single-view.
  st = self.st = CheckingShapeTracker((1, 5, 5))
  st.pad(((0, 4), (0, 5), (0, 0)))
  st.reshape((10, 25))
  assert len(st.views) == 1
  st.assert_same()
|
||||
|
||||
def test_reshape_only_1s(self):
  # Adding/removing only 1-dims (with a padded non-trivial mask) must always
  # stay a single view, whatever arrangement of 1s the target shape uses.
  st = self.st = CheckingShapeTracker((1, 1, 1, 4, 1, 3, 5, 1))
  st.pad(((0, 4), (0, 0), (0, 0), (1, 1), (0, 0), (0, 0), (0, 0), (0, 0)))
  for target in ((5, 6, 3, 5), (1, 1, 5, 6, 3, 5, 1, 1), (1, 5, 6, 1, 3, 1, 5, 1)):
    st.reshape(target)
    assert len(st.views) == 1
    st.assert_same()
|
||||
|
||||
def test_zero_mask_1(self):
  # Shrink entirely into the padded region: every index is invalid, so the
  # reshaped mask must be the all-zero mask regardless of target shape.
  st = self.st = CheckingShapeTracker((1, 3, 2))
  st.pad(((0, 0), (0, 3), (0, 0)))
  st.shrink(((0, 1), (3, 6), (0, 2)))
  for target in ((3, 2), (1, 3, 1, 2, 1)):
    st.reshape(target)
    st.assert_same()
|
||||
|
||||
def test_zero_mask_2(self):
  # Like test_zero_mask_1, but the shrink lands in padding on two axes.
  st = self.st = CheckingShapeTracker((1, 3, 2))
  st.pad(((0, 2), (0, 3), (0, 0)))
  st.shrink(((2, 3), (3, 6), (0, 2)))
  for target in ((3, 2), (1, 3, 1, 2, 1)):
    st.reshape(target)
    st.assert_same()
|
||||
|
||||
def test_expanded_reshaped(self):
  # Expand (stride-0 dims), pad, then split a dim: still one view.
  st = self.st = CheckingShapeTracker((1, 3, 2, 1))
  st.expand((5, 3, 2, 2))
  st.pad(((0, 0), (0, 3), (0, 0), (0, 0)))
  st.reshape((5, 2, 3, 2, 2))
  assert len(st.views) == 1
  st.assert_same()
|
||||
|
||||
def test_splitting_big(self):
  # Pad, flatten, permute, then re-split: check the exact strides and mask
  # of the resulting single view.
  st = self.st = CheckingShapeTracker((1, 5, 1, 15, 1))
  st.pad(((0, 0), (0, 5), (0, 0), (0, 15), (0, 0)))
  st.reshape((10, 1, 30))
  st.permute((2, 1, 0))
  st.reshape((2, 3, 5, 2, 5))
  assert len(st.views) == 1
  v = st.views[-1]
  assert v.strides == (15, 5, 1, 75, 15)
  assert v.mask == ((0, 1), (0, 3), (0, 5), (0, 1), (0, 5))
|
||||
|
||||
def test_combining_big(self):
  # Pad dim 1 by (2,2) and flatten everything into one 105-long dim: must
  # merge to a single view with stride-1 data, mask (30,75) and offset -30.
  self.st = CheckingShapeTracker((1,3,1,5,3,1))
  self.st.pad(((0,0),(2,2),(0,0),(0,0),(0,0),(0,0)))
  self.st.reshape((1,1,1,105,1,1))
  assert len(self.st.views) == 1
  v = self.st.views[-1]
  # BUG FIX: the original wrote `assert A and B, v.offset == -30`, which made
  # the offset comparison the assert *message* — it was never checked.
  assert v.strides == (0, 0, 0, 1, 0, 0)
  assert v.mask == ((0, 1), (0, 1), (0, 1), (30, 75), (0, 1), (0, 1))
  assert v.offset == -30
|
||||
|
||||
class TestShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
  # Fresh 7x4 tracker for each test in this class.
  self.st = CheckingShapeTracker((7,4))
|
||||
|
|
|
@ -95,30 +95,58 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
|||
if None in strides: return None
|
||||
return View(vm1.shape, strides, mst.real_offset(), vm1.mask)
|
||||
|
||||
def _reshape_mask(view: View, new_shape:Tuple[int, ...]) -> Tuple[Optional[Tuple[Tuple[int, int], ...]],bool]:
  """Try to re-express view.mask in terms of new_shape.

  Returns (mask, extra):
    - extra=False: mask is valid for new_shape (None if the view had no mask,
      or all-(0,0) when nothing in the view is valid).
    - extra=True: the mask cannot be reshaped (the reshape cuts across a
      masked region); view.mask is returned unchanged and the caller must
      fall back to stacking views.
  """
  # assumes view can be reshaped to new_shape (if it had no mask), this implies we won't have to worry about strides
  if view.mask is None: return view.mask, False
  new_mask: List[Tuple[int, int]] = []
  # walk old shape/mask and new shape right-to-left in lockstep; missing
  # entries default to size-1 dims / the trivial (0,1) mask
  r_masks, r_shape, r_new_shape = reversed(cast(Tuple[Tuple[int, int],...], view.mask)), reversed(view.shape), reversed(new_shape)
  stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
  while len(new_mask) < len(new_shape):
    if mask[1]-mask[0] < 1: # if the mask is never valid, just return all zeros
      return ((0,0),)*len(new_shape), False
    if old_dim == new_dim*stride: # easy, can just copy the mask
      new_mask.append((mask[0]//stride, (mask[1]-1)//stride+1))
      # advance to the next (old dim, new dim, mask) triple and reset stride
      stride, old_dim, new_dim, mask = 1, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0,1))
    elif old_dim > new_dim: # splitting the old mask
      # we cannot split if the reshape cuts across the mask
      if (mask[0]%(new_dim*stride)!=0 or mask[1]%(new_dim*stride)!=0) and mask[0]//(new_dim*stride)!=(mask[1]-1)//(new_dim*stride):
        return view.mask, True
      new_mask.append((mask[0]%(new_dim*stride)//stride, (mask[1]-1)%(new_dim*stride)//stride+1))
      # the remaining mask still needs to be split, we need to determine the mask for the next dimension
      # we maintain the stride
      stride *= new_dim
      new_dim = next(r_new_shape, 1)
    elif old_dim < new_dim*stride: # combining masks
      next_mask = next(r_masks, (0,1))
      # if the current dimension is masked, we cannot merge unless the next masks have an index range of 1
      if (mask[0]!=0 or mask[1]!=old_dim) and next_mask[1]-next_mask[0]!=1:
        return view.mask, True
      # we combine the current mask with the next and go through the loop again with the next dimension
      mask = (next_mask[0]*old_dim+mask[0], (next_mask[1]-1)*old_dim+mask[1])
      old_dim *= next(r_shape, 1)
  for mask in (mask, *r_masks): # if the old shape has leading 1s, need to make sure their mask is (0,1), otherwise the mask is zero'd
    if mask != (0,1): return ((0,0),)*len(new_shape), False
  return tuple(reversed(new_mask)), False
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
  """Reshape `view` to `new_shape`.

  Returns (new_view, extra): extra=True means the reshape could not be folded
  into a single view and the caller must keep an additional view.
  NOTE(review): this span comes from a rendered diff; some lines below look
  like a mix of removed and added hunk lines — confirm against the commit.
  """
  shape, mask, strides, offset = view.shape, view.mask, view.strides, view.offset
  # check if this is adding or removing 1s (only)
  # NOTE: this is optional, but removes most calls to (expensive!) merge_views (with mask, not optional)
  if [x for x in shape if x != 1] == [x for x in new_shape if x != 1]:
    # keep the strides of non-1 dims; inserted 1-dims get stride 0
    new_strides: List[int] = [y for x,y in zip(shape, strides) if x != 1]
    new_strides_tuple: Tuple[int, ...] = tuple([0 if x == 1 else new_strides.pop(0) for x in new_shape])
    new_mask_tuple = None
    if mask:
      # a dropped 1-dim whose mask excludes index 0 invalidates everything
      for x,y in zip(shape, mask):
        if x == 1 and y != (0, 1):
          new_mask_tuple = ((0,0),) * len(new_shape)
          break
      else:
        new_mask: List[Tuple[int, int]] = [y for x,y in zip(shape, mask) if x != 1]
        new_mask_tuple = tuple([(0,1) if x == 1 else new_mask.pop(0) for x in new_shape])
    return View(new_shape, new_strides_tuple, offset, new_mask_tuple), False
  if view.contiguous or (view.shape == new_shape): return View(new_shape), False
  # try to tile each old (dim, stride) run with a suffix of new_shape
  strides, reverse_shape = [], reversed(new_shape)
  for d, s in reversed(view.shape_strides):
    acc, new_stride = 1, s
    while acc < d:
      new_dim = next(reverse_shape)
      acc *= new_dim
      strides.append(new_stride)
      new_stride *= new_dim
    if acc != d: break  # new dims don't evenly tile this run; abandon fast path
  else:
    strides += [0,] * (len(new_shape) - len(strides))
    mask, extra = _reshape_mask(view, new_shape)
    if not extra: return View(new_shape, tuple(reversed(strides)), view.offset, mask), False

  # fallback: merge with a fresh contiguous view of the target shape
  new_view = View(new_shape)
  # NOTE(review): contiguous was already handled above, so this check looks
  # dead here — possibly an old (removed) diff line captured by the scrape.
  if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
  if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
  if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
  return new_view, True
  # NOTE(review): unreachable — almost certainly a removed diff line rendered
  # into this scrape; verify against the actual commit before trusting it.
  return View(new_shape), True
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_pad_args(shape:Tuple[int,...], arg:Tuple[Tuple[int, int], ...]):
|
||||
|
@ -147,7 +175,6 @@ class ShapeTracker:
|
|||
# this is the real size (ish)
def size(self):
  """Element count of the last view, ignoring broadcast (stride-0) dims."""
  final = self.views[-1]
  return prod([s for s, st in zip(final.shape, final.strides) if st != 0])
|
||||
|
||||
# these are multiview strides, value is None if it's not a simple strided dimension
|
||||
# TODO: this can be shared code between simplify and merge_views
|
||||
def real_offset(self) -> int:
|
||||
real_offset, mask = self.expr_node(Variable('zero', 0, 0))
|
||||
|
|
Loading…
Reference in New Issue