promote layernorm to tensor op

George Hotz 2021-11-29 18:08:21 -05:00
parent dca076dbf1
commit 125e74293f
2 changed files with 7 additions and 8 deletions

examples/vit.py

@@ -14,11 +14,6 @@ from extra.utils import fetch
 from tinygrad.tensor import Tensor
 
-def layernorm(x, sz, eps=1e-5):
-  y = (x - x.mean(axis=-1, keepdim=True))
-  layer_var = (y*y).mean(axis=-1, keepdim=True)
-  return y.div(layer_var.add(eps).sqrt())
-
 class ViTBlock:
   def __init__(self, embed_dim, num_heads, ff_dim):
     # Multi-Head Attention
@@ -63,10 +58,10 @@ class ViTBlock:
     inputs = x.reshape(shape=(-1, embed_dim))
 
     # run multi head attention (bs, T, num_heads, head_size)
-    x = layernorm(inputs, embed_dim).linear(self.ln1)
+    x = inputs.layernorm().linear(self.ln1)
     x = inputs + self.attn(x, bs).dropout(0.1)
 
-    xin = layernorm(x, embed_dim).linear(self.ln2)
+    xin = x.layernorm().linear(self.ln2)
     x = x + xin.linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
 
     return x.reshape(shape=(bs, -1, embed_dim))
@@ -92,7 +87,7 @@ class ViT:
     x = self.cls_token.cat(pe, dim=1) + self.pos_embed
     for l in self.tbs:
       x = l(x)
-    x = layernorm(x, x.shape[-1]).linear(self.norm)
+    x = x.layernorm().linear(self.norm)
     return x[:, 0].linear(self.head)
 
 Tensor.training = False
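The removed helper and the new Tensor method compute the same normalization: subtract the mean over the last axis, then divide by the square root of the variance plus eps. Note that the old helper's sz argument was never used in its body, which is why the method drops it and the call sites no longer pass embed_dim. For reference, a minimal NumPy sketch of the same computation (layernorm_ref is a hypothetical name for illustration; NumPy spells the keyword keepdims where tinygrad uses keepdim):

import numpy as np

def layernorm_ref(x, eps=1e-5):
  # subtract the mean over the last axis (keepdims so it broadcasts back)
  y = x - x.mean(axis=-1, keepdims=True)
  # divide by the standard deviation, with eps for numerical stability
  return y / np.sqrt((y * y).mean(axis=-1, keepdims=True) + eps)

x = np.random.randn(2, 5, 8).astype(np.float32)
out = layernorm_ref(x)
# every row of the output has roughly zero mean and unit variance
assert np.allclose(out.mean(axis=-1), 0.0, atol=1e-3)
assert np.allclose(out.var(axis=-1), 1.0, atol=1e-2)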

tinygrad/tensor.py

@@ -317,6 +317,10 @@ class Tensor:
     ret = l(ret)
     return ret
 
+  def layernorm(x, eps=1e-5):
+    y = (x - x.mean(axis=-1, keepdim=True))
+    return y.div((y*y).mean(axis=-1, keepdim=True).add(eps).sqrt())
+
 # An instantiation of the Function is the Context
 class Function:
   def __new__(cls, *args, **kwargs):
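With layernorm on Tensor, call sites use the method form, as the examples/vit.py hunks above show. A minimal usage sketch, assuming Tensor.randn is available as elsewhere in tensor.py:

from tinygrad.tensor import Tensor

x = Tensor.randn(2, 5, 8)
out = x.layernorm()           # normalizes over the last axis with the default eps=1e-5
out = x.layernorm(eps=1e-6)   # eps is the only remaining parameter

The method's first parameter is written x rather than self, which still works because Python passes the instance positionally; it keeps the body identical to the helper it replaces.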