diff --git a/test/test_efficientnet.py b/test/test_efficientnet.py index 8294ed39..c48ec0fd 100644 --- a/test/test_efficientnet.py +++ b/test/test_efficientnet.py @@ -17,7 +17,7 @@ def _load_labels(): _LABELS = _load_labels() -def preprocess(img): +def preprocess(img, new=False): # preprocess image aspect_ratio = img.size[0] / img.size[1] img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0)))) @@ -27,21 +27,29 @@ def preprocess(img): img = img[y0: y0 + 224, x0: x0 + 224] # low level preprocess - img = np.moveaxis(img, [2, 0, 1], [0, 1, 2]) - img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224) - #img /= 255.0 - #img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1)) - #img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1)) + if new: + img = img.astype(np.float32) + img -= [127.0, 127.0, 127.0] + img /= [128.0, 128.0, 128.0] + img = img[None] + else: + img = np.moveaxis(img, [2, 0, 1], [0, 1, 2]) + img = img.astype(np.float32)[:3].reshape(1, 3, 224, 224) + img /= 255.0 + img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1)) + img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1)) return img + def _infer(model: EfficientNet, img, bs=1): + img = preprocess(img) # run the net if bs > 1: img = img.repeat(bs, axis=0) out = model.forward(Tensor(img)).cpu() return _LABELS[np.argmax(out.data[0])] -chicken_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg')) -car_img = preprocess(Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg')) +chicken_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/Chicken.jpg') +car_img = Image.open(pathlib.Path(__file__).parent / 'efficientnet/car.jpg') class TestEfficientNet(unittest.TestCase): @classmethod diff --git a/test/test_onnx.py b/test/test_onnx.py index 4b98b31a..5c26cdea 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -5,6 +5,7 @@ import numpy as np import onnx from extra.utils import fetch from tinygrad.tensor import Tensor +from tinygrad.helpers import prod def run_onnx(onnx_model, inputs={}, debug=False): def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim) @@ -29,11 +30,16 @@ def run_onnx(onnx_model, inputs={}, debug=False): # get weights and biases for inp in onnx_model.graph.initializer: - #print(inp.name, inp.dims, inp.data_type, len(inp.raw_data)) - if len(inp.raw_data) == 0: - tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims)) - else: + if len(inp.raw_data) > 0: tensors[inp.name] = buffer_parse(inp) + elif len(inp.float_data) > 0: + tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims)) + elif len(inp.int64_data) > 0: + tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.float32).reshape(inp.dims)) + else: + print(inp.name, inp.dims, inp.data_type, len(inp.raw_data)) + print(inp) + raise Exception("no data") # get inputs for inp in onnx_model.graph.input: @@ -45,8 +51,9 @@ def run_onnx(onnx_model, inputs={}, debug=False): assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}" tensors[inp.name] = Tensor(inputs[inp.name]) else: - print(f"filling {inp.name} shape {shape} with 0") - tensors[inp.name] = Tensor.zeros(*shape) + raise Exception(f"no data for {inp.name} with shape {shape}") + #print(f"filling {inp.name} shape {shape} with 0") + #tensors[inp.name] = Tensor.zeros(*shape) for num,n in enumerate(onnx_model.graph.node): @@ -55,17 +62,33 @@ def run_onnx(onnx_model, inputs={}, debug=False): opt = attribute_to_dict(n.attribute) if n.op_type == "Conv": x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None) - assert opt['dilations'] == (1,1) - ret = x.pad2d(opt['pads']).conv2d(w, b, stride=opt['strides'], groups=opt['group']) + assert 'dilations' not in opt or opt['dilations'] == (1,1) + # pads are in different order + pads = (opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]) + ret = x.pad2d(pads).conv2d(w, b, stride=opt['strides'], groups=opt['group'] if 'group' in opt else 1) elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha']) elif n.op_type == "Relu": ret = inp[0].relu() elif n.op_type == "Sigmoid": ret = inp[0].sigmoid() elif n.op_type == "Tanh": ret = inp[0].tanh() - elif n.op_type == "Add": ret = inp[0] + inp[1] - elif n.op_type == "Sub": ret = inp[0] - inp[1] - elif n.op_type == "Mul": ret = inp[0] * inp[1] + elif n.op_type == "Softmax": ret = inp[0].softmax() + elif n.op_type in ["Add", "Sub", "Mul"]: + # TODO: add this to tinygrad + if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape): + inp[1] = inp[1].reshape(inp[0].shape) + # TODO: is this right? + if 'broadcast' in opt: + new_shape = [1 for x in range(len(inp[0].shape))] + new_shape[opt['broadcast']] = -1 + #print(inp[1].shape, new_shape) + inp[1] = inp[1].reshape(new_shape) + if n.op_type == "Add": ret = inp[0] + inp[1] + if n.op_type == "Sub": ret = inp[0] - inp[1] + if n.op_type == "Mul": ret = inp[0] * inp[1] elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0) elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis']) + elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm']) + elif n.op_type == "Squeeze": + ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']]) elif n.op_type == "Clip": if 'min' in opt and 'max' in opt: ret = inp[0].clip(opt['min'], opt['max']) else: ret = inp[0].clip(inp[1], inp[2]) @@ -78,26 +101,36 @@ def run_onnx(onnx_model, inputs={}, debug=False): tensors[o] = inp[0].slice(arg=arg) i = i+s continue - elif n.op_type == "Gemm": - a,w,b = inp + elif n.op_type in ["Gemm", "MatMul"]: + x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None) #print(a.shape, w.shape, b.shape) - if opt['transB'] == 1: w = w.transpose() - ret = a.linear(w,b) + if 'transB' in opt and opt['transB'] == 1: w = w.transpose() + ret = x.dot(w) if b is None else x.linear(w,b) elif n.op_type == "BatchNormalization": from tinygrad.nn import batch_normalize - ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt['epsilon']) + # does ONNX really specify a default eps? + #print(n) + ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt['epsilon'] if 'epsilon' in opt else 1e-5) + elif n.op_type == "AveragePool": + assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1) + ret = inp[0].avg_pool2d(opt['kernel_shape']) elif n.op_type == "MaxPool": + assert opt['kernel_shape'] == opt['strides'] + #opt['kernel_shape'] = opt['strides'] + # TODO: this is untested and probably wrong ret = inp[0].pad2d(opt['pads']) ret = ret.max_pool2d(opt['kernel_shape']) - chan = ret.shape[1] # strides aren't supported in max_pool - w = Tensor.eye(chan).reshape((chan, chan, 1, 1)) - ret = ret.conv2d(w, stride=opt['strides']) + #chan = ret.shape[1] + #w = Tensor.eye(chan).reshape((chan, chan, 1, 1)) + #ret = ret.conv2d(w, stride=opt['strides']) else: print("UNSUPPORTED", n.op_type, n.input, n.output) raise Exception(f"op_type {n.op_type} not supported") assert len(n.output) == 1 + if debug: print(ret.shape) tensors[n.output[0]] = ret + #print(ret.numpy().mean()) return {outp.name:tensors[outp.name] for outp in onnx_model.graph.output} @@ -125,16 +158,24 @@ class TestOnnxModel(unittest.TestCase): torch_out = run_onnx_torch(onnx_model, inputs).numpy() print(tinygrad_out, torch_out) np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2) + def test_resnet(self): - # mobilenet requires "Shape", "Gather", "Unsqueeze" - # googlenet doesn't work without dilated convs - dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v1-7.onnx") + # NOTE: many onnx models can't be run right now due to max pool with strides != kernel_size + dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx") onnx_model = onnx.load(io.BytesIO(dat)) - from test.test_efficientnet import chicken_img, car_img, _LABELS - inputs = {"data": chicken_img} - tinygrad_out = run_onnx(onnx_model, inputs, False)['resnetv15_dense0_fwd'].numpy() - cls = tinygrad_out.argmax() + from test.test_efficientnet import chicken_img, car_img, preprocess, _LABELS + + def run(img): + inputs = {"images:0": preprocess(img, new=True)} + tinygrad_out = list(run_onnx(onnx_model, inputs, False).values())[0].numpy() + return tinygrad_out.argmax() + + cls = run(chicken_img) print(cls, _LABELS[cls]) + assert _LABELS[cls] == "hen" + cls = run(car_img) + print(cls, _LABELS[cls]) + assert "car" in _LABELS[cls] if __name__ == "__main__": unittest.main() diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index c7e93fcf..13e8c92a 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -1,7 +1,7 @@ import functools import numpy as np import pyopencl as cl -from tinygrad.helpers import binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps +from tinygrad.helpers import prod, binary_broadcast, get_conv_args, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps cl_ctx, cl_queue = None, None def require_init_gpu(): @@ -21,7 +21,7 @@ class GPUBuffer: def __init__(self, shape, hostbuf=None): require_init_gpu() self.shape, self.dtype = tuple(shape), np.float32 - self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*roundup(np.prod(shape))) # padding + self.cl = hostbuf.cl if isinstance(hostbuf, GPUBuffer) else cl.Buffer(cl_ctx, cl.mem_flags.READ_WRITE, 4*roundup(prod(shape))) # padding if hostbuf is not None and not isinstance(hostbuf, GPUBuffer): cl.enqueue_copy(cl_queue, self.cl, hostbuf.astype(np.float32).ravel())