mirror of https://github.com/commaai/tinygrad.git
fix shape test
This commit is contained in:
parent 3becefa218
commit 5cdfeffe2c
@@ -107,7 +107,6 @@ def get_run_onnx(onnx_model):
elif n.op_type == "ReduceL2": ret = inp[0].pow(2).sum(axis=opt['axes'], keepdim=opt['keepdims']).sqrt()
elif n.op_type == "ReduceSum": ret = inp[0].sum(axis=opt['axes'], keepdim=opt['keepdims'])
elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True)
elif n.op_type == "Shape": ret = inp[0].shape
elif n.op_type == "Expand": ret = inp[0].reshape([1]*(max(len(inp[0].shape), len(inp[1]))-len(inp[0].shape)) + list(inp[0].shape)) # just broadcast
elif n.op_type == "Div": ret = inp[0].div(inp[1])
elif n.op_type == "Constant": ret = opt['value'] if 'value' in opt else opt['value_float']
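For reference, a NumPy-only sketch of what the ReduceL2, GlobalAveragePool, and Expand handlers above are reproducing; the input values, axes, and target shape here are illustrative, not taken from the diff.

import numpy as np

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)              # made-up input

# ReduceL2: sqrt of the sum of squares along the reduced axes, keeping dims
reduce_l2 = np.sqrt((x ** 2).sum(axis=2, keepdims=True))          # shape (2, 3, 1)

# GlobalAveragePool: mean over every axis past N and C, keeping dims
global_avg = x.mean(axis=tuple(range(2, x.ndim)), keepdims=True)  # shape (2, 3, 1)

# Expand: pad leading 1s onto the shape (as the reshape above does), then broadcast
expanded = np.broadcast_to(x.reshape((1,) + x.shape), (2, 2, 3, 4))

print(reduce_l2.shape, global_avg.shape, expanded.shape)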
@@ -164,7 +163,7 @@ def get_run_onnx(onnx_model):
raise Exception(f"op_type {n.op_type} not supported")
if not isinstance(ret, tuple): ret = (ret, )
assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
if debug: print([x.shape for x in ret])
if debug: print([x.shape if isinstance(x, Tensor) else None for x in ret])
for i in range(len(n.output)): intermediate_tensors[n.output[i]] = ret[i]
#print(ret.numpy().mean())
if num == ONNXLIMIT:
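The debug-print change above is needed once an op can return something that is not a Tensor, such as the plain Python list a Shape node now produces. A minimal illustration with np.ndarray standing in for tinygrad's Tensor:

import numpy as np

# made-up outputs: one array-like result plus a Shape-style list result
ret = (np.zeros((1, 3, 224, 224), dtype=np.float32), [1, 3, 224, 224])

# old form: a plain list has no .shape attribute
# print([x.shape for x in ret])  # AttributeError

# new form: report None for anything that is not a tensor
print([x.shape if isinstance(x, np.ndarray) else None for x in ret])  # [(1, 3, 224, 224), None]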
@@ -60,4 +60,6 @@ def Dropout(data, ratio=0.5, training_mode=False, seed=None):
if seed is not None: Tensor.manual_seed(seed)
_mask : np.ndarray = np.asarray(Tensor._rng.binomial(1, 1.0-ratio, size=data.shape), dtype=data.dtype)
mask = Tensor(_mask, requires_grad=False, device=data.device)
return data * mask * (1/(1.0 - ratio)), mask

def Shape(data, end=None, start=0): return list(data.shape)[start:end]
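A NumPy-only sketch of the two ops in this hunk: inverted dropout scales the surviving activations by 1/(1-ratio) so the expected value is preserved, and Shape with start/end simply slices the shape list. The sample data and ratio are made up.

import numpy as np

rng = np.random.default_rng(0)           # stand-in for Tensor._rng
data = np.ones((4, 5), dtype=np.float32)
ratio = 0.2

mask = rng.binomial(1, 1.0 - ratio, size=data.shape).astype(data.dtype)
out = data * mask * (1 / (1.0 - ratio))  # kept entries become 1/(1-ratio), dropped become 0
print(out.mean())                        # close to 1.0 in expectation

def shape_op(data, end=None, start=0):   # same slicing behavior as the Shape op above
  return list(data.shape)[start:end]

print(shape_op(data), shape_op(data, start=1))  # [4, 5] [5]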
@@ -1,6 +1,8 @@
import unittest
from onnx.backend.base import Backend, BackendRep
import onnx.backend.test
import numpy as np
from tinygrad.tensor import Tensor
from typing import Any, Tuple

# pip3 install tabulate
@@ -17,7 +19,7 @@ class TinygradModel(BackendRep):
def run(self, inputs: Any, **kwargs: Any) -> Tuple[Any, ...]:
real_inputs = {k:v for k,v in zip(self.input_names, inputs)}
ret = self.fxn(real_inputs, debug=True)
return tuple(x.numpy() for x in ret.values())
return tuple(x.numpy() if isinstance(x, Tensor) else np.array(x) for x in ret.values())

class TinygradBackend(Backend):
@classmethod
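The run() change is the counterpart on the test side: the ONNX backend harness compares NumPy arrays, so Tensor outputs go through .numpy() while anything else (such as a Shape result, which is a plain list) is wrapped in np.array. A rough sketch with a stand-in Tensor class rather than tinygrad's own:

import numpy as np

class FakeTensor:                        # stand-in for tinygrad's Tensor; only .numpy() matters here
  def __init__(self, arr): self.arr = np.asarray(arr)
  def numpy(self): return self.arr

ret = {"y": FakeTensor(np.zeros((2, 3))), "y_shape": [2, 3]}  # made-up backend outputs
outs = tuple(x.numpy() if isinstance(x, FakeTensor) else np.array(x) for x in ret.values())
print([o.shape for o in outs])           # [(2, 3), (2,)]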
@@ -41,21 +43,26 @@ backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
# backend_test.include(str(x).split(" ")[0])

# passing node tests
backend_test.include('test_unsqueeze_*')
backend_test.include('test_gemm_*')
backend_test.include('test_batchnorm_*')
#backend_test.include('test_unsqueeze_*')
#backend_test.include('test_gemm_*')
#backend_test.include('test_batchnorm_*')
#backend_test.include('test_transpose_*')

backend_test.include('test_shape_*')

# almost passing node tests
#backend_test.include('test_conv_.*')
#backend_test.include('test_dropout_*')

# good to investigate
#backend_test.include('test_slice_*')

# failing for real reasons
#backend_test.include('test_averagepool_2d_*')
#backend_test.include('test_maxpool_2d_*')

"""
backend_test.include('test_sum_*')
backend_test.include('test_transpose_*')
backend_test.include('test_tanh_*')

# should be passing (good place to start!)
@@ -76,16 +83,13 @@ backend_test.include('test_clip_*')

# the node tests, slowly
#backend_test.include('test_reduce_sum_*')
#backend_test.include('test_shape_*')
#backend_test.include('test_softmax_*')
#backend_test.include('test_slice_*')
#backend_test.include('test_lrn_*')
#backend_test.include('test_batchnorm_*')

# working big model tests
backend_test.include('test_resnet50')
backend_test.include('test_densenet121')
backend_test.include('test_vgg19')
#backend_test.include('test_resnet50')
#backend_test.include('test_densenet121')
#backend_test.include('test_vgg19')

"""
# wrong big model tests
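For context on how these include lines are consumed (the plumbing sits outside this hunk): onnx.backend.test's BackendTest treats each include() argument as a regex over generated test-case names, and the enabled cases are normally exported to unittest with the standard boilerplate below. This is a sketch assuming the TinygradBackend class defined earlier in this file.

import unittest
import onnx.backend.test

# assumes TinygradBackend, the Backend subclass defined above in this file
backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
backend_test.include('test_shape_.*')    # include() takes a regex over test names

# export the selected cases as unittest tests and run them
globals().update(backend_test.enabled_report().test_cases)
if __name__ == '__main__':
  unittest.main()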