onnx update for trilu and argmax (#3283)

* support 0 in shape for tril and triu

* select_last_index for ArgMax and ArgMin

* pass **kwargs
chenyu 2024-01-30 18:39:16 -05:00 committed by GitHub
parent 5b46b0ff3d
commit 7816c3b692
4 changed files with 5 additions and 10 deletions

extra/onnx_ops.py

@@ -120,10 +120,8 @@ def Unsqueeze(data: Tensor, axes):
 def Binarizer(x, threshold=0.0): return (x > threshold).float()
 def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
-  axis = axis + x.ndim if axis < 0 else axis
-  m = x == (x.max(axis=axis, keepdim=keepdims) if keepdims else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis))
-  c = Tensor.arange(x.shape[axis]).reshape(*[1]*(axis), x.shape[axis], *[1]*(x.ndim - axis-1)) * m
-  return c.max(axis=axis,keepdim=keepdims).cast(dtypes.int64)
+  if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
+  return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
 def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
 def Concat(*xs: List[Tensor], axis): return xs[0].cat(*xs[1:], dim=axis)
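
Note (illustration, not part of the commit): the new ArgMax reduces the select_last_index case to the existing first-occurrence argmax by flipping along the axis and mapping the index back. The same identity in plain NumPy:

import numpy as np

# the max value 3 is tied at indices 0 and 2
x = np.array([3, 1, 3, 2])
first = x.argmax()                           # first occurrence -> 0
# flip, take the first occurrence in the flipped view, map it back
last = (x.shape[0] - 1) - x[::-1].argmax()   # last occurrence -> 2
assert (first, last) == (0, 2)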

test/external/external_test_onnx_backend.py

@@ -81,8 +81,6 @@ backend_test.exclude('test_convinteger_*')
 backend_test.exclude('test_matmulinteger_*')
 # we don't support indexes
-# backend_test.exclude('test_argmax_*') # Needs more work: select_last_index
-# backend_test.exclude('test_argmin_*') # Needs more work: select_last_index
 backend_test.exclude('test_nonzero_*')
 # no support for mod
@@ -128,10 +126,6 @@ backend_test.exclude('test_bitwise_*')
 backend_test.exclude('test_blackmanwindow_*')
 backend_test.exclude('test_bernoulli_*')
 backend_test.exclude('test_det_*')
-backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support
-backend_test.exclude('test_triu_zero_cpu') # TODO: zero array triu support
 backend_test.exclude('test_col2im_*')
 backend_test.exclude('test_hammingwindow_*')
 backend_test.exclude('test_hannwindow_*')

test/test_ops.py

@@ -252,12 +252,14 @@ class TestOps(unittest.TestCase):
     helper_test_op([(3,3)], lambda x: x.tril(1))
     helper_test_op([(3,3)], lambda x: x.tril(-1))
     helper_test_op([(5,3,3)], lambda x: x.tril())
+    helper_test_op([(5,0,3)], lambda x: x.tril())
     helper_test_op([(5,3,3)], lambda x: x.tril(1))
   def test_triu(self):
     helper_test_op([(3,3)], lambda x: x.triu())
     helper_test_op([(3,3)], lambda x: x.triu(1))
     helper_test_op([(3,3)], lambda x: x.triu(-1))
     helper_test_op([(5,3,3)], lambda x: x.triu())
+    helper_test_op([(5,0,3)], lambda x: x.triu())
     helper_test_op([(5,3,3)], lambda x: x.triu(1))
   def test_maximum(self):
     helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
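
Note (illustration, not part of the commit): the two added cases exercise matrices with a zero-size dimension; helper_test_op compares tinygrad against torch, for which tril/triu on an empty matrix is a shape-preserving no-op (assumes PyTorch is installed):

import torch

x = torch.randn(5, 0, 3)            # a batch of five empty 0x3 matrices
assert x.tril().shape == (5, 0, 3)
assert x.triu(1).shape == (5, 0, 3)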

tinygrad/tensor.py

@@ -713,6 +713,7 @@ class Tensor:
   @staticmethod
   def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
     assert all_int((r,c)), "does not support symbolic"
+    if r == 0: return Tensor.zeros((r, c), **kwargs)
     return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
   def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
   def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
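
Note (illustration, not part of the commit): _tri builds a boolean mask by broadcasting row indices against shifted column indices; entry (i, j) is True when i <= j - k, i.e. on or above the k-th diagonal. triu keeps the masked elements, while tril masks with k+1 and keeps the complement. The new r == 0 guard returns an empty all-zero mask directly instead of going through the zero-length arange/expand path. The same construction in NumPy, with tri_mask as a hypothetical stand-in for Tensor._tri:

import numpy as np

def tri_mask(r: int, c: int, k: int = 0) -> np.ndarray:
  # entry (i, j) is True when i <= j - k (on/above the k-th diagonal)
  if r == 0: return np.zeros((r, c), dtype=bool)  # mirrors the new guard
  return np.arange(r)[:, None] <= np.arange(-k, c - k)[None, :]

a = np.arange(16).reshape(4, 4)
assert (np.where(tri_mask(4, 4, 0), a, 0) == np.triu(a)).all()  # triu keeps the mask
assert (np.where(tri_mask(4, 4, 1), 0, a) == np.tril(a)).all()  # tril keeps the complement of the k+1 mask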