diff --git a/README.md b/README.md index 93d6af90..22a6be19 100644 --- a/README.md +++ b/README.md @@ -111,9 +111,9 @@ You need to support 15 basic ops: Add, Sub, Mul, Pow # binary ops (with broadcasting) Relu, Log, Exp # unary ops Sum, Max # reduce ops (with axis argument) -Dot # matrix multiplication -Conv2D, MaxPool2D # 2D ops -Pad2D, Reshape, Transpose # moving things around ops +Dot, Conv2D # matrix multiplication and conv +Reshape, Transpose # moving things around ops +Unpad2D, Pad2D # stupid slices ``` ## ImageNet inference diff --git a/test/test_ops.py b/test/test_ops.py index d8350061..bcf18dd9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -164,6 +164,7 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(), lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), device=self.device) + @cpu_only def test_maxpool2d(self): for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: with self.subTest(kernel_size=ksz): diff --git a/tinygrad/ops_cpu.py b/tinygrad/ops_cpu.py index 97078d82..891ab255 100644 --- a/tinygrad/ops_cpu.py +++ b/tinygrad/ops_cpu.py @@ -105,18 +105,29 @@ register('dot', Dot) # ************* simple ops ************* +# TODO: Combine Pad2D and Unpad2D into something generic class Pad2D(Function): @staticmethod def forward(ctx, x, padding=None): - ctx.save_for_backward(padding) - return np.pad(x, ((0,0), (0,0), tuple(padding[2:4]), tuple(padding[0:2]))) + return np.pad(x, ((0,0), (0,0), tuple(ctx.padding[2:4]), tuple(ctx.padding[0:2]))) @staticmethod def backward(ctx, grad_output): - padding, = ctx.saved_tensors - return grad_output[..., padding[2]:-padding[3], padding[0]:-padding[1]] + return grad_output[..., + ctx.padding[2]:(None if ctx.padding[3] == 0 else -ctx.padding[3]), + ctx.padding[0]:(None if ctx.padding[1] == 0 else -ctx.padding[1])] register('pad2d', Pad2D) +class Unpad2D(Function): + @staticmethod + def forward(ctx, x, padding=None): + return Pad2D.backward(ctx, x) + + @staticmethod + def backward(ctx, grad_output): + return Pad2D.forward(ctx, grad_output) +register('unpad2d', Unpad2D) + class Reshape(Function): @staticmethod def forward(ctx, x, shape): @@ -237,36 +248,3 @@ class Conv2D(Function): return gdx.reshape((bs, ctx.groups*cin, OY, OX)), gdw.reshape((ctx.groups*rcout, cin, H, W)) register('conv2d', Conv2D) - -# ************* pooling ops ************* - -def stack_for_pool(x, py, px): - my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px - xup = x[:, :, :my, :mx] - stack = [xup[:, :, k//px::py, k%px::px][None] for k in range(py*px)] - return np.concatenate(stack, axis=0) - -def unstack_for_pool(fxn, s, py, px): - my, mx = (s[2]//py)*py, (s[3]//px)*px - for k in range(py*px): - Y, X = k//px, k%px - ll = fxn(Y*px+X) - if X == 0 and Y == 0: - ret = np.zeros(s, dtype=ll.dtype) - ret[:, :, Y:my:py, X:mx:px] = ll - return ret - -class MaxPool2D(Function): - @staticmethod - def forward(ctx, x, kernel_size=(2, 2)): - stack = stack_for_pool(x, *kernel_size) - idxs = np.argmax(stack, axis=0) - ctx.save_for_backward(idxs, x.shape) - return np.max(stack, axis=0) - - @staticmethod - def backward(ctx, grad_output): - idxs,s = ctx.saved_tensors - return unstack_for_pool(lambda idx: grad_output * (idxs == idx), s, *ctx.kernel_size) -register('max_pool2d', MaxPool2D) - diff --git a/tinygrad/ops_gpu.py b/tinygrad/ops_gpu.py index d366dbe2..c7c2ce87 100644 --- a/tinygrad/ops_gpu.py +++ b/tinygrad/ops_gpu.py @@ -14,54 +14,6 @@ def uint2(x, y): return np.array((x,y), dtype=cl.cltypes.uint2) i32 = np.int32 -def subsample_op(ctx, input, kernel_size, stride, iter_op, result_op, decls=''): - py, px = stride - N, C, Yin, Xin = input.shape - Yout, Xout = (Yin-kernel_size[0])//py+1, (Xin-kernel_size[1])//px+1 - ret = buffer_new(ctx, (N, C, Yout, Xout), zero=True) - subsample = clbuild(ctx.cl_ctx, "subsample", """ - __kernel void subsample(__global float *output, __global const float *input, uint2 osize, uint2 isize, - uint2 ksz, uint2 stride) { - int3 gid = (int3)(get_global_id(2), get_global_id(1), get_global_id(0)); - int oid = gid.x + osize.x*(gid.y + osize.y*gid.z); - """+decls+"""; - for (uint j=0; j