mirror of https://github.com/commaai/tinygrad.git
contiguous_view (#336)
* contiguous_view
* non contig reduce too
* conv fast
* maybe faster valid
* improve test_onnx
* improve params
* elementwise_op
* draw non contig
This commit is contained in:
parent fb72ea3fbd
commit d05e7c291a
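For context on the API change in this commit: run_onnx(onnx_model, inputs) is replaced by get_run_onnx(onnx_model), which parses the graph once and returns a runner that takes only the input dict. A minimal usage sketch mirroring the tests below (the import path for get_run_onnx is assumed, and the model file is assumed to be downloaded locally; input names and shapes come from the openpilot test):

    import onnx
    import numpy as np
    from test.test_onnx import get_run_onnx  # assumed location; this diff defines it in the ONNX test module

    onnx_model = onnx.load("supercombo.onnx")  # assumes the model was fetched beforehand
    run_onnx = get_run_onnx(onnx_model)        # build the runner once

    inputs = {
      "input_imgs": np.random.randn(1, 12, 128, 256),
      "big_input_imgs": np.random.randn(1, 12, 128, 256),
      "desire": np.zeros((1, 8)),
      "traffic_convention": np.array([[1., 0.]]),
      "initial_state": np.zeros((1, 512)),
    }
    inputs = {k: v.astype(np.float32) for k, v in inputs.items()}

    # old call (removed in this commit): run_onnx(onnx_model, inputs)['outputs'].numpy()
    tinygrad_out = run_onnx(inputs)['outputs'].numpy()
    print(tinygrad_out.shape)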
@@ -12,7 +12,7 @@ from tinygrad.nn import batch_normalize
MAX_CONVS = int(os.getenv("MAX_CONVS", -1))

def run_onnx(onnx_model, inputs={}, debug=False):
def get_run_onnx(onnx_model):
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
def buffer_parse(inp):
if inp.data_type == 1:
@@ -46,92 +46,96 @@ def run_onnx(onnx_model, inputs={}, debug=False):
print(inp)
raise Exception("no data")

# get inputs
for inp in onnx_model.graph.input:
if inp.name in tensors: continue
shape = shape_to_tuple(inp.type.tensor_type.shape)
if shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size
if inp.name in inputs:
input_shape = inputs[inp.name].shape
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
tensors[inp.name] = Tensor(inputs[inp.name])
else:
raise Exception(f"no data for {inp.name} with shape {shape}")
def run_onnx(inputs={}, debug=False):
input_tensors = {}

conv_count = 0
for num,n in enumerate(onnx_model.graph.node):
if debug: print(f"{num}: op {n.op_type}")
inp = [tensors[x] for x in n.input]
opt = attribute_to_dict(n.attribute)

# free ones
if n.op_type == "Relu": ret = inp[0].relu()
elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
elif n.op_type == "Tanh": ret = inp[0].tanh()
elif n.op_type == "Softmax": ret = inp[0].softmax()
elif n.op_type == "MatMul": ret = inp[0].matmul(inp[1])
# one liners
elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt['min'], opt['max'])))
elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0)
elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm'])
elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True)
elif n.op_type == "BatchNormalization": ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt.get('epsilon', 1e-5))
elif n.op_type == "Gemm": ret = inp[0].linear(inp[1].transpose() if opt.get('transB', 0) == 1 else inp[1], inp[2])
elif n.op_type == "Conv":
x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
assert 'dilations' not in opt or opt['dilations'] == (1,1)
if opt['pads'][0] == opt['pads'][2] and opt['pads'][1] == opt['pads'][3]:
# symmetric padding
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1), padding=opt['pads'][0:2])
# get inputs
for inp in onnx_model.graph.input:
if inp.name in tensors: continue
shape = shape_to_tuple(inp.type.tensor_type.shape)
if shape[0] == 0: shape = tuple([1]+list(shape[1:])) # 1 batch size
if inp.name in inputs:
input_shape = inputs[inp.name].shape
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
input_tensors[inp.name] = Tensor(inputs[inp.name])
else:
x = x.pad2d((opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]))
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1))
conv_count += 1
if conv_count == MAX_CONVS:
ret.numpy()
break
elif n.op_type in ["Add", "Sub", "Mul"]:
# TODO: add this to tinygrad? i don't think it's in torch
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
inp[1] = inp[1].reshape(inp[0].shape)
# TODO: is this right?
if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])
if n.op_type == "Add": ret = inp[0] + inp[1]
if n.op_type == "Sub": ret = inp[0] - inp[1]
if n.op_type == "Mul": ret = inp[0] * inp[1]
elif n.op_type == "Split":
i = 0
arg = [(0,x) for x in inp[0].shape]
for o,s in zip(n.output, opt['split']):
arg[opt['axis']] = (i,i+s)
tensors[o] = inp[0].slice(arg=arg)
i = i+s
continue
elif n.op_type == "AveragePool":
assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
ret = inp[0].avg_pool2d(opt['kernel_shape'])
elif n.op_type == "MaxPool":
assert opt['kernel_shape'] == opt['strides']
#opt['kernel_shape'] = opt['strides']
# TODO: this is untested and probably wrong
ret = inp[0].pad2d(opt['pads'])
ret = ret.max_pool2d(opt['kernel_shape'])
# strides aren't supported in max_pool
#chan = ret.shape[1]
#w = Tensor.eye(chan).reshape((chan, chan, 1, 1))
#ret = ret.conv2d(w, stride=opt['strides'])
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise Exception(f"op_type {n.op_type} not supported")
assert len(n.output) == 1
if debug: print(ret.shape)
tensors[n.output[0]] = ret
#print(ret.numpy().mean())
raise Exception(f"no data for {inp.name} with shape {shape}")

return {outp.name:tensors[outp.name] for outp in onnx_model.graph.output}
conv_count = 0
for num,n in enumerate(onnx_model.graph.node):
if debug: print(f"{num}: op {n.op_type}")
inp = [tensors[x] if x in tensors else input_tensors[x] for x in n.input]
opt = attribute_to_dict(n.attribute)

# free ones
if n.op_type == "Relu": ret = inp[0].relu()
elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
elif n.op_type == "Tanh": ret = inp[0].tanh()
elif n.op_type == "Softmax": ret = inp[0].softmax()
elif n.op_type == "MatMul": ret = inp[0].matmul(inp[1])
# one liners
elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt['min'], opt['max'])))
elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
elif n.op_type == "Flatten": ret = inp[0].flatten(opt['axis'] if 'axis' in opt else 0)
elif n.op_type == "Transpose": ret = inp[0].permute(order=opt['perm'])
elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
elif n.op_type == "GlobalAveragePool": ret = inp[0].mean(axis=tuple(range(2, len(inp[0].shape))), keepdim=True)
elif n.op_type == "BatchNormalization": ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], inp[4], opt.get('epsilon', 1e-5))
elif n.op_type == "Gemm": ret = inp[0].linear(inp[1].transpose() if opt.get('transB', 0) == 1 else inp[1], inp[2])
elif n.op_type == "Conv":
x,w,b = inp if len(inp) == 3 else (inp[0], inp[1], None)
assert 'dilations' not in opt or opt['dilations'] == (1,1)
if opt['pads'][0] == opt['pads'][2] and opt['pads'][1] == opt['pads'][3]:
# symmetric padding
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1), padding=opt['pads'][0:2])
else:
x = x.pad2d((opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]))
ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1))
conv_count += 1
if conv_count == MAX_CONVS:
ret.numpy()
break
elif n.op_type in ["Add", "Sub", "Mul"]:
# TODO: add this to tinygrad? i don't think it's in torch
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
inp[1] = inp[1].reshape(inp[0].shape)
# TODO: is this right?
if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])
if n.op_type == "Add": ret = inp[0] + inp[1]
if n.op_type == "Sub": ret = inp[0] - inp[1]
if n.op_type == "Mul": ret = inp[0] * inp[1]
elif n.op_type == "Split":
i = 0
arg = [(0,x) for x in inp[0].shape]
for o,s in zip(n.output, opt['split']):
arg[opt['axis']] = (i,i+s)
tensors[o] = inp[0].slice(arg=arg)
i = i+s
continue
elif n.op_type == "AveragePool":
assert opt['kernel_shape'] == opt['strides'] or opt['strides'] == (1,1)
ret = inp[0].avg_pool2d(opt['kernel_shape'])
elif n.op_type == "MaxPool":
assert opt['kernel_shape'] == opt['strides']
#opt['kernel_shape'] = opt['strides']
# TODO: this is untested and probably wrong
ret = inp[0].pad2d(opt['pads'])
ret = ret.max_pool2d(opt['kernel_shape'])
# strides aren't supported in max_pool
#chan = ret.shape[1]
#w = Tensor.eye(chan).reshape((chan, chan, 1, 1))
#ret = ret.conv2d(w, stride=opt['strides'])
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise Exception(f"op_type {n.op_type} not supported")
assert len(n.output) == 1
if debug: print(ret.shape)
tensors[n.output[0]] = ret
#print(ret.numpy().mean())

return {outp.name:tensors[outp.name] for outp in onnx_model.graph.output}
return run_onnx

def run_onnx_torch(onnx_model, inputs):
import torch
@@ -142,9 +146,28 @@ def run_onnx_torch(onnx_model, inputs):
return torch_out

class TestOnnxModel(unittest.TestCase):
def test_benchmark_openpilot_model(self):
dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx")
onnx_model = onnx.load(io.BytesIO(dat))
run_onnx = get_run_onnx(onnx_model)
for _ in range(5):
inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
"big_input_imgs": np.random.randn(*(1, 12, 128, 256)),
"desire": np.zeros((1, 8)),
"traffic_convention": np.array([[1., 0.]]),
"initial_state": np.zeros((1, 512))
}
inputs = {k:v.astype(np.float32) for k,v in inputs.items()}
st = time.monotonic()
tinygrad_out = run_onnx(inputs)['outputs'].numpy()
et = time.monotonic() - st
print(f"ran openpilot model in {et*1000.0:.2f} ms")

def test_openpilot_model(self):
dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx")
onnx_model = onnx.load(io.BytesIO(dat))
run_onnx = get_run_onnx(onnx_model)
inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
"big_input_imgs": np.random.randn(*(1, 12, 128, 256)),

@@ -154,7 +177,7 @@ class TestOnnxModel(unittest.TestCase):
}
inputs = {k:v.astype(np.float32) for k,v in inputs.items()}
st = time.monotonic()
tinygrad_out = run_onnx(onnx_model, inputs)['outputs'].numpy()
tinygrad_out = run_onnx(inputs)['outputs'].numpy()
et = time.monotonic() - st
print(f"ran openpilot model in {et*1000.0:.2f} ms")

@@ -167,10 +190,11 @@ class TestOnnxModel(unittest.TestCase):
dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx")
onnx_model = onnx.load(io.BytesIO(dat))
from test.test_efficientnet import chicken_img, car_img, preprocess, _LABELS
run_onnx = get_run_onnx(onnx_model)

def run(img):
inputs = {"images:0": preprocess(img, new=True)}
tinygrad_out = list(run_onnx(onnx_model, inputs, False).values())[0].numpy()
tinygrad_out = list(run_onnx(inputs, False).values())[0].numpy()
return tinygrad_out.argmax()

cls = run(chicken_img)

@@ -1,7 +1,9 @@
import functools
import numpy as np
import pyopencl as cl
from typing import List, Tuple
from tinygrad.helpers import prod
from tinygrad.llops.ops_cpu import unary_op
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps
from tinygrad.shapetracker import ShapeTracker, View, strides_for_shape

@@ -42,7 +44,7 @@ class GPUBuffer:

def toCPU(self):
data = np.empty(self.shape, dtype=np.float32)
cl.enqueue_copy(cl_queue, data, self.cl, is_blocking=True)
cl.enqueue_copy(cl_queue, data, contiguous(self).cl, is_blocking=True)
return data

class CLProgram:
@@ -59,33 +61,27 @@ def clbuild(name, prg, options=tuple(), argdtypes=None):
return CLProgram(name, prg, options, argdtypes)

code_for_op = {
UnaryOps.RELU: 'max(A, (float)0.)', UnaryOps.EXP: 'exp(A)', UnaryOps.LOG: 'log(A)', UnaryOps.NEG: '-A', UnaryOps.SIGN: 'sign(A)',
UnaryOps.NOOP: "A", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.EXP: "exp(A)", UnaryOps.LOG: "log(A)", UnaryOps.NEG: "-A", UnaryOps.SIGN: "sign(A)",
BinaryOps.ADD: "A+B", BinaryOps.SUB: "A-B", BinaryOps.MUL: "A*B", BinaryOps.DIV: "B/A", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)"
}

def unary_op(op, x):
ret = GPUBuffer(x.shape)
unop = clbuild("unop", """
__kernel void unop(__global const float4 *a_g, __global float4 *res_g) {
int gid = get_global_id(0);
float4 A = a_g[gid];
res_g[gid] = convert_float4("""+code_for_op[op]+""");
}""")
unop([roundup(prod(ret.shape))//4], None, x.cl, ret.cl)
def contiguous_view(x:GPUBuffer, name:str):
return f"inline float get_{name}(__global const float *x, int gid) {{ int valid = 1; int idx = gid; {x.st.expr().replace('//', '/')}; return valid ? x[idx] : 0.0;}}"

def elementwise_op(bufs: List[Tuple[str, GPUBuffer]], code):
assert all(buf.shape == bufs[0][1].shape for _, buf in bufs)
ret = GPUBuffer(bufs[0][1].shape)
ewop = clbuild("ewop", '\n'.join([contiguous_view(buf, name) for name, buf in bufs])+
"__kernel void ewop(__global float *res_g, "+','.join([f"__global const float *{name}_g" for name, _ in bufs])+") {"+
"int gid = get_global_id(0);"+
'\n'.join([f"float {name} = get_{name}({name}_g, gid);" for name, _ in bufs])+
f"res_g[gid] = {code}; }}")
ewop([prod(ret.shape)], None, ret.cl, *[buf.cl for _, buf in bufs])
return ret

def binary_op(op, x, y):
ret = GPUBuffer(x.shape)
assert x.shape == ret.shape and y.shape == ret.shape
binop = clbuild("binop", """
__kernel void binop(__global const float4 *a_g, __global const float4 *b_g, __global float4 *res_g) {
int gid = get_global_id(0);
float4 A = a_g[gid];
float4 B = b_g[gid];
res_g[gid] = convert_float4("""+code_for_op[op]+""");
}""")
binop([roundup(prod(ret.shape))//4], None, x.cl, y.cl, ret.cl)
return ret
def unary_op(op, x): return elementwise_op([("A", x)], code_for_op[op])
def binary_op(op, x, y): return elementwise_op([("A", x), ("B", y)], code_for_op[op])
def contiguous(x:GPUBuffer): return x if x.st.contiguous else unary_op(UnaryOps.NOOP, x)

def reduce_op(op, inp, new_shape):
ret = GPUBuffer(new_shape)
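The hunk above is the core of this change: contiguous_view() emits an inline get_<name>() helper that turns a flat global id into an index into a possibly non-contiguous buffer (using the ShapeTracker expression), and elementwise_op() prepends one such helper per input so a single kernel can read strided/permuted data in place. contiguous() then falls out for free: a NOOP elementwise op through the view writes the data into a fresh contiguous buffer. A rough Python model of the index mapping the generated helper performs (illustrative only; gid_to_idx and the explicit shape/strides arguments are not tinygrad APIs):

    # Conceptual sketch, not tinygrad code: map a flat output index gid
    # through a view's shape/strides/offset to an index in the underlying buffer.
    def gid_to_idx(gid, shape, strides, offset=0):
      idx = offset
      for size, stride in reversed(list(zip(shape, strides))):
        idx += (gid % size) * stride
        gid //= size
      return idx

    # Example: a row-major 3x4 buffer viewed as its transpose (shape (4,3), strides (1,4)).
    buf = list(range(12))
    out = [buf[gid_to_idx(g, (4, 3), (1, 4))] for g in range(12)]
    print(out)  # the transposed elements, gathered without materializing a copy first
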
@@ -108,12 +104,12 @@ def reduce_op(op, inp, new_shape):
acc *= shp

# TODO: support multistage reduces
prg = """
prg = contiguous_view(inp, 'A')+"""
__kernel void reduce(__global const float *a_g, __global float *res_g) {
int gid = get_global_id(0); int idx = gid;"""+view.expr.replace('//', '/')+""";
float out = """+start+""";\n"""+ \
'\n'.join(loop_start[::-1])+"""
float a = a_g[idx];
float a = get_A(a_g, idx);
"""+code+""";\n"""+ \
'\n'.join(loop_end)+"""
res_g[gid] = out;
@@ -121,26 +117,20 @@ def reduce_op(op, inp, new_shape):
clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)
return ret

def contiguous(x, ret=None):
if ret is None: ret = GPUBuffer(x.st.shape)
clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {
int gid = get_global_id(0); int valid = 1; int idx = gid; """+x.st.expr().replace('//', '/')+""";
ret[gid] = valid ? x[idx] : 0.0; // should never be out-of-bounds accesses
}""")([prod(ret.shape)], None, x.cl, ret.cl)
return ret

def movement_op(op, x, arg=None):
def movement_op(op, x, arg):
ret = GPUBuffer(x.st, x)
ret.st.movement_op(op, arg)
if ret.st.contiguous: return ret
else: return contiguous(ret)
return ret

def processing_op(op,x,w,C):
ret = GPUBuffer((C.bs, C.cout, C.oy, C.ox))
assert op == ProcessingOps.CONV, f"{op} isn't supported"
ints = ''.join(f"int {x} = {getattr(C, x)};" for x in ["H", "W", "cin", "ys", "xs", "dx", "dy", "px", "py"])
params = [(f"int {x}", getattr(C, x)) for x in ["groups", "rcout", "oy", "ox", "iy", "ix"]]
conv_prg = clbuild("conv", """
__kernel void conv(__global const float* restrict input, __global const float* restrict weight, __global float* restrict output,
int H, int W, int groups, int rcout, int cin, int oy, int ox, int iy, int ix, int ys, int xs, int bs, int dx, int dy, int px, int py) {
"""+','.join([x[0] for x in params])+""") {
"""+ints+"""
int B = get_global_id(0)/(groups*rcout); // range 0-bs
int g = (get_global_id(0)/rcout)%groups;
int c = get_global_id(0) % rcout;
@@ -154,24 +144,22 @@ def processing_op(op,x,w,C):

float acc = 0.0;
for (int ci = 0; ci < cin; ci++) {

#ifdef ONEBYONE
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + IY*ix + IX] * \
weight[g*rcout*cin + c*cin + ci];
#else
for (int y = 0; y < H; y++) { for (int x = 0; x < W; x++) {
int idx_y = y*dy + IY - py;
int idx_x = x*dx + IX - px;
#ifdef ALLVALID
acc += input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
#else
int valid = (idx_y >= 0 && idx_y < iy && idx_x >= 0 && idx_x < ix);
acc += valid ? input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + idx_y*ix + idx_x] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x] : 0.0;
} }
acc += valid * input[B*groups*cin*iy*ix + g*cin*iy*ix + ci*iy*ix + clamp(idx_y, 0, iy-1)*ix + clamp(idx_x, 0, ix-1)] * \
weight[g*rcout*cin*H*W + c*cin*H*W + ci*H*W + y*W + x];
#endif
} }
}
output[gid] = acc;
}""",
options=tuple(["-DONEBYONE"]) if C.H == 1 and C.W == 1 and C.px == 0 and C.py == 0 else tuple(),
argdtypes=tuple([None, None, None] + [np.int32]*16))
conv_prg([C.bs*C.cout, C.oy, C.ox], None, x.cl, w.cl, ret.cl,
*[x for x in list(C[0:12])+[C.dx, C.dy, C.px, C.py]])
options=tuple(["-DALLVALID"]) if C.px == 0 and C.py == 0 else tuple(),
argdtypes=tuple([None, None, None] + [np.int32]*len(params)))
conv_prg([C.bs*C.cout, C.oy, C.ox], None, contiguous(x).cl, contiguous(w).cl, ret.cl, *[x[1] for x in params])
return ret

@@ -1,6 +1,6 @@
# TODO: move Device to here and proxy buffer call
from enum import Enum
UnaryOps = Enum("UnaryOps", ["RELU", "EXP", "LOG", "NEG", "SIGN"])
UnaryOps = Enum("UnaryOps", ["NOOP", "RELU", "EXP", "LOG", "NEG", "SIGN"])
BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "SLICE", "EXPAND", "FLIP"])

@@ -46,9 +46,11 @@ def log_op(optype, op, ret, inp, dashed=False):
G.add_edge(nm(x), nm(ret), label=sop, color='#808080' if dashed else '', style='dashed' if dashed else '')
if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)
if nm(ret) not in G.nodes: G.add_node(nm(ret))
st = getattr(ret, "st", None)
non_contiguous = st is not None and not st.contiguous
G.nodes[nm(ret)]['label'] = str(ret.shape)
G.nodes[nm(ret)]['fillcolor'] = top_colors[optype]
G.nodes[nm(ret)]['style'] = 'filled'
G.nodes[nm(ret)]['fillcolor'] = top_colors[optype] + ('80' if non_contiguous else '')
G.nodes[nm(ret)]['style'] = 'filled, dashed' if non_contiguous else 'filled'

class Ops:
def unary_op(ctx, op:UnaryOps, x):