[ready] refactor getitem round 2 :D (#2568)

* new getitem

* go

* add temporary simple tests

* better

* comments

* WOW that took awhile

* save 1 line lol

* work

* still need to add comprehensive tests, but i think getitem looks nice :D

* GIMME GREEN CI CHECKMARK PLS

* try..

* k idk

* added tests for errors

* fixed small hack

* added tests

* almost good

* try no contig?

* yay no more contig + comments and spacing

* finishing touches (comments)

* revert regex unittests lol

* add suggested change

* oops I fell asleep yesterday
geohotstan 2023-12-05 11:36:32 +08:00 committed by GitHub
parent 6ba6349c97
commit f12bcccb87
2 changed files with 107 additions and 56 deletions


@@ -615,6 +615,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None])
helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2])
helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1])
helper_test_op([(3,3,3)], lambda x: x[None, None, 1, None, 2, 0:2], lambda x: x[None, None, 1, None, 2, 0:2])
def test_slice_one_endpoint_out_of_bounds(self):
helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4])
@@ -660,11 +661,11 @@ class TestOps(unittest.TestCase):
def test_slice_errors(self):
a = Tensor.ones(4, 3)
with self.assertRaises(IndexError):
a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds)
a[1, 77] # IndexError: (out of bounds).
a[0, -77]
a[..., ...] # IndexError: only single ellipsis
with self.assertRaises(IndexError): a[1, 77, 77, 77] # IndexError: too many indices (raised before the out-of-bounds check)
with self.assertRaises(IndexError): a[1, 77] # IndexError: out of bounds
with self.assertRaises(IndexError): a[1, -77]
with self.assertRaises(IndexError): a[..., ...] # IndexError: only single ellipsis
with self.assertRaises(ValueError): a[::0, 1] # no 0 strides
def test_slice_ellipsis(self):
helper_test_op([(3,3,3,3)], lambda x: x[..., 0], lambda x: x[..., 0])
@@ -1219,7 +1220,6 @@ class TestOps(unittest.TestCase):
def _get_index_randoms(self):
# indices cannot have gradient
# TODO currently does not support IndexError for out of bounds idx values
a = torch.randint(low=-1, high=1, size=(2,1,1,1,1,1), dtype=torch.int64, requires_grad=False)
b = torch.randint(high=1, size=(1,3,1,1,1,1), dtype=torch.int64, requires_grad=False)
c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False)
@@ -1275,6 +1275,41 @@ class TestOps(unittest.TestCase):
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[1,2,3],...], lambda x: x[i,j,k,[1,2,3],...])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[2,1,0],e], lambda x: x[i,[2,1,0],k,[2,1,0],p])
def test_slice_fancy_indexing_tuple_indices(self):
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(0),b,c,d,:], lambda x: x[(0),j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1),b,c,d,:], lambda x: x[(1),j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1,0),b,c,d,:], lambda x: x[(1,0),j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,(1,2,3),...], lambda x: x[i,j,k,(1,2,3),...])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,(2,1,0),c,(2,1,0),e], lambda x: x[i,(2,1,0),k,(2,1,0),p])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,(2,1,0),None,c,(2,1,0),e], lambda x: x[1,(2,1,0),None,k,(2,1,0),p])
def test_slice_fancy_indexing_list_with_tensors(self):
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a]], lambda x: x[[i]])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,1]], lambda x: x[[i,1]])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,[1,1]]], lambda x: x[[i,[1,1]]])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,(1,1)]], lambda x: x[[i,(1,1)]])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,b,c,d,e]], lambda x: x[[i,j,k,o,p]])
def test_slice_fancy_indexing_tuple_with_tensors(self):
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
# helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,),], lambda x: x[(i,),]) TypeError: only integer tensors of a single element can be converted to an index
# helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,1),], lambda x: x[(i,1),]) TypeError: only integer tensors of a single element can be converted to an index
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,[1,1])], lambda x: x[(i,[1,1])])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,(1,1))], lambda x: x[(i,(1,1))])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,b,c,d,e)], lambda x: x[(i,j,k,o,p)])
def test_slice_fancy_indexing_errors(self): ...
# TODO: currently we do not support IndexError for out-of-bounds idx values
# any out of bounds in fancy indexing returns 0
# ex: Tensor([1,2])[Tensor([1,2,55])].numpy() -> array([2., 0., 0.], dtype=float32)
# TODO: currently we do not support tensor indexing with a nested list of tensors
# ex: torch.tensor([1,2])[[[[torch.tensor(1)]]]] -> tensor([[2]])
# currently we raise ValueError: setting an array element with a sequence.
# or TypeError: only integer tensors of a single element can be converted to an index
def test_gather(self):
# indices cannot have gradient
# indices cannot be negative (torch gather)


@@ -1,7 +1,7 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set, DefaultDict
from collections import defaultdict
from functools import partialmethod, reduce
from itertools import accumulate
@@ -302,75 +302,91 @@ class Tensor:
# - There's a special case where a permute is needed at the end:
# - if first Tensor passed in (expand dims) is not at dim 0
# - and the following Tensors do not follow consecutively to the end of the fancy indexing dims
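# - e.g. x[:, Tensor([0,1]), :, Tensor([0,1])]: the Tensors land on dims 1 and 3 (not starting at
# - dim 0 and not consecutive), so the broadcast fancy dims are permuted to the front, as in numpy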
def __getitem__(self, indices) -> Tensor: # indices: Union[int, slice, Tensor, None, Ellipsis, List, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
def normalize_int(e, i, dim_sz):
if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
# TODO: boolean indices
# TODO: figure out the exact acceptable types for indices, especially for internal list/tuple types
# TODO: update docs
def __getitem__(self, indices: Union[int, slice, Tensor, None, List, Tuple]) -> Tensor: # no ellipsis type...
# 1. indices normalization and validation
# treat internal tuples and lists as Tensors and standardize indices to list type
if isinstance(indices, (tuple, list)):
if isinstance(indices, list) and all(isinstance(i, int) for i in indices): indices = [Tensor(indices)] # special case <indices: List[int]>, a lil ugly
else: indices = [Tensor(list(i)) if isinstance(i, (tuple, list)) else i for i in indices]
else: indices = [indices]
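# e.g. x[[1,2]] becomes [Tensor([1,2])], while x[(1, [0,1])] becomes [1, Tensor([0,1])]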
# TODO: if indices is a tuple of any sequence, or if indices is a list, it's for advanced indexing
orig_slices = list(indices) if isinstance(indices, tuple) else [indices]
count = defaultdict(list)
for i,v in enumerate(orig_slices): count[type(v)].append(i)
# filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
num_slices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
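# e.g. for self.shape=(3,3,3), x[..., 0] turns indices into [slice(None), slice(None), 0]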
# TODO: boolean indices
if (num_slices := len(count[int]) + len(count[slice]) + len(count[Tensor]) + len(count[list])) > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
if len(ellipsis_found := count[type(Ellipsis)]) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
# use Dict[type, List[dimension]] to track elements in indices
type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)
# replace ellipsis with equivalent number of slice(None)
# TODO: move all slice(None) to the end and transpose non-None to the front
ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices)
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
# record None for dimension injection later
type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]
valid_slices = [v for v in orig_slices if v is not None]
valid_slices = [v if isinstance(v, slice) else slice(y_ := normalize_int(v, i, dim_sz), y_+1) if isinstance(v, int) else slice(None) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
# filter None and record rest of indices
indices_filtered = tuple(v for v in indices if v is not None)
for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)
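# e.g. x[None, 0, :, Tensor([1])] on a 3D tensor -> {None: [0], int: [0], slice: [1], Tensor: [2]}
# (None dims are counted over indices, the rest over indices_filtered)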
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
# validation! raise Errors
if len(ellipsis_idx) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
if float in type_dim: raise IndexError("float type is not valid index")
if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
for dim in type_dim[int]:
if indices_filtered[dim] >= self.shape[dim] or indices_filtered[dim] < -self.shape[dim]: raise IndexError(f"index {indices_filtered[dim]} is out of bounds for dimension {dim} with size {self.shape[dim]}")
# normalize! indices -> start, stop, strides
start, stop, strides = zip(*y) if (y := [i.indices(sh) if isinstance(i, slice) else slice(normalized:= i if i != -1 else sh-1, normalized+1, 1).indices(sh) if isinstance(i, int) else (0, sh, 1) for i, sh in zip(indices_filtered, self.shape)]) else ((), (), ()) # type: ignore[arg-type]
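# e.g. for sh=4: slice(None, None, -2).indices(4) -> (3, -1, -2), and i=-1 -> slice(3, 4, 1).indices(4) -> (3, 4, 1)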
# 2. basic indexing (no copy)
# apply slices and flip where strides are negative
new_slice = tuple(((0, 0) if e < s else (s, e)) if st > 0 else ((0, 0) if e > s else (e+1, s+1)) for s, e, st in zip(start, stop, strides))
sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
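# e.g. x[3:0:-1] on a dim of size 4: (start, stop, stride) = (3, 0, -1), shrink keeps (1, 4), flip yields x[3], x[2], x[1]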
new_shape = sliced_tensor.shape
new_shape = list(sliced_tensor.shape)
# add strides by pad -> reshape -> shrink
if any(abs(s) != 1 for s in strides):
strides = tuple(abs(s) for s in strides)
# Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
padded_tensor = sliced_tensor.pad(tuple((0, s-(dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape)))
# Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)))
new_shape = reshaped_tensor.shape[::2]
# Shrink: do [:, 0]
new_shape = list(reshaped_tensor.shape[::2])
sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)))
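# e.g. x[::2] on a dim of size 5: pad (5,) -> (6,), reshape (6,) -> (3, 2), shrink [:, :1] keeps x[0], x[2], x[4]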
final_shape, it_shape, dim, tensors, dim_collapsed = [], iter(new_shape), [], [], 0
for i,s in enumerate(orig_slices):
if s is None: final_shape.append(1)
else: # s is int or slice or Tensor
dim_shape = next(it_shape)
if isinstance(s, list): s = Tensor(s)
if isinstance(s, int): dim_collapsed += 1
else:
assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}"
final_shape.append(dim_shape)
if isinstance(s, Tensor):
tensors.append(s)
dim.append(i-dim_collapsed)
ret = sliced_tensor.reshape(tuple(final_shape))
# inject dim=1 for None and collapse dim for int
for dim in type_dim[None]: new_shape.insert(dim, 1)
for dim in (dims_collapsed := [dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int])]): new_shape.pop(dim)
for dim_sh in new_shape: assert isinstance(dim_sh, int), f"does not support symbolic shape {dim_sh}"
ret = sliced_tensor.reshape(tuple(new_shape))
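# e.g. for x.shape=(3,3), x[None, 1]: basic slicing gives (1, 3), None inserts -> (1, 1, 3), the int dim pops -> (1, 3)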
# 3. advanced indexing (copy)
if type_dim[Tensor]:
# extract tensors and tensor dimensions
idx, tdim = [], []
for tensor_dim in type_dim[Tensor]:
dims_collapsed_, dims_injected = sum(1 for d in dims_collapsed if tensor_dim >= d), sum(1 for d in type_dim[None] if tensor_dim >= d)
tdim.append(td := tensor_dim - dims_collapsed_ + dims_injected)
idx.append((t := indices[tensor_dim + dims_injected]).sign().__neg__().relu() * ret.shape[td] + t) # normalize the negative tensor indices
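# e.g. t=Tensor([-1]) on a dim of size 4: (-t.sign()).relu() is 1, so idx becomes 1*4 + (-1) = 3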
if tensors: # Fancy/tensor indexing
# normalize idx
# TODO: first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
idx = [t.sign().contiguous().__neg__().contiguous().relu() * ret.shape[d] + t for d,t in zip(dim, tensors)]
max_dim = max(i.ndim for i in idx)
# compute sum_dim, arange, and idx
sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(dim)]
arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, dim))]
first_idx = [idx[0].reshape(*[1]*dim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - dim[0] - 1))]
rest_idx = [i.reshape(*[1]*dim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - dim[0] - n)) for n,i in enumerate(idx[1:], 1)]
max_dim = max(i.ndim for i in idx)
sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))]
first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
idx = first_idx + rest_idx
ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
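# e.g. x.shape=(4,3), x[Tensor([0,1]), Tensor([2,0])]: max_dim=1, arange[0] is (4,1,1), idx[0] is (1,2,1), ret becomes (4,1,3)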
# iteratively fancy index
# iteratively eq -> mul -> sum fancy index
for a,i,sd in zip(arange, idx, sum_dim): ret = (a==i).mul(ret).sum(sd)
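# e.g. continuing above: (arange==idx) one-hots the matches, mul selects, and the two sums leave shape (2,), i.e. [x[0,2], x[1,0]]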
# special permute case
if dim[0] != 0 and len(dim) != 1 and dim != list(range(dim[0], dim[-1]+1)):
if tdim[0] != 0 and len(tdim) != 1 and tdim != list(range(tdim[0], tdim[-1]+1)):
ret_dims = list(range(ret.ndim))
ret = ret.permute(ret_dims[dim[0]:dim[0]+max_dim] + ret_dims[:dim[0]] + ret_dims[dim[0]+max_dim:])
ret = ret.permute(ret_dims[tdim[0]:tdim[0]+max_dim] + ret_dims[:tdim[0]] + ret_dims[tdim[0]+max_dim:])
return ret
def __setitem__(self,indices,v): return self.__getitem__(indices).assign(v)