contiguous_view (#336)

* contiguous_view

* non contig reduce too

* conv fast

* maybe faster valid

* improve test_onnx

* improve params

* elementwise_op

* draw non contig
This commit is contained in:
George Hotz 2022-06-19 20:37:28 -07:00 committed by GitHub
parent fb72ea3fbd
commit d05e7c291a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 152 additions and 138 deletions

View File

@ -12,7 +12,7 @@ from tinygrad.nn import batch_normalize
MAX_CONVS = int(os.getenv("MAX_CONVS", -1))
def run_onnx(onnx_model, inputs={}, debug=False):
def get_run_onnx(onnx_model):
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
def buffer_parse(inp):
if inp.data_type == 1:
@ -46,92 +46,96 @@ def run_onnx(onnx_model, inputs={}, debug=False):
print(inp)
raise Exception("no data")
# get inputs
for inp in onnx_model.graph.input:
if inp.name in tensors: continue
shape = shape_to_tuple(inp.type.tensor_type.shape)
if shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size
if inp.name in inputs:
input_shape = inputs[inp.name].shape
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
tensors[inp.name] = Tensor(inputs[inp.name])
else:
raise Exception(f"no data for {inp.name} with shape {shape}")
def run_onnx(inputs={}, debug=False):
input_tensors = {}
conv_count = 0
for num,n in enumerate(onnx_model.graph.node):
if debug: print(f"{num}: op {n.op_type}")
inp = [tensors[x] for x in n.input]
opt = attribute_to_dict(n.attribute)
# free ones
if n.op_type == "Relu": ret = inp[0].relu()
elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
elif n.op_type == "Tanh": ret = inp[0].tanh()
elif n.op_type == "Softmax": ret = inp[0].softmax()
elif n.op_type == "MatMul": ret = inp[0].matmul(inp[1])
# one liners
elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt['min'], opt['max'])))
elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0)
elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm'])
elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True)
elif n.op_type == "BatchNormalization": ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt.get('epsilon', 1e-5))
elif n.op_type == "Gemm": ret = inp[0].linear(inp[1].transpose() if opt.get('transB', 0) == 1 else inp[1], inp[2])
elif n.op_type == "Conv":
x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
assert 'dilations' not in opt or opt['dilations'] == (1,1)
if opt['pads'][0] == opt['pads'][2] and opt['pads'][1] == opt['pads'][3]:
# symmetric padding
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1), padding=opt['pads'][0:2])
# get inputs
for inp in onnx_model.graph.input:
if inp.name in tensors: continue
shape = shape_to_tuple(inp.type.tensor_type.shape)
if shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size
if inp.name in inputs:
input_shape = inputs[inp.name].shape
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
input_tensors[inp.name] = Tensor(inputs[inp.name])
else:
x = x.pad2d((opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]))
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1))
conv_count += 1
if conv_count == MAX_CONVS:
ret.numpy()
break
elif n.op_type in ["Add", "Sub", "Mul"]:
# TODO: add this to tinygrad? i don't think it's in torch
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
inp[1] = inp[1].reshape(inp[0].shape)
# TODO: is this right?
if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])
if n.op_type == "Add": ret = inp[0] + inp[1]
if n.op_type == "Sub": ret = inp[0] - inp[1]
if n.op_type == "Mul": ret = inp[0] * inp[1]
elif n.op_type == "Split":
i = 0
arg = [(0,x) for x in inp[0].shape]
for o,s in zip(n.output, opt['split']):
arg[opt['axis']] = (i,i+s)
tensors[o] = inp[0].slice(arg=arg)
i = i+s
continue
elif n.op_type == "AveragePool":
assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
ret = inp[0].avg_pool2d(opt['kernel_shape'])
elif n.op_type == "MaxPool":
assert opt['kernel_shape'] == opt['strides']
#opt['kernel_shape'] = opt['strides']
# TODO: this is untested and probably wrong
ret = inp[0].pad2d(opt['pads'])
ret = ret.max_pool2d(opt['kernel_shape'])
# strides aren't supported in max_pool
#chan = ret.shape[1]
#w = Tensor.eye(chan).reshape((chan, chan, 1, 1))
#ret = ret.conv2d(w, stride=opt['strides'])
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise Exception(f"op_type {n.op_type} not supported")
assert len(n.output) == 1
if debug: print(ret.shape)
tensors[n.output[0]] = ret
#print(ret.numpy().mean())
raise Exception(f"no data for {inp.name} with shape {shape}")
return {outp.name:tensors[outp.name] for outp in onnx_model.graph.output}
conv_count = 0
for num,n in enumerate(onnx_model.graph.node):
if debug: print(f"{num}: op {n.op_type}")
inp = [tensors[x] if x in tensors else input_tensors[x] for x in n.input]
opt = attribute_to_dict(n.attribute)
# free ones
if n.op_type == "Relu": ret = inp[0].relu()
elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
elif n.op_type == "Tanh": ret = inp[0].tanh()
elif n.op_type == "Softmax": ret = inp[0].softmax()
elif n.op_type == "MatMul": ret = inp[0].matmul(inp[1])
# one liners
elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt['min'], opt['max'])))
elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0)
elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm'])
elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True)
elif n.op_type == "BatchNormalization": ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt.get('epsilon', 1e-5))
elif n.op_type == "Gemm": ret = inp[0].linear(inp[1].transpose() if opt.get('transB', 0) == 1 else inp[1], inp[2])
elif n.op_type == "Conv":
x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
assert 'dilations' not in opt or opt['dilations'] == (1,1)
if opt['pads'][0] == opt['pads'][2] and opt['pads'][1] == opt['pads'][3]:
# symmetric padding
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1), padding=opt['pads'][0:2])
else:
x = x.pad2d((opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]))
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1))
conv_count += 1
if conv_count == MAX_CONVS:
ret.numpy()
break
elif n.op_type in ["Add", "Sub", "Mul"]:
# TODO: add this to tinygrad? i don't think it's in torch
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
inp[1] = inp[1].reshape(inp[0].shape)
# TODO: is this right?
if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])
if n.op_type == "Add": ret = inp[0] + inp[1]
if n.op_type == "Sub": ret = inp[0] - inp[1]
if n.op_type == "Mul": ret = inp[0] * inp[1]
elif n.op_type == "Split":
i = 0
arg = [(0,x) for x in inp[0].shape]
for o,s in zip(n.output, opt['split']):
arg[opt['axis']] = (i,i+s)
tensors[o] = inp[0].slice(arg=arg)
i = i+s
continue
elif n.op_type == "AveragePool":
assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
ret = inp[0].avg_pool2d(opt['kernel_shape'])
elif n.op_type == "MaxPool":
assert opt['kernel_shape'] == opt['strides']
#opt['kernel_shape'] = opt['strides']
# TODO: this is untested and probably wrong
ret = inp[0].pad2d(opt['pads'])
ret = ret.max_pool2d(opt['kernel_shape'])
# strides aren't supported in max_pool
#chan = ret.shape[1]
#w = Tensor.eye(chan).reshape((chan, chan, 1, 1))
#ret = ret.conv2d(w, stride=opt['strides'])
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise Exception(f"op_type {n.op_type} not supported")
assert len(n.output) == 1
if debug: print(ret.shape)
tensors[n.output[0]] = ret
#print(ret.numpy().mean())
return {outp.name:tensors[outp.name] for outp in onnx_model.graph.output}
return run_onnx
def run_onnx_torch(onnx_model, inputs):
import torch
@ -142,9 +146,28 @@ def run_onnx_torch(onnx_model, inputs):
return torch_out
class TestOnnxModel(unittest.TestCase):
def test_benchmark_openpilot_model(self):
dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx")
onnx_model = onnx.load(io.BytesIO(dat))
run_onnx = get_run_onnx(onnx_model)
for _ in range(5):
inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
"big_input_imgs": np.random.randn(*(1, 12, 128, 256)),
"desire": np.zeros((1, 8)),
"traffic_convention": np.array([[1., 0.]]),
"initial_state": np.zeros((1, 512))
}
inputs = {k:v.astype(np.float32) for k,v in inputs.items()}
st = time.monotonic()
tinygrad_out = run_onnx(inputs)['outputs'].numpy()
et = time.monotonic() - st
print(f"ran openpilot model in {et*1000.0:.2f} ms")
def test_openpilot_model(self):
dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx")
onnx_model = onnx.load(io.BytesIO(dat))
run_onnx = get_run_onnx(onnx_model)
inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
"big_input_imgs": np.random.randn(*(1, 12, 128, 256)),
@ -154,7 +177,7 @@ class TestOnnxModel(unittest.TestCase):
}
inputs = {k:v.astype(np.float32) for k,v in inputs.items()}
st = time.monotonic()
tinygrad_out = run_onnx(onnx_model, inputs)['outputs'].numpy()
tinygrad_out = run_onnx(inputs)['outputs'].numpy()
et = time.monotonic() - st
print(f"ran openpilot model in {et*1000.0:.2f} ms")
@ -167,10 +190,11 @@ class TestOnnxModel(unittest.TestCase):
dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx")
onnx_model = onnx.load(io.BytesIO(dat))
from test.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
run_onnx = get_run_onnx(onnx_model)
def run(img):
inputs = {"images:0": preprocess(img, new=True)}
tinygrad_out = list(run_onnx(onnx_model, inputs, False).values())[0].numpy()
tinygrad_out = list(run_onnx(inputs, False).values())[0].numpy()
return tinygrad_out.argmax()
cls = run(chicken_img)

