mirror of https://github.com/commaai/tinygrad.git
test_onnx works with enet also
This commit is contained in:
parent 6fdb276886
commit 2305a5347b
@@ -17,7 +17,7 @@ def _load_labels():
 _LABELS = _load_labels()
 
-def preprocess(img):
+def preprocess(img, new=False):
   # preprocess image
   aspect_ratio = img.size[0] / img.size[1]
   img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
@@ -27,21 +27,29 @@ def preprocess(img):
   img = img[y0: y0 + 224, x0: x0 + 224]
 
   # low level preprocess
-  img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
-  img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224)
-  #img /= 255.0
-  #img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
-  #img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))
+  if new:
+    img = img.astype(np.float32)
+    img -= [127.0, 127.0, 127.0]
+    img /= [128.0, 128.0, 128.0]
+    img = img[None]
+  else:
+    img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
+    img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224)
+    img /= 255.0
+    img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
+    img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))
   return img
 
 def _infer(model: EfficientNet, img, bs=1):
+  img = preprocess(img)
   # run the net
   if bs > 1: img = img.repeat(bs, axis=0)
   out = model.forward(Tensor(img)).cpu()
   return _LABELS[np.argmax(out.data[0])]
 
-chicken_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg'))
-car_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg'))
+chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg')
+car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')
 
 class TestEfficientNet(unittest.TestCase):
   @classmethod
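The new=True path mirrors TF-style EfficientNet-lite preprocessing: keep the image in HWC layout, map pixel values to roughly [-1, 1] via (x - 127) / 128, and add a batch axis, so the ONNX model receives an NHWC (1, 224, 224, 3) float input. The old path is torchvision-style ImageNet normalization in NCHW. A standalone numpy sketch of the two, with a dummy image (not part of the commit):

import numpy as np

hwc = np.random.randint(0, 256, (224, 224, 3)).astype(np.float32)  # stand-in HWC image

# new=True: NHWC, values in about [-1, 1]
lite = ((hwc - 127.0) / 128.0)[None]                    # (1, 224, 224, 3)

# new=False: NCHW, per-channel ImageNet mean/std
chw = np.moveaxis(hwc, [2, 0, 1], [0, 1, 2]).reshape(1, 3, 224, 224) / 255.0
chw -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
chw /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))

print(lite.shape, chw.shape)  # (1, 224, 224, 3) (1, 3, 224, 224)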
@@ -5,6 +5,7 @@ import numpy as np
 import onnx
 from extra.utils import fetch
 from tinygrad.tensor import Tensor
+from tinygrad.helpers import prod
 
 def run_onnx(onnx_model, inputs={}, debug=False):
   def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
@@ -29,11 +30,16 @@ def run_onnx(onnx_model, inputs={}, debug=False):
 
   # get weights and biases
   for inp in onnx_model.graph.initializer:
-    #print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
-    if len(inp.raw_data) == 0:
-      tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims))
-    else:
+    if len(inp.raw_data) > 0:
       tensors[inp.name] = buffer_parse(inp)
+    elif len(inp.float_data) > 0:
+      tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims))
+    elif len(inp.int64_data) > 0:
+      tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.float32).reshape(inp.dims))
+    else:
+      print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
+      print(inp)
+      raise Exception("no data")
 
   # get inputs
   for inp in onnx_model.graph.input:
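Background on the initializer branches: an ONNX TensorProto stores its payload either packed into raw_data (a flat little-endian byte string) or in a typed repeated field such as float_data or int64_data, depending on the exporter, so a loader has to check all of them. buffer_parse (defined elsewhere in this file) handles the raw_data case; a minimal float32-only equivalent might look like this sketch:

import numpy as np
import onnx

def parse_raw_float32(inp):
  # raw_data is a flat little-endian byte string; dims carries the shape
  assert inp.data_type == onnx.TensorProto.FLOAT and len(inp.raw_data) > 0
  return np.frombuffer(inp.raw_data, dtype=np.float32).reshape(tuple(inp.dims))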
@@ -45,8 +51,9 @@ def run_onnx(onnx_model, inputs={}, debug=False):
       assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
       tensors[inp.name] = Tensor(inputs[inp.name])
     else:
-      print(f"filling {inp.name} shape {shape} with 0")
-      tensors[inp.name] = Tensor.zeros(*shape)
+      raise Exception(f"no data for {inp.name} with shape {shape}")
+      #print(f"filling {inp.name} shape {shape} with 0")
+      #tensors[inp.name] = Tensor.zeros(*shape)
 
 
   for num,n in enumerate(onnx_model.graph.node):
@@ -55,17 +62,33 @@ def run_onnx(onnx_model, inputs={}, debug=False):
     opt = attribute_to_dict(n.attribute)
     if n.op_type == "Conv":
       x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
-      assert opt['dilations'] == (1,1)
-      ret = x.pad2d(opt['pads']).conv2d(w, b, stride=opt['strides'], groups=opt['group'])
+      assert 'dilations' not in opt or opt['dilations'] == (1,1)
+      # pads are in different order
+      pads = (opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3])
+      ret = x.pad2d(pads).conv2d(w, b, stride=opt['strides'], groups=opt['group'] if 'group' in opt else 1)
     elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
     elif n.op_type == "Relu": ret = inp[0].relu()
     elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
-    elif n.op_type == "Add": ret = inp[0] + inp[1]
-    elif n.op_type == "Sub": ret = inp[0] - inp[1]
-    elif n.op_type == "Mul": ret = inp[0] * inp[1]
+    elif n.op_type == "Tanh": ret = inp[0].tanh()
     elif n.op_type == "Softmax": ret = inp[0].softmax()
+    elif n.op_type in ["Add", "Sub", "Mul"]:
+      # TODO: add this to tinygrad
+      if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
+        inp[1] = inp[1].reshape(inp[0].shape)
+      # TODO: is this right?
+      if 'broadcast' in opt:
+        new_shape = [1 for x in range(len(inp[0].shape))]
+        new_shape[opt['broadcast']] = -1
+        #print(inp[1].shape, new_shape)
+        inp[1] = inp[1].reshape(new_shape)
+      if n.op_type == "Add": ret = inp[0] + inp[1]
+      if n.op_type == "Sub": ret = inp[0] - inp[1]
+      if n.op_type == "Mul": ret = inp[0] * inp[1]
     elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0)
     elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
     elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm'])
     elif n.op_type == "Squeeze":
       ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
     elif n.op_type == "Clip":
       if 'min' in opt and 'max' in opt: ret = inp[0].clip(opt['min'], opt['max'])
       else: ret = inp[0].clip(inp[1], inp[2])
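Two notes on this hunk. First, ONNX packs Conv pads as all begin values followed by all end values, i.e. (top, left, bottom, right) for 2D, while pad2d wants the begin/end values for each spatial axis grouped together; the (0, 2, 1, 3) reshuffle does that regrouping:

pads = (1, 2, 3, 4)                               # ONNX order: (top, left, bottom, right)
reordered = (pads[0], pads[2], pads[1], pads[3])  # (1, 3, 2, 4): begin/end grouped per axis

Second, the broadcast attribute handled in the Add/Sub/Mul branch only exists in old opsets (Add-6 and earlier); in that spec broadcast is a 0/1 flag and the position lives in a separate axis attribute, so indexing new_shape with opt['broadcast'] itself is questionable, which is presumably why the "# TODO: is this right?" is there.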
@@ -78,26 +101,36 @@ def run_onnx(onnx_model, inputs={}, debug=False):
         tensors[o] = inp[0].slice(arg=arg)
         i = i+s
       continue
-    elif n.op_type == "Gemm":
-      a,w,b = inp
-      #print(a.shape, w.shape, b.shape)
-      if opt['transB'] == 1: w = w.transpose()
-      ret = a.linear(w,b)
+    elif n.op_type in ["Gemm", "MatMul"]:
+      x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
+      if 'transB' in opt and opt['transB'] == 1: w = w.transpose()
+      ret = x.dot(w) if b is None else x.linear(w,b)
     elif n.op_type == "BatchNormalization":
       from tinygrad.nn import batch_normalize
-      ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt['epsilon'])
+      # does ONNX really specify a default eps?
+      #print(n)
+      ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt['epsilon'] if 'epsilon' in opt else 1e-5)
     elif n.op_type == "AveragePool":
       assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
       ret = inp[0].avg_pool2d(opt['kernel_shape'])
     elif n.op_type == "MaxPool":
       assert opt['kernel_shape'] == opt['strides']
+      #opt['kernel_shape'] = opt['strides']
+      # TODO: this is untested and probably wrong
       ret = inp[0].pad2d(opt['pads'])
       ret = ret.max_pool2d(opt['kernel_shape'])
-      chan = ret.shape[1]
       # strides aren't supported in max_pool
-      w = Tensor.eye(chan).reshape((chan, chan, 1, 1))
-      ret = ret.conv2d(w, stride=opt['strides'])
+      #chan = ret.shape[1]
+      #w = Tensor.eye(chan).reshape((chan, chan, 1, 1))
+      #ret = ret.conv2d(w, stride=opt['strides'])
     else:
       print("UNSUPPORTED", n.op_type, n.input, n.output)
       raise Exception(f"op_type {n.op_type} not supported")
     assert len(n.output) == 1
     if debug: print(ret.shape)
     tensors[n.output[0]] = ret
     #print(ret.numpy().mean())
 
   return {outp.name:tensors[outp.name] for outp in onnx_model.graph.output}
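On the MaxPool change: the now-commented lines were a workaround for strided pooling. max_pool2d here only pools with stride equal to the kernel, and a 1x1 channel-identity convolution (Tensor.eye over channels) applied with stride s is just subsampling, so chaining the two spaces the pooling windows kernel*s apart instead of s. That is presumably why the workaround is commented out, the assert stays, and the test below notes that models with strides != kernel_size can't run yet. The subsampling equivalence, as a numpy sketch:

import numpy as np

x = np.arange(16, dtype=np.float32).reshape(4, 4)
# an identity 1x1 conv with stride 2 over one channel picks every 2nd element:
print(x[::2, ::2])  # [[ 0.  2.]
                    #  [ 8. 10.]]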
@@ -125,16 +158,24 @@ class TestOnnxModel(unittest.TestCase):
     torch_out = run_onnx_torch(onnx_model, inputs).numpy()
     print(tinygrad_out, torch_out)
     np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2)
 
   def test_resnet(self):
     # mobilenet requires "Shape", "Gather", "Unsqueeze"
     # googlenet doesn't work without dilated convs
-    dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v1-7.onnx")
+    # NOTE: many onnx models can't be run right now due to max pool with strides != kernel_size
+    dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx")
     onnx_model = onnx.load(io.BytesIO(dat))
-    from test.test_efficientnet import chicken_img, car_img, _LABELS
-    inputs = {"data": chicken_img}
-    tinygrad_out = run_onnx(onnx_model, inputs, False)['resnetv15_dense0_fwd'].numpy()
-    cls = tinygrad_out.argmax()
+    from test.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
+
+    def run(img):
+      inputs = {"images:0": preprocess(img, new=True)}
+      tinygrad_out = list(run_onnx(onnx_model, inputs, False).values())[0].numpy()
+      return tinygrad_out.argmax()
+
+    cls = run(chicken_img)
     print(cls, _LABELS[cls])
     assert _LABELS[cls] == "hen"
+    cls = run(car_img)
+    print(cls, _LABELS[cls])
+    assert "car" in _LABELS[cls]
 
 if __name__ == "__main__":
   unittest.main()
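The input key "images:0" reflects that efficientnet-lite4 was converted from TensorFlow, whose graph tensor names carry a ":0" output-index suffix, and preprocess(img, new=True) supplies exactly the NHWC (1, 224, 224, 3) float input such a graph expects. (The method keeps its old test_resnet name even though it now loads efficientnet-lite4.) A quick way to confirm the expected input name and shape, assuming the .onnx file has already been fetched to a local path:

import onnx

model = onnx.load("efficientnet-lite4-11.onnx")  # local path assumed, for illustration
for inp in model.graph.input:
  dims = tuple(d.dim_value for d in inp.type.tensor_type.shape.dim)
  print(inp.name, dims)  # expect something like: images:0 (1, 224, 224, 3)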
@@ -1,7 +1,7 @@
 import functools
 import numpy as np
 import pyopencl as cl
-from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
+from tinygrad.helpers import prod, binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
 
 cl_ctx, cl_queue = None, None
 def require_init_gpu():
@@ -21,7 +21,7 @@ class GPUBuffer:
   def __init__(self, shape, hostbuf=None):
     require_init_gpu()
     self.shape, self.dtype = tuple(shape), np.float32
-    self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*roundup(np.prod(shape))) # padding
+    self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*roundup(prod(shape))) # padding
     if hostbuf is not None and not isinstance(hostbuf, GPUBuffer):
       cl.enqueue_copy(cl_queue, self.cl, hostbuf.astype(np.float32).ravel())
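Swapping np.prod for tinygrad.helpers.prod matters here because cl.Buffer wants a plain byte count: np.prod returns a numpy scalar (float64 for an empty shape, a platform-dependent integer type otherwise), while a pure-Python product always yields an int. A sketch of what such a helper looks like (the exact definition lives in tinygrad/helpers.py):

from functools import reduce

def prod(x):
  # pure-Python product: returns a plain int for integer shapes,
  # and int 1 (not numpy float64) for an empty shape
  return reduce(lambda a, b: a * b, x, 1)

assert prod((3, 4, 5)) == 60 and isinstance(prod(()), int)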