# Mirror of https://github.com/commaai/tinygrad.git (60 lines, 2.4 KiB, Python)
from pathlib import Path
|
|
import torch
|
|
from tinygrad import nn
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.helpers import fetch, get_child
|
|
|
|
class DownsampleBlock:
  """Two (conv -> InstanceNorm -> ReLU) stages with 3x3x3 kernels and 1-voxel padding per side.

  Only the first convolution uses `stride`; with stride=2 it shrinks the spatial extent,
  with stride=1 (as the UNet3D input block uses) it keeps it.
  The attribute names `conv1` / `conv2` and the list positions inside them are part of the
  weight-loading path, so keep them stable.
  """
  def __init__(self, c0, c1, stride=2):
    k3 = (3, 3, 3)
    pad = (1, 1, 1, 1, 1, 1)  # one (before, after) pair per spatial dimension
    # 3-element kernel/padding tuples are handed to nn.Conv2d here, i.e. it is used as a volumetric conv
    self.conv1 = [nn.Conv2d(c0, c1, kernel_size=k3, stride=stride, padding=pad, bias=False), nn.InstanceNorm(c1), Tensor.relu]
    self.conv2 = [nn.Conv2d(c1, c1, kernel_size=k3, padding=pad, bias=False), nn.InstanceNorm(c1), Tensor.relu]

  def __call__(self, x):
    """Apply conv1's layers, then conv2's layers, to `x`."""
    for stage in (self.conv1, self.conv2):
      x = x.sequential(stage)
    return x
class UpsampleBlock:
  """Transposed-conv upsample, channel-concatenate a skip tensor, then two conv stages.

  `upsample_conv` maps c0 -> c1 channels with a 2x2x2 kernel and stride 2. The skip tensor is
  concatenated on dim 1 and must therefore carry c1 channels, because `conv1` expects 2 * c1 inputs.
  Attribute names and list positions are part of the weight-loading path; keep them stable.
  """
  def __init__(self, c0, c1):
    k3, pad = (3, 3, 3), (1, 1, 1, 1, 1, 1)
    self.upsample_conv = [nn.ConvTranspose2d(c0, c1, kernel_size=(2, 2, 2), stride=2)]
    # input = upsampled tensor (c1 channels) ++ skip tensor (c1 channels)
    self.conv1 = [nn.Conv2d(2 * c1, c1, kernel_size=k3, padding=pad, bias=False), nn.InstanceNorm(c1), Tensor.relu]
    self.conv2 = [nn.Conv2d(c1, c1, kernel_size=k3, padding=pad, bias=False), nn.InstanceNorm(c1), Tensor.relu]

  def __call__(self, x, skip):
    """Upsample `x`, join it with `skip` along dim 1, and refine with conv1 then conv2."""
    upsampled = x.sequential(self.upsample_conv)
    merged = upsampled.cat(skip, dim=1)
    refined = merged.sequential(self.conv1)
    return refined.sequential(self.conv2)
class UNet3D:
  """3-D U-Net: input block, downsampling path, bottleneck, and an upsampling path with skip connections.

  Args:
    in_channels: channel count of the input tensor (dim 1).
    n_class: channel count produced by the final 1x1x1 convolution.
    filters: channel width of each resolution level, from the input level downwards.
      The default reproduces the architecture of the checkpoint used by `load_from_pretrained`;
      any other widths build a different network whose parameters will fail that method's shape check.

  Raises:
    ValueError: if `filters` is empty.
  """
  def __init__(self, in_channels=1, n_class=3, filters=(32, 64, 128, 256, 320)):
    filters = list(filters)
    if not filters:
      raise ValueError("filters must contain at least one channel width")
    # consecutive (in, out) width pairs for the encoder
    inp, out = filters[:-1], filters[1:]
    self.input_block = DownsampleBlock(in_channels, filters[0], stride=1)
    self.downsample = [DownsampleBlock(i, o) for i, o in zip(inp, out)]
    self.bottleneck = DownsampleBlock(filters[-1], filters[-1])
    # The first decoder block takes the bottleneck output together with the deepest skip (both filters[-1] wide).
    # The remaining blocks mirror the encoder pairs in reverse.
    self.upsample = [UpsampleBlock(filters[-1], filters[-1])] + [UpsampleBlock(i, o) for i, o in zip(out[::-1], inp[::-1])]
    self.output = {"conv": nn.Conv2d(filters[0], n_class, kernel_size=(1, 1, 1))}

  def __call__(self, x):
    """Run the network on `x` and return the output of the final 1x1x1 convolution."""
    x = self.input_block(x)
    # skip tensors, shallowest first: the input-block output, then every downsample output
    outputs = [x]
    for downsample in self.downsample:
      x = downsample(x)
      outputs.append(x)
    x = self.bottleneck(x)
    # the decoder consumes skips deepest-first; there is exactly one skip per upsample block
    for upsample, skip in zip(self.upsample, outputs[::-1]):
      x = upsample(x, skip)
    return self.output["conv"](x)

  def load_from_pretrained(self):
    """Download the reference PyTorch checkpoint and copy its weights into this model in place.

    The checkpoint is fetched to `weights/unet-3d.ckpt` one directory above this file's directory.
    It is loaded with `torch.jit.load` onto the CPU, and every state-dict entry is assigned to the
    tensor that `get_child` resolves for its key.
    NOTE(review): keys are presumably dotted attribute paths (e.g. "downsample.0.conv1.0.weight"),
    which is why attribute names and list positions in the blocks must match the checkpoint.

    Raises:
      AssertionError: if a checkpoint tensor's shape differs from the model parameter it targets.
    """
    fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt"
    fetch("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn)
    state_dict = torch.jit.load(fn, map_location=torch.device("cpu")).state_dict()
    for k, v in state_dict.items():
      obj = get_child(self, k)
      assert obj.shape == v.shape, (k, obj.shape, v.shape)
      obj.assign(v.numpy())
if __name__ == "__main__":
  # Script entry: build the default network and load the downloaded pretrained weights into it.
  model = UNet3D()
  model.load_from_pretrained()