onnx Einsum, CumSum, DepthToSpace, SpaceToDepth (#3252)

* onnx Einsum, CumSum, DepthToSpace, SpaceToDepth

Einsum inner product and `...` (ellipsis) equations are not supported

* add `--durations=20` to the ONNX backend test runs in CI
Author: chenyu, 2024-01-26 10:47:53 -05:00 (committed by GitHub)
Parent: e45ffdb6cf
Commit: bc92c4cc32
3 changed files with 36 additions and 10 deletions
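
From the first bullet: the new Einsum op simply forwards to `Tensor.einsum` (see the ops hunk below), so explicit-index equations work, while inner-product and ellipsis equations stay excluded in the backend tests. A minimal sketch, assuming the usual `tinygrad.tensor.Tensor` import and `.numpy()` realization (illustration only, not part of this diff):

```python
from tinygrad.tensor import Tensor

a = Tensor([[1., 2.], [3., 4.]])
b = Tensor([[5., 6.], [7., 8.]])
# explicit-index equations work, e.g. a plain matmul
print(Tensor.einsum("ij,jk->ik", a, b).numpy())

# forms still excluded in the backend tests below:
#   Tensor.einsum("i,i", x, y)        # inner product (scalar output)
#   Tensor.einsum("...ii->...i", x)   # ellipsis / batch diagonal
```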

View File

@@ -126,7 +126,7 @@ jobs:
- name: Run Pytest
run: TORCH=1 python -m pytest -n=auto test/ --durations=20
- name: Run ONNX
run: TORCH=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py
run: TORCH=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
testopencl:
strategy:
@@ -286,7 +286,7 @@ jobs:
- name: Run metal test
run: METAL=1 python -m pytest -n=auto test/ --ignore=test/external --ignore=test/models --durations=20
- name: Run ONNX
run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py
run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- name: Test tensor core ops
run: METAL=1 TC=2 DEBUG=3 python test/test_ops.py TestOps.test_big_gemm
- name: Test LLaMA compile speed

View File

@@ -149,6 +149,20 @@ def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, tr
  if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
  return ret
def Einsum(*Inputs: List[Tensor], equation): return Tensor.einsum(equation, Inputs)
def CumSum(X:Tensor, axis:Tensor, exclusive=0, reverse=0):
  axis = safe_numpy(axis).item()
  if axis < 0: axis += X.ndim
  if reverse: X = X.flip(axis)
  if exclusive:
    pad_arg, shrink_arg = [None] * X.ndim, [None] * X.ndim
    pad_arg[axis] = (1, 0)
    shrink_arg[axis] = (0, X.shape[axis])
    X = X.pad(tuple(pad_arg)).shrink(tuple(shrink_arg))
  if reverse: return X.cumsum(axis).flip(axis)
  return X.cumsum(axis)
# works with Tensors.ndim != 4
def _batchnorm(self:Tensor, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor):
  shape = [1, -1] + [1] * (self.ndim-2)
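
The pad/shrink in CumSum implements ONNX's `exclusive` attribute (prepend a zero and drop the last element along the axis) and `flip` implements `reverse`. A NumPy reference of the same semantics, for illustration only (`cumsum_ref` is a made-up name, not part of this diff):

```python
import numpy as np

def cumsum_ref(x: np.ndarray, axis: int, exclusive=0, reverse=0) -> np.ndarray:
  # ONNX CumSum semantics: optionally reverse along `axis`, optionally shift by one (exclusive)
  if reverse: x = np.flip(x, axis)
  if exclusive:
    pad = [(0, 0)] * x.ndim
    pad[axis] = (1, 0)  # prepend a zero and drop the last element, mirroring the pad/shrink above
    x = np.take(np.pad(x, pad), np.arange(x.shape[axis]), axis=axis)
  out = np.cumsum(x, axis=axis)
  return np.flip(out, axis) if reverse else out

assert np.allclose(cumsum_ref(np.array([1., 2., 3.]), 0, exclusive=1), [0., 1., 3.])
assert np.allclose(cumsum_ref(np.array([1., 2., 3.]), 0, reverse=1), [6., 5., 3.])
```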
@@ -311,6 +325,17 @@ def ConvTranspose(X: Tensor, W: Tensor, B:Optional[Tensor]=None, auto_pad="NOTSE
  if out_sh: output_padding = [os - rs for os, rs in zip(output_shape, out_sh)]
  return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)
def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
  b, c, h, w = X.shape
  if mode == "DCR":
    return X.reshape(b, blocksize, blocksize, c // (blocksize**2), h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, c // (blocksize**2), h * blocksize, w * blocksize)
  elif mode == "CRD":
    return X.reshape(b, c // (blocksize ** 2), blocksize, blocksize, h, w).permute(0, 1, 4, 2, 5, 3).reshape(b, c // (blocksize ** 2), h * blocksize, w * blocksize)
def SpaceToDepth(X:Tensor, blocksize:int):
  b, c, h, w = X.shape
  return X.reshape(b, c, h // blocksize, blocksize, w // blocksize, blocksize).permute(0, 3, 5, 1, 2, 4).reshape(b, c * (blocksize**2), h // blocksize, w // blocksize)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
  if isinstance(ratio, Tensor) and not ratio.shape: ratio = safe_numpy(ratio) # ratio and tensor is passed in as Tensor with shape: ()
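
As a sanity check on the reshape/permute patterns above, here is a NumPy sketch of DCR DepthToSpace and SpaceToDepth (illustration only, not the tinygrad code; function names are made up). It also shows that SpaceToDepth exactly inverts the DCR mode:

```python
import numpy as np

def depth_to_space_dcr(x: np.ndarray, bs: int) -> np.ndarray:
  # DCR mode: the depth split (bs, bs, c // bs**2) mirrors the reshape in DepthToSpace above
  b, c, h, w = x.shape
  return x.reshape(b, bs, bs, c // bs**2, h, w).transpose(0, 3, 4, 1, 5, 2).reshape(b, c // bs**2, h * bs, w * bs)

def space_to_depth(x: np.ndarray, bs: int) -> np.ndarray:
  b, c, h, w = x.shape
  return x.reshape(b, c, h // bs, bs, w // bs, bs).transpose(0, 3, 5, 1, 2, 4).reshape(b, c * bs**2, h // bs, w // bs)

x = np.arange(1 * 8 * 2 * 2, dtype=np.float32).reshape(1, 8, 2, 2)
y = depth_to_space_dcr(x, 2)                    # shape (1, 2, 4, 4)
assert np.array_equal(space_to_depth(y, 2), x)  # SpaceToDepth inverts DCR DepthToSpace
```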

View File

@@ -60,14 +60,14 @@ if Device.DEFAULT in ["METAL"] or (OSX and Device.DEFAULT == "GPU"):
backend_test.exclude('test_eyelike_with_dtype_cpu')
backend_test.exclude('test_reduce_log_sum_exp*')
backend_test.exclude('test_operator_add*')
backend_test.exclude('test_einsum_*')
backend_test.exclude('test_cumsum_*')
# no float16 in CI, LLVM segfaults, GPU requires cl_khr_fp16
if Device.DEFAULT in ['LLVM', 'CUDA', 'GPU'] and CI:
backend_test.exclude('float16')
backend_test.exclude('FLOAT16')
backend_test.exclude('string')
# dtype cast
backend_test.exclude('STRING')
backend_test.exclude('FLOAT8')
@@ -91,6 +91,11 @@ backend_test.exclude('test_mod_*')
# no boolean ops (2d, 3d, 4d)
backend_test.exclude('test_bitshift_*')
# no string ops
backend_test.exclude('string')
backend_test.exclude('test_strnorm_*')
backend_test.exclude('test_regex_*')
# no scatternd gathernd
backend_test.exclude('test_gathernd_*')
backend_test.exclude('test_scatternd_*')
@@ -122,7 +127,6 @@ backend_test.exclude('test_range_int32_type_negative_delta_expanded_cpu')
backend_test.exclude('test_bitwise_*')
backend_test.exclude('test_blackmanwindow_*')
backend_test.exclude('test_bernoulli_*')
backend_test.exclude('test_cumsum_*')
backend_test.exclude('test_det_*')
backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support
@@ -134,8 +138,8 @@ backend_test.exclude('test_hannwindow_*')
backend_test.exclude('test_hardmax_*')
backend_test.exclude('test_gridsample_*')
backend_test.exclude('test_dft_*')
backend_test.exclude('test_einsum_*')
backend_test.exclude('test_strnorm_*')
backend_test.exclude('test_einsum_batch_diagonal_cpu*') # TODO: equation = '...ii ->...i'
backend_test.exclude('test_einsum_inner_prod_cpu*') # TODO: equation = 'i,i'
backend_test.exclude('test_unique_*')
backend_test.exclude('test_sequence_*')
backend_test.exclude('test_nonmaxsuppression_*')
@@ -150,8 +154,6 @@ backend_test.exclude('test_melweightmatrix_*')
backend_test.exclude('test_basic_deform_conv_*')
backend_test.exclude('test_deform_conv_*')
backend_test.exclude('test_lppool_*')
backend_test.exclude('test_depthtospace_*')
backend_test.exclude('test_spacetodepth_*')
backend_test.exclude('test_scan*')
backend_test.exclude('test_split_to_sequence_*')
backend_test.exclude('test_resize_downsample_scales_cubic_*') # unsure how to implement cubic
@@ -160,7 +162,6 @@ backend_test.exclude('test_resize_upsample_scales_cubic_*') # unsure how to impl
backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to implement cubic
# rest of the failing tests
backend_test.exclude('test_regex_*') # does not support string Tensors
backend_test.exclude('test_resize_downsample_scales_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # antialias not implemented
backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill value after clip