mirror of https://github.com/commaai/tinygrad.git
onnx Einsum, CumSum, DepthToSpace, SpaceToDepth (#3252)
* onnx Einsum, CumSum, DepthToSpace, SpaceToDepth
  (Einsum inner product and `...` equations are not supported)
* add --durations=20 to the ONNX backend test runs in CI
parent e45ffdb6cf
commit bc92c4cc32
@@ -126,7 +126,7 @@ jobs:
     - name: Run Pytest
       run: TORCH=1 python -m pytest -n=auto test/ --durations=20
     - name: Run ONNX
-      run: TORCH=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py
+      run: TORCH=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
 
   testopencl:
     strategy:
@@ -286,7 +286,7 @@ jobs:
     - name: Run metal test
       run: METAL=1 python -m pytest -n=auto test/ --ignore=test/external --ignore=test/models --durations=20
     - name: Run ONNX
-      run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py
+      run: METAL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
     - name: Test tensor core ops
       run: METAL=1 TC=2 DEBUG=3 python test/test_ops.py TestOps.test_big_gemm
     - name: Test LLaMA compile speed
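Note: pytest's --durations=N flag reports the N slowest tests at the end of a run, so the only behavioural change in the two workflow hunks above is that the ONNX backend job now reports its slowest 20 tests, matching the other pytest steps. The equivalent local invocation (copied from the new workflow line) is:

  TORCH=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20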
@@ -149,6 +149,20 @@ def Gemm(A: Tensor, B: Tensor, C: Tensor=None, alpha=1.0, beta=1.0, transA=0, tr
   if C is not None: ret += beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
   return ret
 
+def Einsum(*Inputs: List[Tensor], equation): return Tensor.einsum(equation, Inputs)
+
+def CumSum(X:Tensor, axis:Tensor, exclusive=0, reverse=0):
+  axis = safe_numpy(axis).item()
+  if axis < 0: axis += X.ndim
+  if reverse: X = X.flip(axis)
+  if exclusive:
+    pad_arg, shrink_arg = [None] * X.ndim, [None] * X.ndim
+    pad_arg[axis] = (1, 0)
+    shrink_arg[axis] = (0, X.shape[axis])
+    X = X.pad(tuple(pad_arg)).shrink(tuple(shrink_arg))
+  if reverse: return X.cumsum(axis).flip(axis)
+  return X.cumsum(axis)
+
 # works with Tensors.ndim != 4
 def _batchnorm(self:Tensor, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor):
   shape = [1, -1] + [1] * (self.ndim-2)
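A quick note on the two new handlers above: like the other functions in this file they take the name of the ONNX op they implement. Einsum simply forwards the ONNX equation string to Tensor.einsum, and CumSum emulates ONNX's exclusive/reverse attributes with a pad/shrink (a shift by one element) and a flip around the plain cumsum. The sketch below restates those semantics in NumPy purely as an illustration of the expected behaviour; it is a reference, not tinygrad code:

import numpy as np

def cumsum_reference(x: np.ndarray, axis: int, exclusive: int = 0, reverse: int = 0) -> np.ndarray:
  # NumPy reference for ONNX CumSum semantics (illustration only)
  if reverse: x = np.flip(x, axis)
  if exclusive:
    pad = [(0, 0)] * x.ndim
    pad[axis] = (1, 0)                              # prepend one zero along the scan axis...
    x = np.delete(np.pad(x, pad), -1, axis=axis)    # ...and drop the last element (same shift as pad+shrink above)
  out = np.cumsum(x, axis=axis)
  return np.flip(out, axis) if reverse else out

a = np.array([1., 2., 3.])
assert np.allclose(cumsum_reference(a, 0), [1., 3., 6.])
assert np.allclose(cumsum_reference(a, 0, exclusive=1), [0., 1., 3.])
assert np.allclose(cumsum_reference(a, 0, reverse=1), [6., 5., 3.])
assert np.allclose(cumsum_reference(a, 0, exclusive=1, reverse=1), [5., 3., 0.])

# Einsum: ONNX passes an explicit equation through to Tensor.einsum, e.g. a matmul:
A, B = np.ones((2, 3)), np.ones((3, 4))
assert np.einsum('ij,jk->ik', A, B).shape == (2, 4)
# The two forms the commit message flags as unsupported are the implicit inner product 'i,i'
# and ellipsis equations such as '...ii->...i'; they stay excluded in the backend tests below.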
@@ -311,6 +325,17 @@ def ConvTranspose(X: Tensor, W: Tensor, B:Optional[Tensor]=None, auto_pad="NOTSE
   if out_sh: output_padding = [os - rs for os, rs in zip(output_shape, out_sh)]
   return X.conv_transpose2d(W, B, stride=strides, groups=group, dilation=dilations, padding=pads if pads is not None else 0, output_padding=output_padding)
 
+def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
+  b, c, h, w = X.shape
+  if mode == "DCR":
+    return X.reshape(b, blocksize, blocksize, c // (blocksize**2), h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, c // (blocksize**2), h * blocksize, w * blocksize)
+  elif mode == "CRD":
+    return X.reshape(b, c // (blocksize ** 2), blocksize, blocksize, h, w).permute(0, 1, 4, 2, 5, 3).reshape(b, c // (blocksize ** 2), h * blocksize, w * blocksize)
+
+def SpaceToDepth(X:Tensor, blocksize:int):
+  b, c, h, w = X.shape
+  return X.reshape(b, c, h // blocksize, blocksize, w // blocksize, blocksize).permute(0, 3, 5, 1, 2, 4).reshape(b, c * (blocksize**2), h // blocksize, w // blocksize)
+
 # Reimplemented here because you need legacy RNG for passing ONNX tests.
 def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
   if isinstance(ratio, Tensor) and not ratio.shape: ratio = safe_numpy(ratio) # ratio and tensor is passed in as Tensor with shape: ()
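To make the reshape/permute choreography above easier to follow: DCR mode treats the channel dimension as (blocksize, blocksize, new_c) and interleaves the two block axes into height and width, CRD orders it as (new_c, blocksize, blocksize) instead, and SpaceToDepth performs the opposite blocking. The NumPy sketch below (NumPy only for illustration; it mirrors the DCR branch and the SpaceToDepth handler above) shows the shape bookkeeping and that SpaceToDepth undoes DCR-mode DepthToSpace:

import numpy as np

def depth_to_space_dcr(x: np.ndarray, blocksize: int) -> np.ndarray:
  # channels split as (blocksize, blocksize, c'); block axes interleaved into H and W
  b, c, h, w = x.shape
  y = x.reshape(b, blocksize, blocksize, c // blocksize**2, h, w).transpose(0, 3, 4, 1, 5, 2)
  return y.reshape(b, c // blocksize**2, h * blocksize, w * blocksize)

def space_to_depth(x: np.ndarray, blocksize: int) -> np.ndarray:
  # inverse blocking: each blocksize x blocksize spatial tile is stacked into the channels
  b, c, h, w = x.shape
  y = x.reshape(b, c, h // blocksize, blocksize, w // blocksize, blocksize).transpose(0, 3, 5, 1, 2, 4)
  return y.reshape(b, c * blocksize**2, h // blocksize, w // blocksize)

x = np.arange(1 * 8 * 2 * 2, dtype=np.float32).reshape(1, 8, 2, 2)
y = depth_to_space_dcr(x, 2)
assert y.shape == (1, 2, 4, 4)                   # (b, c/bs^2, h*bs, w*bs)
assert np.array_equal(space_to_depth(y, 2), x)   # SpaceToDepth inverts DCR DepthToSpace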
@@ -60,14 +60,14 @@ if Device.DEFAULT in ["METAL"] or (OSX and Device.DEFAULT == "GPU"):
   backend_test.exclude('test_eyelike_with_dtype_cpu')
   backend_test.exclude('test_reduce_log_sum_exp*')
   backend_test.exclude('test_operator_add*')
+  backend_test.exclude('test_einsum_*')
+  backend_test.exclude('test_cumsum_*')
 
 # no float16 in CI, LLVM segfaults, GPU requires cl_khr_fp16
 if Device.DEFAULT in ['LLVM', 'CUDA', 'GPU'] and CI:
   backend_test.exclude('float16')
   backend_test.exclude('FLOAT16')
 
-backend_test.exclude('string')
-
 # dtype cast
 backend_test.exclude('STRING')
 backend_test.exclude('FLOAT8')
@@ -91,6 +91,11 @@ backend_test.exclude('test_mod_*')
 # no boolean ops (2d, 3d, 4d)
 backend_test.exclude('test_bitshift_*')
 
+# no string ops
+backend_test.exclude('string')
+backend_test.exclude('test_strnorm_*')
+backend_test.exclude('test_regex_*')
+
 # no scatternd gathernd
 backend_test.exclude('test_gathernd_*')
 backend_test.exclude('test_scatternd_*')
@@ -122,7 +127,6 @@ backend_test.exclude('test_range_int32_type_negative_delta_expanded_cpu')
 backend_test.exclude('test_bitwise_*')
 backend_test.exclude('test_blackmanwindow_*')
 backend_test.exclude('test_bernoulli_*')
-backend_test.exclude('test_cumsum_*')
 backend_test.exclude('test_det_*')
 
 backend_test.exclude('test_tril_zero_cpu') # TODO: zero array tril support
@@ -134,8 +138,8 @@ backend_test.exclude('test_hannwindow_*')
 backend_test.exclude('test_hardmax_*')
 backend_test.exclude('test_gridsample_*')
 backend_test.exclude('test_dft_*')
-backend_test.exclude('test_einsum_*')
-backend_test.exclude('test_strnorm_*')
+backend_test.exclude('test_einsum_batch_diagonal_cpu*') # TODO: equation = '...ii ->...i'
+backend_test.exclude('test_einsum_inner_prod_cpu*') # TODO: equation = 'i,i'
 backend_test.exclude('test_unique_*')
 backend_test.exclude('test_sequence_*')
 backend_test.exclude('test_nonmaxsuppression_*')
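The swap above replaces the blanket einsum exclusion with the two specific cases the commit message says remain unsupported. For reference, this is what those two equations mean (NumPy used only to illustrate; these are not tinygrad calls):

import numpy as np

a = np.array([1., 2., 3.])
x = np.arange(2 * 3 * 3, dtype=np.float32).reshape(2, 3, 3)

np.einsum('i,i', a, a)        # implicit inner product, scalar 14.0   -> test_einsum_inner_prod
np.einsum('...ii->...i', x)   # batch diagonal via ellipsis, shape (2, 3) -> test_einsum_batch_diagonal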
@@ -150,8 +154,6 @@ backend_test.exclude('test_melweightmatrix_*')
 backend_test.exclude('test_basic_deform_conv_*')
 backend_test.exclude('test_deform_conv_*')
 backend_test.exclude('test_lppool_*')
-backend_test.exclude('test_depthtospace_*')
-backend_test.exclude('test_spacetodepth_*')
 backend_test.exclude('test_scan*')
 backend_test.exclude('test_split_to_sequence_*')
 backend_test.exclude('test_resize_downsample_scales_cubic_*') # unsure how to implement cubic
@@ -160,7 +162,6 @@ backend_test.exclude('test_resize_upsample_scales_cubic_*') # unsure how to impl
 backend_test.exclude('test_resize_upsample_sizes_cubic_*') # unsure how to implement cubic
 
 # rest of the failing tests
-backend_test.exclude('test_regex_*') # does not support string Tensors
 backend_test.exclude('test_resize_downsample_scales_linear_antialias_cpu') # antialias not implemented
 backend_test.exclude('test_resize_downsample_sizes_linear_antialias_cpu') # antialias not implemented
 backend_test.exclude('test_resize_tf_crop_and_resize_cpu') # unsure about fill value after clip