Strided CPU Pooling was introduced assuming a small kernel size (<= (10,10)), but efficientnet.py feeds kernel_size=(112,112). This causes a huge array buffer allocation in stack_for_pool() that hangs inference for a long time or until the system OOMs. Revert CPU Pooling for now, and re-introduce #74 later alongside a new global-average-pooling op that can be used instead of avgpool2d with a large kernel size for efficientnet inference.

Co-authored-by: Ryan Neph <ryanneph@google.com>
commit c21c2a0b62
parent 53157fb876
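To put the allocation in numbers: the strided stack_for_pool() removed below builds a buffer of shape (ky, kx, N, C, my+ky, mx+kx). Assuming the pooling early in EfficientNet sees roughly a (1, 32, 112, 112) activation (the exact shape is an assumption; only kernel_size=(112,112) comes from this commit), that buffer is on the order of 19 GiB for a single image:

    import numpy as np

    # Assumed activation shape; only kernel_size=(112,112) is stated in the commit message.
    N, C, H, W = 1, 32, 112, 112
    (ky, kx), (py, px) = (112, 112), (112, 112)      # kernel_size == stride
    my, mx = (H - ky)//py + 1, (W - kx)//px + 1      # pooled output is 1x1

    # Buffer shape allocated by the strided stack_for_pool() below.
    stack_shape = (ky, kx, N, C, my + ky, mx + kx)   # (112, 112, 1, 32, 113, 113)
    print(np.prod(stack_shape) * 4 / 2**30)          # ~19.1 GiB of float32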
@@ -97,29 +97,13 @@ class TestOps(unittest.TestCase):
           lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
           lambda x: Tensor.max_pool2d(x, kernel_size=ksz), gpu=self.gpu, forward_only=self.gpu)

-  def test_maxpool2d_strided_fwd(self):
-    for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
-      for strd in [(1,1), (2,1), (2,2), (4,2)]:
-        with self.subTest(kernel_size=ksz, stride=strd):
-          helper_test_op([(32,2,110,28)],
-            lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, stride=strd),
-            lambda x: Tensor.max_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, forward_only=True)
-
   def test_avgpool2d(self):
-    for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
+    for ksz in [(2,2), ]:#, (3,3), (3,2), (5,5), (5,1)]:
       with self.subTest(kernel_size=ksz):
         helper_test_op([(32,2,111,28)],
           lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
           lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), gpu=self.gpu)

-  def test_avgpool2d_strided_fwd(self):
-    for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
-      for strd in [(1,1), (2,1), (2,2), (4,2)]:
-        with self.subTest(kernel_size=ksz, stride=strd):
-          helper_test_op([(32,2,111,28)],
-            lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, stride=strd),
-            lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, stride=strd), gpu=self.gpu, forward_only=True)
-
 if GPU:
   class TestOpsGPU(TestOps):
     gpu = True
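If the global-average-pooling op proposed in the commit message is added later, a test in the same style could slot in here. A sketch, assuming a hypothetical Tensor.global_avg_pool2d (torch's adaptive_avg_pool2d with output size (1,1) serves as the reference):

  def test_global_avgpool2d(self):
    # hypothetical op; follows the helper_test_op pattern used above
    helper_test_op([(32,2,111,28)],
      lambda x: torch.nn.functional.adaptive_avg_pool2d(x, (1,1)),
      lambda x: Tensor.global_avg_pool2d(x), gpu=self.gpu)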
@@ -1,3 +1,4 @@
+import sys
 import warnings
 import numpy as np
 from .tensor import Function, register
@@ -237,67 +238,54 @@ register('conv2d', Conv2D)

 # ************* pooling ops *************

-def stack_for_pool(x, kernel_size, stride, fill_value=0):
-  (ky, kx), (py, px) = kernel_size, stride
-  my, mx = (x.shape[2]-ky)//py+1, (x.shape[3]-kx)//px+1
-  stack = fill_value*np.ones((ky, kx, *x.shape[:2], my+ky, mx+kx), dtype=x.dtype)
-  for Y in range(ky):
-    for X in range(kx):
-      sl = x[..., Y:Y+my*py+ky:py, X:X+mx*px+kx:px]
-      stack[Y, X, ..., :sl.shape[2], :sl.shape[3]] = sl
-  return stack.reshape(-1, *stack.shape[2:]), (my, mx)
+def stack_for_pool(x, py, px):
+  my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
+  stack = []
+  xup = x[:, :, :my, :mx]
+  for Y in range(py):
+    for X in range(px):
+      stack.append(xup[:, :, Y::py, X::px][None])
+  return np.concatenate(stack, axis=0)

-def unstack_for_pool(fxn, s, kernel_size, stride):
-  (ky, kx), (py, px) = kernel_size, stride
-  for Y in range(ky):
-    for X in range(kx):
-      ll = fxn(Y*kx+X)
+def unstack_for_pool(fxn, s, py, px):
+  my, mx = (s[2]//py)*py, (s[3]//px)*px
+  for Y in range(py):
+    for X in range(px):
+      ll = fxn(Y*px+X)
       if X == 0 and Y == 0:
-        ret = np.zeros((*s[:2], s[2]+ky, s[3]+kx), dtype=ll.dtype)
-      ret[..., Y:Y+ll.shape[2]*py:py, X:X+ll.shape[3]*px:px] = ll
-  return ret[..., :s[2], :s[3]]
+        ret = np.zeros(s, dtype=ll.dtype)
+      ret[:, :, Y:my:py, X:mx:px] = ll
+  return ret

 class MaxPool2D(Function):
   @staticmethod
-  def forward(ctx, x, kernel_size=(2, 2), stride=None):
-    if not stride:
-      ctx.stride = stride = kernel_size
-    stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=-np.inf)
-    idxs = np.nanargmax(stack, axis=0)[..., :my, :mx]
+  def forward(ctx, x, kernel_size=(2, 2)):
+    stack = stack_for_pool(x, *kernel_size)
+    idxs = np.argmax(stack, axis=0)
     ctx.save_for_backward(idxs, x.shape)
-    return np.amax(stack, axis=0)[..., :my, :mx]
+    return np.max(stack, axis=0)

   @staticmethod
   def backward(ctx, grad_output):
-    # TODO implement for stride != kernel_size
-    if ctx.kernel_size != ctx.stride:
-      raise NotImplementedError("CPU MaxPool2D.backward() with stride != kernel_size not implemented")
     idxs,s = ctx.saved_tensors
     return unstack_for_pool(
       lambda idx: grad_output * (idxs == idx),
-      s, ctx.kernel_size, ctx.stride)
+      s, *ctx.kernel_size)
 register('max_pool2d', MaxPool2D)

 class AvgPool2D(Function):
   @staticmethod
-  def forward(ctx, x, kernel_size=(2, 2), stride=None):
-    if not stride:
-      ctx.stride = stride = kernel_size
-    stack, (my, mx) = stack_for_pool(x, kernel_size, stride, fill_value=np.nan)
+  def forward(ctx, x, kernel_size=(2, 2)):
+    stack = stack_for_pool(x, *kernel_size)
     ctx.save_for_backward(x.shape)
-    with warnings.catch_warnings():
-      warnings.simplefilter("ignore")
-      return np.nanmean(stack, axis=0)[...,:my, :mx]
+    return np.mean(stack, axis=0)

   @staticmethod
   def backward(ctx, grad_output):
-    # TODO implement for stride != kernel_size
-    if ctx.kernel_size != ctx.stride:
-      raise NotImplementedError("CPU AvgPool2D.backward() with stride != kernel_size not implemented")
     s, = ctx.saved_tensors
     py, px = ctx.kernel_size
     return unstack_for_pool(
       lambda idx: grad_output/py/px,
-      s, ctx.kernel_size, ctx.stride)
+      s, py, px)
 register('avg_pool2d', AvgPool2D)
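For reference, the restored (non-strided) pooling works by cropping the input to a multiple of the kernel, collecting one strided view per kernel offset, and reducing over the stacked axis. A small standalone demo of that idea, mirroring the stack_for_pool() added above:

    import numpy as np

    def stack_for_pool(x, py, px):
      # crop to a multiple of the kernel, then take one strided slice per (Y, X) offset
      my, mx = (x.shape[2]//py)*py, (x.shape[3]//px)*px
      stack = []
      xup = x[:, :, :my, :mx]
      for Y in range(py):
        for X in range(px):
          stack.append(xup[:, :, Y::py, X::px][None])
      return np.concatenate(stack, axis=0)

    x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
    stack = stack_for_pool(x, 2, 2)   # shape (4, 1, 1, 2, 2)
    print(np.mean(stack, axis=0))     # 2x2 average pooling, as in AvgPool2D.forward
    print(np.max(stack, axis=0))      # 2x2 max pooling, as in MaxPool2D.forward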
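The commit message suggests re-introducing #74 together with a global-average-pooling op, so that efficientnet.py no longer calls avgpool2d with a 112x112 kernel. A minimal sketch of what such a CPU op could look like, following the Function/register pattern above (the GlobalAvgPool2D name and the 'global_avg_pool2d' registration are assumptions, not part of this commit):

    import numpy as np
    from .tensor import Function, register

    class GlobalAvgPool2D(Function):
      @staticmethod
      def forward(ctx, x):
        ctx.save_for_backward(x.shape)
        # reduce the whole spatial extent at once: no (ky, kx, ...) staging buffer
        return x.mean(axis=(2, 3), keepdims=True)

      @staticmethod
      def backward(ctx, grad_output):
        s, = ctx.saved_tensors
        # spread the incoming gradient evenly over every pooled position
        return grad_output * np.ones(s, dtype=grad_output.dtype) / (s[2] * s[3])
    register('global_avg_pool2d', GlobalAvgPool2D)

With such an op registered, efficientnet.py could call x.global_avg_pool2d() wherever it currently uses avgpool2d with a kernel spanning the whole feature map, avoiding the stack_for_pool() blow-up entirely.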