mirror of https://github.com/commaai/tinygrad.git
onnx update for trilu and argmax (#3283)
* support 0 in shape for tril and triu * select_last_index for ArgMax and ArgMin * pass **kwargs
This commit is contained in:
parent
5b46b0ff3d
commit
7816c3b692
|
@ -120,10 +120,8 @@ def Unsqueeze(data: Tensor, axes):
|
|||
def Binarizer(x, threshold=0.0): return (x > threshold).float()
|
||||
|
||||
def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
|
||||
axis = axis + x.ndim if axis < 0 else axis
|
||||
m = x == (x.max(axis=axis, keepdim=keepdims) if keepdims else x.max(axis=axis, keepdim=keepdims).unsqueeze(axis))
|
||||
c = Tensor.arange(x.shape[axis]).reshape(*[1]*(axis), x.shape[axis], *[1]*(x.ndim - axis-1)) * m
|
||||
return c.max(axis=axis,keepdim=keepdims).cast(dtypes.int64)
|
||||
if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
|
||||
return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
|
||||
def ArgMin(x, axis=0, keepdims=1, select_last_index=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
|
||||
|
||||
def Concat(*xs: List[Tensor], axis): return xs[0].cat(*xs[1:], dim=axis)
|
||||
|
|
|
@ -81,8 +81,6 @@ backend_test.exclude('test_convinteger_*')
|
|||
backend_test.exclude('test_matmulinteger_*')
|
||||
|
||||
# we don't support indexes
|
||||
# backend_test.exclude('test_argmax_*') # Needs more work: select_last_index
|
||||
# backend_test.exclude('test_argmin_*') # Needs more work: select_last_index
|
||||
backend_test.exclude('test_nonzero_*')
|
||||
|
||||
# no support for mod
|
||||
|
@ -128,10 +126,6 @@ backend_test.exclude('test_bitwise_*')
|
|||
backend_test.exclude('test_blackmanwindow_*')
|
||||
backend_test.exclude('test_bernoulli_*')
|
||||
backend_test.exclude('test_det_*')
|
||||
|
||||
backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support
|
||||
backend_test.exclude('test_triu_zero_cpu') # TODO: zero array triu support
|
||||
|
||||
backend_test.exclude('test_col2im_*')
|
||||
backend_test.exclude('test_hammingwindow_*')
|
||||
backend_test.exclude('test_hannwindow_*')
|
||||
|
|
|
@ -252,12 +252,14 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(3,3)], lambda x: x.tril(1))
|
||||
helper_test_op([(3,3)], lambda x: x.tril(-1))
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,0,3)], lambda x: x.tril())
|
||||
helper_test_op([(5,3,3)], lambda x: x.tril(1))
|
||||
def test_triu(self):
|
||||
helper_test_op([(3,3)], lambda x: x.triu())
|
||||
helper_test_op([(3,3)], lambda x: x.triu(1))
|
||||
helper_test_op([(3,3)], lambda x: x.triu(-1))
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,0,3)], lambda x: x.triu())
|
||||
helper_test_op([(5,3,3)], lambda x: x.triu(1))
|
||||
def test_maximum(self):
|
||||
helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
|
||||
|
|
|
@ -713,6 +713,7 @@ class Tensor:
|
|||
@staticmethod
|
||||
def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor:
|
||||
assert all_int((r,c)), "does not support symbolic"
|
||||
if r == 0: return Tensor.zeros((r, c), **kwargs)
|
||||
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
|
||||
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
|
||||
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
|
||||
|
|
Loading…
Reference in New Issue