mirror of https://github.com/commaai/tinygrad.git
max to behave on ties like torch (#229)
* checkpoint
* fixing pow
* undo pow
* backward max on GPU and CPU rewrite
* indentation
* changing seed for curiosity
* max replaced equality
* undo seed
* rebase
* fixed tests
* merge error
parent 30f8132646
commit e2f833f58f
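In short: when several entries tie for the maximum, the backward pass now splits the incoming gradient evenly across all of them instead of sending it to a single argmax index, matching what torch does for a full max() reduction. A minimal sketch of that rule, separate from the diff itself, using the same [[1.0,1.0,0.0,1.0]] values as the new test:

import numpy as np
import torch

x = [[1.0, 1.0, 0.0, 1.0]]

# torch reference: the upstream gradient (0.5, from mul(0.5)) is shared by the three tied maxima
t = torch.tensor(x, requires_grad=True)
t.max().mul(0.5).backward()
print(t.grad)                    # ~0.1667 for each tied entry, 0 elsewhere

# numpy version of the rule the rewritten CPU backward implements: tie mask divided by tie count
a = np.array(x, dtype=np.float32)
mask = (a == np.amax(a, keepdims=True))
print(mask * 0.5 / mask.sum())   # [[0.1667 0.1667 0. 0.1667]]

The new test_max case below checks exactly this agreement through helper_test_op's new vals argument.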
@@ -6,9 +6,12 @@ import timeit
 import functools
 from tinygrad.tensor import Tensor, GPU, ANE, Device
 
-def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False):
+def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, device=Device.CPU, forward_only=False, vals=None):
   torch.manual_seed(0)
-  ts = [torch.rand(x, requires_grad=True) for x in shps]
+  if shps is None:
+    ts = [torch.tensor(x, requires_grad=True) for x in vals]
+  else:
+    ts = [torch.rand(x, requires_grad=True) for x in shps]
   tst = [Tensor(x.detach().numpy()) for x in ts]
   if device==Device.GPU:
     tst = [x.gpu() for x in tst]
@@ -85,8 +88,10 @@ class TestOps(unittest.TestCase):
   def test_max(self):
     helper_test_op([(45,3)], lambda x: x.max(), Tensor.max, device=self.device)
     helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device)
-    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
     helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0].mul(0.5), lambda x: Tensor.max(x, axis=1).mul(0.5), device=self.device)
+    helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), device=self.device,
+      vals=[
+        [[1.0,1.0,0.0,1.0]],
+      ])
+    helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1), device=self.device)
   def test_mean_axis(self):
     helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2)), device=self.device)
@@ -115,18 +115,21 @@ register('sum', Sum)
 
 class Max(Function):
   @staticmethod
-  def forward(ctx, input, axis=None):
-    am = input.argmax(axis=axis)
-    am = np.expand_dims(am, axis=axis) if axis is not None else np.array([am])
-    ctx.save_for_backward(input.shape, am, axis)
-    return np.take_along_axis(input, am, axis=axis).squeeze(axis=axis)
+  def forward(ctx, inp, axis=None):
+    axis = [axis] if type(axis) == int else axis
+    ret = np.amax(inp, axis=None if axis is None else tuple(axis), keepdims=True)
+    ctx.save_for_backward(inp, axis, ret)
+    if axis is not None:
+      ret = ret.reshape([inp.shape[i] for i in range(len(inp.shape)) if i not in axis])
+    return ret
 
   @staticmethod
   def backward(ctx, grad_output):
-    shape, am, axis = ctx.saved_tensors
-    ret = np.zeros(shape, dtype=np.float32)
-    np.put_along_axis(ret, am, grad_output.reshape(am.shape), axis=axis)
-    return ret
+    input, axis, ret = ctx.saved_tensors
+    shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
+    ret2 = (input==ret.reshape(shape))
+    div = ret2.sum(axis=None if axis is None else tuple(axis), keepdims=True)
+    return ret2*grad_output.reshape(shape)/div
 register('max', Max)
 
 # ************* movement ops *************
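When an axis is given, the rewritten CPU backward applies the same rule per reduced slice. A standalone numpy sketch of the forward/backward pair, with illustrative names that are not part of the diff:

import numpy as np

def max_forward(inp, axis):
    axis = [axis] if isinstance(axis, int) else axis
    # keepdims=True so the result broadcasts against inp in the backward pass
    return np.amax(inp, axis=tuple(axis), keepdims=True)

def max_backward(inp, ret, grad_output, axis):
    axis = [axis] if isinstance(axis, int) else axis
    shape = [1 if i in axis else inp.shape[i] for i in range(len(inp.shape))]
    ret2 = (inp == ret.reshape(shape))                # mask of maxima, ties included
    div = ret2.sum(axis=tuple(axis), keepdims=True)   # number of ties per slice
    return ret2 * grad_output.reshape(shape) / div    # each slice's gradient split evenly

x = np.array([[1., 3., 3.],
              [2., 2., 0.]], dtype=np.float32)
out = max_forward(x, axis=1)                          # [[3.], [2.]]
print(max_backward(x, out, np.ones_like(out), axis=1))
# [[0.  0.5 0.5]
#  [0.5 0.5 0. ]]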
@@ -242,8 +242,10 @@ class Max(Function):
   def backward(ctx, grad_output):
     input, axis, ret = ctx.saved_tensors
     shape = [1 if axis is None or i in axis else input.shape[i] for i in range(len(input.shape))]
-    ret2 = binary_op(ctx, "1.0*(a == b)", input, GPUBuffer(shape, ret))
-    return binary_op(ctx, 'a*b', ret2, GPUBuffer(shape, grad_output))
+    ret2 = binary_op(ctx, "1.0*(a==b)", input, GPUBuffer(shape, ret))
+    div = reduce_op(ctx, "out += a", "out+1e-10", ret2, axis=axis)
+    ret3 = binary_op(ctx, "a/b", ret2, GPUBuffer(shape, div))
+    return binary_op(ctx, 'a*b', ret3, GPUBuffer(shape, grad_output))
 register('max', Max, device=Device.GPU)
 
 class Matmul(Function):