Add Onnx op Shrink (#851)

* Add onnx Shrink operation

* Fix soft/hard shrink onnx test
This commit is contained in:
M4tthewDE 2023-05-29 22:15:39 +02:00 committed by GitHub
parent 6f2b3755ca
commit 4408c25e9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 0 deletions

View File

@ -165,6 +165,9 @@ def get_run_onnx(onnx_model):
starts = starts + inp[0].shape[axis] if starts < 0 else starts
arg[axis] = (starts, ends)
ret = inp[0].slice(arg=arg)
elif n.op_type == "Shrink":
bias = opt['bias'] if 'bias' in opt else 0
ret = (inp[0] < -opt['lambd'])*(inp[0]+bias) + (inp[0] > opt['lambd'])*(inp[0]-bias)
elif hasattr(onnx_ops, n.op_type):
fxn = getattr(onnx_ops, n.op_type)
if isinstance(fxn, dict):