diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 6907d769..4843839a 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -85,8 +85,8 @@ def AveragePool(X, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_p return padding_included / div def MaxPool(X, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1): - assert ceil_mode == 0 and storage_order == 0 and dilations == 1 - return _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-2:]).max_pool2d(kernel_shape, stride=strides) + assert ceil_mode == 0 and storage_order == 0 + return _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-2:]).max_pool2d(kernel_shape, stride=strides, dilation=dilations) def Conv(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1): return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=(pads[1], pads[3], pads[0], pads[2]) if pads is not None else 0) diff --git a/test/test_ops.py b/test/test_ops.py index 81587f4d..902afaa1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -635,6 +635,12 @@ class TestOps(unittest.TestCase): lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride)) + def test_maxpool2d_dilation(self): + for dilation in [(2, 3), (3, 2), 2, 3]: + helper_test_op([(32,2,110,28)], + lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation), + lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation)) + def test_avgpool2d(self): shape = (32,2,111,28) for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 9a046f19..66b67982 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -367,7 +367,7 @@ class Tensor: # NOTE: these work for more than 2D def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) - def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) + def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor: HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))