instant identity removal

This commit is contained in:
George Hotz 2023-02-25 09:46:04 -08:00
parent a8de233e12
commit 8b96522e1d
2 changed files with 14 additions and 5 deletions

View File

@ -72,11 +72,20 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
def test_div_const(self):
helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1)
helper_test_op([(45,65)], lambda x: 1/x, lambda x: 1/x)
helper_test_op([(45,65)], lambda x: x/2, lambda x: x/2)
helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x)
def test_pow(self):
helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0)
helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
def test_pow_const(self):
helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0)
helper_test_op([(45,65)], lambda x: 1.0**x, lambda x: 1.0**x)
helper_test_op([(45,65)], lambda x: x**2.0, lambda x: x**2.0)
helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x)
def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0)
def test_relu(self):

View File

@ -382,11 +382,11 @@ class Tensor:
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
def add(self, x, reverse=False): return self._broadcasted(mlops.Add, x, reverse)
def sub(self, x, reverse=False): return self._broadcasted(mlops.Sub, x, reverse)
def mul(self, x, reverse=False): return self._broadcasted(mlops.Mul, x, reverse)
def pow(self, x, reverse=False): return self._broadcasted(mlops.Pow, x, reverse)
def div(self, x, reverse=False): return self._broadcasted(mlops.Div, x, reverse)
def add(self, x, reverse=False): return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self
def sub(self, x, reverse=False): return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
def mul(self, x, reverse=False): return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
def pow(self, x, reverse=False): return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def div(self, x, reverse=False): return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def matmul(self, x:Tensor, reverse=False): return x.dot(self) if reverse else self.dot(x)
def maximum(self, x): return self._broadcasted(mlops.Maximum, x)