max to behave on ties like torch (#229)

* checkpoint

* fixing pow

* undo pow

* backward max on GPU and CPU rewrite

* indentation

* changing seed for curiosity

* max replaced equality

* undo seed

* rebase

* fixed tests

* merge error
Marcel Bischoff 2020-12-30 18:52:50 -05:00 committed by GitHub
parent 30f8132646
commit e2f833f58f
3 changed files with 25 additions and 15 deletions
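
For context on the title: when several entries tie for the maximum, torch's full-reduction max spreads the incoming gradient evenly across the tied positions, rather than sending it all to a single argmax (the old CPU path below) or the full amount to every tie (the old GPU path). A minimal torch-only sketch of that reference behavior, not part of the commit:

import torch

# three entries tie for the max; the incoming gradient of 1.0 is split three ways
x = torch.tensor([1.0, 1.0, 0.0, 1.0], requires_grad=True)
x.max().backward()
print(x.grad)  # ≈ [1/3, 1/3, 0, 1/3]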

test/test_ops.py

@@ -6,9 +6,12 @@ import timeit
 import functools
 from tinygrad.tensor import Tensor, GPU, ANE, Device
-def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False):
+def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False, vals=None):
   torch.manual_seed(0)
-  ts = [torch.rand(x, requires_grad=True) for x in shps]
+  if shps is None:
+    ts = [torch.tensor(x, requires_grad=True) for x in vals]
+  else:
+    ts = [torch.rand(x, requires_grad=True) for x in shps]
   tst = [Tensor(x.detach().numpy()) for x in ts]
   if device==Device.GPU:
     tst = [x.gpu() for x in tst]
@@ -85,8 +88,10 @@ class TestOps(unittest.TestCase):
   def test_max(self):
     helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
     helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
-    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
-    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0].mul(0.5), lambda x: Tensor.max(x, axis=1).mul(0.5), device=self.device)
+    helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device,
+      vals=[
+        [[1.0,1.0,0.0,1.0]],
+      ])
+    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
   def test_mean_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device)
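
With shps=None, helper_test_op now builds its torch inputs from the explicit vals list, which is what makes a deterministic tie case like [[1.0,1.0,0.0,1.0]] possible. Roughly the reference numbers that case pins down, computed with torch alone (a sketch, not the test itself):

import torch

t = torch.tensor([[1.0, 1.0, 0.0, 1.0]], requires_grad=True)
t.max().mul(0.5).backward()
print(t.grad)  # ≈ [[1/6, 1/6, 0, 1/6]]; the tinygrad gradient must now match this split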

tinygrad/ops_cpu.py

@@ -115,18 +115,21 @@ register('sum', Sum)
 class Max(Function):
   @staticmethod
-  def forward(ctx, input, axis=None):
-    am = input.argmax(axis=axis)
-    am = np.expand_dims(am, axis=axis) if axis is not None else np.array([am])
-    ctx.save_for_backward(input.shape, am, axis)
-    return np.take_along_axis(input, am, axis=axis).squeeze(axis=axis)
+  def forward(ctx, inp, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
+    ctx.save_for_backward(inp, axis, ret)
+    if axis is not None:
+      ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
+    return ret
   @staticmethod
   def backward(ctx, grad_output):
-    shape, am, axis = ctx.saved_tensors
-    ret = np.zeros(shape, dtype=np.float32)
-    np.put_along_axis(ret, am, grad_output.reshape(am.shape), axis=axis)
-    return ret
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = (input==ret.reshape(shape))
+    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
+    return ret2*grad_output.reshape(shape)/div
 register('max', Max)
 # ************* movement ops *************
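
The new backward replaces the argmax/put_along_axis scatter with an even split of the gradient over every position that attains the max. A self-contained NumPy sketch of that rule with simplified shapes (it assumes grad_output already broadcasts against the input, which the real code arranges via the reshape above):

import numpy as np

def max_grad(inp, grad_output, axis=None):
  # forward result kept broadcastable, mask of all tied maxima, tie count per slice
  ret = np.amax(inp, axis=axis, keepdims=True)
  mask = (inp == ret)
  div = mask.sum(axis=axis, keepdims=True)
  # each tied position receives an equal share of the incoming gradient
  return mask * grad_output / div

x = np.array([[1.0, 1.0, 0.0, 1.0]])
print(max_grad(x, 1.0))  # [[0.333.. 0.333.. 0. 0.333..]]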

tinygrad/ops_gpu.py

@@ -242,8 +242,10 @@ class Max(Function):
   def backward(ctx, grad_output):
     input, axis, ret = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = binary_op(ctx, "1.0*(a == b)", input, GPUBuffer(shape, ret))
-    return binary_op(ctx, 'a*b', ret2, GPUBuffer(shape, grad_output))
+    ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
+    div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
+    ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
+    return binary_op(ctx, 'a*b', ret3, GPUBuffer(shape, grad_output))
 register('max', Max, device=Device.GPU)
 class Matmul(Function):
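
The GPU backward appears to encode the same rule with tinygrad's element-wise (binary_op) and reduction (reduce_op) kernels: a 1.0/0.0 tie mask, a reduce-sum of that mask to count ties per reduced slice (the out+1e-10 reads as a small guard on the divisor), an element-wise divide, and a final scale by grad_output. Assuming that reading, a NumPy restatement of the four calls for an axis-wise case (illustrative only, not tinygrad API):

import numpy as np

a = np.array([[1.0, 1.0, 0.0, 1.0],
              [0.0, 2.0, 2.0, 2.0]])
ret = np.amax(a, axis=1, keepdims=True)        # saved forward result, one max per row
ret2 = 1.0 * (a == ret)                        # "1.0*(a==b)": tie mask
div = ret2.sum(axis=1, keepdims=True) + 1e-10  # "out += a" reduce, then "out+1e-10"
ret3 = ret2 / div                              # "a/b": split evenly within each row
print(ret3 * 1.0)                              # "a*b": scale by grad_output (1.0 here)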