save 2 lines

This commit is contained in:
George Hotz 2020-12-29 14:02:50 -05:00
parent ea341c84fe
commit fb6aaefb9b
1 changed files with 2 additions and 4 deletions

View File

@ -203,8 +203,7 @@ class Tensor:
def mean(self, axis=None):
out = self.sum(axis=axis)
coeff = np.prod(out.shape)/np.prod(self.shape)
return out * coeff
return out * (np.prod(out.shape)/np.prod(self.shape))
def sqrt(self):
return self.pow(0.5)
@ -242,8 +241,7 @@ class Tensor:
# TODO: this needs a test
if Tensor.training:
_mask = np.asarray(np.random.binomial(1, 1.0-p, size=self.shape), dtype=self.dtype)
ret = self * Tensor(_mask, requires_grad=False, device=self.device)
return ret * (1/(1.0 - p))
return self * Tensor(_mask, requires_grad=False, device=self.device) * (1/(1.0 - p))
else:
return self