update some test_tensor.py cases with 0 in shape (#2368)

This commit is contained in:
chenyu 2023-11-19 20:35:05 -05:00 committed by GitHub
parent 6add808f6a
commit c4cc4966ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 8 deletions

View File

@ -1,6 +1,5 @@
import numpy as np
import torch
import struct
import unittest, copy
import mmap
from tinygrad.tensor import Tensor, Device
@ -179,6 +178,7 @@ class TestTinygrad(unittest.TestCase):
assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}"
def test_ndim(self):
assert Tensor(1).ndim == 0
assert Tensor.randn(1).ndim == 1
assert Tensor.randn(2,2,2).ndim == 3
assert Tensor.randn(1,1,1,1,1,1).ndim == 6
@ -216,7 +216,7 @@ class TestTinygrad(unittest.TestCase):
assert Tensor.randn(1,2,5).numel() == 10
assert Tensor.randn(1,1,1,1,1,1).numel() == 1
assert Tensor([]).numel() == 0
# assert Tensor.randn(1,0,2,5) == 0 # TODO: fix empty tensors
assert Tensor.randn(1,0,2,5).numel() == 0
def test_element_size(self):
for _, dtype in dtypes.fields().items():
@ -230,8 +230,8 @@ class TestTinygrad(unittest.TestCase):
x.dot(layer).mean().backward()
def test_zerosized_tensors(self):
Tensor([]).realize()
Tensor([]).numpy()
np.testing.assert_equal(Tensor([]).numpy(), np.array([]))
np.testing.assert_equal(Tensor(None).numpy(), np.array([]))
def test_tensor_ndarray_dtype(self):
arr = np.array([1]) # where dtype is implicitly int64
@ -269,10 +269,6 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(ua_arr, (Tensor(ua_arr)/Tensor(1)).numpy())
class TestZeroShapeTensor(unittest.TestCase):
def test_from_empty(self):
np.testing.assert_equal(Tensor([]).numpy(), np.array([]))
np.testing.assert_equal(Tensor(None).numpy(), np.array([]))
def test_shape_stride(self):
t = Tensor.rand(3, 2, 0)
assert t.shape == (3, 2, 0)