mirror of https://github.com/commaai/tinygrad.git
touchup test_indexing (#3169)
This commit is contained in:
parent
a04e4d0442
commit
097b1390ec
|
@ -1,12 +1,12 @@
|
|||
# test cases are modified from pytorch test_indexing.py https://github.com/pytorch/pytorch/blob/597d3fb86a2f3b8d6d8ee067e769624dcca31cdb/test/test_indexing.py
|
||||
|
||||
import math, unittest, random, copy, warnings
|
||||
import unittest, random, copy, warnings
|
||||
import numpy as np
|
||||
|
||||
from tinygrad import Tensor, dtypes, Device, TinyJit
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.helpers import CI, all_same, prod
|
||||
|
||||
random.seed(42)
|
||||
|
||||
|
@ -16,13 +16,12 @@ def numpy_testing_assert_equal_helper(a, b):
|
|||
np.testing.assert_equal(a, b)
|
||||
|
||||
def consec(shape, start=1):
|
||||
return Tensor.arange(math.prod(shape)).reshape(shape)+start
|
||||
return Tensor.arange(prod(shape)).reshape(shape)+start
|
||||
|
||||
# creates strided tensor with base set to reference tensor's base, equivalent to torch.set_()
|
||||
def set_(reference: Tensor, shape, strides, offset):
|
||||
if reference.lazydata.base.realized is None: reference.realize()
|
||||
assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
|
||||
# TODO: this shouldn't directly create a LazyBuffer
|
||||
strided = Tensor(reference.lazydata._view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
|
||||
assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
|
||||
return strided
|
||||
|
@ -48,7 +47,7 @@ def all_(tensor:Tensor) -> Tensor:
|
|||
|
||||
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
||||
def diagonal(tensor:Tensor) -> Tensor:
|
||||
assert tensor.ndim == 2 and all(sh == sh[0] for sh in tensor.shape), 'only support 2 ndim square tensors'
|
||||
assert tensor.ndim == 2 and all_same(tensor.shape), 'only support 2 ndim square tensors'
|
||||
return (Tensor.eye(tensor.shape[0]) * tensor).sum(0)
|
||||
|
||||
# https://numpy.org/doc/stable/reference/generated/numpy.unravel_index.html
|
||||
|
|
Loading…
Reference in New Issue