Add max_pool2d dilation (#833)

This commit is contained in:
Marcello Fuschi 2023-05-29 00:16:48 +02:00 committed by GitHub
parent 7460bd9b02
commit 6d49925a26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 3 deletions

View File

@ -85,8 +85,8 @@ def AveragePool(X, kernel_shape, auto_pad="NOTSET", ceil_mode=0, count_include_p
return padding_included / div return padding_included / div
def MaxPool(X, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1): def MaxPool(X, kernel_shape, auto_pad="NOTSET", ceil_mode=0, dilations=1, pads=None, storage_order=0, strides=1):
assert ceil_mode == 0 and storage_order == 0 and dilations == 1 assert ceil_mode == 0 and storage_order == 0
return _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-2:]).max_pool2d(kernel_shape, stride=strides) return _padding(X, pads, auto_pad, constant_value=-np.inf, axes=tuple(range(len(X.shape)))[-2:]).max_pool2d(kernel_shape, stride=strides, dilation=dilations)
def Conv(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1): def Conv(X, W, B=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1):
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=(pads[1], pads[3], pads[0], pads[2]) if pads is not None else 0) return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=(pads[1], pads[3], pads[0], pads[2]) if pads is not None else 0)

View File

@ -635,6 +635,12 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride), lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride),
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride)) lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride))
def test_maxpool2d_dilation(self):
for dilation in [(2, 3), (3, 2), 2, 3]:
helper_test_op([(32,2,110,28)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation),
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation))
def test_avgpool2d(self): def test_avgpool2d(self):
shape = (32,2,111,28) shape = (32,2,111,28)
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:

View File

@ -367,7 +367,7 @@ class Tensor:
# NOTE: these work for more than 2D # NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) def avg_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
def max_pool2d(self, kernel_size=(2,2), stride=None): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor: def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor:
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1)) HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))