diff --git a/examples/vit.py b/examples/vit.py
index 63abd56e..3ef04389 100644
--- a/examples/vit.py
+++ b/examples/vit.py
@@ -74,49 +74,10 @@
 img = img.astype(np.float32)[:3].reshape(1,3,224,224)
 img /= 255.0
 img -= 0.5
 img /= 0.5
-#img[:] = 0
-
-"""
-import torch
-from timm.models.vision_transformer import vit_tiny_patch16_224
-mdl = vit_tiny_patch16_224(pretrained=True)
-#out = mdl(torch.Tensor(img))
-#choice = out.argmax(axis=1).item()
-#print(out[0, choice], lbls[choice])
-
-pe = m.patch_embed(Tensor(img))
-x = m.cls_token.cat(pe, dim=1) + m.pos_embed
-x = m.tbs[0](x)
-#x = layernorm(x, 192).linear(m.tbs[0].ln1)
-
-xp = mdl.patch_embed(torch.Tensor(img))
-xp = torch.cat((mdl.cls_token, xp), dim=1) + mdl.pos_embed
-xp = mdl.blocks[0](xp)
-#xp = mdl.blocks[0].norm1(xp)
-
-print(x.shape, xp.shape)
-print(np.max(x.data), np.max(xp.detach().numpy()))
-print(np.max(np.abs(x.data - xp.detach().numpy())))
-
-exit(0)
-"""
-
-
-
-
-
-#import matplotlib.pyplot as plt
-#plt.imshow(np.transpose(img[0], (1,2,0)))
-#plt.show()
 out = m.forward(Tensor(img))
 outnp = out.cpu().data.ravel()
 choice = outnp.argmax()
 print(out.shape, choice, outnp[choice], lbls[choice])
-#lookup = dict([x.split(" ") for x in open("cache/classids.txt").read().strip().split("\n")])
-#cls = open("cache/imagenet21k_wordnet_ids.txt").read().strip().split("\n")
-#print(cls[choice], lookup[cls[choice]])
-
-
diff --git a/models/transformer.py b/models/transformer.py
index c05add42..066b734e 100644
--- a/models/transformer.py
+++ b/models/transformer.py
@@ -70,8 +70,7 @@ class Transformer:
     onehot = onehot.reshape(bs*x.shape[1], self.maxlen+self.syms)
     x = Tensor(onehot, device=x.device).dot(self.embed).reshape(shape=(bs, x.shape[1], -1))
-    for t in self.tbs:
-      x = t(x)
+    x = x.sequential(self.tbs)
     x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).logsoftmax()
     return x.reshape(shape=(bs, -1, x.shape[-1]))
@@ -94,8 +93,7 @@ class ViT:
   def forward(self, x):
     pe = self.patch_embed(x)
     x = self.cls_token.add(Tensor.zeros(pe.shape[0],1,1)).cat(pe, dim=1) + self.pos_embed
-    for l in self.tbs:
-      x = l(x)
+    x = x.sequential(self.tbs)
     x = x.layernorm().linear(self.norm)
     return x[:, 0].linear(self.head)
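
Note on the models/transformer.py refactor: the sequential call is a left fold of the input tensor through a list of callables, which replaces the explicit loop over transformer blocks. A minimal sketch of the pattern, assuming a tinygrad-style helper (the names below are illustrative, not the repo's exact implementation):

  import functools

  def sequential(x, layers):
    # Fold x through each callable layer in order;
    # behaviorally equivalent to: for l in layers: x = l(x)
    return functools.reduce(lambda acc, layer: layer(acc), layers, x)

  # Usage matching the forward passes above, where tbs is the list of
  # TransformerBlock instances:
  #   x = sequential(x, tbs)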