some dtype creation spec test cases (#2722)

This commit is contained in:
chenyu 2023-12-11 19:33:49 -05:00 committed by GitHub
parent ee9e1d3662
commit 4075208127
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 1 deletions

View File

@ -218,5 +218,21 @@ class TestHelpers(unittest.TestCase):
def test_scalar(self, dtype, amt):
assert dtype.vec(amt).scalar() == dtype
class TestTypeSpec(unittest.TestCase):
def test_creation(self):
assert Tensor([]).dtype == Tensor.default_type
# assert Tensor([1]).dtype == dtypes.int
assert Tensor([1.1]).dtype == Tensor.default_type
def test_const_full(self):
assert Tensor.ones([2,3]).dtype == Tensor.default_type
assert Tensor.zeros([2,3]).dtype == Tensor.default_type
assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type
# assert Tensor.full([2,3], 3).dtype == dtypes.int
def test_reduce_0d_default(self):
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
if __name__ == '__main__':
unittest.main()

View File

@ -473,7 +473,7 @@ class Tensor:
axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0, mlops.Max: -float("inf")}[fxn])
if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn])
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
return ret if keepdim else ret.reshape(shape=shape)