diff --git a/test/test_ops.py b/test/test_ops.py
index ea690c87..32f5a554 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -6,9 +6,12 @@ import timeit
 import functools
 from tinygrad.tensor import Tensor, GPU, ANE, Device
 
-def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False):
+def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False, vals=None):
   torch.manual_seed(0)
-  ts = [torch.rand(x, requires_grad=True) for x in shps]
+  if shps is None:
+    ts = [torch.tensor(x, requires_grad=True) for x in vals]
+  else:
+    ts = [torch.rand(x, requires_grad=True) for x in shps]
   tst = [Tensor(x.detach().numpy()) for x in ts]
   if device==Device.GPU:
     tst = [x.gpu() for x in tst]
@@ -85,8 +88,10 @@ class TestOps(unittest.TestCase):
   def test_max(self):
     helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
     helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
-    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
-    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0].mul(0.5), lambda x: Tensor.max(x, axis=1).mul(0.5), device=self.device)
+    helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device,
+      vals=[
+        [[1.0,1.0,0.0,1.0]],
+      ])
     helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
   def test_mean_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device)
diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py
index 78691ca4..5fd8c787 100644
--- a/tinygrad/ops_cpu.py
+++ b/tinygrad/ops_cpu.py
@@ -115,18 +115,21 @@ register('sum', Sum)
 
 class Max(Function):
   @staticmethod
-  def forward(ctx, input, axis=None):
-    am = input.argmax(axis=axis)
-    am = np.expand_dims(am, axis=axis) if axis is not None else np.array([am])
-    ctx.save_for_backward(input.shape, am, axis)
-    return np.take_along_axis(input, am, axis=axis).squeeze(axis=axis)
+  def forward(ctx, inp, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
+    ctx.save_for_backward(inp, axis, ret)
+    if axis is not None:
+      ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
+    return ret
 
   @staticmethod
   def backward(ctx, grad_output):
-    shape, am, axis = ctx.saved_tensors
-    ret = np.zeros(shape, dtype=np.float32)
-    np.put_along_axis(ret, am, grad_output.reshape(am.shape), axis=axis)
-    return ret
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = (input==ret.reshape(shape))
+    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
+    return ret2*grad_output.reshape(shape)/div
 register('max', Max)
 
 # ************* movement ops *************
diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py
index 51746cd1..d9b34aee 100644
--- a/tinygrad/ops_gpu.py
+++ b/tinygrad/ops_gpu.py
@@ -242,8 +242,10 @@ class Max(Function):
   def backward(ctx, grad_output):
     input, axis, ret = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = binary_op(ctx, "1.0*(a == b)", input, GPUBuffer(shape, ret))
-    return binary_op(ctx, 'a*b', ret2, GPUBuffer(shape, grad_output))
+    ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
+    div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
+    ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
+    return binary_op(ctx, 'a*b', ret3, GPUBuffer(shape, grad_output))
 register('max', Max, device=Device.GPU)
 
 class Matmul(Function):
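
Note on the change above: both the CPU and GPU backward passes now distribute the incoming gradient evenly across every element that ties for the maximum, instead of routing it all through a single argmax index. The added test with vals=[[1.0,1.0,0.0,1.0]] appears to exercise exactly that tie-breaking behaviour against torch. A minimal standalone NumPy sketch of the rule (max_backward is an illustrative name, not part of the patch):

import numpy as np

def max_backward(inp, grad_output, axis=None):
  # forward max, kept broadcastable against the input
  ret = np.amax(inp, axis=axis, keepdims=True)
  # 1.0 wherever the maximum is attained (possibly several ties per slice)
  mask = (inp == ret).astype(np.float32)
  # number of tied maxima in each reduced slice
  div = mask.sum(axis=axis, keepdims=True)
  # split the upstream gradient evenly among the ties
  return mask * np.reshape(grad_output, ret.shape) / div

For inp = np.array([[1.0, 1.0, 0.0, 1.0]]) and grad_output = np.float32(0.5), this yields [[1/6, 1/6, 0.0, 1/6]]: the 0.5 is shared by the three tied maxima, which is the behaviour the new test case checks.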