mirror of https://github.com/commaai/tinygrad.git
[ready] refactor getitem round 2 :D (#2568)
* new getitem
* go
* add temporary simple tests
* better
* comments
* WOW that took awhile
* save 1 line lol
* work
* still need to add comprehensive tests, but i think getitem looks nice :D
* GIMME GREEN CI CHECKMARK PLS
* try..
* k idk
* added tests for errors
* fixed small hack
* added tests
* almost good
* try no contig?
* yay no more contig + comments and spacing
* finishing touches (comments)
* revert regex unittests lol
* add suggested change
* oops I fell asleep yesterday
This commit is contained in:
parent 6ba6349c97
commit f12bcccb87
@@ -615,6 +615,7 @@ class TestOps(unittest.TestCase):
    helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None])
    helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2])
    helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1])
    helper_test_op([(3,3,3)], lambda x: x[None, None, 1, None, 2, 0:2], lambda x: x[None, None, 1, None, 2, 0:2])

  def test_slice_one_endpoint_out_of_bounds(self):
    helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4])
@@ -660,11 +661,11 @@ class TestOps(unittest.TestCase):
  def test_slice_errors(self):
    a = Tensor.ones(4, 3)
    with self.assertRaises(IndexError):
      a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds)
      a[1, 77] # IndexError: (out of bounds).
      a[0, -77]
      a[..., ...] # IndexError: only single ellipsis
    with self.assertRaises(IndexError): a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds)
    with self.assertRaises(IndexError): a[1, 77] # IndexError: (out of bounds).
    with self.assertRaises(IndexError): a[1, -77]
    with self.assertRaises(IndexError): a[..., ...] # IndexError: only single ellipsis
    with self.assertRaises(ValueError): a[::0, 1] # no 0 strides

  def test_slice_ellipsis(self):
    helper_test_op([(3,3,3,3)], lambda x: x[..., 0], lambda x: x[..., 0])
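
Note (not part of the diff): these tests pin down two different behaviours. An out-of-bounds slice endpoint such as x[0:4] on a size-3 dimension is clamped, following Python slice semantics, while an out-of-bounds integer index or a zero-step slice raises. A minimal sketch of the distinction, assuming only the Tensor API already used in this diff:

from tinygrad.tensor import Tensor

a = Tensor.ones(4, 3)
print(a[0:77].shape)                      # (4, 3): the endpoint is clamped, no error
try: a[1, 77]
except IndexError as e: print(e)          # out-of-bounds integer index raises
try: a[::0, 1]
except ValueError as e: print(e)          # zero-step slice raises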
@@ -1219,7 +1220,6 @@ class TestOps(unittest.TestCase):
  def _get_index_randoms(self):
    # indices cannot have gradient
    # TODO currently does not support IndexError for out of bounds idx values
    a = torch.randint(low=-1, high=1, size=(2,1,1,1,1,1), dtype=torch.int64, requires_grad=False)
    b = torch.randint(high=1, size=(1,3,1,1,1,1), dtype=torch.int64, requires_grad=False)
    c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False)
@@ -1275,6 +1275,41 @@ class TestOps(unittest.TestCase):
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[1,2,3],...], lambda x: x[i,j,k,[1,2,3],...])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[2,1,0],e], lambda x: x[i,[2,1,0],k,[2,1,0],p])

  def test_slice_fancy_indexing_tuple_indices(self):
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(0),b,c,d,:], lambda x: x[(0),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1),b,c,d,:], lambda x: x[(1),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1,0),b,c,d,:], lambda x: x[(1,0),j,k,o,:])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,(1,2,3),...], lambda x: x[i,j,k,(1,2,3),...])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,(2,1,0),c,(2,1,0),e], lambda x: x[i,(2,1,0),k,(2,1,0),p])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,(2,1,0),None,c,(2,1,0),e], lambda x: x[1,(2,1,0),None,k,(2,1,0),p])

  def test_slice_fancy_indexing_list_with_tensors(self):
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a]], lambda x: x[[i]])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,1]], lambda x: x[[i,1]])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,[1,1]]], lambda x: x[[i,[1,1]]])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,(1,1)]], lambda x: x[[i,(1,1)]])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,b,c,d,e]], lambda x: x[[i,j,k,o,p]])

  def test_slice_fancy_indexing_tuple_with_tensors(self):
    a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()
    # helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,),], lambda x: x[(i,),]) TypeError: only integer tensors of a single element can be converted to an index
    # helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,1),], lambda x: x[(i,1),]) TypeError: only integer tensors of a single element can be converted to an index
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,[1,1])], lambda x: x[(i,[1,1])])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,(1,1))], lambda x: x[(i,(1,1))])
    helper_test_op([(2,5,6,5,3,4)], lambda x: x[(a,b,c,d,e)], lambda x: x[(i,j,k,o,p)])

  def test_slice_fancy_indexing_errors(self): ...
    # TODO: currently we do not support IndexError for out of bounds idx values
    # any out of bounds in fancy indexing returns 0
    # ex: Tensor([1,2])[Tensor([1,2,55])].numpy() -> array([2., 0., 0.], dtype=float32)
    # TODO: currently we do not support tensor indexing for a list of list of tensors
    # ex: torch.tensor([1,2])[[[[torch.tensor(1)]]]] -> tensor([[2]])
    # currently we return ValueError: setting an array element with a sequence.
    # E TypeError: only integer tensors of a single element can be converted to an index

  def test_gather(self):
    # indices cannot have gradient
    # indices cannot be negative (torch gather)
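
Note (not part of the diff): the TODO comments above describe current behaviour rather than the desired end state. Out-of-bounds values inside an index tensor do not raise yet; they gather 0 instead, as in the example already given in the comment:

from tinygrad.tensor import Tensor

# index 1 is valid, indices 2 and 55 are out of bounds and currently gather 0
print(Tensor([1,2])[Tensor([1,2,55])].numpy())   # array([2., 0., 0.], dtype=float32)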
@@ -1,7 +1,7 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set, DefaultDict
from collections import defaultdict
from functools import partialmethod, reduce
from itertools import accumulate
@@ -302,75 +302,91 @@ class Tensor:
  # - There's a special case where a permute is needed at the end:
  #   - if the first Tensor passed in (expand dims) is not at dim 0
  #   - and the following Tensors do not follow consecutively to the end of fancy indexing's dims
  def __getitem__(self, indices) -> Tensor: # indices: Union[int, slice, Tensor, None, Ellipsis, List, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]]
    def normalize_int(e, i, dim_sz):
      if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
      raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
  # TODO: boolean indices
  # TODO: figure out the exact acceptable types for indices, especially for internal list/tuple types
  # TODO: update docs
  def __getitem__(self, indices: Union[int, slice, Tensor, None, List, Tuple]) -> Tensor: # no ellipsis type...
    # 1. indices normalization and validation
    # treat internal tuples and lists as Tensors and standardize indices to list type
    if isinstance(indices, (tuple, list)):
      if isinstance(indices, list) and all(isinstance(i, int) for i in indices): indices = [Tensor(indices)] # special case <indices: List[int]>, a lil ugly
      else: indices = [Tensor(list(i)) if isinstance(i, (tuple, list)) else i for i in indices]
    else: indices = [indices]

    # TODO: if indices is a tuple of any sequence, or if indices is a list, it's for advanced indexing
    orig_slices = list(indices) if isinstance(indices, tuple) else [indices]
    count = defaultdict(list)
    for i,v in enumerate(orig_slices): count[type(v)].append(i)
    # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
    ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
    fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices)
    num_slices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None)
    indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_slices)

    # TODO: boolean indices
    if (num_slices := len(count[int]) + len(count[slice]) + len(count[Tensor]) + len(count[list])) > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
    if len(ellipsis_found := count[type(Ellipsis)]) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
    # use Dict[type, List[dimension]] to track elements in indices
    type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list)

    # replace ellipsis with equivalent number of slice(None)
    # TODO: move all slice(None) to the end and transpose non-None to the front
    ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices)
    orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
    # record None for dimension injection later
    type_dim[None] = [dim for dim, i in enumerate(indices) if i is None]

    valid_slices = [v for v in orig_slices if v is not None]
    valid_slices = [v if isinstance(v, slice) else slice(y_ := normalize_int(v, i, dim_sz), y_+1) if isinstance(v, int) else slice(None) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
    # filter None and record rest of indices
    indices_filtered = tuple(v for v in indices if v is not None)
    for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim)

    start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
    # validation! raise Errors
    if len(ellipsis_idx) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
    if float in type_dim: raise IndexError("float type is not valid index")
    if any(isinstance(i, slice) and i.step == 0 for i in indices): raise ValueError('slice step cannot be 0')
    if num_slices > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
    for dim in type_dim[int]:
      if indices_filtered[dim] >= self.shape[dim] or indices_filtered[dim] < -self.shape[dim]: raise IndexError(f"index {indices_filtered[dim]} is out of bounds for dimension {dim} with size {self.shape[dim]}")

    # normalize! indices -> start, stop, strides
    start, stop, strides = zip(*y) if (y := [i.indices(sh) if isinstance(i, slice) else slice(normalized:= i if i != -1 else sh-1, normalized+1, 1).indices(sh) if isinstance(i, int) else (0, sh, 1) for i, sh in zip(indices_filtered, self.shape)]) else ((), (), ()) # type: ignore[arg-type]

    # 2. basic indexing (no copy)
    # apply slices and flip where strides are negative
    new_slice = tuple(((0, 0) if e < s else (s, e)) if st > 0 else ((0, 0) if e > s else (e+1, s+1)) for s, e, st in zip(start, stop, strides))
    sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0])
    new_shape = sliced_tensor.shape
    new_shape = list(sliced_tensor.shape)

    # add strides by pad -> reshape -> shrink
    if any(abs(s) != 1 for s in strides):
      strides = tuple(abs(s) for s in strides)
      # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded]
      padded_tensor = sliced_tensor.pad(tuple((0, s-(dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape)))
      # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
      reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides)))
      new_shape = reshaped_tensor.shape[::2]
      # Shrink: do [:, 0]
      new_shape = list(reshaped_tensor.shape[::2])
      sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape)))

    final_shape, it_shape, dim, tensors, dim_collapsed = [], iter(new_shape), [], [], 0
    for i,s in enumerate(orig_slices):
      if s is None: final_shape.append(1)
      else: # s is int or slice or Tensor
        dim_shape = next(it_shape)
        if isinstance(s, list): s = Tensor(s)
        if isinstance(s, int): dim_collapsed += 1
        else:
          assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}"
          final_shape.append(dim_shape)
          if isinstance(s, Tensor):
            tensors.append(s)
            dim.append(i-dim_collapsed)
    ret = sliced_tensor.reshape(tuple(final_shape))
    # inject dim=1 for None and collapse dim for int
    for dim in type_dim[None]: new_shape.insert(dim, 1)
    for dim in (dims_collapsed := [dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int])]): new_shape.pop(dim)
    for dim_sh in new_shape: assert isinstance(dim_sh, int), f"does not support symbolic shape {dim_sh}"

    ret = sliced_tensor.reshape(tuple(new_shape))

    # 3. advanced indexing (copy)
    if type_dim[Tensor]:

      # extract tensors and tensor dimensions
      idx, tdim = [], []
      for tensor_dim in type_dim[Tensor]:
        dims_collapsed_, dims_injected = sum(1 for d in dims_collapsed if tensor_dim >= d), sum(1 for d in type_dim[None] if tensor_dim >= d)
        tdim.append(td := tensor_dim - dims_collapsed_ + dims_injected)
        idx.append((t := indices[tensor_dim + dims_injected]).sign().__neg__().relu() * ret.shape[td] + t) # normalize the negative tensor indices

    if tensors: # Fancy/tensor indexing
      # normalize idx
      # TODO: first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
      idx = [t.sign().contiguous().__neg__().contiguous().relu() * ret.shape[d] + t for d,t in zip(dim, tensors)]
      max_dim = max(i.ndim for i in idx)
      # compute sum_dim, arange, and idx
      sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(dim)]
      arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, dim))]
      first_idx = [idx[0].reshape(*[1]*dim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - dim[0] - 1))]
      rest_idx = [i.reshape(*[1]*dim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - dim[0] - n)) for n,i in enumerate(idx[1:], 1)]
      max_dim = max(i.ndim for i in idx)
      sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
      arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))]
      first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
      rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
      idx = first_idx + rest_idx
      ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:])
      # iteratively fancy index

      # iteratively eq -> mul -> sum fancy index
      for a,i,sd in zip(arange, idx, sum_dim): ret = (a==i).mul(ret).sum(sd)

      # special permute case
      if dim[0] != 0 and len(dim) != 1 and dim != list(range(dim[0], dim[-1]+1)):
      if tdim[0] != 0 and len(tdim) != 1 and tdim != list(range(tdim[0], tdim[-1]+1)):
        ret_dims = list(range(ret.ndim))
        ret = ret.permute(ret_dims[dim[0]:dim[0]+max_dim] + ret_dims[:dim[0]] + ret_dims[dim[0]+max_dim:])
        ret = ret.permute(ret_dims[tdim[0]:tdim[0]+max_dim] + ret_dims[:tdim[0]] + ret_dims[tdim[0]+max_dim:])
    return ret

  def __setitem__(self,indices,v): return self.__getitem__(indices).assign(v)
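
Note (not part of the commit): the core of the new advanced-indexing path is the (a==i).mul(ret).sum(sd) loop. For each index tensor, an arange over the indexed dimension is broadcast against the normalized index tensor to form a one-hot mask, the mask is multiplied into the data, and the original dimension is summed away. A minimal one-dimensional sketch of that trick, using only ops that appear in the diff:

from tinygrad.tensor import Tensor

x = Tensor([10., 20., 30., 40.])            # gather x[idx] along dim 0 without a dedicated gather op
idx = Tensor([3, 1, 1])
a = Tensor.arange(4).reshape(4, 1)          # column of candidate positions 0..3
onehot = a == idx.reshape(1, 3)             # (4, 3) mask: entry (r, c) is 1 where r == idx[c]
out = onehot.mul(x.reshape(4, 1)).sum(0)    # multiply in the data and sum out the indexed dim
print(out.numpy())                          # [40. 20. 20.]

The t.sign().__neg__().relu() * ret.shape[td] + t expression above wraps negative indices: it adds the dimension size exactly when an index entry is negative and leaves non-negative entries untouched, so the mask comparison only ever sees valid positive positions.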
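Note (not part of the commit): the "add strides by pad -> reshape -> shrink" step implements strided slicing without a strided movement op. Each dimension is padded up to a multiple of its stride, reshaped into (groups, stride) pairs, and then shrunk to keep only element 0 of every group. A rough one-dimensional sketch of the same idea, again only using ops from the diff:

from tinygrad.tensor import Tensor

x = Tensor([0., 1., 2., 3., 4., 5., 6.])                           # want x[::3] -> [0., 3., 6.]
s, dim_sz = 3, 7
padded = x.pad(((0, s - dim_sz % s if dim_sz % s != 0 else 0),))   # pad 7 -> 9 with zeros
grouped = padded.reshape(3, s)                                     # one row per kept element
strided = grouped.shrink(((0, 3), (0, 1))).reshape(3)              # keep column 0 of each row
print(strided.numpy())                                             # [0. 3. 6.]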