diff --git a/examples/vit.py b/examples/vit.py
index 5a714bfe..54c297f8 100644
--- a/examples/vit.py
+++ b/examples/vit.py
@@ -15,13 +15,9 @@ from extra.utils import fetch
 from tinygrad.tensor import Tensor
 
 def layernorm(x, sz, eps=1e-5):
-  in_shape = x.shape
-  x = x.reshape(shape=(-1, sz))
-  layer_mean = x.mean(axis=(-1,)).reshape(shape=[-1, 1])
-  y = (x - layer_mean)
-  layer_var = (y*y).mean(axis=(-1,))
-  ret = y.div(layer_var.add(eps).reshape(shape=[-1, 1]).sqrt())
-  return ret.reshape(shape=in_shape)
+  y = (x - x.mean(axis=-1, keepdim=True))
+  layer_var = (y*y).mean(axis=-1, keepdim=True)
+  return y.div(layer_var.add(eps).sqrt())
 
 class ViTBlock:
   def __init__(self, embed_dim, num_heads, ff_dim):
@@ -46,7 +42,7 @@ class ViTBlock:
 
   def attn(self, x, bs):
     embed_dim = self.num_heads * self.head_size
-    query, key, value = [x.affine(y) \
+    query, key, value = [x.linear(y) \
       .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
       for y in [self.query_dense, self.key_dense, self.value_dense]]
 
@@ -58,7 +54,7 @@ class ViTBlock:
     weights = score.softmax()                                   # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))   # (bs, T, num_heads, head_size)
 
-    return attention.reshape(shape=(-1, embed_dim)).affine(self.final)
+    return attention.reshape(shape=(-1, embed_dim)).linear(self.final)
 
   def __call__(self, x):
     # bs x T x embed_dim
@@ -67,11 +63,11 @@ class ViTBlock:
     inputs = x.reshape(shape=(-1, embed_dim))
 
     # run multi head attention (bs, T, num_heads, head_size)
-    x = layernorm(inputs, embed_dim).affine(self.ln1)
+    x = layernorm(inputs, embed_dim).linear(self.ln1)
     x = inputs + self.attn(x, bs).dropout(0.1)
 
-    xin = layernorm(x, embed_dim).affine(self.ln2)
-    x = x + xin.affine(self.ff1).gelu().affine(self.ff2).dropout(0.1)
+    xin = layernorm(x, embed_dim).linear(self.ln2)
+    x = x + xin.linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
     return x.reshape(shape=(bs, -1, embed_dim))
 
 class ViT:
@@ -96,8 +92,8 @@ class ViT:
     x = self.cls_token.cat(pe, dim=1) + self.pos_embed
     for l in self.tbs:
       x = l(x)
-    x = layernorm(x, x.shape[-1]).affine(self.norm)
-    return x[:, 0].affine(self.head)
+    x = layernorm(x, x.shape[-1]).linear(self.norm)
+    return x[:, 0].linear(self.head)
 
 Tensor.training = False
 m = ViT()
@@ -171,7 +167,7 @@ mdl = vit_tiny_patch16_224(pretrained=True)
 pe = m.patch_embed(Tensor(img))
 x = m.cls_token.cat(pe, dim=1) + m.pos_embed
 x = m.tbs[0](x)
-#x = layernorm(x, 192).affine(m.tbs[0].ln1)
+#x = layernorm(x, 192).linear(m.tbs[0].ln1)
 
 xp = mdl.patch_embed(torch.Tensor(img))
 xp = torch.cat((mdl.cls_token, xp), dim=1) + mdl.pos_embed
diff --git a/extra/utils.py b/extra/utils.py
index 5ac04a98..01ea32cf 100644
--- a/extra/utils.py
+++ b/extra/utils.py
@@ -25,13 +25,8 @@ def get_parameters(obj):
     for x in obj:
       parameters.extend(get_parameters(x))
   elif hasattr(obj, '__dict__'):
-    if isinstance(obj, nn.Sequential):
-      for layer in obj.layers:
-        for v in layer.__dict__.values():
-          parameters.extend(get_parameters(v))
-    else:
-      for v in obj.__dict__.values():
-        parameters.extend(get_parameters(v))
+    for v in obj.__dict__.values():
+      parameters.extend(get_parameters(v))
   return parameters
 
 def my_unpickle(fb0):
diff --git a/models/resnet.py b/models/resnet.py
index dd41e4e0..5a3e623d 100644
--- a/models/resnet.py
+++ b/models/resnet.py
@@ -47,17 +47,17 @@ class BasicBlock:
     self.bn1 = nn.BatchNorm2D(planes)
     self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
     self.bn2 = nn.BatchNorm2D(planes)
-    self.downsample = nn.Sequential()
+    self.downsample = []
     if stride != 1 or in_planes != self.expansion*planes:
-      self.downsample = nn.Sequential(
+      self.downsample = [
         nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
         nn.BatchNorm2D(self.expansion*planes)
-      )
+      ]
 
   def __call__(self, x):
     out = self.bn1(self.conv1(x)).relu()
     out = self.bn2(self.conv2(out))
-    out = out + self.downsample(x)
+    out = out + x.sequential(self.downsample)
     out = out.relu()
     return out
@@ -72,12 +72,12 @@ class Bottleneck:
     self.bn2 = nn.BatchNorm2D(planes)
     self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
     self.bn3 = nn.BatchNorm2D(self.expansion*planes)
-    self.downsample = nn.Sequential()
+    self.downsample = []
     if stride != 1 or in_planes != self.expansion*planes:
-      self.downsample = nn.Sequential(
+      self.downsample = [
         nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
         nn.BatchNorm2D(self.expansion*planes)
-      )
+      ]
 
   def __call__(self, x):
     out = self.bn1(self.conv1(x)).relu()
@@ -105,14 +105,14 @@ class ResNet:
     for stride in strides:
       layers.append(block(self.in_planes, planes, stride))
       self.in_planes = planes * block.expansion
-    return nn.Sequential(*layers)
+    return layers
 
   def forward(self, x):
     out = self.bn1(self.conv1(x)).relu()
-    out = self.layer1(out)
-    out = self.layer2(out)
-    out = self.layer3(out)
-    out = self.layer4(out)
+    out = out.sequential(self.layer1)
+    out = out.sequential(self.layer2)
+    out = out.sequential(self.layer3)
+    out = out.sequential(self.layer4)
     out = out.mean(3).mean(2)
     out = self.fc(out).logsoftmax()
     return out
diff --git a/models/transformer.py b/models/transformer.py
index 3f75feb4..f56a77d9 100644
--- a/models/transformer.py
+++ b/models/transformer.py
@@ -37,7 +37,7 @@ class TransformerBlock:
     inputs = x.reshape(shape=(-1, embed_dim))
 
     # run multi head attention (bs, T, num_heads, head_size)
-    query, key, value = [inputs.affine(y) \
+    query, key, value = [inputs.linear(y) \
       .reshape(shape=(bs, -1, self.num_heads, self.head_size)) \
      for y in [self.query_dense, self.key_dense, self.value_dense]]
 
@@ -49,10 +49,10 @@ class TransformerBlock:
     weights = score.softmax()                                   # (bs, num_heads, T, T)
     attention = weights.dot(value).transpose(order=(0,2,1,3))   # (bs, T, num_heads, head_size)
 
-    x = inputs + attention.reshape(shape=(-1, embed_dim)).affine(self.final).dropout(0.1)
-    x = layernorm(x, embed_dim).affine(self.ln1)
-    x = x + x.affine(self.ff1).relu().affine(self.ff2).dropout(0.1)
-    x = layernorm(x, embed_dim).affine(self.ln2)
+    x = inputs + attention.reshape(shape=(-1, embed_dim)).linear(self.final).dropout(0.1)
+    x = layernorm(x, embed_dim).linear(self.ln1)
+    x = x + x.linear(self.ff1).relu().linear(self.ff2).dropout(0.1)
+    x = layernorm(x, embed_dim).linear(self.ln2)
     return x.reshape(shape=(bs, -1, embed_dim))
 
 class Transformer:
diff --git a/test/test_nn.py b/test/test_nn.py
index b5c1aee1..36c5be9c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -56,14 +56,14 @@ class TestNN(unittest.TestCase):
 
     def _test_linear(x):
       # create in tinygrad
-      layer = Linear(in_dim, out_dim)
-      z = layer(x)
+      layer = (Tensor.uniform(in_dim, out_dim), Tensor.zeros(out_dim))
+      z = x.linear(layer)
 
       # create in torch
       with torch.no_grad():
         torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
-        torch_layer.weight[:] = torch.tensor(layer.weight.data.T, dtype=torch.float32)
-        torch_layer.bias[:] = torch.tensor(layer.bias.data, dtype=torch.float32)
+        torch_layer.weight[:] = torch.tensor(layer[0].data.T, dtype=torch.float32)
+        torch_layer.bias[:] = torch.tensor(layer[1].data, dtype=torch.float32)
         torch_x = torch.tensor(x.cpu().data, dtype=torch.float32)
         torch_z = torch_layer(torch_x)
 
diff --git a/tinygrad/nn.py b/tinygrad/nn.py
index 2ddb7647..ea9508dc 100644
--- a/tinygrad/nn.py
+++ b/tinygrad/nn.py
@@ -31,35 +31,6 @@ class BatchNorm2D:
     x = (x - mean.reshape(shape=[1, -1, 1, 1])) * self.weight.reshape(shape=[1, -1, 1, 1])
     return x.div(var.add(self.eps).reshape(shape=[1, -1, 1, 1])**0.5) + self.bias.reshape(shape=[1, -1, 1, 1])
 
-class Linear:
-  def __init__(self, in_dim, out_dim, bias=True):
-    self.in_dim = in_dim
-    self.out_dim = out_dim
-    self.use_bias = bias
-    self.weight = Tensor.uniform(in_dim, out_dim)
-    if self.use_bias:
-      self.bias = Tensor.zeros(out_dim)
-
-  def __call__(self, x):
-    B, *dims, D = x.shape
-    x = x.reshape(shape=(B * np.prod(dims).astype(np.int32), D))
-    x = x.dot(self.weight)
-    if self.use_bias:
-      x = x.add(self.bias.reshape(shape=[1, -1]))
-    x = x.reshape(shape=(B, *dims, -1))
-    return x
-
-class Dropout:
-  def __init__(self, p=0.5):
-    self.p = p
-
-  def __call__(self, x):
-    return x.dropout(p=self.p)
-
-class Identity:
-  def __call__(self, x):
-    return x
-
 class Conv2d:
   def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
     self.out_channels = out_channels
@@ -79,11 +50,3 @@ class Conv2d:
       x = x.add(self.bias.reshape(shape=(1, -1, 1, 1)))
     return x
 
-class Sequential:
-  def __init__(self, *layers):
-    self.layers = layers
-
-  def __call__(self, x):
-    for l in self.layers:
-      x = l(x)
-    return x
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 4ebadd0d..45bf7136 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -304,12 +304,18 @@ class Tensor:
   def max_pool2d(self, kernel_size=(2,2)):
     return self._pool2d(*kernel_size).max(axis=(3,5))
 
-  def affine(self, params):
+  # ***** functional nn ops *****
+
+  def linear(self, params):
     shp = [1] * (len(self.shape)-1) + [-1]
-    if len(params[0].shape) == 1: # elementwise affine
-      return self.mul(params[0].reshape(shape=shp)).add(params[1].reshape(shape=shp))
-    else:
-      return self.dot(params[0]).add(params[1].reshape(shape=shp))
+    ret = self.mul(params[0].reshape(shape=shp)) if len(params[0].shape) == 1 else self.dot(params[0])
+    return ret.add(params[1].reshape(shape=shp))
+
+  def sequential(self, ll):
+    ret = self
+    for l in ll:
+      ret = l(ret)
+    return ret
 
 # An instantiation of the Function is the Context
 class Function:
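
Usage note (not part of the diff): a minimal sketch of the two functional ops this change introduces, Tensor.linear and Tensor.sequential, assuming the tinygrad API exactly as it appears above. The layer sizes and random input are arbitrary illustration values.

import numpy as np
from tinygrad.tensor import Tensor

# a "layer" is now just a (weight, bias) tuple of Tensors; illustrative 784 -> 128 -> 10 MLP
l1 = (Tensor.uniform(784, 128), Tensor.zeros(128))
l2 = (Tensor.uniform(128, 10), Tensor.zeros(10))

x = Tensor(np.random.randn(32, 784).astype(np.float32))

# Tensor.linear does a dot for a 2-D weight (elementwise mul for a 1-D one) and adds the bias;
# Tensor.sequential replaces nn.Sequential by folding a plain Python list of callables over self
out = x.sequential([
  lambda t: t.linear(l1).relu(),
  lambda t: t.linear(l2),
])
print(out.shape)  # expected: (32, 10)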