mirror of https://github.com/commaai/tinygrad.git

remove dumb nn ops

commit dca076dbf1
parent 33720e733f
@@ -15,13 +15,9 @@ from extra.utils import fetch
 from tinygrad.tensor import Tensor

 def layernorm(x, sz, eps=1e-5):
-  in_shape = x.shape
-  x = x.reshape(shape=(-1, sz))
-  layer_mean = x.mean(axis=(-1,)).reshape(shape=[-1, 1])
-  y = (x - layer_mean)
-  layer_var = (y*y).mean(axis=(-1,))
-  ret = y.div(layer_var.add(eps).reshape(shape=[-1, 1]).sqrt())
-  return ret.reshape(shape=in_shape)
+  y = (x - x.mean(axis=-1, keepdim=True))
+  layer_var = (y*y).mean(axis=-1, keepdim=True)
+  return y.div(layer_var.add(eps).sqrt())

 class ViTBlock:
   def __init__(self, embed_dim, num_heads, ff_dim):
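
Note: the rewritten layernorm drops the reshape round-trip and normalizes along the last axis with keepdim means. A standalone NumPy sketch of why the two paths agree (names and shapes here are illustrative, not from the diff):

import numpy as np

def layernorm_ref(x, eps=1e-5):
    # normalize over the last axis, as both the old and new tinygrad code do
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(2, 5, 8).astype(np.float32)
old = layernorm_ref(x.reshape(-1, 8)).reshape(x.shape)  # old path: flatten to (-1, sz), normalize, reshape back
new = layernorm_ref(x)                                   # new path: keepdim means, no reshape needed
assert np.allclose(old, new, atol=1e-5)
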
@@ -46,7 +42,7 @@ class ViTBlock:
   def attn(self, x, bs):
     embed_dim = self.num_heads * self.head_size

-    query, key, value = [x.affine(y) \
+    query, key, value = [x.linear(y) \
       .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]

@@ -58,7 +54,7 @@ class ViTBlock:
     weights = score.softmax()  # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))  # (bs, T, num_heads, head_size)

-    return attention.reshape(shape=(-1, embed_dim)).affine(self.final)
+    return attention.reshape(shape=(-1, embed_dim)).linear(self.final)

   def __call__(self, x):
     # bs x T x embed_dim
@@ -67,11 +63,11 @@ class ViTBlock:
     inputs = x.reshape(shape=(-1, embed_dim))

     # run multi head attention (bs, T, num_heads, head_size)
-    x = layernorm(inputs, embed_dim).affine(self.ln1)
+    x = layernorm(inputs, embed_dim).linear(self.ln1)
     x = inputs + self.attn(x, bs).dropout(0.1)

-    xin = layernorm(x, embed_dim).affine(self.ln2)
-    x = x + xin.affine(self.ff1).gelu().affine(self.ff2).dropout(0.1)
+    xin = layernorm(x, embed_dim).linear(self.ln2)
+    x = x + xin.linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
     return x.reshape(shape=(bs, -1, embed_dim))

 class ViT:
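
Note: self.ln1, self.ff1 and friends are (weight, bias) pairs, so the same linear call covers both the elementwise layernorm affine (1-D weight) and the dense feed-forward matmul (2-D weight). A rough NumPy sketch of that dispatch, mirroring the Tensor.linear added later in this diff (the helper name is mine):

import numpy as np

def linear_like(x, params):
    # 1-D weight -> elementwise scale-and-shift; 2-D weight -> matmul; bias broadcasts over the last axis
    w, b = params
    shp = [1] * (x.ndim - 1) + [-1]
    ret = x * w.reshape(shp) if w.ndim == 1 else x @ w
    return ret + b.reshape(shp)

x = np.random.randn(4, 16)
print(linear_like(x, (np.ones(16), np.zeros(16))).shape)              # elementwise params, e.g. ln1 -> (4, 16)
print(linear_like(x, (np.random.randn(16, 32), np.zeros(32))).shape)  # dense params, e.g. ff1 -> (4, 32)
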
@@ -96,8 +92,8 @@ class ViT:
     x = self.cls_token.cat(pe, dim=1) + self.pos_embed
     for l in self.tbs:
       x = l(x)
-    x = layernorm(x, x.shape[-1]).affine(self.norm)
-    return x[:, 0].affine(self.head)
+    x = layernorm(x, x.shape[-1]).linear(self.norm)
+    return x[:, 0].linear(self.head)

 Tensor.training = False
 m = ViT()
@@ -171,7 +167,7 @@ mdl = vit_tiny_patch16_224(pretrained=True)
 pe = m.patch_embed(Tensor(img))
 x = m.cls_token.cat(pe, dim=1) + m.pos_embed
 x = m.tbs[0](x)
-#x = layernorm(x, 192).affine(m.tbs[0].ln1)
+#x = layernorm(x, 192).linear(m.tbs[0].ln1)

 xp = mdl.patch_embed(torch.Tensor(img))
 xp = torch.cat((mdl.cls_token, xp), dim=1) + mdl.pos_embed
@@ -25,11 +25,6 @@ def get_parameters(obj):
     for x in obj:
       parameters.extend(get_parameters(x))
   elif hasattr(obj, '__dict__'):
-    if isinstance(obj, nn.Sequential):
-      for layer in obj.layers:
-        for v in layer.__dict__.values():
-          parameters.extend(get_parameters(v))
-    else:
     for v in obj.__dict__.values():
       parameters.extend(get_parameters(v))
   return parameters
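
Note: with the nn.Sequential special case gone, parameter collection leans on the existing list branch, so layers kept in a plain Python list attribute are still found. A simplified, self-contained sketch of that traversal (FakeTensor and the class names are stand-ins, not tinygrad code):

class FakeTensor: pass   # stand-in for tinygrad.Tensor, just for this sketch

def get_params(obj):
    # simplified version of the traversal left in get_parameters after this commit
    params = []
    if isinstance(obj, FakeTensor):
        params.append(obj)
    elif isinstance(obj, list):
        for x in obj:
            params.extend(get_params(x))
    elif hasattr(obj, '__dict__'):
        for v in obj.__dict__.values():
            params.extend(get_params(v))
    return params

class Layer:
    def __init__(self): self.weight, self.bias = FakeTensor(), FakeTensor()

class Block:
    def __init__(self): self.downsample = [Layer(), Layer()]   # plain list instead of nn.Sequential

print(len(get_params(Block())))   # 4
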
@@ -47,17 +47,17 @@ class BasicBlock:
     self.bn1 = nn.BatchNorm2D(planes)
     self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
     self.bn2 = nn.BatchNorm2D(planes)
-    self.downsample = nn.Sequential()
+    self.downsample = []
     if stride != 1 or in_planes != self.expansion*planes:
-      self.downsample = nn.Sequential(
+      self.downsample = [
         nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
         nn.BatchNorm2D(self.expansion*planes)
-      )
+      ]

   def __call__(self, x):
     out = self.bn1(self.conv1(x)).relu()
     out = self.bn2(self.conv2(out))
-    out = out + self.downsample(x)
+    out = out + x.sequential(self.downsample)
     out = out.relu()
     return out

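
Note: the shortcut is now a plain list fed to Tensor.sequential, and an empty list behaves like the old empty nn.Sequential(), i.e. the identity. A tiny standalone sketch of that folding behaviour (the lambdas stand in for the 1x1 conv and batchnorm):

def sequential(x, layers):
    # fold the input through a list of callables; an empty list is the identity
    for l in layers:
        x = l(x)
    return x

assert sequential(3.0, []) == 3.0                                   # stride 1, same width: plain shortcut
assert sequential(3.0, [lambda v: v * 2, lambda v: v + 1]) == 7.0   # projection path: conv then batchnorm
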
@@ -72,12 +72,12 @@ class Bottleneck:
     self.bn2 = nn.BatchNorm2D(planes)
     self.conv3 = nn.Conv2d(planes, self.expansion *planes, kernel_size=1, bias=False)
     self.bn3 = nn.BatchNorm2D(self.expansion*planes)
-    self.downsample = nn.Sequential()
+    self.downsample = []
     if stride != 1 or in_planes != self.expansion*planes:
-      self.downsample = nn.Sequential(
+      self.downsample = [
         nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
         nn.BatchNorm2D(self.expansion*planes)
-      )
+      ]

   def __call__(self, x):
     out = self.bn1(self.conv1(x)).relu()
@@ -105,14 +105,14 @@ class ResNet:
     for stride in strides:
       layers.append(block(self.in_planes, planes, stride))
       self.in_planes = planes * block.expansion
-    return nn.Sequential(*layers)
+    return layers

   def forward(self, x):
     out = self.bn1(self.conv1(x)).relu()
-    out = self.layer1(out)
-    out = self.layer2(out)
-    out = self.layer3(out)
-    out = self.layer4(out)
+    out = out.sequential(self.layer1)
+    out = out.sequential(self.layer2)
+    out = out.sequential(self.layer3)
+    out = out.sequential(self.layer4)
     out = out.mean(3).mean(2)
     out = self.fc(out).logsoftmax()
     return out
@@ -37,7 +37,7 @@ class TransformerBlock:
     inputs = x.reshape(shape=(-1, embed_dim))

     # run multi head attention (bs, T, num_heads, head_size)
-    query, key, value = [inputs.affine(y) \
+    query, key, value = [inputs.linear(y) \
       .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]

@@ -49,10 +49,10 @@ class TransformerBlock:
     weights = score.softmax()  # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))  # (bs, T, num_heads, head_size)

-    x = inputs + attention.reshape(shape=(-1, embed_dim)).affine(self.final).dropout(0.1)
-    x = layernorm(x, embed_dim).affine(self.ln1)
-    x = x + x.affine(self.ff1).relu().affine(self.ff2).dropout(0.1)
-    x = layernorm(x, embed_dim).affine(self.ln2)
+    x = inputs + attention.reshape(shape=(-1, embed_dim)).linear(self.final).dropout(0.1)
+    x = layernorm(x, embed_dim).linear(self.ln1)
+    x = x + x.linear(self.ff1).relu().linear(self.ff2).dropout(0.1)
+    x = layernorm(x, embed_dim).linear(self.ln2)
     return x.reshape(shape=(bs, -1, embed_dim))

 class Transformer:
@@ -56,14 +56,14 @@ class TestNN(unittest.TestCase):
     def _test_linear(x):

       # create in tinygrad
-      layer = Linear(in_dim, out_dim)
-      z = layer(x)
+      layer = (Tensor.uniform(in_dim, out_dim), Tensor.zeros(out_dim))
+      z = x.linear(layer)

       # create in torch
       with torch.no_grad():
         torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
-        torch_layer.weight[:] = torch.tensor(layer.weight.data.T, dtype=torch.float32)
-        torch_layer.bias[:] = torch.tensor(layer.bias.data, dtype=torch.float32)
+        torch_layer.weight[:] = torch.tensor(layer[0].data.T, dtype=torch.float32)
+        torch_layer.bias[:] = torch.tensor(layer[1].data, dtype=torch.float32)
         torch_x = torch.tensor(x.cpu().data, dtype=torch.float32)
         torch_z = torch_layer(torch_x)

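
Note: the test now builds the layer as a bare (weight, bias) tuple with the weight stored as (in_dim, out_dim); torch.nn.Linear keeps its weight as (out_dim, in_dim), which is why the copy goes through .T. A NumPy sketch of that layout difference (shapes are made up):

import numpy as np

in_dim, out_dim = 8, 4
weight = np.random.randn(in_dim, out_dim).astype(np.float32)  # tinygrad layout: (in_dim, out_dim)
bias = np.zeros(out_dim, dtype=np.float32)
x = np.random.randn(2, in_dim).astype(np.float32)

z = x @ weight + bias                      # what x.linear((weight, bias)) computes
torch_weight = weight.T                    # torch.nn.Linear stores (out_dim, in_dim), hence the .T when copying
z_torch_style = x @ torch_weight.T + bias  # torch computes x @ weight.T + bias
assert np.allclose(z, z_torch_style)
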
@@ -31,35 +31,6 @@ class BatchNorm2D:
     x = (x - mean.reshape(shape=[1, -1, 1, 1])) * self.weight.reshape(shape=[1, -1, 1, 1])
     return x.div(var.add(self.eps).reshape(shape=[1, -1, 1, 1])**0.5) + self.bias.reshape(shape=[1, -1, 1, 1])

-class Linear:
-  def __init__(self, in_dim, out_dim, bias=True):
-    self.in_dim = in_dim
-    self.out_dim = out_dim
-    self.use_bias = bias
-    self.weight = Tensor.uniform(in_dim, out_dim)
-    if self.use_bias:
-      self.bias = Tensor.zeros(out_dim)
-
-  def __call__(self, x):
-    B, *dims, D = x.shape
-    x = x.reshape(shape=(B * np.prod(dims).astype(np.int32), D))
-    x = x.dot(self.weight)
-    if self.use_bias:
-      x = x.add(self.bias.reshape(shape=[1, -1]))
-    x = x.reshape(shape=(B, *dims, -1))
-    return x
-
-class Dropout:
-  def __init__(self, p=0.5):
-    self.p = p
-
-  def __call__(self, x):
-    return x.dropout(p=self.p)
-
-class Identity:
-  def __call__(self, x):
-    return x
-
 class Conv2d:
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
     self.out_channels = out_channels
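
Note: the deleted nn.Linear flattened the leading dimensions itself before the matmul; the functional Tensor.linear leaves that to the caller, which is why the ViT/Transformer blocks reshape to (-1, embed_dim) first. A NumPy comparison of the two conventions (function names are mine):

import numpy as np

w, b = np.random.randn(8, 16), np.zeros(16)

def old_linear(x):
    # what the removed nn.Linear did: flatten the leading dims, matmul, restore the shape
    B, *dims, D = x.shape
    out = x.reshape(B * int(np.prod(dims)), D) @ w + b
    return out.reshape(B, *dims, -1)

def new_linear(x2d):
    # the functional linear just does dot + bias; the caller flattens beforehand
    return x2d @ w + b

x = np.random.randn(2, 5, 8)
assert np.allclose(old_linear(x), new_linear(x.reshape(-1, 8)).reshape(2, 5, -1))
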
@@ -79,11 +50,3 @@ class Conv2d:
     x = x.add(self.bias.reshape(shape=(1, -1, 1, 1)))
     return x

-class Sequential:
-  def __init__(self, *layers):
-    self.layers = layers
-
-  def __call__(self, x):
-    for l in self.layers:
-      x = l(x)
-    return x
@@ -304,12 +304,18 @@ class Tensor:
   def max_pool2d(self, kernel_size=(2,2)):
     return self._pool2d(*kernel_size).max(axis=(3,5))

-  def affine(self, params):
+  # ***** functional nn ops *****
+
+  def linear(self, params):
     shp = [1] * (len(self.shape)-1) + [-1]
-    if len(params[0].shape) == 1: # elementwise affine
-      return self.mul(params[0].reshape(shape=shp)).add(params[1].reshape(shape=shp))
-    else:
-      return self.dot(params[0]).add(params[1].reshape(shape=shp))
+    ret = self.mul(params[0].reshape(shape=shp)) if len(params[0].shape) == 1 else self.dot(params[0])
+    return ret.add(params[1].reshape(shape=shp))
+
+  def sequential(self, ll):
+    ret = self
+    for l in ll:
+      ret = l(ret)
+    return ret

 # An instantiation of the Function is the Context
 class Function:
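
Note: after this change a layer is just a (weight, bias) tuple passed to linear, and a model is just a list of callables passed to sequential. A hedged usage sketch against the tinygrad API as of this commit (layer sizes are made up):

from tinygrad.tensor import Tensor

# a "layer" is now just a (weight, bias) tuple
l1 = (Tensor.uniform(16, 32), Tensor.zeros(32))
l2 = (Tensor.uniform(32, 8), Tensor.zeros(8))
x = Tensor.uniform(4, 16)
out = x.linear(l1).relu().linear(l2)   # functional linear layers

# ...and a "model" is just a list of callables for sequential
model = [lambda t: t.linear(l1), lambda t: t.relu(), lambda t: t.linear(l2)]
out2 = x.sequential(model)
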