Add tensor.numel (#869)

* add tensor.numel

* add tensor.numel
This commit is contained in:
Nima Khodaveisi 2023-05-31 01:08:09 +02:00 committed by GitHub
parent 2e393f7ef2
commit 5670123d88
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 0 deletions

View File

@ -179,5 +179,9 @@ class TestTinygrad(unittest.TestCase):
assert a.dtype != b.dtype and a.dtype == dtypes.float32 and b.dtype == dtypes.int8, "a.dtype should be float and b.dtype should be char"
assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}"
def test_numel(self):
a = Tensor.empty(6, 12, 79)
self.assertTrue(a.numel() == 5688)
if __name__ == '__main__':
unittest.main()

View File

@ -158,6 +158,8 @@ class Tensor:
def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
cond = (self != 0.0)
return cond * input_ + (1.0 - cond) * other
def numel(self): return prod(self.shape)
# ***** (numpy) rng helper functions *****
# TODO: move randomness generation out of numpy