fix bug in getitem, drop int axis

George Hotz 2021-11-29 14:01:24 -05:00
parent c752033283
commit 9ce881f88c
2 changed files with 62 additions and 2 deletions


@@ -8,9 +8,63 @@ with tf.io.gfile.GFile(fn, "rb") as f:
g.write(dat)
"""
from tinygrad.tensor import Tensor
from models.transformer import TransformerBlock
class ViT:
  def __init__(self):
    self.conv_weight = Tensor.uniform(192, 3, 16, 16)
    self.conv_bias = Tensor.zeros(192)
    self.cls = Tensor.ones(1, 1, 192)
    self.tbs = [TransformerBlock(embed_dim=192, num_heads=3, ff_dim=768) for i in range(12)]
    self.pos = Tensor.ones(1, 197, 192)
    self.head = (Tensor.uniform(192, 21843), Tensor.zeros(21843))
  def forward(self, x):
    print(x.shape)
    x = x.conv2d(self.conv_weight, stride=16)
    x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
    print(x.shape)
    x = x.reshape(shape=(x.shape[0], 192, -1)).transpose(order=(0,2,1))
    print(x.shape)
    x = self.cls.cat(x, dim=1)
    print(x.shape)
    for l in self.tbs:
      x = l(x)
    return x[:, 0].affine(self.head)
m = ViT()
import numpy as np
dat = np.load("cache/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz")
for x in dat.keys():
  print(x, dat[x].shape)
  print(x, dat[x].shape, dat[x].dtype)
m.conv_weight.assign(np.transpose(dat['embedding/kernel'], (3,2,1,0)))
m.conv_bias.assign(dat['embedding/bias'])
m.cls.assign(dat['cls'])
m.pos.assign(dat['Transformer/posembed_input/pos_embedding'])
for i in range(12):
  m.tbs[i].query_dense[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/kernel'].reshape(192, 192))
  m.tbs[i].query_dense[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/bias'].reshape(192))
  m.tbs[i].key_dense[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/kernel'].reshape(192, 192))
  m.tbs[i].key_dense[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/bias'].reshape(192))
  m.tbs[i].value_dense[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/kernel'].reshape(192, 192))
  m.tbs[i].value_dense[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/bias'].reshape(192))
  m.tbs[i].final[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/kernel'].reshape(192, 192))
  m.tbs[i].final[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/bias'].reshape(192))
  m.tbs[i].ff1[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/kernel'])
  m.tbs[i].ff1[1].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/bias'])
  m.tbs[i].ff2[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/kernel'])
  m.tbs[i].ff2[1].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/bias'])
  m.tbs[i].ln1[0].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_0/scale'])
  m.tbs[i].ln1[1].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_0/bias'])
  m.tbs[i].ln2[0].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/scale'])
  m.tbs[i].ln2[1].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/bias'])
test_input = Tensor.ones(1, 3, 224, 224)
out = m.forward(test_input)
print(out.shape)
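
For context, the token-count arithmetic behind self.pos = Tensor.ones(1, 197, 192) can be checked by hand. This is a minimal sketch, not part of the commit, assuming the 224x224 test input used above:

# Sketch only: why the positional embedding has 197 positions of width 192.
H = W = 224            # spatial size of the test input
patch = 16             # 16x16 conv with stride 16 => non-overlapping patches
num_patches = (H // patch) * (W // patch)   # 14 * 14 = 196 patch tokens
num_tokens = num_patches + 1                # prepended cls token => 197
embed_dim = 192                             # 3 heads * 64 dims, matching the (192, 192) kernel reshapes
print(num_patches, num_tokens, embed_dim)   # 196 197 192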


@@ -64,6 +64,9 @@ class Tensor:
return f"<Tensor {self.data!r} with grad {(self.grad.data if self.grad else None)!r}>"
def assign(self, x):
if not isinstance(x, Tensor):
x = Tensor(x)
assert self.shape == x.shape
self.data = x.data
@property
@@ -170,6 +173,7 @@ class Tensor:
  def __getitem__(self, val):
    arg = []
    new_shape = []
    if val is not None:
      for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]):
        if isinstance(s, int):
@@ -177,8 +181,10 @@
        else:
          arg.append((s.start if s.start is not None else 0,
            (s.stop if s.stop >=0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
          new_shape.append(arg[-1][1] - arg[-1][0])
        assert s.step is None or s.step == 1
    return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
    new_shape += self.shape[len(arg):]
    return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))]).reshape(shape=new_shape)
  def cat(self, y, dim=0):
    assert len(self.shape) == len(y.shape)
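
The effect of the __getitem__ change, sketched against numpy indexing semantics (used here only as a reference point, not code from this commit): an integer index should drop its axis while the equivalent slice keeps it, so the x[:, 0] in ViT.forward comes back as the (1, 192) cls token rather than (1, 1, 192).

import numpy as np

# Sketch only: the indexing semantics the new new_shape/reshape path is after.
x = np.zeros((1, 197, 192), dtype=np.float32)
print(x[:, 0:1].shape)  # (1, 1, 192) -- a slice keeps the axis
print(x[:, 0].shape)    # (1, 192)    -- an integer index drops the axis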