mirror of https://github.com/commaai/tinygrad.git
fix bug in getitem, drop int axis
This commit is contained in:
parent c752033283
commit 9ce881f88c
@@ -8,9 +8,63 @@ with tf.io.gfile.GFile(fn, "rb") as f:
g.write(dat)
"""

from tinygrad.tensor import Tensor
from models.transformer import TransformerBlock

class ViT:
  def __init__(self):
    self.conv_weight = Tensor.uniform(192, 3, 16, 16)
    self.conv_bias = Tensor.zeros(192)
    self.cls = Tensor.ones(1, 1, 192)
    self.tbs = [TransformerBlock(embed_dim=192, num_heads=3, ff_dim=768) for i in range(12)]
    self.pos = Tensor.ones(1, 197, 192)
    self.head = (Tensor.uniform(192, 21843), Tensor.zeros(21843))

  def forward(self, x):
    print(x.shape)
    x = x.conv2d(self.conv_weight, stride=16)
    x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
    print(x.shape)
    x = x.reshape(shape=(x.shape[0], 192, -1)).transpose(order=(0,2,1))
    print(x.shape)
    x = self.cls.cat(x, dim=1)
    print(x.shape)
    for l in self.tbs:
      x = l(x)
    return x[:, 0].affine(self.head)

m = ViT()

import numpy as np
dat = np.load("cache/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz")
for x in dat.keys():
  print(x, dat[x].shape)
  print(x, dat[x].shape, dat[x].dtype)

m.conv_weight.assign(np.transpose(dat['embedding/kernel'], (3,2,1,0)))
m.conv_bias.assign(dat['embedding/bias'])
m.cls.assign(dat['cls'])
m.pos.assign(dat['Transformer/posembed_input/pos_embedding'])

for i in range(12):
  m.tbs[i].query_dense[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/kernel'].reshape(192, 192))
  m.tbs[i].query_dense[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/bias'].reshape(192))
  m.tbs[i].key_dense[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/kernel'].reshape(192, 192))
  m.tbs[i].key_dense[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/bias'].reshape(192))
  m.tbs[i].value_dense[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/kernel'].reshape(192, 192))
  m.tbs[i].value_dense[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/bias'].reshape(192))
  m.tbs[i].final[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/kernel'].reshape(192, 192))
  m.tbs[i].final[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/bias'].reshape(192))
  m.tbs[i].ff1[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/kernel'])
  m.tbs[i].ff1[1].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/bias'])
  m.tbs[i].ff2[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/kernel'])
  m.tbs[i].ff2[1].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/bias'])
  m.tbs[i].ln1[0].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_0/scale'])
  m.tbs[i].ln1[1].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_0/bias'])
  m.tbs[i].ln2[0].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/scale'])
  m.tbs[i].ln2[1].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/bias'])

test_input = Tensor.ones(1, 3, 224, 224)
out = m.forward(test_input)
print(out.shape)
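A note on the weight loading above: the checkpoint is a Flax-style ViT-Ti/16 .npz, so (assuming the usual layout) embedding/kernel is stored HWIO as (16, 16, 3, 192), while tinygrad's conv2d weight is OIHW, hence the (3, 2, 1, 0) transpose; the 197 rows of self.pos are (224/16)^2 = 196 patches plus the class token. A minimal sketch of the assumed shape conversion, not part of the diff:

import numpy as np

hwio = np.zeros((16, 16, 3, 192), dtype=np.float32)  # assumed checkpoint layout: H, W, in_channels, out_channels
oihw = np.transpose(hwio, (3, 2, 1, 0))              # out_channels, in_channels, W, H
assert oihw.shape == (192, 3, 16, 16)                # matches Tensor.uniform(192, 3, 16, 16) above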
@@ -64,6 +64,9 @@ class Tensor:
    return f"<Tensor {self.data!r} with grad {(self.grad.data if self.grad else None)!r}>"

  def assign(self, x):
    if not isinstance(x, Tensor):
      x = Tensor(x)
    assert self.shape == x.shape
    self.data = x.data

  @property
@@ -170,6 +173,7 @@ class Tensor:

  def __getitem__(self, val):
    arg = []
    new_shape = []
    if val is not None:
      for i, s in enumerate(val if isinstance(val, (list, tuple)) else [val]):
        if isinstance(s, int):
@@ -177,8 +181,10 @@ class Tensor:
        else:
          arg.append((s.start if s.start is not None else 0,
            (s.stop if s.stop >= 0 else self.shape[i]+s.stop) if s.stop is not None else self.shape[i]))
          new_shape.append(arg[-1][1] - arg[-1][0])
        assert s.step is None or s.step == 1
    return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
    new_shape += self.shape[len(arg):]
    return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))]).reshape(shape=new_shape)

  def cat(self, y, dim=0):
    assert len(self.shape) == len(y.shape)
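The new_shape bookkeeping above is the bug fix named in the commit message: only slice dimensions contribute to new_shape, so the final reshape drops the axis of any plain integer index, numpy-style. This is also what lets the ViT example take x[:, 0] and get a (batch, 192) class token. A rough sketch of the expected behavior (shapes assumed, since the int branch of __getitem__ is elided in this hunk):

from tinygrad.tensor import Tensor

t = Tensor.ones(2, 3, 4)
print(t[0].shape)       # (3, 4)    int index: axis dropped
print(t[0:1].shape)     # (1, 3, 4) slice: axis kept
print(t[0, 1:3].shape)  # (2, 4)    mixed: int axis dropped, slice axis sized 2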