mirror of https://github.com/commaai/tinygrad.git
some dtype creation spec test cases (#2722)
This commit is contained in:
parent
ee9e1d3662
commit
4075208127
|
@ -218,5 +218,21 @@ class TestHelpers(unittest.TestCase):
|
|||
def test_scalar(self, dtype, amt):
|
||||
assert dtype.vec(amt).scalar() == dtype
|
||||
|
||||
class TestTypeSpec(unittest.TestCase):
|
||||
def test_creation(self):
|
||||
assert Tensor([]).dtype == Tensor.default_type
|
||||
# assert Tensor([1]).dtype == dtypes.int
|
||||
assert Tensor([1.1]).dtype == Tensor.default_type
|
||||
|
||||
def test_const_full(self):
|
||||
assert Tensor.ones([2,3]).dtype == Tensor.default_type
|
||||
assert Tensor.zeros([2,3]).dtype == Tensor.default_type
|
||||
assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type
|
||||
# assert Tensor.full([2,3], 3).dtype == dtypes.int
|
||||
|
||||
def test_reduce_0d_default(self):
|
||||
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
|
||||
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -473,7 +473,7 @@ class Tensor:
|
|||
axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
|
||||
axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
|
||||
shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_)
|
||||
if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0, mlops.Max: -float("inf")}[fxn])
|
||||
if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0.0, mlops.Max: -float("inf")}[fxn])
|
||||
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
|
||||
return ret if keepdim else ret.reshape(shape=shape)
|
||||
|
||||
|
|
Loading…
Reference in New Issue