diff --git a/test/test_dtype.py b/test/test_dtype.py index b6c4a013..4036dbc1 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -218,5 +218,21 @@ class TestHelpers(unittest.TestCase): def test_scalar(self, dtype, amt): assert dtype.vec(amt).scalar() == dtype +class TestTypeSpec(unittest.TestCase): + def test_creation(self): + assert Tensor([]).dtype == Tensor.default_type + # assert Tensor([1]).dtype == dtypes.int + assert Tensor([1.1]).dtype == Tensor.default_type + + def test_const_full(self): + assert Tensor.ones([2,3]).dtype == Tensor.default_type + assert Tensor.zeros([2,3]).dtype == Tensor.default_type + assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type + # assert Tensor.full([2,3], 3).dtype == dtypes.int + + def test_reduce_0d_default(self): + assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type + # assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b371c671..49cee2d6 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -473,7 +473,7 @@ class Tensor: axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis)) axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_] shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_) - if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0, mlops.Max: -float("inf")}[fxn]) + if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn]) ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)])) return ret if keepdim else ret.reshape(shape=shape)