mirror of https://github.com/commaai/tinygrad.git
Added `nn.ConvTranspose1d` (#1243)
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
7399f6dad7
commit
872e2198fe
|
@ -5,7 +5,7 @@ from extra.utils import WINDOWS
|
|||
from tinygrad.helpers import getenv
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn import BatchNorm2d, Conv1d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm
|
||||
from tinygrad.nn import BatchNorm2d, Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm
|
||||
import torch
|
||||
|
||||
class TestNN(unittest.TestCase):
|
||||
|
@ -115,6 +115,27 @@ class TestNN(unittest.TestCase):
|
|||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI")
|
||||
def test_conv_transpose1d(self):
|
||||
BS, C1, W = 4, 16, 224
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
# create in tinygrad
|
||||
layer = ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.uniform(BS, C1, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.cpu().numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI")
|
||||
def test_conv_transpose2d(self):
|
||||
BS, C1, H, W = 4, 16, 224, 224
|
||||
|
|
|
@ -50,6 +50,9 @@ class Conv2d:
|
|||
def __call__(self, x):
|
||||
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
|
||||
|
||||
def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
|
||||
return ConvTranspose2d(in_channels, out_channels, (kernel_size,), stride, padding, output_padding, dilation, groups, bias)
|
||||
|
||||
class ConvTranspose2d:
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
|
||||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
|
|
Loading…
Reference in New Issue