mirror of https://github.com/commaai/tinygrad.git
update some test_tensor.py cases with 0 in shape (#2368)
This commit is contained in:
parent
6add808f6a
commit
c4cc4966ed
|
@ -1,6 +1,5 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import struct
|
||||
import unittest, copy
|
||||
import mmap
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
|
@ -179,6 +178,7 @@ class TestTinygrad(unittest.TestCase):
|
|||
assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}"
|
||||
|
||||
def test_ndim(self):
|
||||
assert Tensor(1).ndim == 0
|
||||
assert Tensor.randn(1).ndim == 1
|
||||
assert Tensor.randn(2,2,2).ndim == 3
|
||||
assert Tensor.randn(1,1,1,1,1,1).ndim == 6
|
||||
|
@ -216,7 +216,7 @@ class TestTinygrad(unittest.TestCase):
|
|||
assert Tensor.randn(1,2,5).numel() == 10
|
||||
assert Tensor.randn(1,1,1,1,1,1).numel() == 1
|
||||
assert Tensor([]).numel() == 0
|
||||
# assert Tensor.randn(1,0,2,5) == 0 # TODO: fix empty tensors
|
||||
assert Tensor.randn(1,0,2,5).numel() == 0
|
||||
|
||||
def test_element_size(self):
|
||||
for _, dtype in dtypes.fields().items():
|
||||
|
@ -230,8 +230,8 @@ class TestTinygrad(unittest.TestCase):
|
|||
x.dot(layer).mean().backward()
|
||||
|
||||
def test_zerosized_tensors(self):
|
||||
Tensor([]).realize()
|
||||
Tensor([]).numpy()
|
||||
np.testing.assert_equal(Tensor([]).numpy(), np.array([]))
|
||||
np.testing.assert_equal(Tensor(None).numpy(), np.array([]))
|
||||
|
||||
def test_tensor_ndarray_dtype(self):
|
||||
arr = np.array([1]) # where dtype is implicitly int64
|
||||
|
@ -269,10 +269,6 @@ class TestTinygrad(unittest.TestCase):
|
|||
np.testing.assert_allclose(ua_arr, (Tensor(ua_arr)/Tensor(1)).numpy())
|
||||
|
||||
class TestZeroShapeTensor(unittest.TestCase):
|
||||
def test_from_empty(self):
|
||||
np.testing.assert_equal(Tensor([]).numpy(), np.array([]))
|
||||
np.testing.assert_equal(Tensor(None).numpy(), np.array([]))
|
||||
|
||||
def test_shape_stride(self):
|
||||
t = Tensor.rand(3, 2, 0)
|
||||
assert t.shape == (3, 2, 0)
|
||||
|
|
Loading…
Reference in New Issue