diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 1feb0550..b9b77a5a 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -120,10 +120,8 @@ def Unsqueeze(data: Tensor, axes): def Binarizer(x, threshold=0.0): return (x > threshold).float() def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0): - axis = axis + x.ndim if axis < 0 else axis - m = x == (x.max(axis=axis, keepdim=keepdims) if keepdims else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis)) - c = Tensor.arange(x.shape[axis]).reshape(*[1]*(axis), x.shape[axis], *[1]*(x.ndim - axis-1)) * m - return c.max(axis=axis,keepdim=keepdims).cast(dtypes.int64) + if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) + return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) def Concat(*xs: List[Tensor], axis): return xs[0].cat(*xs[1:], dim=axis) diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index bb08c585..669515d1 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -81,8 +81,6 @@ backend_test.exclude('test_convinteger_*') backend_test.exclude('test_matmulinteger_*') # we don't support indexes -# backend_test.exclude('test_argmax_*') # Needs more work: select_last_index -# backend_test.exclude('test_argmin_*') # Needs more work: select_last_index backend_test.exclude('test_nonzero_*') # no support for mod @@ -128,10 +126,6 @@ backend_test.exclude('test_bitwise_*') backend_test.exclude('test_blackmanwindow_*') backend_test.exclude('test_bernoulli_*') backend_test.exclude('test_det_*') - -backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support -backend_test.exclude('test_triu_zero_cpu') # TODO: zero array triu support - backend_test.exclude('test_col2im_*') backend_test.exclude('test_hammingwindow_*') backend_test.exclude('test_hannwindow_*') diff --git a/test/test_ops.py b/test/test_ops.py index 5c085946..1ee710cb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -252,12 +252,14 @@ class TestOps(unittest.TestCase): helper_test_op([(3,3)], lambda x: x.tril(1)) helper_test_op([(3,3)], lambda x: x.tril(-1)) helper_test_op([(5,3,3)], lambda x: x.tril()) + helper_test_op([(5,0,3)], lambda x: x.tril()) helper_test_op([(5,3,3)], lambda x: x.tril(1)) def test_triu(self): helper_test_op([(3,3)], lambda x: x.triu()) helper_test_op([(3,3)], lambda x: x.triu(1)) helper_test_op([(3,3)], lambda x: x.triu(-1)) helper_test_op([(5,3,3)], lambda x: x.triu()) + helper_test_op([(5,0,3)], lambda x: x.triu()) helper_test_op([(5,3,3)], lambda x: x.triu(1)) def test_maximum(self): helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index de436480..4ee59c9d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -713,6 +713,7 @@ class Tensor: @staticmethod def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor: assert all_int((r,c)), "does not support symbolic" + if r == 0: return Tensor.zeros((r, c), **kwargs) return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c) def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0) def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)