mirror of https://github.com/commaai/tinygrad.git
onnx 58/109/208
This commit is contained in:
parent e8a153e4e9
commit 85452fbaf3
@@ -123,8 +123,6 @@ def get_run_onnx(onnx_model):
      args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(shape)] for i in indices]
      ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis)
      ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
    elif n.op_type in ["Sum"]:
      ret = functools.reduce(Tensor.__add__, inp)
    elif n.op_type in ["Add", "Sub", "Mul", "Pow"]:
      # TODO: add this to tinygrad? i don't think it's in torch
      if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
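The Gather path above needs no dedicated gather llop: it takes one (i, i+1) slice per index along the gather axis, concatenates the pieces, and squeezes the axis when only one index was given. A minimal numpy sketch of the same trick (numpy arrays stand in for tinygrad Tensors, and indices is assumed to be a plain list of ints already unpacked from the node inputs):

import numpy as np

def gather_via_slices(x, indices, axis):
  # one slice per index: the [i, i+1) window along `axis`, the full range elsewhere
  args = [tuple(slice(i, i+1) if j == axis else slice(0, d) for j, d in enumerate(x.shape)) for i in indices]
  out = np.concatenate([x[a] for a in args], axis=axis)
  # a single index means the gathered axis is squeezed away, mirroring the "squeeze if needed" line above
  return out.reshape([d for j, d in enumerate(x.shape) if j != axis]) if len(indices) == 1 else out

x = np.arange(12).reshape(3, 4)
print(gather_via_slices(x, [2, 0], axis=0))  # rows 2 and 0, shape (2, 4)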
@@ -2,6 +2,7 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
from extra.onnx import safe_numpy
import numpy as np
import functools

def Unsqueeze(data, axes):
  axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
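The Unsqueeze snippet normalizes negative axes against the input rank before using them. A quick illustration of just that step, with a plain list where the real code unwraps a tensor through safe_numpy:

rank = 3          # len(data.shape) for a rank-3 input
axes = [-1, 0]
print([rank + int(x) if x < 0 else int(x) for x in axes])  # [2, 0]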
@@ -81,6 +82,7 @@ def Expand(input, shape):

def Identity(input): return input
def Neg(input): return -input
def Reciprocal(input): return input.reciprocal()
def Sqrt(input): return input.sqrt()
def Sign(input): return input.sign()
def Abs(input): return input.abs()
@@ -97,6 +99,9 @@ def Softmax(input, axis=-1): return input.softmax(axis)
def LogSoftmax(input, axis=-1): return input.log_softmax(axis)
def Clip(input, min=-3.4e38, max=3.4e38): return input.clip(min, max)

def Sum(*data_0): return functools.reduce(Tensor.__add__, data_0)
def Mean(*data_0): return functools.reduce(Tensor.__add__, data_0) / len(data_0)

def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None else ([] if noop_with_empty_axes else None)

# ReduceProd would require a new llop
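The _axes helper packs the reduce-op axes contract into one expression: explicit axes pass through, a missing axes input means reduce over all dims, unless noop_with_empty_axes is set, in which case it yields an empty list and the reduce becomes a no-op. A small sketch of the three cases (plain lists in place of the tensor that safe_numpy would unwrap):

def _axes_sketch(axes, noop_with_empty_axes):
  if axes is not None: return [int(x) for x in axes]  # explicit axes pass through
  return [] if noop_with_empty_axes else None         # [] -> no-op, None -> reduce all dims

assert _axes_sketch([1, -1], 0) == [1, -1]
assert _axes_sketch(None, 1) == []
assert _axes_sketch(None, 0) is None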
@@ -111,4 +116,11 @@ def ReduceLogSum(data, axes=None, keepdims=1, noop_with_empty_axes=0): return da
def ReduceLogSumExp(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()

def GlobalAveragePool(X): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)

def Tile(input, repeats):
  repeats_ = [int(x) for x in safe_numpy(repeats)]
  new_shape = [x for i in range(len(input.shape)) for x in [1,input.shape[i]]]
  expand_shape = [x for r,s in zip(repeats_, input.shape) for x in [r,s]]
  final_shape = [r*s for r,s in zip(repeats_, input.shape)]
  return input.reshape(new_shape).expand(expand_shape).reshape(final_shape)
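The Global pools just above reduce over every axis from 2 onward, i.e. all spatial dims of an NCHW tensor. Tile likewise avoids a dedicated repeat llop with a reshape/expand/reshape dance: interleave a singleton in front of each input dim, broadcast each singleton up to its repeat count, then flatten adjacent pairs. A numpy sketch of the same shape walk for a (2, 3) input tiled (2, 2):

import numpy as np

x = np.arange(6).reshape(2, 3)
repeats = (2, 2)
new_shape = [d for s in x.shape for d in (1, s)]                      # (1, 2, 1, 3)
expand_shape = [d for r, s in zip(repeats, x.shape) for d in (r, s)]  # (2, 2, 2, 3)
final_shape = [r * s for r, s in zip(repeats, x.shape)]               # (4, 6)
tiled = np.broadcast_to(x.reshape(new_shape), expand_shape).reshape(final_shape)
assert (tiled == np.tile(x, repeats)).all()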
@@ -44,9 +44,16 @@ backend_test.exclude('test_max_*')
# add support for SoftmaxCrossEntropyLoss and NegativeLogLikelihoodLoss
backend_test.exclude('test_sce_*')

# no optimizers (add them)
backend_test.exclude('test_adagrad_*')
backend_test.exclude('test_adam_*')
backend_test.exclude('test_nesterov_momentum_*')

# we only support float32
backend_test.exclude('test_add_uint8_*')
backend_test.exclude('test_sub_uint8_*')
backend_test.exclude('test_div_uint8_*')
backend_test.exclude('test_mul_uint8_*')
backend_test.exclude('test_cast_*')
backend_test.exclude('test_castlike_*')
@@ -75,6 +82,7 @@ backend_test.exclude('test_and*')
backend_test.exclude('test_xor*')
backend_test.exclude('test_or*')
backend_test.exclude('test_bitshift_*')
backend_test.exclude('test_not_*')

# no scatter gather
backend_test.exclude('test_gather_*')
@@ -82,9 +90,13 @@ backend_test.exclude('test_gathernd_*')
backend_test.exclude('test_scatter_*')
backend_test.exclude('test_scatternd_*')

# no quantize
backend_test.exclude('test_dequantizelinear_*')
backend_test.exclude('test_dynamicquantizelinear_*')
backend_test.exclude('test_qlinearmatmul_*')
backend_test.exclude('test_quantizelinear_*')

# unsupported (strange) ops
backend_test.exclude('test_adagrad_*')
backend_test.exclude('test_adam_*')
backend_test.exclude('test_argmax_*')
backend_test.exclude('test_argmin_*')
backend_test.exclude('test_bitwise_*')
@@ -102,8 +114,6 @@ backend_test.exclude('test_gru_*')
backend_test.exclude('test_gridsample_*')
backend_test.exclude('test_if_*')
backend_test.exclude('test_compress_*')
backend_test.exclude('test_dequantizelinear_*')
backend_test.exclude('test_dynamicquantizelinear_*')
backend_test.exclude('test_det_*')
backend_test.exclude('test_dft_*')
backend_test.exclude('test_einsum_*')
@@ -111,13 +121,18 @@ backend_test.exclude('test_erf_*')
backend_test.exclude('test_strnorm_*')
backend_test.exclude('test_unique_*')
backend_test.exclude('test_sequence_*')
backend_test.exclude('test_nonmaxsuppression_*')
backend_test.exclude('test_reversesequence_*')
backend_test.exclude('test_roialign_*')
backend_test.exclude('test_rnn_*')
backend_test.exclude('test_top_k_*')

backend_test.include('test_selu_*')
# disable model tests for now since they are slow
for x in backend_test.test_suite:
  if 'OnnxBackendRealModelTest' in str(type(x)):
    backend_test.exclude(str(x).split(" ")[0])

# the node tests
#for x in backend_test.test_suite:
#  if 'OnnxBackendNodeModelTest' in str(type(x)):
#    backend_test.include(str(x).split(" ")[0])
#backend_test.include('test_tile_*')

# passing node tests
"""
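For context on the harness these excludes drive: onnx.backend.test.BackendTest generates one unittest case per ONNX conformance test, and include/exclude filter them by name pattern. A hedged sketch of the surrounding scaffolding, with DummyBackend as a stand-in for whatever backend class the real test file wires up:

import onnx.backend.base
import onnx.backend.test

class DummyBackend(onnx.backend.base.Backend):
  # a real backend also implements prepare()/run_node(); this stub only declares device support
  @classmethod
  def supports_device(cls, device): return device == 'CPU'

backend_test = onnx.backend.test.BackendTest(DummyBackend, __name__)
backend_test.exclude('test_adagrad_*')  # patterns match against generated test names

# each case stringifies as "test_name (package.Class)", so the first whitespace-separated
# token is itself a usable pattern, exactly what the model-test loop above exploits
for x in backend_test.test_suite:
  if 'OnnxBackendRealModelTest' in str(type(x)):
    backend_test.exclude(str(x).split(" ")[0])

# hand the surviving cases to the unittest loader
globals().update(backend_test.enable_report().test_cases)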