
View.reshape without symbolic

This section sketches a proof of "Complete, Fast and Correct View.reshapes without using Symbolic". The goal is to reduce the number of multi-view cases, which add runtime cost.

  1. old_shape = (s1,s2,...,si,s(i+1),...,sn)
  2. old_stride = (st1, st2, ... ,sti, st(i+1), ..., stn)
  3. merge_old_shape = (p1, p2), where p1 = s1 * ... * si and p2 = s(i+1) * ... * sn.
  4. new_shape = (k1, ..., kp, k(p+1), ..., kl)
  5. prod(new_shape) = p1 * p2 (trivially, since a reshape preserves the number of elements).
  6. mask and new_mask represent the valid index ranges before and after the reshape, respectively.

Assumption

The dimensions forming p1 are mergeable with each other, as are the dimensions forming p2 (the conditions for mergeability are discussed below), but p1 and p2 cannot be merged with each other.

Claim

If prod([k1 ... kp]) < p1 and prod([k1 ... k(p+1)]) > p1, the reshape is not possible.

Proof

The dimension k(p+1) would have to span elements from both p1 and p2, which requires p1 and p2 to be mergeable; by assumption, they are not.

Conclusion

Hence, the reshape is possible only if ∃ p such that prod([k1 ... kp]) = p1.
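The conclusion can be checked mechanically: every boundary between unmergeable groups of the old shape must be hit exactly by some prefix product of new_shape. The sketch below assumes the old shape has already been merged into groups such as (p1, p2) and generalizes to any number of groups; `find_split` is a hypothetical helper for illustration, not part of tinygrad.

```python
from itertools import accumulate
import operator

def find_split(merged_shape, new_shape):
    """Hypothetical helper: for each boundary between unmergeable groups in
    merged_shape, return the p with prod(new_shape[:p]) equal to the product of
    the groups before that boundary, or None if some boundary is not hit."""
    # Prefix products of the old groups, e.g. (p1, p2) -> [p1, p1*p2].
    old_prefix = list(accumulate(merged_shape, operator.mul))
    # Prefix products of the new shape: [k1, k1*k2, ..., k1*...*kl].
    new_prefix = list(accumulate(new_shape, operator.mul))
    if old_prefix[-1] != new_prefix[-1]:
        raise ValueError("shapes have different numbers of elements")
    splits = []
    for boundary in old_prefix[:-1]:
        if boundary not in new_prefix:
            # Some k(p+1) would straddle two unmergeable groups.
            return None
        # index() picks the first match, so size-1 dims go to the later group.
        splits.append(new_prefix.index(boundary) + 1)
    return splits
```

For example, with merged_shape = (6, 4), the new shape (2, 3, 4) splits after p = 2 because 2 * 3 = 6, while (4, 6) is rejected because no prefix product equals 6.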

Conditions for mergeability

Case 1 - All non-zero strides

The dimensions merge if stx = st(x+1) * s(x+1) for every x ∈ [1, ..., i-1, i+1, ..., n-1]; x = i is excluded because p1 and p2 are not merged.

Proof

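Consider merging (s1 ... si) -> p1. The merged dimension must be described by a single new stride, which is only possible if the dimensions are contiguous, i.e. each stride equals the next stride times the next size. The merged stride is then sti.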
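The contiguity condition from Case 1 can be applied greedily from left to right. This is a minimal sketch that ignores masks; `merge_contiguous` is a hypothetical helper, not tinygrad's implementation.

```python
def merge_contiguous(shape, strides):
    """Hypothetical helper: group adjacent dims with st[x] == st[x+1] * s[x+1].
    Returns one (merged_size, merged_stride) pair per group. Masks are ignored
    and size-1 dims are not special-cased."""
    if not shape:
        return []
    groups = [(shape[0], strides[0])]
    for size, stride in zip(shape[1:], strides[1:]):
        prev_size, prev_stride = groups[-1]
        if prev_stride == stride * size:
            # Contiguous: one step of the previous dim skips exactly one full
            # run of this dim, so the group keeps the innermost stride.
            groups[-1] = (prev_size * size, stride)
        else:
            # Not contiguous: this dim starts a new group (like the p1/p2 boundary).
            groups.append((size, stride))
    return groups
```

A fully contiguous (2, 3, 4) view with strides (12, 4, 1) collapses to a single group of size 24 and stride 1, while strides (24, 4, 1) leave the groups (2, 24) and (12, 1). Note that the same check also merges runs of zero strides (0 == 0 * s), whereas a zero stride followed by a non-zero one blocks the merge, which is the situation Case 2 deals with.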

Case 2 - Some stride is zero

Let stj = 0, st(j+1) != 0 and s(j+1) > 1, where 1 < j < i.

If sj = 1, the reshape is trivial.

If sj > 1:

  • If maskj has range > 1, the reshape is not possible, because the values along s(j+1) must be repeated at least once and a single stride cannot capture repetition.
  • If maskj has range = 1, the reshape is possible, since the dimension then effectively has size 1, with some offset.
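Both bullets can be checked by brute force: a merged dimension with a single stride visits memory offsets in arithmetic progression, so if the offsets of the masked region are not arithmetic, no single stride can describe them. This sketch assumes half-open (start, end) mask ranges and row-major order; `offsets` and `single_stride` are hypothetical helpers for illustration.

```python
from itertools import product

def offsets(shape, strides, mask=None):
    """Hypothetical helper: memory offsets visited in row-major order,
    restricted to a mask of half-open (start, end) ranges, one per dim."""
    mask = mask or tuple((0, s) for s in shape)
    ranges = [range(lo, hi) for lo, hi in mask]
    return [sum(i * st for i, st in zip(idx, strides)) for idx in product(*ranges)]

def single_stride(offs):
    """True if the offsets form an arithmetic progression (one stride suffices)."""
    if len(offs) < 2:
        return True
    step = offs[1] - offs[0]
    return all(b - a == step for a, b in zip(offs, offs[1:]))

# stj = 0, s(j+1) = 3: the inner run 0, 1, 2 repeats, so no single stride works.
print(offsets((2, 3), (0, 1)), single_stride(offsets((2, 3), (0, 1))))
# maskj has range 1: only one row is valid, and the offsets are 0, 1, 2 again.
print(single_stride(offsets((2, 3), (0, 1), ((1, 2), (0, 3)))))
```

The first call shows the repetition [0, 1, 2, 0, 1, 2] from the first bullet; the masked call reduces dimension j to a single index, matching the second bullet.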

Conditions for reshaping mask

Case 1 - Splitting a Dimension - the mask must not be cut for the reshape to succeed.

  • Example - [1,2,3,4,5,6,7,8] -> [[1,2,3,4],[5,6,7,8]]; mask = ((2,6)); new_mask[0] = (0,2) (trivial split).

  • new_mask[1] is not possible: the valid elements 3 to 6 would need columns 2-3 in the first row but columns 0-1 in the second. It is only possible if the mask spans [1-8] or lies within a single row, [1-4] or [5-8].
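The splitting rule can be sketched for the reshape (outer*inner,) -> (outer, inner). This assumes half-open (start, end) mask ranges, which matches the examples in this section; `split_mask` is a hypothetical helper, not tinygrad's API.

```python
def split_mask(mask, inner):
    """Hypothetical helper: split a 1-D half-open mask (start, end) into
    (outer_mask, inner_mask) for (outer*inner,) -> (outer, inner).
    Returns None when the mask would be cut by the split."""
    start, end = mask
    if end <= start:
        return None  # empty masks are not handled in this sketch
    first_row, last_row = start // inner, (end - 1) // inner
    if first_row == last_row:
        # The mask lies inside a single row: fix the row, keep the column range.
        return (first_row, first_row + 1), (start % inner, (end - 1) % inner + 1)
    if start % inner == 0 and end % inner == 0:
        # The mask covers whole rows only: every row in range is fully valid.
        return (start // inner, end // inner), (0, inner)
    return None  # starts or ends mid-row across several rows: the mask is cut
```

For the example above, split_mask((2, 6), 4) returns None, while split_mask((0, 8), 4) gives ((0, 2), (0, 4)) and split_mask((5, 7), 4) gives ((1, 2), (1, 3)).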

Case 2 - Combining Dimensions - the mask should unfold continuously.

  • Example - [[1,2],[3,4],[5,6]] -> [1,2,3,4,5,6]; mask = ((0,2),(0,2)).

  • new_mask = (0,4); this is only possible because mask1 spans the whole dimension.

  • If mask1 did not span the whole dimension, combining would only be possible if mask0 had range 1, as shown below.

    • [[1,2,3],[4,5,6]] -> [1,2,3,4,5,6]; mask = ((1,2),(0,2)); new_mask = ((3,5)).
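The combining rule is the mirror image of splitting. This sketch handles (outer, inner) -> (outer*inner,) and assumes half-open mask ranges; `combine_mask` is a hypothetical helper for illustration.

```python
def combine_mask(mask, inner):
    """Hypothetical helper: combine a 2-D half-open mask ((a0, b0), (a1, b1))
    for shape (outer, inner) into a 1-D mask for (outer*inner,).
    Returns None if the valid region is not one contiguous run."""
    (a0, b0), (a1, b1) = mask
    if a1 == 0 and b1 == inner:
        # mask1 spans the whole dimension: consecutive valid rows join up.
        return (a0 * inner, b0 * inner)
    if b0 - a0 == 1:
        # mask0 has range 1: a partial single row is still contiguous.
        return (a0 * inner + a1, a0 * inner + b1)
    return None  # several partially valid rows leave gaps between them
```

Both examples above are reproduced: combine_mask(((0, 2), (0, 2)), 2) gives (0, 4) and combine_mask(((1, 2), (0, 2)), 3) gives (3, 5), while combine_mask(((0, 2), (0, 2)), 3) returns None.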