mirror of https://github.com/commaai/tinygrad.git
Accelerate with CL (#325)
* accelerated opencl
* it's running, it's just wrong
* bugfix
* model is correct in opencl
* lazy image convert
* add padding support to convolution
* that stuff was all upstreamed
* remove HEAD
* oops
* test_simple_conv2d_4 passes, add dilation support
* put logic in ops_opencl
* fix crash
* hmm, stride seems okay
* padding for batched inputs
* just an issue now with cout%4
* op model still passes
* fix startPackedInputChannel
* pre and post processing ops for graph
* don't break other llops
* shapetrackering
* reshapes are free
* lazy movement ops
This commit is contained in:
parent bd7068f635
commit d5b3e18540
accel/opencl/conv.cl
@@ -0,0 +1,102 @@
#define NUM_OUTPUTS 4

__kernel void conv(
    read_only image2d_t input,
    read_only image2d_t weights,
    write_only image2d_t output,
    short numPackedInputChannelsForGroup,
    short totalNumPackedInputChannels,
    short numPackedOutputChannelsForGroup,
    short totalNumPackedOutputChannels,
    short numOutputColumns,
    short numOutputRows, short numInputRows,
    short filterSizeX, short filterSizeY,
    short paddingX, short paddingY,
    short strideX, short strideY,
    short dilationX, short dilationY) {

  const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  float4 outputValues[NUM_OUTPUTS];
  for (short i = 0; i < NUM_OUTPUTS; ++i) {
    outputValues[i] = (float4)(0, 0, 0, 0);
  }

  short packedOutputChannel = get_global_id(0);
  int2 weightLocation;
  weightLocation.x = 0;
  weightLocation.y = packedOutputChannel;

  short groupNum = (packedOutputChannel / numPackedOutputChannelsForGroup);
  short startPackedInputChannel = mul24(groupNum, numPackedInputChannelsForGroup);
  short startOutputColumn = mul24((short)get_global_id(1), NUM_OUTPUTS);
  short startX = mad24(mad24(startOutputColumn, strideX, -paddingX), totalNumPackedInputChannels, startPackedInputChannel);
  short strideWithChannels = mul24(strideX, totalNumPackedInputChannels);

  short outputRow = get_global_id(2);
  int2 inputLocation;

#ifdef BATCH
  // TODO: this doesn't work with y padding
  inputLocation.y = mad24(outputRow % numOutputRows, strideY, -paddingY);
  short batchOffset = (outputRow / numOutputRows) * numInputRows;
  inputLocation.y += batchOffset;
#else
  inputLocation.y = mad24(outputRow, strideY, -paddingY);
#endif

  for (short rfRow = 0; rfRow < filterSizeY; ++rfRow) {
    // numPackedInputChannelsForGroup is 1 in depthwise
    for (short packedInputChannel = 0; packedInputChannel < numPackedInputChannelsForGroup; ++packedInputChannel) {
      short startXForChannel = startX + packedInputChannel;
      for (short rfColumn = 0; rfColumn < filterSizeX; ++rfColumn) {

        short dilatedStepX = mul24(totalNumPackedInputChannels, dilationX);
        inputLocation.x = mad24(rfColumn, dilatedStepX, startXForChannel);
        float4 inputValues[NUM_OUTPUTS];
        for (short i = 0; i < NUM_OUTPUTS; ++i) {
          inputValues[i] = read_imagef(input, smp, inputLocation);
          inputLocation.x += strideWithChannels;
        }

#ifdef DEPTHWISE
        float4 weightValues = read_imagef(weights, smp, weightLocation);
        ++weightLocation.x;
        for (short i = 0; i < NUM_OUTPUTS; ++i) {
          outputValues[i] += inputValues[i] * weightValues;
        }
#else
        float4 weightValues[4];
        for (short outChIdx = 0; outChIdx < 4; ++outChIdx) {
          weightValues[outChIdx] = read_imagef(weights, smp, weightLocation);
          ++weightLocation.x;
        }

        for (short i = 0; i < NUM_OUTPUTS; ++i) {
          float4 curOutputValues = outputValues[i];
          curOutputValues.x += dot(inputValues[i], weightValues[0]);
          curOutputValues.y += dot(inputValues[i], weightValues[1]);
          curOutputValues.z += dot(inputValues[i], weightValues[2]);
          curOutputValues.w += dot(inputValues[i], weightValues[3]);
          outputValues[i] = curOutputValues;
        }
#endif
      }
    }
    inputLocation.y += dilationY;
  }

  // insert unary and binary ops here

  // output to memory
  int2 outputLocation;
  short outputColumn = startOutputColumn;
  outputLocation.y = outputRow;
  for (short i = 0; i < NUM_OUTPUTS; ++i) {
    outputLocation.x = mad24(outputColumn, totalNumPackedOutputChannels, packedOutputChannel);
    if (outputColumn < numOutputColumns) {
      write_imagef(output, outputLocation, outputValues[i]);
    }
    ++outputColumn;
  }
}
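A hedged orientation sketch (editor's addition, not part of the commit): how an NCHW tensor is packed into the image this kernel reads, and how the kernel is dispatched. It mirrors preprocessing_op and conv in accel/opencl/ops_opencl.py below; all sizes here are made up.

import numpy as np

bs, cin, iy, ix = 1, 8, 6, 6  # illustrative sizes, groups assumed to be 1

# NCHW -> (bs*iy, ix*cin//4, 4): each image pixel is a float4 holding 4
# consecutive channels; along x, packed channels vary fastest, then columns.
x = np.random.randn(bs, cin, iy, ix).astype(np.float32)
packed = x.reshape(bs, cin//4, 4, iy, ix).transpose(0, 3, 4, 1, 2).reshape(bs*iy, ix*(cin//4), 4)

# dispatch (see conv below): global size is (cout//4, (ox+3)//4, bs*oy), i.e.
# one work item per packed output channel, per NUM_OUTPUTS output columns, per output row.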
accel/opencl/ops_opencl.py
@@ -0,0 +1,228 @@
# this is focused on speed
# it may not run everything

import pathlib
import numpy as np
from tinygrad.ops import MovementOps, ProcessingOps
from tinygrad.llops.ops_gpu import require_init_gpu, clbuild, sync, get_cl_queue, get_cl_ctx
from tinygrad.llops.ops_gpu import contiguous
from tinygrad.llops.ops_gpu import unary_op as unary_op_gpu, binary_op as binary_op_gpu, reduce_op as reduce_op_gpu
from tinygrad.helpers import prod
from tinygrad.shapetracker import ShapeTracker
import pyopencl as cl
from copy import deepcopy

def roundup(x, n=4): return (x+(n-1))//n * n   # round x up to a multiple of n
def flip(x): return (x[1], x[0])               # numpy shape is (H, W); cl.Image shape is (W, H)
class OpenCLBuffer:
  def __init__(self, shape, hostbuf=None, _buf=None, _image=None):
    require_init_gpu()
    self.shapetracker = deepcopy(shape) if isinstance(shape, ShapeTracker) else ShapeTracker(*shape)
    self._buf = _buf
    self._image = _image
    self.dtype = np.float32
    if hostbuf is not None:
      # TODO: lazy?
      self._buf = cl.Buffer(get_cl_ctx(), cl.mem_flags.READ_WRITE, 4*roundup(prod(shape)))
      cl.enqueue_copy(get_cl_queue(), self._buf, hostbuf.astype(np.float32).ravel())

  def clone(self):
    return OpenCLBuffer(self.shapetracker, _buf=self._buf, _image=self._image)

  @property
  def shape(self): return self.shapetracker.shape

  @staticmethod
  def fromCPU(x):
    return OpenCLBuffer(x.shape, x)

  def toCPU(self):
    data = np.empty(self.shape, dtype=np.float32)
    if not self.shapetracker.contiguous:
      tmp = OpenCLBuffer(self.shapetracker.shape)
      contiguous(None, self, self.shapetracker, tmp)
    else:
      tmp = self
    cl.enqueue_copy(get_cl_queue(), data, tmp.cl, is_blocking=True)
    return data

  @property
  def cl(self):
    if self._buf is None:
      self._buf = cl.Buffer(get_cl_ctx(), cl.mem_flags.READ_WRITE, 4*roundup(prod(self.shape)))
      if self._image is not None:
        assert prod(self.shape) == prod(self._image.shape)*4
        print(f"converting {self.shape} back to buffer, image shape is {self._image.shape}")
        clbuild("from_image", """
          __kernel void from_image(
              read_only image2d_t in,
              __global float4 *out) {
            const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
            int2 l;
            l.y = get_global_id(1);
            l.x = get_global_id(0);
            int W = get_image_width(in);
            out[l.y*W + l.x] = read_imagef(in, smp, l);
          }
        """)(self._image.shape, None, self._image, self._buf)
        self._image = None
    return self._buf

  @property
  def image(self):
    if self._image is None:
      assert self.shape[2] == 4 and len(self.shape) == 3
      fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT)
      self._image = cl.Image(get_cl_ctx(), cl.mem_flags.READ_WRITE, fmt, shape=flip(self.shape))
      if self._buf is not None:
        assert prod(self.shape) == prod(self._image.shape)*4
        print(f"converting {self.shape} to image with shape {self._image.shape}")
        clbuild("to_image", """
          __kernel void to_image(
              __global const float4 *in,
              write_only image2d_t out) {
            int2 l;
            l.y = get_global_id(1);
            l.x = get_global_id(0);
            int W = get_image_width(out);
            write_imagef(out, l, in[l.y*W + l.x]);
          }
        """)(self._image.shape, None, self._buf, self._image)
        self._buf = None
    return self._image
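# illustrative usage (editor's addition, not from the commit): the same tensor
# migrates lazily between a linear cl.Buffer and an RGBA float image; reading
# .image drops the buffer copy, and reading .cl drops the image copy.
#   b = OpenCLBuffer((5,8,4), np.random.randn(5,8,4).astype(np.float32))
#   img = b.image   # runs to_image; b._buf becomes None
#   buf = b.cl      # runs from_image; b._image becomes None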
def unary_op(ctx, op, x):
  # TODO: this doesn't actually have to be contiguous
  x = contiguous(ctx, x, x.shapetracker) if not x.shapetracker.contiguous else x
  return unary_op_gpu(ctx, op, x)

def binary_op(ctx, op, x, y):
  x = contiguous(ctx, x, x.shapetracker) if not x.shapetracker.contiguous else x
  y = contiguous(ctx, y, y.shapetracker) if not y.shapetracker.contiguous else y
  return binary_op_gpu(ctx, op, x, y)

def reduce_op(ctx, op, x, new_shape):
  x = contiguous(ctx, x, x.shapetracker) if not x.shapetracker.contiguous else x
  return reduce_op_gpu(ctx, op, x, new_shape)
def movement_op(ctx, op, x, arg=None):
  xc = x.clone()
  # convert from image if the buffer can change shape
  if op in [MovementOps.EXPAND, MovementOps.SLICE]: xc.cl
  xc.shapetracker.movement_op(op, arg)
  return xc
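# illustrative sketch (editor's addition, not from the commit): RESHAPE and
# PERMUTE only mutate the ShapeTracker, so no kernel runs and no data moves;
# EXPAND and SLICE can change the element count, so the image is converted
# back to a linear buffer first via xc.cl above.
#   a = OpenCLBuffer.fromCPU(np.zeros((4,4), dtype=np.float32))
#   b = movement_op(None, MovementOps.RESHAPE, a, (16,))
#   assert b._buf is a._buf  # same cl.Buffer, new view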
def load(x):
  with open(x) as f:
    ret = f.read()
  return ret
def conv(x,w,ret,C):
  print(x.shapetracker.expr(), w.shapetracker.expr())
  print(x.shape, w.shape, ret.shape)
  options = []
  if C.cin == 1: options.append("-DDEPTHWISE")
  if C.bs > 1:
    options.append("-DBATCH")
    assert C.py == 0, "batched conv doesn't work with y-padding"
  conv_prg = clbuild("conv", load(pathlib.Path(__file__).parent.parent.parent / 'accel/opencl/conv.cl'), tuple(options))
  assert C.cout%4 == 0
  kernel_args = [C.cout//4, (C.ox+3)//4, C.bs*C.oy]
  conv_args = [max(1, C.cin//4), C.groups*C.cin//4, max(1, C.rcout//4), C.cout//4, C.ox, C.oy, C.iy, C.W, C.H, C.px, C.py, C.xs, C.ys, C.dx, C.dy]
  print(conv_args, kernel_args)
  conv_prg(kernel_args, None, x.image, w.image, ret.image, *[np.int16(x) for x in conv_args])
def processing_op(ctx,op,x,w,out_shape,C):
  assert op == ProcessingOps.CONV, f"{op} isn't supported"
  ret = ctx.buffer((C.bs*C.oy, C.ox*C.cout//4, 4))
  conv(x, w, ret, C)
  return ret
# input format is N, H x W, C//4 x 4
# dweight (depthwise weight) format is oc//4 x ch, cw x 4(oc)
# weight format is oc//4 x ch, ic//4, cw, 4(oc) x 4(ic)
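# worked example (editor's illustration, not from the commit), with
# bs=1, groups=1, cin=8, rcout=cout=16, H=W=3:
#   x: (1,1,8,iy,ix) --PERMUTE(0,3,4,1,2)--> (1,iy,ix,1,8) --RESHAPE--> (iy, ix*2, 4)
#   w: (1,16,8,3,3) --RESHAPE--> (4,4,2,4,3,3) --PERMUTE(0,4,2,5,1,3)--> (4,3,2,3,4,4)
#      --RESHAPE--> (4, 72, 4)   # cout//4 rows of H*(cin//4)*W*4 packed float4s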
def preprocessing_op(ctx,op,x,w,out_shape,C):
  assert op == ProcessingOps.CONV, f"{op} isn't supported"
  x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix))
  w = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W))

  print(x.shape, w.shape)

  if C.bs > 1 and C.py > 0:
    # explicitly add y-padding for batched inputs
    # N C H W
    xs = [(0, s) for s in x.shape]
    xs[3] = (-C.py, x.shape[3]+C.py)
    x = ctx.movement_op(MovementOps.SLICE, x, xs)
    C = C._replace(iy=C.iy + C.py*2, py=0)

  # hack for non multiples of 4 on C.cin
  if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0):
    to_add = 4 - (C.cin % 4)
    ws = [(0, s) for s in w.shape]
    ws[2] = (0, w.shape[2]+to_add)
    w = ctx.movement_op(MovementOps.SLICE, w, ws)

    xs = [(0, s) for s in x.shape]
    xs[2] = (0, x.shape[2]+to_add)
    x = ctx.movement_op(MovementOps.SLICE, x, xs)
    C = C._replace(cin = C.cin + to_add)

  # hack for non multiples of 4 on C.rcout
  if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
    added_output_channels = 4 - (C.rcout % 4)
    ws = [(0, s) for s in w.shape]
    ws[1] = (0, w.shape[1]+added_output_channels)
    w = ctx.movement_op(MovementOps.SLICE, w, ws)
    C = C._replace(rcout = C.rcout + added_output_channels, cout = C.groups * (C.rcout + added_output_channels))

  # packed
  assert (C.groups*C.cin) % 4 == 0
  print(x.shape)
  x = ctx.movement_op(MovementOps.PERMUTE, x, (0,3,4,1,2))
  x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4))

  assert C.cout % 4 == 0
  if C.cin == 1:
    # depthwise
    w = ctx.movement_op(MovementOps.RESHAPE, w, (C.cout//4,4,C.H*C.W))
    w = ctx.movement_op(MovementOps.PERMUTE, w, (0,2,1))
  else:
    w = ctx.movement_op(MovementOps.RESHAPE, w, (C.cout//4,4,C.cin//4,4,C.H,C.W))
    w = ctx.movement_op(MovementOps.PERMUTE, w, (0,4,2,5,1,3))
    w = ctx.movement_op(MovementOps.RESHAPE, w, (C.cout//4, C.H * C.cin//4 * C.W * 4, 4))

  x = contiguous(ctx, x, x.shapetracker) if not x.shapetracker.contiguous else x
  w = contiguous(ctx, w, w.shapetracker) if not w.shapetracker.contiguous else w
  return x,w,C
def postprocessing_op(ctx, op, ret, out_shape, C):
  added_output_channels = C.rcout - out_shape[1]//C.groups

  # undo hack for non multiples of 4 on C.rcout
  if added_output_channels != 0:
    ret = ctx.movement_op(MovementOps.RESHAPE, ret, (C.bs, C.oy, C.ox, C.groups, C.rcout))
    xs = [(0, s) for s in ret.shape]
    xs[4] = (0, ret.shape[4]-added_output_channels)
    ret = ctx.movement_op(MovementOps.SLICE, ret, xs)
    C = C._replace(rcout = C.rcout - added_output_channels, cout = C.groups * (C.rcout - added_output_channels))

  ret = ctx.movement_op(MovementOps.RESHAPE, ret, (C.bs, C.oy, C.ox, C.cout))
  ret = ctx.movement_op(MovementOps.PERMUTE, ret, (0,3,1,2))
  return ret
def test_image():
  hostbuf = np.random.randn(5,8,4).astype(np.float32)
  x = OpenCLBuffer((5,8,4), hostbuf)
  assert np.allclose(x.toCPU(), hostbuf)
  print(x.image)
  assert np.allclose(x.toCPU(), hostbuf)

if __name__ == "__main__":
  test_image()
test/test_conv.py
@@ -1,15 +1,38 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor

class TestConv(unittest.TestCase):
  def test_simple(self):
    x = Tensor.ones(1,12,128,256)
    w = Tensor.ones(32,12,3,3)
    ret = x.conv2d(w, padding=(1,1)).numpy()
    ret = x.conv2d(w, stride=(2,2), padding=(1,1)).numpy()
    # it's not 108 around the padding
    assert (ret[:, :, 1:-1, 1:-1] == 108).all()
    assert ret[0,0,0,0] == 48
    assert ret[0,0,0,1] == 72

  def test_many_simple(self):
    x = Tensor(np.arange(8*2*8).reshape(1,8,2,8).astype(np.float32))
    #w = Tensor(np.arange(8*8*1*1).reshape(8,8,1,1).astype(np.float32))
    w = Tensor.eye(8).reshape((8,8,1,1))
    ret = x.conv2d(w, stride=(1,2), padding=(0,0)).numpy()
    print(ret)

  def test_first_three(self):
    x = Tensor.ones(1,12,128,256)

    w = Tensor.ones(32,12,3,3)
    x = x.conv2d(w, stride=(2,2), padding=(1,1))

    w = Tensor.ones(32,1,3,3)
    x = x.conv2d(w, padding=(1,1), groups=32)

    w = Tensor.ones(16,32,1,1)
    x = x.conv2d(w)

    x = x.numpy()
    print(x.shape)

if __name__ == '__main__':
  unittest.main()
test/test_ops.py
@@ -183,6 +183,11 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
       lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)

+  def test_simple_conv2d_batched(self):
+    helper_test_op([(2,4,9,9), (4,4,3,3)],
+      lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
+      lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5, forward_only=True)
+
   def test_conv2d(self):
     for bs in [1,8]:
       for cin in [1,3]:
tinygrad/llops/ops_gpu.py
@@ -43,8 +43,8 @@ class GPUBuffer:
     return data

 @functools.lru_cache
-def clbuild(name, prg):
-  clprg = cl.Program(cl_ctx, prg).build().__getattr__(name)
+def clbuild(name, prg, options=tuple()):
+  clprg = cl.Program(cl_ctx, prg).build(options=options).__getattr__(name)
   def run(*args): clprg(cl_queue, *args)
   return run
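Since clbuild is memoized with functools.lru_cache, options must stay hashable, hence the tuple default. A hedged usage sketch (placeholder names, mirroring how conv calls it in ops_opencl.py above):

prg = clbuild("conv", conv_src, ("-DDEPTHWISE",))   # conv_src: contents of conv.cl
prg(global_size, None, x_image, w_image, out_image, *conv_args)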
@@ -119,8 +119,8 @@ def reduce_op(ctx, op, inp, new_shape):
   clbuild("reduce", prg)([prod(ret.shape)], None, inp.cl, ret.cl)
   return ret

-def contiguous(ctx, x, st):
-  ret = ctx.buffer(st.shape)
+def contiguous(ctx, x, st, ret=None):
+  if ret is None: ret = ctx.buffer(st.shape)
   clbuild("contiguous", """__kernel void contiguous(__global const float *x, __global float *ret) {
     int gid = get_global_id(0); int valid = 1; int idx = gid; """+st.expr().replace('//', '/')+""";
     ret[gid] = valid ? x[idx] : 0.0;  // should never be out-of-bounds accesses
tinygrad/llops/ops_opencl.py (symlink)
@@ -0,0 +1 @@
../../accel/opencl/ops_opencl.py
tinygrad/ops.py
@@ -44,37 +44,39 @@ def log_op(op, ret, inp):
 class Ops:
   def unary_op(ctx, op:UnaryOps, x):
     ret = ctx.op.unary_op(ctx, op, x)
-    log_op(op, ret, [x])
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == x.shape
+    log_op(op, ret, [x])
     return ret

   def reduce_op(ctx, op:ReduceOps, x, new_shape):
     ret = ctx.op.reduce_op(ctx, op, x, new_shape)
-    log_op(op, ret, [x])
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == tuple(new_shape)
+    log_op(op, ret, [x])
     return ret

   def binary_op(ctx, op:BinaryOps, x, y):
     assert x.shape == y.shape
     ret = ctx.op.binary_op(ctx, op, x, y)
-    log_op(op, ret, [x, y])
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == x.shape
+    log_op(op, ret, [x, y])
     return ret

   def movement_op(ctx, op:MovementOps, x, arg=None):
     ret = ctx.op.movement_op(ctx, op, x, arg)
-    log_op(op, ret, [x])
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == ShapeTracker(*x.shape).movement_op(op, arg).shape
+    log_op(op, ret, [x])
     return ret

   def processing_op(ctx, op:ProcessingOps, x, y, out_shape, C):
     # TODO: can we do better than out_shape?
+    if getattr(ctx.op, "preprocessing_op", None) is not None: x,y,C = ctx.op.preprocessing_op(ctx, op, x, y, out_shape, C)
     ret = ctx.op.processing_op(ctx, op, x, y, out_shape, C)
-    log_op(op, ret, [x, y])
+    if getattr(ctx.op, "postprocessing_op", None) is not None: ret = ctx.op.postprocessing_op(ctx, op, ret, out_shape, C)
     assert isinstance(ret, ctx.buffer)
     assert ret.shape == out_shape
+    log_op(op, ret, [x, y])
     return ret
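The hooks are looked up with getattr, so backends that don't define them are unaffected. A minimal sketch of the contract (editor's addition; hook names are from this diff, the backend module is hypothetical):

# hypothetical llop module, e.g. tinygrad/llops/ops_mybackend.py
def preprocessing_op(ctx, op, x, w, out_shape, C):
  return x, w, C   # e.g. repack x/w into a device-friendly layout and patch C

def processing_op(ctx, op, x, w, out_shape, C):
  raise NotImplementedError  # run the conv, must return a ctx.buffer

def postprocessing_op(ctx, op, ret, out_shape, C):
  return ret       # undo the repacking; must end with ret.shape == out_shape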
tinygrad/shapetracker.py
@@ -54,8 +54,10 @@ def strides_for_shape(shape):
 # TODO: support simplification across views
 class ShapeTracker:
   def __init__(self, *shape, strides=None):
     assert all([isinstance(x, int) for x in shape])
     if len(shape) == 0: shape = (1,)
     self.views = [View(shape, strides_for_shape(shape) if strides == None else strides)]
+    self.contiguous = True

   @property
   def shape(self): return tuple(self.views[-1].shape)

@@ -74,18 +76,21 @@ class ShapeTracker:
     self.views.append(View(new_shape, strides_for_shape(new_shape)))

   def permute(self, *axis):
+    self.contiguous = False
     assert all([isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis])
-    assert len(set(axis)) == len(axis)
+    assert len(set(axis)) == len(axis) and len(axis) == len(self.shape)
     strides = strides_for_shape(self.shape)
     self.views.append(View([self.shape[a] for a in axis], [strides[a] for a in axis]))

   def slice(self, *arg):
+    self.contiguous = False
     assert len(arg) == len(self.shape)
     strides = strides_for_shape(self.shape)
     offset = sum([strides[i]*x for i,(x,_) in enumerate(arg)])
     self.views += [View([y-x for x,y in arg], strides, offset), ZeroView(self.shape, arg)]

   def expand(self, *new_shape):
+    self.contiguous = False
     assert all([isinstance(x, int) for x in new_shape])
     assert all([x == y or x == 1 for x,y in zip(self.shape, new_shape)])
     strides = [s if x == y else 0 for s,(x,y) in zip(strides_for_shape(self.shape), zip(self.shape, new_shape))]

@@ -93,6 +98,7 @@ class ShapeTracker:

   # TODO: combine with slice? this doesn't require a ZeroView, though slice shouldn't always either
   def stride(self, *mul):
+    self.contiguous = False
     assert all([isinstance(x, int) for x in mul])
     old_strides = strides_for_shape(self.shape)
     strides = [z*m for z,m in zip(old_strides, mul)]

@@ -101,5 +107,6 @@ class ShapeTracker:
     self.views.append(View(new_shape, strides, offset))

   # TODO: this is a special case of slice with strides, remove it
+  # though it's nice that it can't change size
   def flip(self, *axis):
     self.stride(*[-1 if i in axis else 1 for i in range(len(self.shape))])
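A hedged sketch (editor's addition) of what the new contiguous flag tracks, using the API exactly as in this diff:

from tinygrad.shapetracker import ShapeTracker

st = ShapeTracker(2, 3)
assert st.contiguous        # fresh trackers start contiguous
st.permute(1, 0)
assert st.shape == (3, 2)
assert not st.contiguous    # movement ops like permute clear the flag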