add support for padding='same' in nn.conv (#6975)

* add support for padding='same' in nn.conv

* express concisely

* simplify loop

* test same padding with dilation and conv1d

* fix bad indentation

* make loop one liner
Bhavya Gada 2024-10-10 20:39:07 -07:00 committed by GitHub
parent 54dcea235d
commit 23c09f4b4c
2 changed files with 67 additions and 2 deletions
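For context, a minimal usage sketch of what this change enables (assuming a tinygrad build that contains this commit; the shape assertion mirrors what the new tests check against torch):

# minimal sketch (assumes this commit): padding='same' keeps H and W unchanged at stride 1,
# including for even kernels and dilated convolutions
from tinygrad import Tensor
from tinygrad.nn import Conv2d

layer = Conv2d(3, 16, kernel_size=4, padding='same', dilation=2)
out = layer(Tensor.uniform(1, 3, 29, 31))
assert out.shape == (1, 16, 29, 31)  # spatial dims preserved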

@@ -172,6 +172,66 @@ class TestNN(unittest.TestCase):
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  def test_conv1d_same_padding(self):
    BS, C1, W = 8, 3, 32
    C2, K, S, P = 16, 3, 1, 'same'
    # create in tinygrad
    layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
    # test
    x = Tensor.uniform(BS, C1, W)
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  def _run_conv2d_same_padding_test(self, BS, C1, C2, H, W, K, S, padding='same', D=1):
    # create in tinygrad
    layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D)
    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
    # test
    x = Tensor.uniform(BS, C1, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  def test_conv2d_same_padding_odd_input(self):
    BS, C1, H, W = 16, 16, 29, 31
    C2, K, S, P = 32, 4, 1, 'same'
    self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)

  def test_conv2d_same_padding_large_kernel(self):
    BS, C1, H, W = 16, 16, 28, 33
    C2, K, S, P = 32, 9, 1, 'same'
    self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)

  def test_conv2d_same_padding_with_dilation(self):
    BS, C1, H, W = 16, 3, 28, 28
    C2, K, S, P, D = 32, 3, 1, 'same', 3
    self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P, D)

  def test_conv2d_same_padding_invalid_stride(self):
    C1, C2, K, S, P = 16, 32, 2, 2, 'same'
    self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)

  def test_conv2d_same_padding_invalid_padding_str(self):
    C1, C2, K, S, P = 16, 32, 2, 1, 'not_same'
    self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)

  @unittest.skip("Takes too long to compile for Compiled backends")
  def test_conv2d_winograd(self):
    BS, C1, H, W = 2, 8, 16, 16

@@ -1,7 +1,7 @@
import math
from typing import Optional, Union, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, make_pair
from tinygrad.nn import optim, state, datasets # noqa: F401

class BatchNorm:
@@ -95,7 +95,12 @@ class Conv2d:
""" """
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups if isinstance(padding, str):
if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
self.padding = [p for d,k in zip(make_pair(dilation,len(self.kernel_size)), self.kernel_size[::-1]) for p in (d*(k-1)//2, d*(k-1) - d*(k-1)//2)]
else: self.padding = padding
self.stride, self.dilation, self.groups = stride, dilation, groups
scale = 1 / math.sqrt(in_channels * prod(self.kernel_size)) scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale) self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
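
The 'same' padding amounts come from the comprehension above: for each spatial dimension the total padding is dilation*(kernel-1), split as floor(total/2) on one side and the remainder on the other (so even kernels pad asymmetrically), walking dimensions last-axis-first to match the padding order the conv expects. A small sketch of the same arithmetic (plain Python; same_pads is a hypothetical helper, not part of tinygrad):

# sketch of the 'same' padding arithmetic (hypothetical helper, not part of tinygrad)
def same_pads(kernel_size, dilation):
  pads = []
  for k, d in zip(reversed(kernel_size), reversed(dilation)):
    total = d * (k - 1)                        # total padding so output size == input size at stride 1
    pads += [total // 2, total - total // 2]   # floor on one side, remainder on the other
  return pads

print(same_pads((3, 3), (1, 1)))  # [1, 1, 1, 1]
print(same_pads((4, 4), (1, 1)))  # [1, 2, 1, 2] -> even kernel pads one extra on one side
print(same_pads((3, 3), (3, 3)))  # [3, 3, 3, 3] -> dilation scales the effective kernel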