mirror of https://github.com/commaai/tinygrad.git
537 lines
22 KiB
Python
Executable File
537 lines
22 KiB
Python
Executable File
#!/usr/bin/env python
|
|
import unittest
|
|
import numpy as np
|
|
import torch
|
|
from tinygrad import Tensor, Device, TinyJit
|
|
from tinygrad.ops import UOps
|
|
from tinygrad.helpers import CI, Context
|
|
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
|
|
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
|
|
from tinygrad.nn.state import load_state_dict
|
|
from tinygrad.engine.schedule import create_schedule
|
|
from tinygrad.engine.realize import run_schedule
|
|
|
|
@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
class TestNN(unittest.TestCase):
  """Numerical parity tests: tinygrad.nn layers and Tensor ops checked against their PyTorch equivalents."""

  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "no int64 on WebGPU")
  def test_sparse_cat_cross_entropy(self):
    """Compare Tensor.sparse_categorical_crossentropy with torch.nn.CrossEntropyLoss over smoothing/ignore_index/reduction."""
    # create in tinygrad
    input_tensor = Tensor.randn(6, 5) # not square to test that mean scaling uses the correct dimension
    target = Tensor([0, 0, 0, 1, 2, 3]) # torch doesn't support target=-1
    torch_input = torch.tensor(input_tensor.numpy())
    torch_target = torch.tensor(target.numpy(), dtype=torch.long)

    for smoothing in [0.0, 0.1, 0.5, 1.0]:
      for ignore_index in [-1, 0, 2]:
        for reduction in ["none", "sum", "mean"]:
          loss = input_tensor.sparse_categorical_crossentropy(target, label_smoothing=smoothing, ignore_index=ignore_index, reduction=reduction)
          torch_loss = torch.nn.CrossEntropyLoss(reduction=reduction, label_smoothing=smoothing, ignore_index=ignore_index)(torch_input, torch_target)
          np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)

          # also test with a batch dimension (of size 1)
          loss = input_tensor.unsqueeze(0).sparse_categorical_crossentropy(
            target.unsqueeze(0), label_smoothing=smoothing, ignore_index=ignore_index, reduction=reduction
          )
          # torch expects the class dimension at axis 1, hence the permute
          torch_loss = torch.nn.CrossEntropyLoss(reduction=reduction, label_smoothing=smoothing, ignore_index=ignore_index)(
            torch_input.unsqueeze(0).permute(0,2,1), torch_target.unsqueeze(0)
          )
          np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)

  def test_batchnorm2d(self, training=False, threed=False, track_running_stats=True):
    """Compare tinygrad BatchNorm with torch BatchNorm2d (or BatchNorm3d when `threed`) for several channel counts.

    Args:
      training: run both layers in training mode (batch statistics, running-stat updates) instead of eval mode.
      threed: use 5D inputs / torch.nn.BatchNorm3d instead of 4D inputs / torch.nn.BatchNorm2d.
      track_running_stats: whether both layers keep running mean/var buffers (also compared when enabled).
    """
    with Tensor.train(training):
      szs = [4, 8, 16, 32]
      for sz in szs:
        # create in tinygrad
        bn = BatchNorm(sz, eps=1e-5, track_running_stats=track_running_stats)
        bn.weight = Tensor.randn(sz)
        bn.bias = Tensor.randn(sz)
        if track_running_stats:
          bn.running_mean = Tensor.randn(sz)
          # A variance must be non-negative, otherwise eval-mode outputs become NaN on both sides and
          # assert_allclose (equal_nan=True) passes vacuously. Tensor.rand samples from [0, 1).
          # (Writing into the array returned by `.numpy()` does not write back to the Tensor, so it can't clamp it.)
          bn.running_var = Tensor.rand(sz)

        # create in torch
        with torch.no_grad():
          if threed:
            tbn = torch.nn.BatchNorm3d(sz, track_running_stats=track_running_stats).eval()
          else:
            tbn = torch.nn.BatchNorm2d(sz, track_running_stats=track_running_stats).eval()
          tbn.training = training
          tbn.weight[:] = torch.tensor(bn.weight.numpy())
          tbn.bias[:] = torch.tensor(bn.bias.numpy())
          if track_running_stats:
            tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy())
            tbn.running_var[:] = torch.tensor(bn.running_var.numpy())

        if track_running_stats:
          np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
          np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)

        # trial
        if threed:
          inn = Tensor.randn(2, sz, 3, 3, 3)
        else:
          inn = Tensor.randn(2, sz, 3, 3)

        # in tinygrad
        outt = bn(inn)

        # in torch
        toutt = tbn(torch.tensor(inn.numpy()))

        # close
        np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6)
        if track_running_stats:
          # in training mode the forward pass updated the running stats; they must still agree
          np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
          np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)

  def test_batchnorm2d_training(self): self.test_batchnorm2d(True, False, True)
  def test_batchnorm2d_no_running_stats(self): self.test_batchnorm2d(False, False, False)
  def test_batchnorm2d_training_no_running_stats(self): self.test_batchnorm2d(True, False, False)
  def test_batchnorm3d(self): self.test_batchnorm2d(False, True, True)
  def test_batchnorm3d_training(self): self.test_batchnorm2d(True, True, True)
  def test_batchnorm3d_no_running_stats(self): self.test_batchnorm2d(False, True, False)
  def test_batchnorm3d_training_no_running_stats(self): self.test_batchnorm2d(True, True, False)

  def test_batchnorm_axis(self):
    """Tensor.batchnorm over axis=(0, 2) must match the flattened single-axis form and torch's functional batch_norm."""
    sz = (2, 4, 3, 2, 2)
    x = Tensor.randn(sz)
    weight = Tensor.randn(2, 3)
    bias = Tensor.randn(2, 3)
    mean = Tensor.randn(2, 3)
    # invstd must be strictly positive: torch takes sqrt(running_var) with running_var = 1/invstd**2,
    # which loses the sign of a negative invstd.
    invstd = Tensor.rand(2, 3) + 0.5
    a = (x.batchnorm(weight, bias, mean, invstd, axis=(0, 2))
         .permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2))
    b = (x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)
         .batchnorm(weight.flatten(), bias.flatten(), mean.flatten(), invstd.flatten()))
    t_x = torch.tensor(x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2).numpy())
    t_weight, t_bias = torch.tensor(weight.flatten().numpy()), torch.tensor(bias.flatten().numpy())
    t_mean, t_invstd = torch.tensor(mean.flatten().numpy()), torch.tensor(invstd.flatten().numpy())
    # eps=0.0 so torch divides by exactly sqrt(1/invstd**2) == invstd (the default eps=1e-5 would skew the result)
    t_out = torch.nn.functional.batch_norm(t_x, t_mean, 1.0 / t_invstd**2, t_weight, t_bias, eps=0.0)

    np.testing.assert_allclose(a.numpy(), b.numpy())
    np.testing.assert_allclose(a.numpy(), t_out.numpy(), rtol=1e-5, atol=1e-5)

  def test_linear(self):
    """Compare Linear with torch.nn.Linear on 2D and 3D inputs."""
    def _test_linear(x, in_dim, out_dim):
      # create in tinygrad
      model = Linear(in_dim, out_dim)
      z = model(x)

      # create in torch
      with torch.no_grad():
        torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
        torch_layer.weight[:] = torch.tensor(model.weight.numpy(), dtype=torch.float32)
        torch_layer.bias[:] = torch.tensor(model.bias.numpy(), dtype=torch.float32)
        torch_x = torch.tensor(x.numpy(), dtype=torch.float32)
        torch_z = torch_layer(torch_x)

      # test
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

    BS, T, in_dim, out_dim = 4, 2, 8, 16
    _test_linear(Tensor.randn(BS, in_dim), in_dim, out_dim)
    _test_linear(Tensor.randn(BS, T, in_dim), in_dim, out_dim) # test with more dims

  def test_conv1d(self):
    """Compare Conv1d with torch.nn.Conv1d (stride 2, padding 1)."""
    BS, C1, W = 4, 16, 224//4
    C2, K, S, P = 64, 7, 2, 1

    # create in tinygrad
    layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.uniform(BS, C1, W)
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  def test_conv2d(self):
    """Compare Conv2d with torch.nn.Conv2d (stride 2, padding 1)."""
    BS, C1, H, W = 4, 16, 224//4, 224//4
    C2, K, S, P = 64, 7, 2, 1

    # create in tinygrad
    layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.uniform(BS, C1, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  @unittest.skip("Takes too long to compile for Compiled backends")
  def test_conv2d_winograd(self):
    """Compare a 3x3 Conv2d built under Context(WINO=1) with torch, forward and backward."""
    BS, C1, H, W = 2, 8, 16, 16
    C2, K, S, P = 8, 3, 1, 1

    # create in tinygrad
    layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
    layer.weight.requires_grad = True
    layer.bias.requires_grad = True

    # create in torch
    torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
    torch_layer.weight = torch.nn.Parameter(torch.tensor(layer.weight.numpy(), dtype=torch.float32))
    torch_layer.bias = torch.nn.Parameter(torch.tensor(layer.bias.numpy(), dtype=torch.float32))

    # test
    x = Tensor.uniform(BS, C1, H, W, requires_grad=True)

    with Context(WINO=1):
      z = layer(x)

      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

      m = z.mean()
      m.backward()
      gw = layer.weight.grad.realize()
      gb = layer.bias.grad.realize()
      gx = x.grad.realize()

      torch_z.mean().backward()
      np.testing.assert_allclose(gw.numpy(), torch_layer.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
      np.testing.assert_allclose(gb.numpy(), torch_layer.bias.grad.numpy(), atol=5e-4, rtol=1e-5)
      np.testing.assert_allclose(gx.numpy(), torch_x.grad.numpy(), atol=5e-4, rtol=1e-5)

  @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI")
  def test_conv_transpose1d(self):
    """Compare ConvTranspose1d with torch.nn.ConvTranspose1d (stride 2, padding 1)."""
    BS, C1, W = 4, 16, 224//4
    C2, K, S, P = 64, 7, 2, 1

    # create in tinygrad
    layer = ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.uniform(BS, C1, W)
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI")
  def test_conv_transpose2d(self):
    """Compare ConvTranspose2d with torch.nn.ConvTranspose2d (stride 2, padding 1)."""
    BS, C1, H, W = 4, 16, 224//4, 224//4
    C2, K, S, P = 64, 7, 2, 1

    # create in tinygrad
    layer = ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.uniform(BS, C1, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  def test_groupnorm(self):
    """Compare GroupNorm with torch.nn.GroupNorm, forward and backward, over 10 random batches."""
    BS, H, W, C, G = 20, 10, 10, 6, 3

    # create in torch
    torch_layer = torch.nn.GroupNorm(G, C).eval()

    # create in tinygrad
    layer = GroupNorm(G, C)
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)

    for _ in range(10):
      # forward
      x = Tensor.randn(BS, C, H, W, requires_grad=True)
      z = layer(x)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)

      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
      np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)

  def test_layernorm(self):
    """Compare LayerNorm over the last two dims with torch.nn.LayerNorm, forward and backward."""
    N, C, H, W = 20, 5, 10, 10

    # create in torch
    torch_layer = torch.nn.LayerNorm([H, W]).eval()

    # create in tinygrad
    layer = LayerNorm([H, W])
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)

    for _ in range(10):
      # forward
      x = Tensor.randn(N, C, H, W, requires_grad=True)
      z = layer(x)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)

      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
      np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)

  def test_layernorm_2d(self):
    """Compare LayerNorm2d (channel-wise) with torch.nn.LayerNorm([C]) applied on an NHWC-permuted input."""
    N, C, H, W = 20, 5, 10, 10

    # create in torch
    torch_layer = torch.nn.LayerNorm([C]).eval()

    # create in tinygrad
    layer = LayerNorm2d(C)
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)

    for _ in range(10):
      # forward
      x = Tensor.randn(N, C, H, W, requires_grad=True)
      z = layer(x)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      # move channels last for torch's LayerNorm, then back to NCHW for comparison
      torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)

      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
      np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)

  def test_instancenorm_2d(self):
    """Compare InstanceNorm on 4D input with torch.nn.InstanceNorm2d(affine=True), forward and backward."""
    N, C, H, W = 20, 10, 10, 10

    # create in torch
    torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()

    # create in tinygrad
    layer = InstanceNorm(C)
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)

    for _ in range(10):
      # forward
      x = Tensor.randn(N, C, H, W, requires_grad=True)
      z = layer(x)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)

      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
      np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)

  def test_instancenorm_3d(self):
    """Compare InstanceNorm on 5D input with torch.nn.InstanceNorm3d(affine=True), forward and backward."""
    N, C, D, H, W = 20, 10, 10, 10, 10

    # create in torch
    torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()

    # create in tinygrad
    layer = InstanceNorm(C)
    layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
    layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)

    for _ in range(10):
      # forward
      x = Tensor.randn(N, C, D, H, W, requires_grad=True)
      z = layer(x)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)

      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
      np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)

  def test_rmsnorm(self):
    """Compare RMSNorm with the LLaMA reference RMSNorm implementation, forward and backward."""
    class TorchRMSNorm(torch.nn.Module):
      # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L34C1-L77C36
      def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

      def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

      def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    B, T, embed_size = 4, 10, 20
    torch_layer = TorchRMSNorm(embed_size)
    layer = RMSNorm(embed_size)
    layer.weight.requires_grad = True

    for _ in range(10):
      # forward
      x = Tensor.randn(B, T, embed_size, requires_grad=True)
      z = layer(x)
      torch_x = torch.tensor(x.numpy(), requires_grad=True)
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)

      # backward
      z.sum().backward()
      torch_z.sum().backward(retain_graph=True)
      np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
      np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)

  def test_embedding(self):
    """Compare Embedding with torch.nn.Embedding: normal input, zero-length input, and under TinyJit."""
    B, T, embed_size, vocab_size = 4, 10, 20, 28

    # create in tinygrad
    layer = Embedding(vocab_size, embed_size)

    with torch.no_grad():
      torch_layer = torch.nn.Embedding(vocab_size, embed_size).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)

    # test
    x = Tensor(np.random.randint(0, vocab_size, (B, T)))
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)

    # test with empty input length
    x = Tensor(np.random.randint(0, vocab_size, (B, 0)))
    z = layer(x)
    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)

    # test with jit enabled
    @TinyJit
    def layer_jit(x):
      return layer(x).realize()

    # several calls so the jitted (replayed) path is exercised, not just the first capture
    for _ in range(3):
      x = Tensor(np.random.randint(0, vocab_size, (B, T)))
      z = layer_jit(x)
      torch_x = torch.tensor(x.numpy())
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)

  def test_embedding_one_kernel(self):
    """Check kernel counts: the first Embedding call schedules 3 SINK kernels, a second call only 1."""
    layer = Embedding(20, 30)
    layer.weight = Tensor.zeros_like(layer.weight).contiguous()
    a = Tensor([[1, 5, 9, 11],
                [12, 19, 8, 1]])
    result = layer(a)
    schedule = create_schedule([result.lazydata])
    self.assertEqual(3, len([item for item in schedule if item.ast.op is UOps.SINK]), "first run realizes arange, weight, and embedding")
    run_schedule(schedule)

    b = Tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
    result = layer(b)
    schedule = create_schedule([result.lazydata])
    self.assertEqual(1, len([item for item in schedule if item.ast.op is UOps.SINK]), "second run realizes embedding only")
    run_schedule(schedule)

  def test_load_state_dict(self):
    """load_state_dict copies the given weight and bias tensors into a Conv2d."""
    layer = Conv2d(3, 5, kernel_size=3)

    state_dict = {
      'weight': Tensor.randn(5, 3, 3, 3),
      'bias': Tensor.randn(5),
    }
    load_state_dict(layer, state_dict)

    np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
    np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())

  @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
  def test_load_state_dict_sharded(self):
    """load_state_dict into a Conv2d whose parameters are sharded over two devices keeps them on those devices."""
    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")

    layer = Conv2d(3, 5, kernel_size=3)
    layer.weight.shard_(devices, -1)
    layer.bias.shard_(devices, None)
    state_dict = {
      'weight': Tensor.randn(5, 3, 3, 3).shard(devices, -1),
      'bias': Tensor.randn(5).shard(devices, None),
    }
    load_state_dict(layer, state_dict)

    self.assertEqual(layer.weight.device, devices)
    self.assertEqual(layer.bias.device, devices)
    np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
    np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())

  def test_lstm_cell(self):
    """Compare LSTMCell with torch.nn.LSTMCell for two steps: zero initial state, then feeding back (h, c)."""
    layer = LSTMCell(32, 16)
    # everything stays under no_grad so torch outputs can be converted with .numpy() directly
    with torch.no_grad():
      torch_layer = torch.nn.LSTMCell(32, 16)
      layer.weight_hh.assign(torch_layer.weight_hh.numpy())
      layer.weight_ih.assign(torch_layer.weight_ih.numpy())
      layer.bias_hh.assign(torch_layer.bias_hh.numpy())
      layer.bias_ih.assign(torch_layer.bias_ih.numpy())

      inp = Tensor.randn(1, 32)
      out_h, out_c = layer(inp)
      torch_out_h, torch_out_c = torch_layer(torch.tensor(inp.numpy()))
      np.testing.assert_allclose(out_h.numpy(), torch_out_h.numpy(), atol=1e-6)
      np.testing.assert_allclose(out_c.numpy(), torch_out_c.numpy(), atol=1e-6)

      out_h, out_c = layer(inp, (out_h, out_c))
      torch_out_h, torch_out_c = torch_layer(torch.tensor(inp.numpy()), (torch_out_h, torch_out_c))
      np.testing.assert_allclose(out_h.numpy(), torch_out_h.numpy(), atol=1e-6)
      np.testing.assert_allclose(out_c.numpy(), torch_out_c.numpy(), atol=1e-6)

  def test_lstm_cell_no_bias(self):
    """LSTMCell(bias=False) runs with and without an explicit (h, c) state and has no bias tensors."""
    layer = LSTMCell(32, 16, bias=False)
    inp = Tensor.randn(1, 32)
    out_h, out_c = layer(inp)
    out_h.realize()
    out_c.realize()
    h = Tensor.randn(1, 16)
    c = Tensor.randn(1, 16)
    out_h, out_c = layer(inp, (h, c))
    out_h.realize()
    out_c.realize()
    assert layer.bias_hh is None
    assert layer.bias_ih is None
|
|
# Entry point: run this module's test cases when executed directly as a script.
if __name__ == "__main__":
  unittest.main()
|