mirror of https://github.com/commaai/tinygrad.git
add ff_dim to transformer
This commit is contained in:
parent b4839eb6bb
commit 99b6051467
@@ -0,0 +1 @@
+*
@@ -27,7 +27,7 @@ def make_dataset():
 from tinygrad.optim import Adam
 
 if __name__ == "__main__":
-  model = Transformer(10, 6, 2, 128, 4)
+  model = Transformer(10, 6, 2, 128, 4, 32)
 
   X_train, Y_train, X_test, Y_test = make_dataset()
   lr = 0.003
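For readability, the updated positional call above can be written in keyword form; the argument names follow the new Transformer.__init__ signature that appears later in this diff (this keyword version is illustrative, not part of the commit):

  # keyword form of the call above, per the new constructor signature
  model = Transformer(syms=10, maxlen=6, layers=2, embed_dim=128, num_heads=4, ff_dim=32)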
@@ -0,0 +1,16 @@
+"""
+fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
+import tensorflow as tf
+with tf.io.gfile.GFile(fn, "rb") as f:
+  dat = f.read()
+with open("cache/"+ fn.rsplit("/", 1)[1], "wb") as g:
+  g.write(dat)
+"""
+
+import numpy as np
+dat = np.load("cache/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz")
+for x in dat.keys():
+  print(x, dat[x].shape)
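The commented-out block in the new file fetches the ViT checkpoint with TensorFlow's GFile. A minimal sketch of an HTTP-only alternative, assuming the vit_models bucket is publicly readable through storage.googleapis.com (this helper is not part of the commit):

  import os, urllib.request

  fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
  # assumption: gs://bucket/path is also reachable at https://storage.googleapis.com/bucket/path
  url = fn.replace("gs://", "https://storage.googleapis.com/")
  os.makedirs("cache", exist_ok=True)
  urllib.request.urlretrieve(url, "cache/" + fn.rsplit("/", 1)[1])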
@@ -11,7 +11,7 @@ def layernorm(x, sz, eps=1e-5):
   return ret.reshape(shape=in_shape)
 
 class TransformerBlock:
-  def __init__(self, embed_dim, num_heads):
+  def __init__(self, embed_dim, num_heads, ff_dim):
     # Multi-Head Attention
     self.num_heads = num_heads
     self.head_size = embed_dim // num_heads
@@ -24,8 +24,8 @@ class TransformerBlock:
 
     self.final = Tensor.uniform(embed_dim, embed_dim)
 
-    self.ff1 = Tensor.uniform(embed_dim, embed_dim)
-    self.ff2 = Tensor.uniform(embed_dim, embed_dim)
+    self.ff1 = Tensor.uniform(embed_dim, ff_dim)
+    self.ff2 = Tensor.uniform(ff_dim, embed_dim)
 
   def __call__(self, x):
     # bs x T x embed_dim
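The body of __call__ is not shown in this hunk, but the new weight shapes imply the usual two-layer feed-forward pattern; a sketch (not the literal implementation) of how ff1 and ff2 would be applied:

  # sketch: expand to ff_dim, apply the nonlinearity, project back to embed_dim
  # x: (bs*T, embed_dim) -> (bs*T, ff_dim) -> (bs*T, embed_dim)
  x = x.dot(self.ff1).relu().dot(self.ff2)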
@@ -54,12 +54,12 @@ class TransformerBlock:
 
 class Transformer:
   # L = layers, H = embed_dim, A = num_heads
-  def __init__(self, syms, maxlen, layers, embed_dim, num_heads):
+  def __init__(self, syms, maxlen, layers, embed_dim, num_heads, ff_dim):
     self.maxlen, self.syms = maxlen, syms
     self.embed = Tensor.uniform(maxlen+syms, embed_dim, requires_grad=False)
     self.tbs = []
     for i in range(layers):
-      self.tbs.append(TransformerBlock(embed_dim, num_heads))
+      self.tbs.append(TransformerBlock(embed_dim, num_heads, ff_dim))
     self.final = Tensor.uniform(embed_dim, syms)
 
   def forward(self, x):
@@ -33,7 +33,7 @@ class TestTrain(unittest.TestCase):
 
   def test_transformer(self):
     # this should be small GPT-2, but the param count is wrong
-    model = Transformer(syms=10, maxlen=6, layers=12, embed_dim=768, num_heads=12)
+    model = Transformer(syms=10, maxlen=6, layers=12, embed_dim=768, num_heads=12, ff_dim=768*4)
     X = np.zeros((BS,6), dtype=np.float32)
     Y = np.zeros((BS,6), dtype=np.int32)
     train_one_step(model,X,Y)
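The "param count is wrong" comment can be made concrete with rough arithmetic (approximate, weights only, biases and layer norms ignored; the GPT-2 small figures are assumed, not taken from this repo):

  embed_dim, ff_dim, layers = 768, 4*768, 12
  attn = 4 * embed_dim * embed_dim              # q/k/v/output projections per block
  ffn  = embed_dim*ff_dim + ff_dim*embed_dim    # ff1 + ff2 per block
  print(layers * (attn + ffn))                  # ~85M block parameters
  # GPT-2 small is ~124M; most of the gap is its 50257x768 token embedding,
  # which this toy model replaces with a tiny (maxlen+syms)=16 by 768 table.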