mirror of https://github.com/commaai/tinygrad.git
finish unsupporting strided pool, add global avg pool test (#92)
This commit is contained in:
parent
7ac1b163a5
commit
16d564a53c
|
@ -90,7 +90,6 @@ class TestOps(unittest.TestCase):
|
|||
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), gpu=self.gpu, forward_only=self.gpu)
|
||||
|
||||
def test_maxpool2d(self):
|
||||
# TODO merge into test_maxpool2d_strided when backward() is implemented
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
with self.subTest(kernel_size=ksz):
|
||||
helper_test_op([(32,2,110,28)],
|
||||
|
@ -98,11 +97,12 @@ class TestOps(unittest.TestCase):
|
|||
lambda x: Tensor.max_pool2d(x, kernel_size=ksz), gpu=self.gpu, forward_only=self.gpu)
|
||||
|
||||
def test_avgpool2d(self):
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
|
||||
shape = (32,2,111,28)
|
||||
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1), shape[2:]]:
|
||||
with self.subTest(kernel_size=ksz):
|
||||
helper_test_op([(32,2,111,28)],
|
||||
helper_test_op([shape],
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), gpu=self.gpu, forward_only=self.gpu)
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), gpu=self.gpu)
|
||||
|
||||
if GPU:
|
||||
class TestOpsGPU(TestOps):
|
||||
|
|
|
@ -411,20 +411,14 @@ register('sigmoid', Sigmoid, gpu=True)
|
|||
class AvgPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, kernel_size=(2, 2)):
|
||||
ctx.stride = stride = ctx.kernel_size
|
||||
iter_op = "group_res += input[iid]"
|
||||
result_op = "group_res / (kernel_size.x * kernel_size.y)"
|
||||
ret = subsample_op(ctx, input, kernel_size, stride, iter_op, result_op)
|
||||
ret = subsample_op(ctx, input, kernel_size, kernel_size, iter_op, result_op)
|
||||
ctx.save_for_backward(input.shape)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise NotImplementedError("GPU AvgPool2D.backward() is broken")
|
||||
|
||||
# TODO implement for stride != kernel_size
|
||||
if ctx.kernel_size != ctx.stride:
|
||||
raise NotImplementedError("GPU AvgPool2D.backward() with stride != kernel_size not implemented")
|
||||
orig_shape, = ctx.saved_tensors
|
||||
result_op = "input[iid] / (kernel_size.x * kernel_size.y)"
|
||||
return supersample_op(ctx, grad_output, orig_shape, ctx.kernel_size, result_op)
|
||||
|
@ -433,15 +427,13 @@ register('avg_pool2d', AvgPool2D, gpu=True)
|
|||
class MaxPool2D(Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, kernel_size=(2, 2)):
|
||||
ctx.stride = stride = ctx.kernel_size
|
||||
init_val = "FLT_MIN"
|
||||
iter_op = "group_res = max(group_res, input[iid])"
|
||||
result_op = "group_res"
|
||||
return subsample_op(ctx, input, kernel_size, stride, iter_op, result_op, init_val=init_val)
|
||||
return subsample_op(ctx, input, kernel_size, kernel_size, iter_op, result_op, init_val=init_val)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# TODO implement with stride support
|
||||
raise NotImplementedError("GPU MaxPool2D.backward() not implemented")
|
||||
register('max_pool2d', MaxPool2D, gpu=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue