mirror of https://github.com/commaai/tinygrad.git
add support for padding='same' in nn.conv (#6975)
* add support for padding='same' in nn.conv * express concisely * simplify loop * test same padding with dilation and conv1d * fix bad indentation * make loop one liner
This commit is contained in:
parent
54dcea235d
commit
23c09f4b4c
|
@ -172,6 +172,66 @@ class TestNN(unittest.TestCase):
|
|||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
def test_conv1d_same_padding(self):
|
||||
BS, C1, W = 8, 3, 32
|
||||
C2, K, S, P = 16, 3, 1, 'same'
|
||||
|
||||
# create in tinygrad
|
||||
layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
def _run_conv2d_same_padding_test(self, BS, C1, C2, H, W, K, S, padding='same', D=1):
|
||||
# create in tinygrad
|
||||
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=padding, dilation=D).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
def test_conv2d_same_padding_odd_input(self):
|
||||
BS, C1, H, W = 16, 16, 29, 31
|
||||
C2, K, S, P = 32, 4, 1, 'same'
|
||||
self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
|
||||
|
||||
def test_conv2d_same_padding_large_kernel(self):
|
||||
BS, C1, H, W = 16, 16, 28, 33
|
||||
C2, K, S, P = 32, 9, 1, 'same'
|
||||
self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P)
|
||||
|
||||
def test_conv2d_same_padding_with_dilation(self):
|
||||
BS, C1, H, W = 16, 3, 28, 28
|
||||
C2, K, S, P, D = 32, 3, 1, 'same', 3
|
||||
self._run_conv2d_same_padding_test(BS, C1, C2, H, W, K, S, P, D)
|
||||
|
||||
def test_conv2d_same_padding_invalid_stride(self):
|
||||
C1, C2, K, S, P = 16, 32, 2, 2, 'same'
|
||||
self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
def test_conv2d_same_padding_invalid_padding_str(self):
|
||||
C1, C2, K, S, P = 16, 32, 2, 1, 'not_same'
|
||||
self.assertRaises(ValueError, Conv2d, C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
@unittest.skip("Takes too long to compile for Compiled backends")
|
||||
def test_conv2d_winograd(self):
|
||||
BS, C1, H, W = 2, 8, 16, 16
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
from typing import Optional, Union, Tuple
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.helpers import prod, make_pair
|
||||
from tinygrad.nn import optim, state, datasets # noqa: F401
|
||||
|
||||
class BatchNorm:
|
||||
|
@ -95,7 +95,12 @@ class Conv2d:
|
|||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups
|
||||
if isinstance(padding, str):
|
||||
if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
|
||||
if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
|
||||
self.padding = [p for d,k in zip(make_pair(dilation,len(self.kernel_size)), self.kernel_size[::-1]) for p in (d*(k-1)//2, d*(k-1) - d*(k-1)//2)]
|
||||
else: self.padding = padding
|
||||
self.stride, self.dilation, self.groups = stride, dilation, groups
|
||||
scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
|
||||
self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
|
||||
self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
|
||||
|
|
Loading…
Reference in New Issue