mirror of https://github.com/commaai/tinygrad.git
much cleaner way to write onnx ops
parent d3029c91c5
commit d3feea302d
@@ -1,10 +1,23 @@
import functools
import importlib
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
from tinygrad.helpers import getenv, DEBUG
from onnx.helper import tensor_dtype_to_np_dtype

# global numpy cache for parameters
numpy_cache = {}
def safe_numpy(t):
  global numpy_cache
  if t not in numpy_cache:
    if DEBUG >= 1:
      print("numpy cache miss", t)
    numpy_cache[t] = t.numpy()
  return numpy_cache[t]

onnx_ops = importlib.import_module('extra.onnx_ops')

ONNXLIMIT = getenv("ONNXLIMIT", -1)

def get_run_onnx(onnx_model):

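Any function defined in extra/onnx_ops.py whose name matches an ONNX op_type can now be looked up on the imported onnx_ops module and called with a node's input tensors and parsed attributes, while safe_numpy memoizes Tensor.numpy() so attribute tensors are only realized once. A minimal sketch of that dispatch, using an illustrative helper name run_node (the hasattr/getattr fallback itself appears later in this diff):

def run_node(op_type, inp, opt):
  # inp: list of input Tensors, opt: dict of parsed node attributes
  if hasattr(onnx_ops, op_type):
    return getattr(onnx_ops, op_type)(*inp, **opt)  # e.g. Unsqueeze(data, axes)
  raise Exception(f"op_type {op_type} not supported")
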
@@ -51,16 +64,6 @@ def get_run_onnx(onnx_model):
  for num,n in enumerate(onnx_model.graph.node):
    attribute_dict[num] = attribute_to_dict(n.attribute)

  # and cache them
  numpy_cache = {}
  def safe_numpy(t):
    nonlocal numpy_cache
    if t not in numpy_cache:
      if DEBUG >= 1:
        print("numpy cache miss", t)
      numpy_cache[t] = t.numpy()
    return numpy_cache[t]

  def run_onnx(inputs={}, debug=False):
    input_tensors = {}
    intermediate_tensors = {}

@@ -108,17 +111,6 @@ def get_run_onnx(onnx_model):
      elif n.op_type == "Div": ret = inp[0].div(inp[1])
      elif n.op_type == "Constant": ret = opt['value'] if 'value' in opt else opt['value_float']
      elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) if x != 0 else inp[0].shape[i] for i,x in enumerate(safe_numpy(inp[1]))])
      elif n.op_type == "Unsqueeze":
        if 'axes' not in opt: opt['axes'] = [int(x) for x in safe_numpy(inp[1])]
        opt['axes'] = [len(inp[0].shape) + x if x < 0 else x for x in opt['axes']]
        ptr = 0
        new_shape = []
        for i in range(len(inp[0].shape) + len(opt['axes'])):
          if i in opt['axes']: new_shape.append(1)
          else:
            new_shape.append(inp[0].shape[ptr])
            ptr += 1
        ret = inp[0].reshape(new_shape)
      elif n.op_type == "Resize":
        # TODO: this is handcoded for YOLOv8
        scales = safe_numpy(inp[2])

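To make the Unsqueeze reshape logic concrete, a worked example with illustrative values (not taken from the diff):

# data.shape = (3, 4), axes = [0, 2]
# i=0 in axes -> append 1; i=1 -> append shape[0]=3; i=2 in axes -> append 1; i=3 -> append shape[1]=4
# new_shape = [1, 3, 1, 4], so the reshaped tensor has shape (1, 3, 1, 4)
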
@@ -134,14 +126,6 @@ def get_run_onnx(onnx_model):
        args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(shape)] for i in indices]
        ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis)
        ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
      elif n.op_type == "BatchNormalization":
        invstd = inp[4].add(opt.get('epsilon', 1e-5))**-0.5
        ret = inp[0].batchnorm(inp[1], inp[2], inp[3], invstd)
      elif n.op_type == "Gemm":
        A = inp[0].transpose() if opt.get('transA', 0) == 1 else inp[0]
        B = inp[1].transpose() if opt.get('transB', 0) == 1 else inp[1]
        ret = opt.get('alpha', 1.0) * (A @ B)
        if len(inp) > 2: ret += opt.get('beta', 1.0) * inp[2]
      elif n.op_type == "Conv":
        x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
        assert 'dilations' not in opt or opt['dilations'] == (1,1)

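For context, the two BatchNormalization lines above implement the standard inference-mode formula; with the ONNX input order X, scale, B, input_mean, input_var, inp[1] through inp[4] are scale, bias, running mean and running variance:

# invstd is 1 / sqrt(running_var + epsilon), so the op computes
# y = scale * (x - running_mean) / sqrt(running_var + epsilon) + B
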
@@ -182,12 +166,15 @@ def get_run_onnx(onnx_model):
        starts = starts + inp[0].shape[axis] if starts < 0 else starts
        arg[axis] = (starts, ends)
        ret = inp[0].slice(arg=arg)
      elif hasattr(onnx_ops, n.op_type):
        ret = getattr(onnx_ops, n.op_type)(*inp, **opt)
      else:
        print("UNSUPPORTED", n.op_type, n.input, n.output)
        raise Exception(f"op_type {n.op_type} not supported")
      assert len(n.output) == 1, f"output size must be 1, it's {n.output}"
      if debug: print(ret.shape)
      intermediate_tensors[n.output[0]] = ret
      if not isinstance(ret, tuple): ret = (ret, )
      assert len(n.output) == len(ret), f"output size must be {len(ret)}, it's {n.output}"
      if debug: print([x.shape for x in ret])
      for i,r in enumerate(ret): intermediate_tensors[n.output[i]] = r
      #print(ret.numpy().mean())
      if num == ONNXLIMIT:
        output_tensor_names = n.output

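The tuple handling above replaces the old single-output assert: a handler may now return several tensors, and each element is bound to the matching name in n.output. A hypothetical illustration (output names invented):

# if n.output == ['Y', 'running_mean', 'running_var'] and the handler returns a 3-tuple
# (as BatchNormalization does in training mode), each element lands in
# intermediate_tensors under the corresponding output name
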
@@ -0,0 +1,34 @@
from extra.onnx import safe_numpy

def Unsqueeze(data, axes):
  axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
  ptr = 0
  new_shape = []
  for i in range(len(data.shape) + len(axes)):
    if i in axes: new_shape.append(1)
    else:
      new_shape.append(data.shape[ptr])
      ptr += 1
  return data.reshape(new_shape)

def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
  ret = alpha * ((A.transpose() if transA == 1 else A) @ (B.transpose() if transB == 1 else B))
  if C is not None: ret += beta * C
  return ret

# TODO: this is copied from tinygrad/nn/__init__.py
def BatchNormalization(X, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0):
  if training_mode:
    x_detached = X.detach()
    current_mean = x_detached.mean(axis=(0,2,3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y*y).mean(axis=(0,2,3))
    current_invstd = current_var.add(epsilon).pow(-0.5)

    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)

    return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
  else:
    invstd = (input_var + epsilon)**-0.5
    return X.batchnorm(scale, B, input_mean, invstd)

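Because the handlers are plain functions over tinygrad Tensors, they can be exercised outside the ONNX runner. A minimal sketch, where the shapes and values are illustrative assumptions, not part of the commit:

from tinygrad.tensor import Tensor
from extra.onnx_ops import Gemm

A, B, C = Tensor.randn(3, 4), Tensor.randn(5, 4), Tensor.randn(3, 5)
out = Gemm(A, B, C, alpha=0.5, beta=2.0, transB=1)  # 0.5 * (A @ B.T) + 2.0 * C
print(out.shape)  # (3, 5)
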
@@ -42,8 +42,11 @@ backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)

# passing node tests
backend_test.include('test_unsqueeze_*')
backend_test.include('test_sum_*')
backend_test.include('test_gemm_*')
backend_test.include('test_batchnorm_*')

"""
backend_test.include('test_sum_*')
backend_test.include('test_transpose_*')
backend_test.include('test_tanh_*')

@@ -53,6 +56,7 @@ backend_test.include('test_reshape_*')
backend_test.include('test_flatten_*')
backend_test.include('test_expand_*')
backend_test.include('test_clip_*')
"""

# requires CastLike?
#backend_test.include('test_relu_*')

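For reference, the generated node tests are normally driven by the standard onnx.backend.test boilerplate at the bottom of such a test file; a sketch of how that usually looks (assumed, not shown in this diff):

# pull the generated test cases into module scope so unittest can discover them
globals().update(backend_test.enable_report().test_cases)

if __name__ == '__main__':
  import unittest
  unittest.main()
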