touchup test_indexing (#3169)

This commit is contained in:
chenyu 2024-01-18 14:32:43 -05:00 committed by GitHub
parent a04e4d0442
commit 097b1390ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 5 deletions

View File

@ -1,12 +1,12 @@
# test cases are modified from pytorch test_indexing.py https://github.com/pytorch/pytorch/blob/597d3fb86a2f3b8d6d8ee067e769624dcca31cdb/test/test_indexing.py
import math, unittest, random, copy, warnings
import unittest, random, copy, warnings
import numpy as np
from tinygrad import Tensor, dtypes, Device, TinyJit
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.helpers import CI
from tinygrad.helpers import CI, all_same, prod
random.seed(42)
@ -16,13 +16,12 @@ def numpy_testing_assert_equal_helper(a, b):
np.testing.assert_equal(a, b)
def consec(shape, start=1):
return Tensor.arange(math.prod(shape)).reshape(shape)+start
return Tensor.arange(prod(shape)).reshape(shape)+start
# creates strided tensor with base set to reference tensor's base, equivalent to torch.set_()
def set_(reference: Tensor, shape, strides, offset):
if reference.lazydata.base.realized is None: reference.realize()
assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
# TODO: this shouldn't directly create a LazyBuffer
strided = Tensor(reference.lazydata._view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
return strided
@ -48,7 +47,7 @@ def all_(tensor:Tensor) -> Tensor:
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
def diagonal(tensor:Tensor) -> Tensor:
assert tensor.ndim == 2 and all(sh == sh[0] for sh in tensor.shape), 'only support 2 ndim square tensors'
assert tensor.ndim == 2 and all_same(tensor.shape), 'only support 2 ndim square tensors'
return (Tensor.eye(tensor.shape[0]) * tensor).sum(0)
# https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html