much cleaner way to write onnx ops

George Hotz 2023-02-24 08:46:28 -08:00
parent d3029c91c5
commit d3feea302d
3 changed files with 58 additions and 33 deletions

extra/onnx.py

@@ -1,10 +1,23 @@
import functools
import importlib
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
from tinygrad.helpers import getenv, DEBUG
from onnx.helper import tensor_dtype_to_np_dtype
# global numpy cache for parameters
numpy_cache = {}
def safe_numpy(t):
global numpy_cache
if t not in numpy_cache:
if DEBUG >= 1:
print("numpy cache miss", t)
numpy_cache[t] = t.numpy()
return numpy_cache[t]
onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model):
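The numpy cache is now a single module-level dict keyed by the Tensor object, so each initializer is converted to numpy at most once instead of once per get_run_onnx call. A minimal standalone sketch of the same memoization pattern, with a hypothetical stand-in for Tensor.numpy():

# sketch: memoize an expensive per-object conversion, keyed by the (hashable) object
_numpy_cache = {}

def cached_convert(t):
  if t not in _numpy_cache:
    print("cache miss", t)              # mirrors the DEBUG print above
    _numpy_cache[t] = expensive_convert(t)
  return _numpy_cache[t]

def expensive_convert(t):
  return [x * 2.0 for x in t]           # hypothetical stand-in for Tensor.numpy()

w = (1, 2, 3)                           # stands in for a weight Tensor (hashable)
print(cached_convert(w))                # first call converts
print(cached_convert(w))                # second call hits the cache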
@@ -51,16 +64,6 @@ def get_run_onnx(onnx_model):
for num,n in enumerate(onnx_model.graph.node):
attribute_dict[num] = attribute_to_dict(n.attribute)
# and cache them
numpy_cache = {}
def safe_numpy(t):
nonlocal numpy_cache
if t not in numpy_cache:
if DEBUG >= 1:
print("numpy cache miss", t)
numpy_cache[t] = t.numpy()
return numpy_cache[t]
def run_onnx(inputs={}, debug=False):
input_tensors = {}
intermediate_tensors = {}
@@ -108,17 +111,6 @@ def get_run_onnx(onnx_model):
elif n.op_type == "Div": ret = inp[0].div(inp[1])
elif n.op_type == "Constant": ret = opt['value'] if 'value' in opt else opt['value_float']
elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) if x != 0 else inp[0].shape[i] for i,x in enumerate(safe_numpy(inp[1]))])
elif n.op_type == "Unsqueeze":
if 'axes' not in opt: opt['axes'] = [int(x) for x in safe_numpy(inp[1])]
opt['axes'] = [len(inp[0].shape) + x if x < 0 else x for x in opt['axes']]
ptr = 0
new_shape = []
for i in range(len(inp[0].shape) + len(opt['axes'])):
if i in opt['axes']: new_shape.append(1)
else:
new_shape.append(inp[0].shape[ptr])
ptr += 1
ret = inp[0].reshape(new_shape)
elif n.op_type == "Resize":
# TODO: this is handcoded for YOLOv8
scales = safe_numpy(inp[2])
@@ -134,14 +126,6 @@ def get_run_onnx(onnx_model):
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(shape)] for i in indices]
ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis)
ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
elif n.op_type == "BatchNormalization":
invstd = inp[4].add(opt.get('epsilon', 1e-5))**-0.5
ret = inp[0].batchnorm(inp[1], inp[2], inp[3], invstd)
elif n.op_type == "Gemm":
A = inp[0].transpose() if opt.get('transA', 0) == 1 else inp[0]
B = inp[1].transpose() if opt.get('transB', 0) == 1 else inp[1]
ret = opt.get('alpha', 1.0) * (A @ B)
if len(inp) > 2: ret += opt.get('beta', 1.0) * inp[2]
elif n.op_type == "Conv":
x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
assert 'dilations' not in opt or opt['dilations'] == (1,1)
@@ -182,12 +166,15 @@ def get_run_onnx(onnx_model):
starts = starts + inp[0].shape[axis] if starts < 0 else starts
arg[axis] = (starts, ends)
ret = inp[0].slice(arg=arg)
elif hasattr(onnx_ops, n.op_type):
ret = getattr(onnx_ops, n.op_type)(*inp, **opt)
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise Exception(f"op_type {n.op_type} not supported")
assert len(n.output) == 1, f"output size must be 1, it's {n.output}"
if debug: print(ret.shape)
intermediate_tensors[n.output[0]] = ret
if not isinstance(ret, tuple): ret = (ret, )
assert len(n.output) == len(ret), f"output size must be {len(ret)}, it's {n.output}"
if debug: print([x.shape for x in ret])
for i,r in enumerate(ret): intermediate_tensors[n.output[i]] = r
#print(ret.numpy().mean())
if num == ONNXLIMIT:
output_tensor_names = n.output
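This fallback is what makes new ops cheap to add: any op_type without an inline handler is looked up by name in extra.onnx_ops and called as f(*inputs, **attributes), and a tuple return value maps onto multi-output nodes. A rough standalone sketch of that dispatch-by-name pattern, with a plain namespace standing in for the imported module:

import types

# hypothetical stand-in for the extra.onnx_ops module: one function per ONNX op_type
fake_ops = types.SimpleNamespace(
  Relu=lambda x: [max(v, 0.0) for v in x],
  Split=lambda x: (x[:len(x) // 2], x[len(x) // 2:]),   # multi-output ops return a tuple
)

def run_node(op_type, inputs, output_names):
  if hasattr(fake_ops, op_type):
    ret = getattr(fake_ops, op_type)(*inputs)
  else:
    raise Exception(f"op_type {op_type} not supported")
  if not isinstance(ret, tuple): ret = (ret,)
  assert len(output_names) == len(ret), f"output size must be {len(ret)}, it's {output_names}"
  return dict(zip(output_names, ret))

print(run_node("Relu", [[-1.0, 2.0, -3.0]], ["y"]))
print(run_node("Split", [[1, 2, 3, 4]], ["a", "b"]))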

extra/onnx_ops.py (new file, 34 lines added)

@@ -0,0 +1,34 @@
from extra.onnx import safe_numpy
def Unsqueeze(data, axes):
axes = [len(data.shape) + int(x) if x < 0 else int(x) for x in safe_numpy(axes)]
ptr = 0
new_shape = []
for i in range(len(data.shape) + len(axes)):
if i in axes: new_shape.append(1)
else:
new_shape.append(data.shape[ptr])
ptr += 1
return data.reshape(new_shape)
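A quick check of the shape arithmetic in Unsqueeze above, with the reshape replaced by plain lists (axes [0, 2] on a (3, 4) input should give (1, 3, 1, 4)):

def unsqueezed_shape(shape, axes):
  # insert a 1 at every position named in axes, keeping the remaining dims in order
  new_shape, ptr = [], 0
  for i in range(len(shape) + len(axes)):
    if i in axes:
      new_shape.append(1)
    else:
      new_shape.append(shape[ptr])
      ptr += 1
  return new_shape

print(unsqueezed_shape([3, 4], [0, 2]))   # [1, 3, 1, 4]
print(unsqueezed_shape([3, 4], [1]))      # [3, 1, 4]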
def Gemm(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
ret = alpha * ((A.transpose() if transA == 1 else A) @ (B.transpose() if transB == 1 else B))
if C is not None: ret += beta * C
return ret
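For reference, ONNX Gemm computes alpha * A' @ B' + beta * C, where A and B are optionally transposed. A small numpy sketch of the same formula, independent of tinygrad:

import numpy as np

def gemm_ref(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
  A = A.T if transA else A
  B = B.T if transB else B
  out = alpha * (A @ B)
  return out + beta * C if C is not None else out

A = np.arange(6, dtype=np.float32).reshape(2, 3)
B = np.arange(12, dtype=np.float32).reshape(3, 4)
C = np.ones((2, 4), dtype=np.float32)
print(gemm_ref(A, B, C, alpha=0.5, beta=2.0))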
# TODO: this is copied from tinygrad/nn/__init__.py
def BatchNormalization(X, scale, B, input_mean, input_var, epsilon=1e-05, momentum=0.9, training_mode=0):
if training_mode:
x_detached = X.detach()
current_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
current_var = (y*y).mean(axis=(0,2,3))
current_invstd = current_var.add(epsilon).pow(-0.5)
running_mean = input_mean * momentum + current_mean * (1 - momentum)
running_var = input_var * momentum + current_var * (1 - momentum)
return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
else:
invstd = (input_var + epsilon)**-0.5
return X.batchnorm(scale, B, input_mean, invstd)
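In eval mode (the else branch) this reduces to the standard normalization with stored statistics, Y = scale * (X - mean) / sqrt(var + epsilon) + B. A numpy sketch for an NCHW input, assuming per-channel statistics:

import numpy as np

def batchnorm_eval_ref(X, scale, B, mean, var, epsilon=1e-5):
  bc = lambda a: a.reshape(1, -1, 1, 1)            # broadcast per-channel vectors over N, H, W
  invstd = (var + epsilon) ** -0.5                 # same invstd formula as the eval branch above
  return bc(scale) * (X - bc(mean)) * bc(invstd) + bc(B)

X = np.random.randn(2, 3, 4, 4).astype(np.float32)
scale, B = np.ones(3, np.float32), np.zeros(3, np.float32)
mean, var = X.mean(axis=(0, 2, 3)), X.var(axis=(0, 2, 3))
print(batchnorm_eval_ref(X, scale, B, mean, var).std())  # ~1.0 after normalizing with its own stats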

ONNX backend test file

@@ -42,8 +42,11 @@ backend_test = onnx.backend.test.BackendTest(TinygradBackend, __name__)
# passing node tests
backend_test.include('test_unsqueeze_*')
backend_test.include('test_sum_*')
backend_test.include('test_gemm_*')
backend_test.include('test_batchnorm_*')
"""
backend_test.include('test_sum_*')
backend_test.include('test_transpose_*')
backend_test.include('test_tanh_*')
@@ -53,6 +56,7 @@ backend_test.include('test_reshape_*')
backend_test.include('test_flatten_*')
backend_test.include('test_expand_*')
backend_test.include('test_clip_*')
"""
# requires CastLike?
#backend_test.include('test_relu_*')