remove dumb nn ops

George Hotz 2021-11-29 18:05:31 -05:00
parent 33720e733f
commit dca076dbf1
7 changed files with 45 additions and 85 deletions

View File

@@ -15,13 +15,9 @@ from extra.utils import fetch
 from tinygrad.tensor import Tensor
 
 def layernorm(x, sz, eps=1e-5):
-  in_shape = x.shape
-  x = x.reshape(shape=(-1, sz))
-  layer_mean = x.mean(axis=(-1,)).reshape(shape=[-1, 1])
-  y = (x - layer_mean)
-  layer_var = (y*y).mean(axis=(-1,))
-  ret = y.div(layer_var.add(eps).reshape(shape=[-1, 1]).sqrt())
-  return ret.reshape(shape=in_shape)
+  y = (x - x.mean(axis=-1, keepdim=True))
+  layer_var = (y*y).mean(axis=-1, keepdim=True)
+  return y.div(layer_var.add(eps).sqrt())
 
 class ViTBlock:
   def __init__(self, embed_dim, num_heads, ff_dim):
@@ -46,7 +42,7 @@ class ViTBlock:
   def attn(self, x, bs):
     embed_dim = self.num_heads * self.head_size
-    query, key, value = [x.affine(y) \
+    query, key, value = [x.linear(y) \
       .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
@@ -58,7 +54,7 @@ class ViTBlock:
     weights = score.softmax() # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
-    return attention.reshape(shape=(-1, embed_dim)).affine(self.final)
+    return attention.reshape(shape=(-1, embed_dim)).linear(self.final)
 
   def __call__(self, x):
     # bs x T x embed_dim
@@ -67,11 +63,11 @@ class ViTBlock:
     inputs = x.reshape(shape=(-1, embed_dim))
 
     # run multi head attention (bs, T, num_heads, head_size)
-    x = layernorm(inputs, embed_dim).affine(self.ln1)
+    x = layernorm(inputs, embed_dim).linear(self.ln1)
     x = inputs + self.attn(x, bs).dropout(0.1)
-    xin = layernorm(x, embed_dim).affine(self.ln2)
-    x = x + xin.affine(self.ff1).gelu().affine(self.ff2).dropout(0.1)
+    xin = layernorm(x, embed_dim).linear(self.ln2)
+    x = x + xin.linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
     return x.reshape(shape=(bs, -1, embed_dim))
 
 class ViT:
@@ -96,8 +92,8 @@ class ViT:
     x = self.cls_token.cat(pe, dim=1) + self.pos_embed
     for l in self.tbs:
       x = l(x)
-    x = layernorm(x, x.shape[-1]).affine(self.norm)
-    return x[:, 0].affine(self.head)
+    x = layernorm(x, x.shape[-1]).linear(self.norm)
+    return x[:, 0].linear(self.head)
 
 Tensor.training = False
 m = ViT()
@@ -171,7 +167,7 @@ mdl = vit_tiny_patch16_224(pretrained=True)
 pe = m.patch_embed(Tensor(img))
 x = m.cls_token.cat(pe, dim=1) + m.pos_embed
 x = m.tbs[0](x)
-#x = layernorm(x, 192).affine(m.tbs[0].ln1)
+#x = layernorm(x, 192).linear(m.tbs[0].ln1)
 
 xp = mdl.patch_embed(torch.Tensor(img))
 xp = torch.cat((mdl.cls_token, xp), dim=1) + mdl.pos_embed
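
A quick sketch of the functional style the hunks above switch to: with keepdim=True the layernorm works on any trailing dimension without reshaping back, and a "layer" is just a (weight, bias) tuple passed to Tensor.linear. This assumes tinygrad at this commit; the shapes and the ln1/ff1 parameter names are illustrative, not from the commit.

from tinygrad.tensor import Tensor

def layernorm(x, sz, eps=1e-5):
  # keepdim=True keeps the reduced axis, so no reshape dance is needed (sz kept to match call sites)
  y = x - x.mean(axis=-1, keepdim=True)
  layer_var = (y*y).mean(axis=-1, keepdim=True)
  return y.div(layer_var.add(eps).sqrt())

x = Tensor.uniform(4, 192)                            # (tokens, embed_dim)
ln1 = (Tensor.uniform(192), Tensor.zeros(192))        # 1-D weight -> elementwise scale/shift
ff1 = (Tensor.uniform(192, 768), Tensor.zeros(768))   # 2-D weight -> matmul plus bias
out = layernorm(x, 192).linear(ln1).linear(ff1).gelu()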

View File

@@ -25,11 +25,6 @@ def get_parameters(obj):
     for x in obj:
       parameters.extend(get_parameters(x))
   elif hasattr(obj, '__dict__'):
-    if isinstance(obj, nn.Sequential):
-      for layer in obj.layers:
-        for v in layer.__dict__.values():
-          parameters.extend(get_parameters(v))
-    else:
     for v in obj.__dict__.values():
       parameters.extend(get_parameters(v))
   return parameters
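
With nn.Sequential gone, a model's layer stack is a plain Python container, and the list/tuple branch above already walks it, so the special case can be dropped. A small sketch; TinyModel is hypothetical, the import path is inferred from the `from extra.utils import fetch` context earlier in this commit, and it assumes the elided branch above treats tuples like lists.

from tinygrad.tensor import Tensor
from extra.utils import get_parameters

class TinyModel:
  def __init__(self):
    self.l1 = (Tensor.uniform(8, 8), Tensor.zeros(8))        # a "linear layer" is now a tuple
    self.blocks = [(Tensor.uniform(8, 8), Tensor.zeros(8))]  # plain list instead of nn.Sequential

print(len(get_parameters(TinyModel())))  # 4 Tensors, collected via __dict__ -> list/tuple recursion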

View File

@@ -47,17 +47,17 @@ class BasicBlock:
     self.bn1 = nn.BatchNorm2D(planes)
     self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
     self.bn2 = nn.BatchNorm2D(planes)
-    self.downsample = nn.Sequential()
+    self.downsample = []
     if stride != 1 or in_planes != self.expansion*planes:
-      self.downsample = nn.Sequential(
+      self.downsample = [
         nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
         nn.BatchNorm2D(self.expansion*planes)
-      )
+      ]
 
   def __call__(self, x):
     out = self.bn1(self.conv1(x)).relu()
     out = self.bn2(self.conv2(out))
-    out = out + self.downsample(x)
+    out = out + x.sequential(self.downsample)
     out = out.relu()
     return out
@@ -72,12 +72,12 @@ class Bottleneck:
     self.bn2 = nn.BatchNorm2D(planes)
     self.conv3 = nn.Conv2d(planes, self.expansion *planes, kernel_size=1, bias=False)
     self.bn3 = nn.BatchNorm2D(self.expansion*planes)
-    self.downsample = nn.Sequential()
+    self.downsample = []
     if stride != 1 or in_planes != self.expansion*planes:
-      self.downsample = nn.Sequential(
+      self.downsample = [
         nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
         nn.BatchNorm2D(self.expansion*planes)
-      )
+      ]
 
   def __call__(self, x):
     out = self.bn1(self.conv1(x)).relu()
@@ -105,14 +105,14 @@ class ResNet:
     for stride in strides:
       layers.append(block(self.in_planes, planes, stride))
       self.in_planes = planes * block.expansion
-    return nn.Sequential(*layers)
+    return layers
 
   def forward(self, x):
     out = self.bn1(self.conv1(x)).relu()
-    out = self.layer1(out)
-    out = self.layer2(out)
-    out = self.layer3(out)
-    out = self.layer4(out)
+    out = out.sequential(self.layer1)
+    out = out.sequential(self.layer2)
+    out = out.sequential(self.layer3)
+    out = out.sequential(self.layer4)
     out = out.mean(3).mean(2)
     out = self.fc(out).logsoftmax()
     return out
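
The ResNet change boils down to: a stage is a plain list of callables, Tensor.sequential folds an input through it, and an empty list is a free identity shortcut. A minimal sketch assuming tinygrad at this commit; the toy weights and names are illustrative.

from tinygrad.tensor import Tensor

w, b = Tensor.uniform(16, 16), Tensor.zeros(16)
block = [lambda t: t.linear((w, b)), lambda t: t.relu()]   # any list of callables works

x = Tensor.uniform(2, 16)
out = x.sequential(block)      # relu(linear(x)), same as calling the old nn.Sequential
out = out + x.sequential([])   # empty downsample list acts as identity, like nn.Sequential() did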

View File

@@ -37,7 +37,7 @@ class TransformerBlock:
     inputs = x.reshape(shape=(-1, embed_dim))
 
     # run multi head attention (bs, T, num_heads, head_size)
-    query, key, value = [inputs.affine(y) \
+    query, key, value = [inputs.linear(y) \
       .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
@@ -49,10 +49,10 @@ class TransformerBlock:
     weights = score.softmax() # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)
 
-    x = inputs + attention.reshape(shape=(-1, embed_dim)).affine(self.final).dropout(0.1)
-    x = layernorm(x, embed_dim).affine(self.ln1)
-    x = x + x.affine(self.ff1).relu().affine(self.ff2).dropout(0.1)
-    x = layernorm(x, embed_dim).affine(self.ln2)
+    x = inputs + attention.reshape(shape=(-1, embed_dim)).linear(self.final).dropout(0.1)
+    x = layernorm(x, embed_dim).linear(self.ln1)
+    x = x + x.linear(self.ff1).relu().linear(self.ff2).dropout(0.1)
+    x = layernorm(x, embed_dim).linear(self.ln2)
     return x.reshape(shape=(bs, -1, embed_dim))
 
 class Transformer:

View File

@@ -56,14 +56,14 @@ class TestNN(unittest.TestCase):
     def _test_linear(x):
       # create in tinygrad
-      layer = Linear(in_dim, out_dim)
-      z = layer(x)
+      layer = (Tensor.uniform(in_dim, out_dim), Tensor.zeros(out_dim))
+      z = x.linear(layer)
 
       # create in torch
       with torch.no_grad():
         torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
-        torch_layer.weight[:] = torch.tensor(layer.weight.data.T, dtype=torch.float32)
-        torch_layer.bias[:] = torch.tensor(layer.bias.data, dtype=torch.float32)
+        torch_layer.weight[:] = torch.tensor(layer[0].data.T, dtype=torch.float32)
+        torch_layer.bias[:] = torch.tensor(layer[1].data, dtype=torch.float32)
         torch_x = torch.tensor(x.cpu().data, dtype=torch.float32)
         torch_z = torch_layer(torch_x)

View File

@@ -31,35 +31,6 @@ class BatchNorm2D:
     x = (x - mean.reshape(shape=[1, -1, 1, 1])) * self.weight.reshape(shape=[1, -1, 1, 1])
     return x.div(var.add(self.eps).reshape(shape=[1, -1, 1, 1])**0.5) + self.bias.reshape(shape=[1, -1, 1, 1])
 
-class Linear:
-  def __init__(self, in_dim, out_dim, bias=True):
-    self.in_dim = in_dim
-    self.out_dim = out_dim
-    self.use_bias = bias
-    self.weight = Tensor.uniform(in_dim, out_dim)
-    if self.use_bias:
-      self.bias = Tensor.zeros(out_dim)
-
-  def __call__(self, x):
-    B, *dims, D = x.shape
-    x = x.reshape(shape=(B * np.prod(dims).astype(np.int32), D))
-    x = x.dot(self.weight)
-    if self.use_bias:
-      x = x.add(self.bias.reshape(shape=[1, -1]))
-    x = x.reshape(shape=(B, *dims, -1))
-    return x
-
-class Dropout:
-  def __init__(self, p=0.5):
-    self.p = p
-
-  def __call__(self, x):
-    return x.dropout(p=self.p)
-
-class Identity:
-  def __call__(self, x):
-    return x
 
 class Conv2d:
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
     self.out_channels = out_channels
@@ -79,11 +50,3 @@ class Conv2d:
       x = x.add(self.bias.reshape(shape=(1, -1, 1, 1)))
     return x
-
-class Sequential:
-  def __init__(self, *layers):
-    self.layers = layers
-
-  def __call__(self, x):
-    for l in self.layers:
-      x = l(x)
-    return x
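
For reference, rough functional equivalents of the classes removed above, written against Tensor ops this commit keeps or adds. This is a sketch, not part of the commit, and assumes tinygrad at this revision; shapes are arbitrary.

from tinygrad.tensor import Tensor

x = Tensor.uniform(4, 10)

# Linear(10, 3)    ->  a (weight, bias) tuple plus Tensor.linear
lin = (Tensor.uniform(10, 3), Tensor.zeros(3))
y = x.linear(lin)

# Dropout(p=0.5)   ->  Tensor.dropout(0.5) called directly in the forward pass
y = y.dropout(0.5)

# Identity()       ->  nothing at all, or sequential over an empty list
y = y.sequential([])

# Sequential(f, g) ->  a plain list of callables plus Tensor.sequential
y = x.sequential([lambda t: t.linear(lin), lambda t: t.relu()])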

View File

@@ -304,12 +304,18 @@ class Tensor:
   def max_pool2d(self, kernel_size=(2,2)):
     return self._pool2d(*kernel_size).max(axis=(3,5))
 
-  def affine(self, params):
+  # ***** functional nn ops *****
+
+  def linear(self, params):
     shp = [1] * (len(self.shape)-1) + [-1]
-    if len(params[0].shape) == 1: # elementwise affine
-      return self.mul(params[0].reshape(shape=shp)).add(params[1].reshape(shape=shp))
-    else:
-      return self.dot(params[0]).add(params[1].reshape(shape=shp))
+    ret = self.mul(params[0].reshape(shape=shp)) if len(params[0].shape) == 1 else self.dot(params[0])
+    return ret.add(params[1].reshape(shape=shp))
+
+  def sequential(self, ll):
+    ret = self
+    for l in ll:
+      ret = l(ret)
+    return ret
 
 # An instantiation of the Function is the Context
 class Function:
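
The new Tensor.linear above picks its behavior from the weight's rank: a 1-D weight gives an elementwise affine (the layernorm scale/shift case), a 2-D weight gives a matmul plus bias. A small shape check, assuming tinygrad at this commit; the tensors are illustrative.

from tinygrad.tensor import Tensor

x = Tensor.uniform(4, 8)
gamma, beta = Tensor.uniform(8), Tensor.zeros(8)   # 1-D weight: x * gamma + beta
w, b = Tensor.uniform(8, 2), Tensor.zeros(2)       # 2-D weight: x.dot(w) + b

print(x.linear((gamma, beta)).shape)   # (4, 8)
print(x.linear((w, b)).shape)          # (4, 2)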