mirror of https://github.com/commaai/tinygrad.git
parent
2e393f7ef2
commit
5670123d88
|
@ -179,5 +179,9 @@ class TestTinygrad(unittest.TestCase):
|
|||
assert a.dtype != b.dtype and a.dtype == dtypes.float32 and b.dtype == dtypes.int8, "a.dtype should be float and b.dtype should be char"
|
||||
assert a.shape == b.shape, f"shape mismatch (Tensor.ones_like){a.shape} != (torch){b.shape}"
|
||||
|
||||
def test_numel(self):
|
||||
a = Tensor.empty(6, 12, 79)
|
||||
self.assertTrue(a.numel() == 5688)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -158,6 +158,8 @@ class Tensor:
|
|||
def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
|
||||
cond = (self != 0.0)
|
||||
return cond * input_ + (1.0 - cond) * other
|
||||
|
||||
def numel(self): return prod(self.shape)
|
||||
|
||||
# ***** (numpy) rng helper functions *****
|
||||
# TODO: move randomness generation out of numpy
|
||||
|
|
Loading…
Reference in New Issue