onnx 58/109/208

This commit is contained in:
George Hotz 2023-02-24 12:19:05 -08:00
parent e8a153e4e9
commit 85452fbaf3
3 changed files with 37 additions and 12 deletions

View File

@ -123,8 +123,6 @@ def get_run_onnx(onnx_model):
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(shape)] for i in indices]
ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis)
ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
elif n.op_type in ["Sum"]:
ret = functools.reduce(Tensor.__add__, inp)
elif n.op_type in ["Add", "Sub", "Mul", "Pow"]:
# TODO: add this to tinygrad? i don't think it's in torch
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):

View File

@ -2,6 +2,7 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
from extra.onnx import safe_numpy
import numpy as np
import functools
def Unsqueeze(data, axes):
axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
@ -81,6 +82,7 @@ def Expand(input, shape):
def Identity(input): return input
def Neg(input): return -input
def Reciprocal(input): return input.reciprocal()
def Sqrt(input): return input.sqrt()
def Sign(input): return input.sign()
def Abs(input): return input.abs()
@ -97,6 +99,9 @@ def Softmax(input, axis=-1): return input.softmax(axis)
def LogSoftmax(input, axis=-1): return input.log_softmax(axis)
def Clip(input, min=-3.4e38, max=3.4e38): return input.clip(min, max)
def Sum(*data_0): return functools.reduce(Tensor.__add__, data_0)
def Mean(*data_0): return functools.reduce(Tensor.__add__, data_0) / len(data_0)
def _axes(axes, noop_with_empty_axes): return [int(x) for x in safe_numpy(axes)] if axes is not None else ([] if noop_with_empty_axes else None)
# ReduceProd would require a new llop
@ -111,4 +116,11 @@ def ReduceLogSum(data, axes=None, keepdims=1, noop_with_empty_axes=0): return da
def ReduceLogSumExp(data, axes=None, keepdims=1, noop_with_empty_axes=0): return data.exp().sum(_axes(axes, noop_with_empty_axes), keepdim=keepdims).log()
def GlobalAveragePool(X): return X.mean(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
def GlobalMaxPool(X): return X.max(axis=tuple(range(2, len(X.shape))), keepdim=True)
def Tile(input, repeats):
repeats_ = [int(x) for x in safe_numpy(repeats)]
new_shape = [x for i in range(len(input.shape)) for x in [1,input.shape[i]]]
expand_shape = [x for r,s in zip(repeats_, input.shape) for x in [r,s]]
final_shape = [r*s for r,s in zip(repeats_, input.shape)]
return input.reshape(new_shape).expand(expand_shape).reshape(final_shape)

View File

@ -44,9 +44,16 @@ backend_test.exclude('test_max_*')
# add support for SoftmaxCrossEntropyLoss and NegativeLogLikelihoodLoss
backend_test.exclude('test_sce_*')
# no optimizers (add them)
backend_test.exclude('test_adagrad_*')
backend_test.exclude('test_adam_*')
backend_test.exclude('test_nesterov_momentum_*')
# we only support float32
backend_test.exclude('test_add_uint8_*')
backend_test.exclude('test_sub_uint8_*')
backend_test.exclude('test_div_uint8_*')
backend_test.exclude('test_mul_uint8_*')
backend_test.exclude('test_cast_*')
backend_test.exclude('test_castlike_*')
@ -75,6 +82,7 @@ backend_test.exclude('test_and*')
backend_test.exclude('test_xor*')
backend_test.exclude('test_or*')
backend_test.exclude('test_bitshift_*')
backend_test.exclude('test_not_*')
# no scatter gather
backend_test.exclude('test_gather_*')
@ -82,9 +90,13 @@ backend_test.exclude('test_gathernd_*')
backend_test.exclude('test_scatter_*')
backend_test.exclude('test_scatternd_*')
# no quantize
backend_test.exclude('test_dequantizelinear_*')
backend_test.exclude('test_dynamicquantizelinear_*')
backend_test.exclude('test_qlinearmatmul_*')
backend_test.exclude('test_quantizelinear_*')
# unsupported (strange) ops
backend_test.exclude('test_adagrad_*')
backend_test.exclude('test_adam_*')
backend_test.exclude('test_argmax_*')
backend_test.exclude('test_argmin_*')
backend_test.exclude('test_bitwise_*')
@ -102,8 +114,6 @@ backend_test.exclude('test_gru_*')
backend_test.exclude('test_gridsample_*')
backend_test.exclude('test_if_*')
backend_test.exclude('test_compress_*')
backend_test.exclude('test_dequantizelinear_*')
backend_test.exclude('test_dynamicquantizelinear_*')
backend_test.exclude('test_det_*')
backend_test.exclude('test_dft_*')
backend_test.exclude('test_einsum_*')
@ -111,13 +121,18 @@ backend_test.exclude('test_erf_*')
backend_test.exclude('test_strnorm_*')
backend_test.exclude('test_unique_*')
backend_test.exclude('test_sequence_*')
backend_test.exclude('test_nonmaxsuppression_*')
backend_test.exclude('test_reversesequence_*')
backend_test.exclude('test_roialign_*')
backend_test.exclude('test_rnn_*')
backend_test.exclude('test_top_k_*')
backend_test.include('test_selu_*')
# disable model tests for now since they are slow
for x in backend_test.test_suite:
if 'OnnxBackendRealModelTest' in str(type(x)):
backend_test.exclude(str(x).split(" ")[0])
# the node tests
#for x in backend_test.test_suite:
# if 'OnnxBackendNodeModelTest' in str(type(x)):
# backend_test.include(str(x).split(" ")[0])
#backend_test.include('test_tile_*')
# passing node tests
"""