add support for padding='same' in nn.conv (#6975)

* add support for padding='same' in nn.conv

* express concisely

* simplify loop

* test same padding with dilation and conv1d

* fix bad indentation

* make loop one liner
Bhavya Gada 2024-10-10 20:39:07 -07:00 committed by GitHub
parent 54dcea235d
commit 23c09f4b4c
2 changed files with 67 additions and 2 deletions
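
For context, a minimal usage sketch of what this change enables (layer sizes and shapes are illustrative, not taken from the commit): with padding='same' and stride=1, the output keeps the input's spatial size, including when dilation is set.

    from tinygrad import Tensor
    from tinygrad.nn import Conv2d

    # 'same' padding preserves spatial dims at stride=1, even with dilation
    conv = Conv2d(3, 16, kernel_size=3, padding='same', dilation=2)
    x = Tensor.uniform(1, 3, 28, 28)
    print(conv(x).shape)  # (1, 16, 28, 28)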

test/test_nn.py

@@ -172,6 +172,66 @@ class TestNN(unittest.TestCase):
     torch_z = torch_layer(torch_x)
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
 
+  def test_conv1d_same_padding(self):
+    BS, C1, W = 8, 3, 32
+    C2, K, S, P = 16, 3, 1, 'same'
+
+    # create in tinygrad
+    layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.uniform(BS, C1, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.numpy())
+    torch_z = torch_layer(torch_x)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
+
+  def _run_conv2d_same_padding_test(self, BS, C1, C2, H, W, K, S, padding='same', D=1):
+    # create in tinygrad
+    layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D)
+
+    # create in torch
+    with torch.no_grad():
+      torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D).eval()
+      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
+      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
+
+    # test
+    x = Tensor.uniform(BS, C1, H, W)
+    z = layer(x)
+    torch_x = torch.tensor(x.numpy())
+    torch_z = torch_layer(torch_x)
+    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
+
+  def test_conv2d_same_padding_odd_input(self):
+    BS, C1, H, W = 16, 16, 29, 31
+    C2, K, S, P = 32, 4, 1, 'same'
+    self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
+
+  def test_conv2d_same_padding_large_kernel(self):
+    BS, C1, H, W = 16, 16, 28, 33
+    C2, K, S, P = 32, 9, 1, 'same'
+    self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
+
+  def test_conv2d_same_padding_with_dilation(self):
+    BS, C1, H, W = 16, 3, 28, 28
+    C2, K, S, P, D = 32, 3, 1, 'same', 3
+    self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P, D)
+
+  def test_conv2d_same_padding_invalid_stride(self):
+    C1, C2, K, S, P = 16, 32, 2, 2, 'same'
+    self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
+
+  def test_conv2d_same_padding_invalid_padding_str(self):
+    C1, C2, K, S, P = 16, 32, 2, 1, 'not_same'
+    self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
+
   @unittest.skip("Takes too long to compile for Compiled backends")
   def test_conv2d_winograd(self):
     BS, C1, H, W = 2, 8, 16, 16

tinygrad/nn/__init__.py

@@ -1,7 +1,7 @@
 import math
 from typing import Optional, Union, Tuple
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import prod
+from tinygrad.helpers import prod, make_pair
 from tinygrad.nn import optim, state, datasets # noqa: F401
 
 class BatchNorm:
@@ -95,7 +95,12 @@ class Conv2d:
   """
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
     self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
-    self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
+    if isinstance(padding, str):
+      if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
+      if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
+      self.padding = [p for d,k in zip(make_pair(dilation,len(self.kernel_size)), self.kernel_size[::-1]) for p in (d*(k-1)//2, d*(k-1) - d*(k-1)//2)]
+    else: self.padding = padding
+    self.stride, self.dilation, self.groups = stride, dilation, groups
     scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
     self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
     self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
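
As a worked illustration of the padding expression above (standalone Python, not part of the diff): each spatial dimension gets a total padding of d*(k-1), split into a before/after pair with the extra unit going to the second element when the total is odd; the kernel dims are iterated in reverse, which appears to line the flat list up with the last-dimension-first order the padding argument expects (the same convention as torch.nn.functional.pad).

    # for a 2D conv with kernel_size=4 and dilation=1 (make_pair(1, 2) gives (1, 1))
    kernel_size, dilation = (4, 4), (1, 1)
    padding = [p for d, k in zip(dilation, kernel_size[::-1])
               for p in (d*(k-1)//2, d*(k-1) - d*(k-1)//2)]
    print(padding)  # [1, 2, 1, 2]: 3 total per dim for the 4-wide kernel, split 1 before / 2 after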