From c21c2a0b62007c79e71dd5cbdb25a5d1b58adddb Mon Sep 17 00:00:00 2001 From: Ryan Neph Date: Mon, 9 Nov 2020 14:58:18 -0800 Subject: [PATCH] revert b0c0c5d: Strided Pool funcs (#74) (#87) Strided CPU Pooling was introduced under the assumption of a small kernel size (<=(10,10)), but efficientnet.py feeds kernel_size=(112,112). This causes a huge array buffer allocation in stack_for_pool() that hangs inference for a long time or until system OOM. Revert CPU Pooling for now, and re-introduce #74 later with a new global-average-pooling op that can be used instead of avgpool2d with large kernel size for efficientnet inference. Co-authored-by: Ryan Neph --- test/test_ops.py | 18 +------------- tinygrad/ops.py | 64 ++++++++++++++++++++---------------------------- 2 files changed, 27 insertions(+), 55 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 673f5b57..6ef1da33 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -97,29 +97,13 @@ class TestOps(unittest.TestCase): lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), lambda x: Tensor.max_pool2d(x, kernel_size=ksz), gpu=self.gpu, forward_only=self.gpu) - def test_maxpool2d_strided_fwd(self): - for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: - for strd in [(1,1), (2,1), (2,2), (4,2)]: - with self.subTest(kernel_size=ksz, stride=strd): - helper_test_op([(32,2,110,28)], - lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, stride=strd), - lambda x: Tensor.max_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, forward_only=True) - def test_avgpool2d(self): - for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: + for ksz in [(2,2), ]:#, (3,3), (3,2), (5,5), (5,1)]: with self.subTest(kernel_size=ksz): helper_test_op([(32,2,111,28)], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz), lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), gpu=self.gpu) - def test_avgpool2d_strided_fwd(self): - for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: - for strd in [(1,1), (2,1), (2,2), (4,2)]: - with 
self.subTest(kernel_size=ksz, stride=strd): - helper_test_op([(32,2,111,28)], - lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, stride=strd), - lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, forward_only=True) - if GPU: class TestOpsGPU(TestOps): gpu = True diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ca0126c6..52e1f212 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,3 +1,4 @@ +import sys import warnings import numpy as np from .tensor import Function, register @@ -237,67 +238,54 @@ register('conv2d', Conv2D) # ************* pooling ops ************* -def stack_for_pool(x, kernel_size, stride, fill_value=0): - (ky, kx), (py, px) = kernel_size, stride - my, mx = (x.shape[2]-ky)//py+1, (x.shape[3]-kx)//px+1 - stack = fill_value*np.ones((ky, kx, *x.shape[:2], my+ky, mx+kx), dtype=x.dtype) - for Y in range(ky): - for X in range(kx): - sl = x[..., Y:Y+my*py+ky:py, X:X+mx*px+kx:px] - stack[Y, X, ..., :sl.shape[2], :sl.shape[3]] = sl - return stack.reshape(-1, *stack.shape[2:]), (my, mx) +def stack_for_pool(x, py, px): + my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px + stack = [] + xup = x[:, :, :my, :mx] + for Y in range(py): + for X in range(px): + stack.append(xup[:, :, Y::py, X::px][None]) + return np.concatenate(stack, axis=0) -def unstack_for_pool(fxn, s, kernel_size, stride): - (ky, kx), (py, px) = kernel_size, stride - for Y in range(ky): - for X in range(kx): - ll = fxn(Y*kx+X) +def unstack_for_pool(fxn, s, py, px): + my, mx = (s[2]//py)*py, (s[3]//px)*px + for Y in range(py): + for X in range(px): + ll = fxn(Y*px+X) if X == 0 and Y == 0: - ret = np.zeros((*s[:2], s[2]+ky, s[3]+kx), dtype=ll.dtype) - ret[..., Y:Y+ll.shape[2]*py:py, X:X+ll.shape[3]*px:px] = ll - return ret[..., :s[2], :s[3]] + ret = np.zeros(s, dtype=ll.dtype) + ret[:, :, Y:my:py, X:mx:px] = ll + return ret class MaxPool2D(Function): @staticmethod - def forward(ctx, x, kernel_size=(2, 2), stride=None): - if not stride: - ctx.stride = 
stride = kernel_size - stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=-np.inf) - idxs = np.nanargmax(stack, axis=0)[..., :my, :mx] + def forward(ctx, x, kernel_size=(2, 2)): + stack = stack_for_pool(x, *kernel_size) + idxs = np.argmax(stack, axis=0) ctx.save_for_backward(idxs, x.shape) - return np.amax(stack, axis=0)[..., :my, :mx] + return np.max(stack, axis=0) @staticmethod def backward(ctx, grad_output): - # TODO implement for stride != kernel_size - if ctx.kernel_size != ctx.stride: - raise NotImplementedError("CPU MaxPool2D.backward() with stride != kernel_size not implemented") idxs,s = ctx.saved_tensors return unstack_for_pool( lambda idx: grad_output * (idxs == idx), - s, ctx.kernel_size, ctx.stride) + s, *ctx.kernel_size) register('max_pool2d', MaxPool2D) class AvgPool2D(Function): @staticmethod - def forward(ctx, x, kernel_size=(2, 2), stride=None): - if not stride: - ctx.stride = stride = kernel_size - stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=np.nan) + def forward(ctx, x, kernel_size=(2, 2)): + stack = stack_for_pool(x, *kernel_size) ctx.save_for_backward(x.shape) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return np.nanmean(stack, axis=0)[...,:my, :mx] + return np.mean(stack, axis=0) @staticmethod def backward(ctx, grad_output): - # TODO implement for stride != kernel_size - if ctx.kernel_size != ctx.stride: - raise NotImplementedError("CPU AvgPool2D.backward() with stride != kernel_size not implemented") s, = ctx.saved_tensors py, px = ctx.kernel_size return unstack_for_pool( lambda idx: grad_output/py/px, - s, ctx.kernel_size, ctx.stride) + s, py, px) register('avg_pool2d', AvgPool2D)