From 3c5a51fb3aeccf6f737aa437b02586d84705849b Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Wed, 15 Nov 2023 23:12:38 +0800 Subject: [PATCH] aaaaaaa finally (#2310) --- extra/onnx_ops.py | 23 ++++++++++++++------- test/external/external_test_onnx_backend.py | 1 - 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index a0cdd291..1e31a4fc 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -344,14 +344,21 @@ def Conv(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=padding) def ConvTranspose(X: Tensor, W: Tensor, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, output_shape=None, output_padding=0, strides=1): - if not kernel_shape: kernel_shape = W.shape - if pads is None and auto_pad != "NOTSET": pads = _auto_pad(X, auto_pad, strides, kernel_shape, dilations) - elif pads is None and auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2) - strides_ = [1]*(W.ndim-1) + [strides] if isinstance(strides, int) else [1]*(W.ndim-len(strides)) + list(strides) - dilations_ = [1]*(W.ndim-1) + [dilations] if isinstance(dilations, int) else [1]*(W.ndim-len(dilations)) + list(dilations) - if output_shape and not output_padding: - out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides_, X.shape, kernel_shape, dilations_))] - output_padding = [os - rs for os, rs in zip(output_shape, out_sh[-len(output_shape):])] + if kernel_shape is None: kernel_shape = W.shape[2:] + if isinstance(strides, int): strides = [strides]*(W.ndim-2) + if isinstance(dilations, int): dilations = [dilations]*(W.ndim-2) + if isinstance(output_padding, int): output_padding = [output_padding]*(W.ndim-2) + out_sh = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))] if output_shape is not None or auto_pad != "NOTSET" else [] + if pads is None: + if output_shape is None: output_shape = [xs*st for xs, st in zip(X.shape[2:], strides)] + if auto_pad == "NOTSET": pads = [0,0] * (X.ndim - 2) + else: + total_padding = [st*(ish-1) + pad + ((ks-1)*dil+1)-osh for st, ish, pad, ks, dil, osh in zip(strides, X.shape[2:], output_padding, kernel_shape, dilations, output_shape)] + pad_shape = flatten([[sh//2, sh-sh//2] for sh in total_padding]) + pads = pad_shape[::2] + pad_shape[1::2] if auto_pad == "SAME_UPPER" else pad_shape[1::2] + pad_shape[::2] + else: + if output_shape is None: output_shape = [st*(xs-1) + (ks-1)*di+1 if n < 2 else st*(xs-1) + (ks-1)*di+1 - pads[n-2] - pads[n-1] for n, (st, xs, ks, di) in enumerate(zip(strides, X.shape[2:], kernel_shape, dilations))] + if out_sh: output_padding = [os - rs for os, rs in zip(output_shape, out_sh)] return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding) # Reimplemented here because you need legacy RNG for passing ONNX tests. diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 6c1d139c..b8f96d69 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -146,7 +146,6 @@ backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to imple # rest of the failing tests backend_test.exclude('test_regex_*') # does not support string Tensors -backend_test.exclude('test_convtranspose_autopad_same_cpu') # TODO geohotstan has no idea how this is done, autopad requires output_shape but output_shape requires pads from autopad backend_test.exclude('test_optional_has_element_empty_optional_input_cpu') # Attempts to create Tensor from None backend_test.exclude('test_reshape_allowzero_reordered_cpu') # reshaping to shape with 0 backend_test.exclude('test_reduce_min_empty_set_cpu') # max a tensor with 0 in shape