# mirror of https://github.com/commaai/tinygrad.git
# 65 lines, 2.4 KiB, Python
from tinygrad.tensor import Tensor
|
|
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
|
|
|
|
class Block:
    """Residual ConvNeXt block.

    The branch applies a 7x7 depthwise convolution, a LayerNorm, two Linear
    layers with a GELU in between (expanding to ``4 * dim`` features and
    projecting back to ``dim``), and a per-feature scale ``gamma``. The branch
    output is added to the block input.
    """

    def __init__(self, dim):
        """Create the layers for a block operating on ``dim`` features."""
        expanded = 4 * dim
        # groups=dim makes the convolution depthwise; padding=3 pairs with the
        # 7x7 kernel.
        self.dwconv = Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = Linear(dim, expanded)
        self.pwconv2 = Linear(expanded, dim)
        self.gamma = Tensor.ones(dim)

    def __call__(self, x:Tensor):
        """Return ``x`` plus the output of the block's branch applied to ``x``."""
        def axis1_to_last(t):
            # Move axis 1 to the end so the norm and Linear layers see the
            # ``dim``-sized axis last (presumably the conv's channel axis).
            return t.permute(0, 2, 3, 1)

        def scale_and_restore(t):
            # Apply the per-feature scale, then undo the permutation above.
            return (self.gamma * t).permute(0, 3, 1, 2)

        branch = [
            self.dwconv,
            axis1_to_last,
            self.norm,
            self.pwconv1,
            Tensor.gelu,
            self.pwconv2,
            scale_and_restore,
        ]
        return x + x.sequential(branch)
|
|
|
|
class ConvNeXt:
    """ConvNeXt image classifier assembled from ``Block`` stages.

    The network alternates a downsampling step with a stage of ``Block``s,
    then averages over the last two axes, normalises, and applies a Linear
    classification head.

    Args:
        in_chans: number of input channels given to the first convolution.
        num_classes: number of output features of the classification head.
        depths: number of ``Block``s in each stage; indexed in step with ``dims``.
        dims: feature width of each stage.
    """

    # The defaults are tuples rather than lists so that the shared default
    # objects cannot be mutated between instantiations; they are only indexed
    # and measured, so list arguments keep working unchanged.
    def __init__(self, in_chans=3, num_classes=1000, depths=(3, 3, 9, 3), dims=(96, 192, 384, 768)):
        # First entry: 4x4 stride-4 convolution followed by a norm.
        # Remaining entries: a norm followed by a 2x2 stride-2 convolution that
        # changes the width from dims[i] to dims[i+1].
        self.downsample_layers = [
            [Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
            *[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
        ]
        self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
        # eps=1e-6 matches every other normalisation layer in this model (and
        # the reference ConvNeXt whose checkpoints get_model loads); without it
        # the final norm silently used the LayerNorm default epsilon instead.
        self.norm = LayerNorm(dims[-1], eps=1e-6)
        self.head = Linear(dims[-1], num_classes)

    def __call__(self, x:Tensor):
        """Run the network on ``x`` and return the output of the head."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = x.sequential(downsample).sequential(stage)
        # Average over the last two axes before the final norm and head.
        return x.mean([-2, -1]).sequential([self.norm, self.head])
|
|
|
|
# *** model definition is done ***
|
|
|
|
# Hyperparameters for each named model size: the per-stage block counts
# ("depths") and per-stage widths ("dims") passed to ConvNeXt as keyword
# arguments. Every entry gets its own list objects.
versions = {
    size: {"depths": list(stage_depths), "dims": list(stage_dims)}
    for size, stage_depths, stage_dims in (
        ("tiny", (3, 3, 9, 3), (96, 192, 384, 768)),
        ("small", (3, 3, 27, 3), (96, 192, 384, 768)),
        ("base", (3, 3, 9, 3), (128, 256, 512, 1024)),
        ("large", (3, 3, 27, 3), (192, 384, 768, 1536)),
        ("xlarge", (3, 3, 27, 3), (256, 512, 1024, 2048)),
    )
}
|
|
|
|
def get_model(version, load_weights=False):
    """Build the ConvNeXt variant named ``version``.

    Args:
        version: key into ``versions`` selecting the depths/dims configuration;
            also interpolated into the checkpoint URL when weights are loaded.
        load_weights: when true, fill the model's tensors from the checkpoint
            at the URL below.

    Returns:
        The constructed ``ConvNeXt`` instance.

    Raises:
        KeyError: if ``version`` is not a key of ``versions``.
    """
    net = ConvNeXt(**versions[version])
    if not load_weights:
        return net

    # Imported here so the helpers are only required when weights are requested.
    from extra.utils import fetch, fake_torch_load, get_child

    checkpoint = fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth')
    state = fake_torch_load(checkpoint)['model']
    for key, value in state.items():
        # Resolve the checkpoint key to the matching tensor on the model, then
        # overwrite it with the stored value reshaped to the tensor's shape.
        target = get_child(net, key)
        target.assign(value.reshape(target.shape)).realize()
    return net
|
|
|
|
if __name__ == "__main__":
    # Build the "tiny" variant with its checkpoint weights loaded.
    net = get_model("tiny", load_weights=True)

    # load image
    from test.models.test_efficientnet import chicken_img, preprocess, _LABELS
    sample = Tensor(preprocess(chicken_img))

    # Set the global Tensor flags before running the forward pass.
    Tensor.training = False
    Tensor.no_grad = True

    # Print the label whose index has the highest output value.
    scores = net(sample).numpy()
    best = scores.argmax()
    print(_LABELS[best])
|