View File

@ -1,7 +1,9 @@
import functools
import numpy as np
import pyopencl as cl
from typing import List, Tuple
from tinygrad.helpers import prod
from tinygrad.llops.ops_cpu import unary_op
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape
@ -42,7 +44,7 @@ class GPUBuffer:
def toCPU(self):
data = np.empty(self.shape, dtype=np.float32)
cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
cl.enqueue_copy(cl_queue, data, contiguous(self).cl, is_blocking=True)
return data
class CLProgram:
@ -59,33 +61,27 @@ def clbuild(name, prg, options=tuple(), argdtypes=None):
return CLProgram(name, prg, options, argdtypes)
code_for_op = {
UnaryOps.RELU: 'max(A, (float)0.)', UnaryOps.EXP: 'exp(A)', UnaryOps.LOG: 'log(A)', UnaryOps.NEG: '-A', UnaryOps.SIGN: 'sign(A)',
UnaryOps.NOOP: "A", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.NEG: "-A", UnaryOps.SIGN: "sign(A)",
BinaryOps.ADD: "A+B", BinaryOps.SUB: "A-B", BinaryOps.MUL: "A*B", BinaryOps.DIV: "B/A", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)"
}
def unary_op(op, x):
ret = GPUBuffer(x.shape)
unop = clbuild("unop", """
__kernel void unop(__global const float4 *a_g, __global float4 *res_g) {
int gid = get_global_id(0);
float4 A = a_g[gid];
res_g[gid] = convert_float4("""+code_for_op[op]+""");
}""")
unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
def contiguous_view(x:GPUBuffer, name:str):
return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}"
def elementwise_op(bufs: List[Tuple[str, GPUBuffer]], code):
assert all(buf.shape == bufs[0][1].shape for _, buf in bufs)
ret = GPUBuffer(bufs[0][1].shape)
ewop = clbuild("ewop", '\n'.join([contiguous_view(buf, name) for name, buf in bufs])+
"__kernel void ewop(__global float *res_g, "+','.join([f"__global const float *{name}_g" for name, _ in bufs])+") {"+
"int gid = get_global_id(0);"+
'\n'.join([f"float {name} = get_{name}({name}_g, gid);" for name, _ in bufs])+
f"res_g[gid] = {code}; }}")
ewop([prod(ret.shape)], None, ret.cl, *[buf.cl for _, buf in bufs])
return ret
def binary_op(op, x, y):
ret = GPUBuffer(x.shape)
assert x.shape == ret.shape and y.shape == ret.shape
binop = clbuild("binop", """
__kernel void binop(__global const float4 *a_g, __global const float4 *b_g, __global float4 *res_g) {
int gid = get_global_id(0);
float4 A = a_g[gid];
float4 B = b_g[gid];
res_g[gid] = convert_float4("""+code_for_op[op]+""");
}""")
binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
return ret
def unary_op(op, x): return elementwise_op([("A", x)], code_for_op[op])
def binary_op(op, x, y): return elementwise_op([("A", x), ("B", y)], code_for_op[op])
def contiguous(x:GPUBuffer): return x if x.st.contiguous else unary_op(UnaryOps.NOOP, x)
def reduce_op(op, inp, new_shape):
ret = GPUBuffer(new_shape)
@ -108,12 +104,12 @@ def reduce_op(op, inp, new_shape):
acc *= shp
# TODO: support multistage reduces
prg = """
prg = contiguous_view(inp, 'A')+"""
__kernel void reduce(__global const float *a_g, __global float *res_g) {
int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+""";
float out = """+start+""";\n"""+ \
'\n'.join(loop_start[::-1])+"""
float a = a_g[idx];
float a = get_A(a_g, idx);
"""+code+""";\n"""+ \
'\n'.join(loop_end)+"""
res_g[gid] = out;
@ -121,26 +117,20 @@ def reduce_op(op, inp, new_shape):
clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)
return ret
def contiguous(x, ret=None):
if ret is None: ret = GPUBuffer(x.st.shape)
clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {
int gid = get_global_id(0); int valid = 1; int idx = gid; """+x.st.expr().replace('//', '/')+""";
ret[gid] = valid ? x[idx] : 0.0; // should never be out-of-bounds accesses
}""")([prod(ret.shape)], None, x.cl, ret.cl)
return ret
def movement_op(op, x, arg=None):
def movement_op(op, x, arg):
ret = GPUBuffer(x.st, x)
ret.st.movement_op(op, arg)
if ret.st.contiguous: return ret
else: return contiguous(ret)
return ret
def processing_op(op,x,w,C):
ret = GPUBuffer((C.bs, C.cout, C.oy, C.ox))
assert op == ProcessingOps.CONV, f"{op} isn't supported"
ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "cin", "ys", "xs", "dx", "dy", "px", "py"])
params = [(f"int {x}", getattr(C, x)) for x in ["groups", "rcout", "oy", "ox", "iy", "ix"]]
conv_prg = clbuild("conv", """
__kernel void conv(__global const float* restrict input, __global const float* restrict weight, __global float* restrict output,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs, int dx, int dy, int px, int py) {
"""+','.join([x[0] for x in params])+""") {
"""+ints+"""
int B = get_global_id(0)/(groups*rcout); // range 0-bs
int g = (get_global_id(0)/rcout)%groups;
int c = get_global_id(0) % rcout;
@ -154,24 +144,22 @@ def processing_op(op,x,w,C):
float acc = 0.0;
for (int ci = 0; ci < cin; ci++) {
#ifdef ONEBYONE
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + IY*ix + IX] * \
weight[g*rcout*cin + c*cin + ci];
#else
for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
int idx_y = y*dy + IY - py;
int idx_x = x*dx + IX - px;
#ifdef ALLVALID
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
#else
int valid = (idx_y >= 0 && idx_y < iy && idx_x >= 0 && idx_x < ix);
acc += valid ? input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x] : 0.0;
} }
acc += valid * input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
#endif
} }
}
output[gid] = acc;
}""",
options=tuple(["-DONEBYONE"]) if C.H == 1 and C.W == 1 and C.px == 0 and C.py == 0 else tuple(),
argdtypes=tuple([None, None, None] + [np.int32]*16))
conv_prg([C.bs*C.cout, C.oy, C.ox], None, x.cl, w.cl, ret.cl,
*[x for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
options=tuple(["-DALLVALID"]) if C.px == 0 and C.py == 0 else tuple(),
argdtypes=tuple([None, None, None] + [np.int32]*len(params)))
conv_prg([C.bs*C.cout, C.oy, C.ox], None, contiguous(x).cl, contiguous(w).cl, ret.cl, *[x[1] for x in params])
return ret

View File

@ -1,6 +1,6 @@
# TODO: move Device to here and proxy buffer call
from enum import Enum
UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
UnaryOps = Enum("UnaryOps", ["NOOP", "RELU", "EXP", "LOG", "NEG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])
@ -46,9 +46,11 @@ def log_op(optype, op, ret, inp, dashed=False):
G.add_edge(nm(x), nm(ret), label=sop, color='#808080' if dashed else '', style='dashed' if dashed else '')
if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)
if nm(ret) not in G.nodes: G.add_node(nm(ret))
st = getattr(ret, "st", None)
non_contiguous = st is not None and not st.contiguous
G.nodes[nm(ret)]['label'] = str(ret.shape)
G.nodes[nm(ret)]['fillcolor'] = top_colors[optype]
G.nodes[nm(ret)]['style'] = 'filled'
G.nodes[nm(ret)]['fillcolor'] = top_colors[optype] + ('80' if non_contiguous else '')
G.nodes[nm(ret)]['style'] = 'filled, dashed' if non_contiguous else 'filled'
class Ops:
def unary_op(ctx, op:UnaryOps, x):