use sequential

This commit is contained in:
George Hotz 2021-11-30 00:25:39 -05:00
parent de938c2d9d
commit 535f02cc64
2 changed files with 2 additions and 43 deletions

View File

@ -74,49 +74,10 @@ img = img.astype(np.float32)[:3].reshape(1,3,224,224)
img /= 255.0
img -= 0.5
img /= 0.5
#img[:] = 0
"""
import torch
from timm.models.vision_transformer import vit_tiny_patch16_224
mdl = vit_tiny_patch16_224(pretrained=True)
#out = mdl(torch.Tensor(img))
#choice = out.argmax(axis=1).item()
#print(out[0, choice], lbls[choice])
pe = m.patch_embed(Tensor(img))
x = m.cls_token.cat(pe, dim=1) + m.pos_embed
x = m.tbs[0](x)
#x = layernorm(x, 192).linear(m.tbs[0].ln1)
xp = mdl.patch_embed(torch.Tensor(img))
xp = torch.cat((mdl.cls_token, xp), dim=1) + mdl.pos_embed
xp = mdl.blocks[0](xp)
#xp = mdl.blocks[0].norm1(xp)
print(x.shape, xp.shape)
print(np.max(x.data), np.max(xp.detach().numpy()))
print(np.max(np.abs(x.data - xp.detach().numpy())))
exit(0)
"""
#import matplotlib.pyplot as plt
#plt.imshow(np.transpose(img[0], (1,2,0)))
#plt.show()
out = m.forward(Tensor(img))
outnp = out.cpu().data.ravel()
choice = outnp.argmax()
print(out.shape, choice, outnp[choice], lbls[choice])
#lookup = dict([x.split(" ") for x in open("cache/classids.txt").read().strip().split("\n")])
#cls = open("cache/imagenet21k_wordnet_ids.txt").read().strip().split("\n")
#print(cls[choice], lookup[cls[choice]])

View File

@ -70,8 +70,7 @@ class Transformer:
onehot = onehot.reshape(bs*x.shape[1], self.maxlen+self.syms)
x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
for t in self.tbs:
x = t(x)
x = x.sequential(self.tbs)
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).logsoftmax()
return x.reshape(shape=(bs, -1, x.shape[-1]))
@ -94,8 +93,7 @@ class ViT:
def forward(self, x):
pe = self.patch_embed(x)
x = self.cls_token.add(Tensor.zeros(pe.shape[0],1,1)).cat(pe, dim=1) + self.pos_embed
for l in self.tbs:
x = l(x)
x = x.sequential(self.tbs)
x = x.layernorm().linear(self.norm)
return x[:, 0].linear(self.head)