vit is now tested

George Hotz 2021-11-30 00:23:06 -05:00
parent aff810e722
commit de938c2d9d
3 changed files with 32 additions and 27 deletions

View File

@@ -13,32 +13,7 @@ import io
from extra.utils import fetch
from tinygrad.tensor import Tensor
from models.transformer import TransformerBlock
class ViT:
  def __init__(self, embed_dim=192):
    self.conv_weight = Tensor.uniform(embed_dim, 3, 16, 16)
    self.conv_bias = Tensor.zeros(embed_dim)
    self.cls_token = Tensor.ones(1, 1, embed_dim)
    self.tbs = [TransformerBlock(embed_dim=embed_dim, num_heads=3, ff_dim=768, prenorm=True) for i in range(12)]
    self.pos_embed = Tensor.ones(1, 197, embed_dim)
    self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000))
    self.norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim))

  def patch_embed(self, x):
    x = x.conv2d(self.conv_weight, stride=16)
    x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
    x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
    return x

  def forward(self, x):
    pe = self.patch_embed(x)
    # TODO: expand cls_token for batch
    x = self.cls_token.cat(pe, dim=1) + self.pos_embed
    for l in self.tbs:
      x = l(x)
    x = x.layernorm().linear(self.norm)
    return x[:, 0].linear(self.head)
from models.transformer import ViT
Tensor.training = False
m = ViT()
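
With the ViT class moved into models/transformer.py, the example script only needs the import above. A minimal usage sketch of the relocated class, assuming randomly initialized weights (so the logits are not meaningful predictions):

import numpy as np
from tinygrad.tensor import Tensor
from models.transformer import ViT

Tensor.training = False
m = ViT()
img = Tensor(np.zeros((1, 3, 224, 224), dtype=np.float32))  # one 224x224 RGB image
logits = m.forward(img)  # class-token output through the 1000-way head
print(logits.shape)      # (1, 1000)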

View File

@@ -75,3 +75,27 @@ class Transformer:
    x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).logsoftmax()
    return x.reshape(shape=(bs, -1, x.shape[-1]))

class ViT:
  def __init__(self, embed_dim=192):
    # 16x16 patch embedding, done as a strided convolution
    self.conv_weight = Tensor.uniform(embed_dim, 3, 16, 16)
    self.conv_bias = Tensor.zeros(embed_dim)
    self.cls_token = Tensor.ones(1, 1, embed_dim)
    self.tbs = [TransformerBlock(embed_dim=embed_dim, num_heads=3, ff_dim=768, prenorm=True) for i in range(12)]
    # 197 = 14*14 patches from a 224x224 input, plus the cls token
    self.pos_embed = Tensor.ones(1, 197, embed_dim)
    self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000))
    self.norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim))

  def patch_embed(self, x):
    x = x.conv2d(self.conv_weight, stride=16)
    x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
    # flatten the patch grid: (bs, embed_dim, 14, 14) -> (bs, 196, embed_dim)
    x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
    return x

  def forward(self, x):
    pe = self.patch_embed(x)
    # adding a (bs, 1, 1) zeros tensor broadcasts cls_token across the batch
    x = self.cls_token.add(Tensor.zeros(pe.shape[0],1,1)).cat(pe, dim=1) + self.pos_embed
    for l in self.tbs:
      x = l(x)
    x = x.layernorm().linear(self.norm)
    return x[:, 0].linear(self.head)
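
The only functional change from the deleted copy above is the cls_token line: adding a zeros tensor of shape (bs, 1, 1) broadcasts the single learned token across the batch before the cat, resolving the old TODO. A minimal numpy sketch of the same broadcasting idiom:

import numpy as np

embed_dim, bs = 192, 4
cls_token = np.ones((1, 1, embed_dim), dtype=np.float32)  # one learned token
pe = np.zeros((bs, 196, embed_dim), dtype=np.float32)     # patch embeddings
# adding (bs, 1, 1) zeros broadcasts the token to (bs, 1, embed_dim)
expanded = cls_token + np.zeros((pe.shape[0], 1, 1), dtype=np.float32)
x = np.concatenate([expanded, pe], axis=1)
print(x.shape)  # (4, 197, 192), matching the (1, 197, embed_dim) pos_embed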

View File

@@ -7,7 +7,7 @@ from tinygrad.tensor import Tensor
from extra.training import train
from extra.utils import get_parameters
from models.efficientnet import EfficientNet
from models.transformer import Transformer
from models.transformer import Transformer, ViT
from models.resnet import ResNet18
BS = int(os.getenv("BS", "4"))
@@ -31,6 +31,12 @@ class TestTrain(unittest.TestCase):
    Y = np.zeros((BS), dtype=np.int32)
    train_one_step(model,X,Y)

  def test_vit(self):
    model = ViT()
    X = np.zeros((BS,3,224,224), dtype=np.float32)
    Y = np.zeros((BS,), dtype=np.int32)
    train_one_step(model,X,Y)

  def test_transformer(self):
    # this should be small GPT-2, but the param count is wrong
    # (real ff_dim is 768*4)
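
The train_one_step helper called in test_vit is defined earlier in this file, outside the hunk shown. A hypothetical sketch of such a helper, assuming tinygrad's optim module plus the train and get_parameters imports above (the optimizer choice and learning rate are assumptions):

import tinygrad.optim as optim

def train_one_step(model, X, Y):
  params = get_parameters(model)
  optimizer = optim.SGD(params, lr=0.001)        # assumed optimizer and lr
  train(model, X, Y, optimizer, steps=1, BS=BS)  # one optimization step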