mirror of https://github.com/commaai/tinygrad.git
touchups
This commit is contained in:
parent
a2e6562330
commit
061e37de39
|
@ -108,9 +108,9 @@ Warning: do not rely on the ANE port. It segfaults sometimes. So if you were doi
|
|||
You need to support 15 basic ops:
|
||||
|
||||
```
|
||||
Add, Sub, Mul, Pow # binary ops
|
||||
Add, Sub, Mul, Pow # binary ops (with broadcasting)
|
||||
Relu, Log, Exp # unary ops
|
||||
Sum, Max # reduce ops
|
||||
Sum, Max # reduce ops (with axis argument)
|
||||
Dot # matrix multiplication
|
||||
Conv2D, MaxPool2D # 2D ops
|
||||
Pad2D, Reshape, Transpose # moving things around ops
|
||||
|
|
|
@ -80,16 +80,15 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, device=self.device)
|
||||
def test_sum(self):
|
||||
helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, axis=1), device=self.device)
|
||||
@cpu_only
|
||||
def test_max(self):
|
||||
helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
|
||||
helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
|
||||
@cpu_only
|
||||
def test_max_axis(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0].mul(0.5), lambda x: Tensor.max(x, axis=1).mul(0.5), device=self.device)
|
||||
def test_sum_axis(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)), device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
|
||||
def test_mean_axis(self):
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device)
|
||||
def test_logsoftmax(self):
|
||||
|
|
|
@ -66,7 +66,7 @@ class Sum(Function):
|
|||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, axis = ctx.saved_tensors
|
||||
axis = [axis] if type(axis) == int else axis
|
||||
axis = [axis] if type(axis) is int else axis
|
||||
shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
|
||||
return grad_output.reshape(shape) + np.zeros_like(input)
|
||||
register('sum', Sum)
|
||||
|
|
Loading…
Reference in New Issue