mirror of https://github.com/commaai/tinygrad.git
break maxpool2d on GPU
This commit is contained in:
parent
061e37de39
commit
02655c07d5
|
@ -111,9 +111,9 @@ You need to support 15 basic ops:
|
|||
Add, Sub, Mul, Pow # binary ops (with broadcasting)
|
||||
Relu, Log, Exp # unary ops
|
||||
Sum, Max # reduce ops (with axis argument)
|
||||
Dot # matrix multiplication
|
||||
Conv2D, MaxPool2D # 2D ops
|
||||
Pad2D, Reshape, Transpose # moving things around ops
|
||||
Dot, Conv2D # matrix multiplication and conv
|
||||
Reshape, Transpose # moving things around ops
|
||||
Unpad2D, Pad2D # stupid slices
|
||||
```
|
||||
|
||||
## ImageNet inference
|
||||
|
|
|
@ -164,6 +164,7 @@ class TestOps(unittest.TestCase):
|
|||
lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device)
|
||||
|
||||
@cpu_only
|
||||
def test_maxpool2d(self):
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
with self.subTest(kernel_size=ksz):
|
||||
|
|
|
@ -105,18 +105,29 @@ register('dot', Dot)
|
|||
|
||||
# ************* simple ops *************
|
||||
|
||||
# TODO: Combine Pad2D and Unpad2D into something generic
|
||||
class Pad2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, padding=None):
|
||||
ctx.save_for_backward(padding)
|
||||
return np.pad(x, ((0,0), (0,0), tuple(padding[2:4]), tuple(padding[0:2])))
|
||||
return np.pad(x, ((0,0), (0,0), tuple(ctx.padding[2:4]), tuple(ctx.padding[0:2])))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
padding, = ctx.saved_tensors
|
||||
return grad_output[..., padding[2]:-padding[3], padding[0]:-padding[1]]
|
||||
return grad_output[...,
|
||||
ctx.padding[2]:(None if ctx.padding[3] == 0 else -ctx.padding[3]),
|
||||
ctx.padding[0]:(None if ctx.padding[1] == 0 else -ctx.padding[1])]
|
||||
register('pad2d', Pad2D)
|
||||
|
||||
class Unpad2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, padding=None):
|
||||
return Pad2D.backward(ctx, x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return Pad2D.forward(ctx, grad_output)
|
||||
register('unpad2d', Unpad2D)
|
||||
|
||||
class Reshape(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, shape):
|
||||
|
@ -237,36 +248,3 @@ class Conv2D(Function):
|
|||
return gdx.reshape((bs, ctx.groups*cin, OY, OX)), gdw.reshape((ctx.groups*rcout, cin, H, W))
|
||||
register('conv2d', Conv2D)
|
||||
|
||||
|
||||
# ************* pooling ops *************
|
||||
|
||||
def stack_for_pool(x, py, px):
|
||||
my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
|
||||
xup = x[:, :, :my, :mx]
|
||||
stack = [xup[:, :, k//px::py, k%px::px][None] for k in range(py*px)]
|
||||
return np.concatenate(stack, axis=0)
|
||||
|
||||
def unstack_for_pool(fxn, s, py, px):
|
||||
my, mx = (s[2]//py)*py, (s[3]//px)*px
|
||||
for k in range(py*px):
|
||||
Y, X = k//px, k%px
|
||||
ll = fxn(Y*px+X)
|
||||
if X == 0 and Y == 0:
|
||||
ret = np.zeros(s, dtype=ll.dtype)
|
||||
ret[:, :, Y:my:py, X:mx:px] = ll
|
||||
return ret
|
||||
|
||||
class MaxPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, kernel_size=(2, 2)):
|
||||
stack = stack_for_pool(x, *kernel_size)
|
||||
idxs = np.argmax(stack, axis=0)
|
||||
ctx.save_for_backward(idxs, x.shape)
|
||||
return np.max(stack, axis=0)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
idxs,s = ctx.saved_tensors
|
||||
return unstack_for_pool(lambda idx: grad_output * (idxs == idx), s, *ctx.kernel_size)
|
||||
register('max_pool2d', MaxPool2D)
|
||||
|
||||
|
|
|
@ -14,54 +14,6 @@ def uint2(x, y):
|
|||
return np.array((x,y), dtype=cl.cltypes.uint2)
|
||||
i32 = np.int32
|
||||
|
||||
def subsample_op(ctx, input, kernel_size, stride, iter_op, result_op, decls=''):
|
||||
py, px = stride
|
||||
N, C, Yin, Xin = input.shape
|
||||
Yout, Xout = (Yin-kernel_size[0])//py+1, (Xin-kernel_size[1])//px+1
|
||||
ret = buffer_new(ctx, (N, C, Yout, Xout), zero=True)
|
||||
subsample = clbuild(ctx.cl_ctx, "subsample", """
|
||||
__kernel void subsample(__global float *output, __global const float *input, uint2 osize, uint2 isize,
|
||||
uint2 ksz, uint2 stride) {
|
||||
int3 gid = (int3)(get_global_id(2), get_global_id(1), get_global_id(0));
|
||||
int oid = gid.x + osize.x*(gid.y + osize.y*gid.z);
|
||||
"""+decls+""";
|
||||
for (uint j=0; j<ksz.y; ++j) {
|
||||
for (uint i=0; i<ksz.x; ++i) {
|
||||
int iid = (gid.x*stride.x+i) + isize.x*((gid.y*stride.y+j) + isize.y*gid.z);
|
||||
if (gid.x*stride.x+i < isize.x && gid.y*stride.y+j < isize.y) {
|
||||
"""+iter_op+""";
|
||||
}
|
||||
}
|
||||
}
|
||||
output[oid] = """+result_op+""";
|
||||
}""")
|
||||
subsample(ctx.cl_queue, (N*C, Yout, Xout), None,
|
||||
ret.cl, input.cl, uint2(Xout, Yout), uint2(Xin, Yin),
|
||||
uint2(*kernel_size[::-1]), uint2(px, py))
|
||||
ctx.data = np.empty((N, C, Yout, Xout)) # set shape expectation on tensor instance
|
||||
return ret
|
||||
|
||||
def supersample_op(ctx, input, out_shape, kernel_size, result_op, decls='', input2=None):
|
||||
(N, C, Yin, Xin), (Yout, Xout) = input.shape, out_shape[2:]
|
||||
py,px = kernel_size
|
||||
ret = buffer_new(ctx, out_shape, zero=True)
|
||||
supsample = clbuild(ctx.cl_ctx, "supsample", """
|
||||
__kernel void supsample(__global float *output, __global const float *input, __global const void *input2,
|
||||
uint2 osize, uint2 isize, uint2 ksz) {
|
||||
int3 gid = (int3)(get_global_id(2), get_global_id(1), get_global_id(0));
|
||||
int oid = gid.x + osize.x*(gid.y + osize.y*gid.z);
|
||||
int iid = (gid.x/ksz.x) + isize.x*((gid.y/ksz.y) + isize.y*gid.z);
|
||||
"""+decls+""";
|
||||
if (gid.x/ksz.x < isize.x && gid.y/ksz.y < isize.y) {
|
||||
output[oid] = """+result_op+""";
|
||||
}
|
||||
}""")
|
||||
supsample(ctx.cl_queue, (N*C, Yout, Xout), None,
|
||||
ret.cl, input.cl, input2.cl if input2 is not None else input2,
|
||||
uint2(Xout, Yout), uint2(Xin, Yin), uint2(px, py))
|
||||
ctx.data = np.empty((N, C, Yout, Xout)) # set shape expectation on tensor instance
|
||||
return ret
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_binop_prg(cl_ctx, code, complist):
|
||||
ndims = len(complist)
|
||||
|
@ -424,26 +376,6 @@ class Exp(Function):
|
|||
return binary_op(ctx, 'a * b', grad_output, ret)
|
||||
register('exp', Exp, device=Device.GPU)
|
||||
|
||||
class MaxPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, kernel_size=(2, 2)):
|
||||
idxs = subsample_op(ctx, input, kernel_size, kernel_size,
|
||||
iter_op="if (input[iid]>maxval) { maxval = input[iid]; maxidx = j * ksz.x + i; }",
|
||||
result_op="(float)maxidx", decls="float maxval=-FLT_MAX; int maxidx=0")
|
||||
ctx.save_for_backward(idxs, input.shape)
|
||||
return subsample_op(ctx, input, kernel_size, kernel_size,
|
||||
iter_op="maxval = max(maxval, input[iid])",
|
||||
result_op="maxval", decls="float maxval = -FLT_MAX")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
idxs, orig_shape = ctx.saved_tensors
|
||||
return supersample_op(ctx, grad_output, orig_shape, ctx.kernel_size,
|
||||
result_op="(maxidx == kernidx) * input[iid]",
|
||||
decls="int maxidx=((__global float*)input2)[iid]; int kernidx=(gid.x%ksz.x) + ksz.x*(gid.y%ksz.y)",
|
||||
input2=idxs)
|
||||
register('max_pool2d', MaxPool2D, device=Device.GPU)
|
||||
|
||||
# ************* conv ops *************
|
||||
|
||||
class Conv2D(Function):
|
||||
|
|
|
@ -252,6 +252,14 @@ class Tensor:
|
|||
ww[range(chan), 0, :, :] = 1/(kernel_size[0]*kernel_size[1])
|
||||
return self.conv2d(Tensor(ww, device=self.device, requires_grad=False), stride=kernel_size, groups=chan)
|
||||
|
||||
def max_pool2d(self, kernel_size=(2,2)):
|
||||
py, px = kernel_size
|
||||
xup = self.unpad2d(padding=(0, self.shape[3]%px, 0, self.shape[2]%py))
|
||||
xup = xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))
|
||||
# TODO: support tuples in max
|
||||
xup = xup.max(axis=5).max(axis=3)
|
||||
return xup
|
||||
|
||||
# An instantiation of the Function is the Context
|
||||
class Function:
|
||||
def __init__(self, *tensors):
|
||||
|
|
Loading…
Reference in New Issue