2023-05-29 11:38:19 +08:00
|
|
|
import random
|
|
|
|
import functools
|
|
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
import nibabel as nib
|
2024-04-29 10:34:18 +08:00
|
|
|
from scipy import signal, ndimage
|
|
|
|
import os
|
2023-05-29 11:38:19 +08:00
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
2024-04-29 10:34:18 +08:00
|
|
|
from tqdm import tqdm
|
2023-05-29 11:38:19 +08:00
|
|
|
from tinygrad.tensor import Tensor
|
2023-12-02 02:42:50 +08:00
|
|
|
from tinygrad.helpers import fetch
|
2023-05-29 11:38:19 +08:00
|
|
|
|
2023-07-08 09:41:58 +08:00
|
|
|
# Dataset directories, all resolved relative to the directory containing this file.
# BASEDIR holds the raw case folders; the other two hold cached .npy arrays.
BASEDIR = Path(__file__).parent.joinpath("kits19", "data")
TRAIN_PREPROCESSED_DIR = Path(__file__).parent.joinpath("kits19", "preprocessed", "train")
VAL_PREPROCESSED_DIR = Path(__file__).parent.joinpath("kits19", "preprocessed", "val")
|
2023-05-29 11:38:19 +08:00
|
|
|
|
2024-04-29 10:34:18 +08:00
|
|
|
@functools.lru_cache(None)
def get_train_files():
  """Return the sorted training case paths found directly under BASEDIR.

  An entry is kept when its stem starts with "case", the integer after the
  last underscore is below 210, and it is not one of the validation cases
  returned by get_val_files(). The result is memoized by lru_cache.
  """
  train_cases = []
  for entry in BASEDIR.iterdir():
    if not entry.stem.startswith("case"): continue
    if int(entry.stem.split("_")[-1]) >= 210: continue
    # get_val_files() is itself cached, so calling it per entry is cheap
    if entry in get_val_files(): continue
    train_cases.append(entry)
  return sorted(train_cases)
|
|
|
|
|
2023-05-29 11:38:19 +08:00
|
|
|
@functools.lru_cache(None)
def get_val_files():
  """Return the sorted validation case paths found directly under BASEDIR.

  The list of evaluation case ids is downloaded with `fetch` (one id per line);
  an entry is selected when the part of its stem after the last underscore
  appears in that list. The result is memoized by lru_cache.
  """
  url = "https://raw.githubusercontent.com/mlcommons/training/master/retired_benchmarks/unet3d/pytorch/evaluation_cases.txt"
  eval_ids = fetch(url).read_text().split("\n")
  return sorted(entry for entry in BASEDIR.iterdir() if entry.stem.split("_")[-1] in eval_ids)
|
2023-05-29 11:38:19 +08:00
|
|
|
|
|
|
|
def load_pair(file_path):
  """Load one case directory.

  Reads `imaging.nii.gz` and `segmentation.nii.gz` from `file_path` with nibabel.

  Returns:
    (image, label, image_spacings): image as float32 and label as uint8, each
    with a new leading axis of size 1, plus entries 1..3 of the image header's
    "pixdim" field as a Python list.
  """
  image_nii = nib.load(file_path / "imaging.nii.gz")
  label_nii = nib.load(file_path / "segmentation.nii.gz")
  image_spacings = image_nii.header["pixdim"][1:4].tolist()
  image = image_nii.get_fdata().astype(np.float32)[None]
  label = label_nii.get_fdata().astype(np.uint8)[None]
  return image, label, image_spacings
|
|
|
|
|
|
|
|
def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
  """Resample `image` and `label` from `image_spacings` to `target_spacing`.

  Args:
    image: numpy array whose axes after the first are the three resampled axes.
    label: numpy array with the same layout as `image`.
    image_spacings: per-axis spacing of the input (any sequence of 3 numbers).
    target_spacing: desired per-axis spacing.

  Returns:
    (image, label) as numpy arrays. When the spacings already match exactly the
    inputs are returned untouched; otherwise the image is interpolated
    trilinearly (align_corners=True) and the label with nearest-neighbour.
  """
  # load_pair() hands over a list while the default target is a tuple; a list
  # never compares equal to a tuple, so normalise both sides before comparing.
  if tuple(image_spacings) != tuple(target_spacing):
    spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
    new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
    # F.interpolate expects a batch axis in front, so add one and drop it again afterwards
    image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
    label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
    image = np.squeeze(image.numpy(), axis=0)
    label = np.squeeze(label.numpy(), axis=0)
  return image, label
|
|
|
|
|
|
|
|
def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
  """Clip `image` to [min_clip, max_clip], then return (clipped - mean) / std."""
  clipped = np.clip(image, min_clip, max_clip)
  return (clipped - mean) / std
|
|
|
|
|
|
|
|
def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
  """Edge-pad `image` and `label` so axes 1..3 are at least `roi_shape`.

  The first axis is never padded. Missing voxels are split between both sides
  of an axis, with the extra one (for odd amounts) going to the end.
  """
  pad_widths = [(0, 0)]
  for axis in range(3):
    missing = max(0, roi_shape[axis] - image.shape[1:][axis])
    before = missing // 2
    pad_widths.append((before, missing - before))
  return np.pad(image, pad_widths, mode="edge"), np.pad(label, pad_widths, mode="edge")
|
|
|
|
|
|
|
|
def preprocess(file_path):
  """Run the full preprocessing chain for one case directory.

  Steps: load_pair -> resample3d -> normal_intensity (on a copy of the image)
  -> pad_to_min_shape. Returns the resulting (image, label) pair.
  """
  raw_image, raw_label, spacings = load_pair(file_path)
  resampled_image, resampled_label = resample3d(raw_image, raw_label, spacings)
  normalized_image = normal_intensity(resampled_image.copy())
  image, label = pad_to_min_shape(normalized_image, resampled_label)
  return image, label
|
|
|
|
|
2024-04-29 10:34:18 +08:00
|
|
|
def preprocess_dataset(filenames, preprocessed_dir, val):
  """Preprocess every case in `filenames` and cache the result as .npy files.

  For each case `<name>` two files are written into `preprocessed_dir`:
  `<name>_x.npy` (float32 image) and `<name>_y.npy` (uint8 label).
  `val` only selects the wording of the progress-bar description.
  """
  os.makedirs(preprocessed_dir, exist_ok=True)
  progress = tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}")
  for fn in progress:
    case = os.path.basename(fn)
    image, label = preprocess(fn)
    np.save(preprocessed_dir / f"{case}_x.npy", image.astype(np.float32), allow_pickle=False)
    np.save(preprocessed_dir / f"{case}_y.npy", label.astype(np.uint8), allow_pickle=False)
|
2024-04-29 10:34:18 +08:00
|
|
|
|
|
|
|
def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
  """Yield (X, Y) batches for the given case paths.

  Args:
    files: sequence of case paths.
    preprocessed_dir: optional Path holding `<case>_x.npy` / `<case>_y.npy`
      caches. A case is loaded from the cache when both files exist; otherwise
      (or when no directory is given) it is preprocessed on the fly.
    val: if True, yield only the first sample of each batch, with an extra
      leading axis added to X and Y left as loaded. If False, every sample is
      randomly cropped, flipped and augmented, and the batch is stacked.
    shuffle: shuffle the visiting order with `random.shuffle`.
    bs: number of cases consumed per step.

  Yields:
    Tuples (X, Y) of numpy arrays.
  """
  order = list(range(len(files)))
  if shuffle: random.shuffle(order)
  for start in range(0, len(files), bs):
    samples = []
    for idx in order[start:start + bs]:
      sample = None
      if preprocessed_dir is not None:
        case = os.path.basename(files[idx])
        x_cached_path, y_cached_path = preprocessed_dir / f"{case}_x.npy", preprocessed_dir / f"{case}_y.npy"
        if x_cached_path.exists() and y_cached_path.exists():
          sample = (np.load(x_cached_path), np.load(y_cached_path))
      # No cache directory, or cache incomplete: compute the sample instead of
      # dropping it (a dropped sample would leave the batch empty).
      if sample is None: sample = preprocess(files[idx])
      samples.append(sample)
    X, Y = [s[0] for s in samples], [s[1] for s in samples]
    if val:
      # NOTE: only the first sample of the batch is emitted in validation mode
      yield X[0][None], Y[0]
    else:
      X_preprocessed, Y_preprocessed = [], []
      for x, y in zip(X, Y):
        x, y = rand_balanced_crop(x, y)
        x, y = rand_flip(x, y)
        x, y = x.astype(np.float32), y.astype(np.uint8)
        x = random_brightness_augmentation(x)
        x = gaussian_noise(x)
        X_preprocessed.append(x)
        Y_preprocessed.append(y)
      yield np.stack(X_preprocessed, axis=0), np.stack(Y_preprocessed, axis=0)
|
2023-05-29 11:38:19 +08:00
|
|
|
|
|
|
|
def gaussian_kernel(n, std):
  """Build an (n, n, n) Gaussian weighting kernel with a peak value of 1.

  The kernel is the cube root of the separable product of a 1-D Gaussian
  window (length `n`, standard deviation `std`) along each axis, divided by
  its maximum.
  """
  window = signal.windows.gaussian(n, std)
  # (w_i * w_j) * w_k via broadcasting: same products as two nested outer products
  kernel = (window[:, None, None] * window[None, :, None]) * window[None, None, :]
  kernel = np.cbrt(kernel)
  kernel /= kernel.max()
  return kernel
|
|
|
|
|
|
|
|
def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
  """Pad axes 2.. of a numpy `volume` so sliding windows tile it exactly.

  Each of the `dim` trailing axes is grown to a multiple of its stride; if the
  result is still smaller than `roi_shape` on that axis, one more stride is
  added. Padding is split between both sides (extra voxel at the end).

  Returns:
    (padded_volume, paddings): the padded numpy array and the flat padding list
    in the order passed to `F.pad` (last axis first, zeros for the two leading axes).
  """
  spatial = volume.shape[2:]
  bounds = []
  for axis in range(dim):
    extra = (strides[axis] - spatial[axis] % strides[axis]) % strides[axis]
    if spatial[axis] + extra < roi_shape[axis]:
      extra += strides[axis]
    bounds.append(extra)
  paddings = []
  for extra in (bounds[2], bounds[1], bounds[0]):
    paddings += [extra // 2, extra - extra // 2]
  paddings += [0, 0, 0, 0]
  padded = F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val)
  return padded.numpy(), paddings
|
|
|
|
|
2024-09-10 16:37:28 +08:00
|
|
|
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5, gpus=None):
  """Run `model` over overlapping windows of `inputs` and blend the outputs.

  `inputs` is cropped, then padded with pad_input so windows of `roi_shape`
  with stride `roi_shape * (1 - overlap)` tile it. Each window is wrapped in a
  tinygrad Tensor on `gpus`, run through a TinyJit-wrapped `model`, and its
  output is accumulated with Gaussian weights. The weighted sum is normalised
  and the padding is removed again.

  NOTE(review): the accumulation buffers are hard-coded to shape (1, 3, ...),
  so the model presumably emits 3 channels for a batch of one - confirm.
  The Gaussian kernel is built from roi_shape[0] only.

  Returns:
    (result, labels): the blended float32 prediction and the cropped labels.
  """
  from tinygrad.engine.jit import TinyJit
  mdl_run = TinyJit(lambda x: model(x).realize())

  image_shape = list(inputs.shape[2:])
  dim = len(image_shape)
  strides = [int(roi_shape[d] * (1 - overlap)) for d in range(dim)]

  # Trim a small remainder (less than half a stride) instead of padding a whole extra window.
  bounds = []
  for d in range(dim):
    remainder = image_shape[d] % strides[d]
    bounds.append(remainder if remainder < strides[d] // 2 else 0)
  crop = (
    ...,
    slice(bounds[0] // 2, image_shape[0] - (bounds[0] - bounds[0] // 2)),
    slice(bounds[1] // 2, image_shape[1] - (bounds[1] - bounds[1] // 2)),
    slice(bounds[2] // 2, image_shape[2] - (bounds[2] - bounds[2] // 2)),
  )
  inputs = inputs[crop]
  labels = labels[crop]

  inputs, paddings = pad_input(inputs, roi_shape, strides)
  padded_shape = inputs.shape[2:]
  size = [(padded_shape[d] - roi_shape[d]) // strides[d] + 1 for d in range(dim)]

  result = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  norm_patch = np.expand_dims(gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0]), axis=0)

  for i in range(0, strides[0] * size[0], strides[0]):
    for j in range(0, strides[1] * size[1], strides[1]):
      for k in range(0, strides[2] * size[2], strides[2]):
        window = (..., slice(i, roi_shape[0] + i), slice(j, roi_shape[1] + j), slice(k, roi_shape[2] + k))
        out = mdl_run(Tensor(inputs[window], device=gpus)).numpy()
        result[window] += out * norm_patch
        norm_map[window] += norm_patch

  result /= norm_map
  # paddings is ordered last-axis-first, hence indices 4, 2, 0 for axes 0, 1, 2
  unpad = (
    ...,
    slice(paddings[4], image_shape[0] + paddings[4]),
    slice(paddings[2], image_shape[1] + paddings[2]),
    slice(paddings[0], image_shape[2] + paddings[0]),
  )
  result = result[unpad]
  return result, labels
|
|
|
|
|
2024-04-29 10:34:18 +08:00
|
|
|
def rand_flip(image, label, axis=(1, 2, 3)):
  """Randomly mirror `image` and `label` together along the given axes.

  Each axis is flipped independently with probability 1 / len(axis); the same
  flips are applied to both arrays. Flipped arrays are returned as copies.
  """
  flip_prob = 1 / len(axis)
  for ax in axis:
    if random.random() >= flip_prob: continue
    image, label = np.flip(image, axis=ax).copy(), np.flip(label, axis=ax).copy()
  return image, label
|
|
|
|
|
|
|
|
def random_brightness_augmentation(image, low=0.7, high=1.3, prob=0.1):
  """With probability `prob`, rescale `image` by a random brightness factor.

  A single factor is drawn uniformly from [low, high) and the image is
  multiplied by (1 + factor), then cast back to its original dtype.
  Otherwise the input is returned unchanged.
  """
  if random.random() >= prob: return image
  factor = np.random.uniform(low=low, high=high, size=1)
  scaled = image * (1 + factor)
  return scaled.astype(image.dtype)
|
|
|
|
|
|
|
|
def gaussian_noise(image, mean=0.0, std=0.1, prob=0.1):
  """With probability `prob`, return `image` plus Gaussian noise.

  The noise standard deviation is drawn uniformly from [0, std) and the noise
  is cast to the image dtype before being added. The input array is never
  modified: a new array is returned when noise is applied, otherwise the
  input itself is returned.
  """
  if random.random() < prob:
    scale = np.random.uniform(low=0.0, high=std)
    noise = np.random.normal(loc=mean, scale=scale, size=image.shape).astype(image.dtype)
    # out-of-place add so the caller's array is left intact
    image = image + noise
  return image
|
|
|
|
|
|
|
|
def _rand_foreg_cropb(image, label, patch_size):
  """Crop a `patch_size` window positioned around a random foreground object.

  A non-zero class is chosen at random from `label`, its connected components
  are found, and one of the (up to) two components with the largest bounding
  boxes is picked. The bounding box is then randomly grown or shrunk to
  `patch_size` on axes 1..3 and clamped to the volume.

  Falls back to `_rand_crop` when `label` has no foreground voxels or no
  component is found.

  Returns:
    (image, label) cropped with the same window; axis 0 is left untouched.
  """
  def adjust(foreg_slice, label, idx):
    # Grow (diff > 0) or shrink (diff < 0) the bounding box on axis `idx` to the patch size,
    # splitting the change randomly between the low and high side.
    diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
    sign = -1 if diff < 0 else 1
    diff = abs(diff)
    ladj = 0 if diff == 0 else random.randrange(diff)
    hadj = diff - ladj
    low = max(0, foreg_slice[idx].start - sign * ladj)
    high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
    # If clamping to the volume made the window too small, extend it on the other side.
    diff = patch_size[idx - 1] - (high - low)
    if diff > 0 and low == 0: high += diff
    elif diff > 0: low -= diff
    return low, high

  foreground = label[label > 0]
  # np.random.choice raises on an empty array, so handle all-background labels explicitly
  if foreground.size == 0: return _rand_crop(image, label, patch_size)
  cl = np.random.choice(np.unique(foreground))
  foreg_slices = ndimage.find_objects(ndimage.label(label==cl)[0])
  foreg_slices = [x for x in foreg_slices if x is not None]
  slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
  slice_idx = np.argsort(slice_volumes)[-2:]
  foreg_slices = [foreg_slices[i] for i in slice_idx]
  # _rand_crop requires patch_size; it was previously omitted here
  if not foreg_slices: return _rand_crop(image, label, patch_size)
  foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
  low_x, high_x = adjust(foreg_slice, label, 1)
  low_y, high_y = adjust(foreg_slice, label, 2)
  low_z, high_z = adjust(foreg_slice, label, 3)
  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
  return image, label
|
|
|
|
|
|
|
|
def _rand_crop(image, label, patch_size):
  """Crop the same random `patch_size` window from `image` and `label`.

  The window start on each of axes 1..3 is drawn uniformly from the valid
  range (0 when the axis already equals the patch size); axis 0 is kept whole.
  """
  starts = []
  for dim_size, patch in zip(image.shape[1:], patch_size):
    slack = dim_size - patch
    starts.append(random.randrange(slack) if slack != 0 else 0)
  window = (slice(None), *(slice(starts[d], starts[d] + patch_size[d]) for d in range(3)))
  return image[window], label[window]
|
|
|
|
|
|
|
|
def rand_balanced_crop(image, label, patch_size=(128, 128, 128), oversampling=0.4):
  """Crop a `patch_size` window from `image` and `label`.

  With probability `oversampling` the window is positioned by
  `_rand_foreg_cropb` (around a foreground object); otherwise it is placed
  uniformly at random by `_rand_crop`.
  """
  crop_fn = _rand_foreg_cropb if random.random() < oversampling else _rand_crop
  image, label = crop_fn(image, label, patch_size)
  return image, label
|
|
|
|
|
2023-05-29 11:38:19 +08:00
|
|
|
if __name__ == "__main__":
  # Smoke test: walk the validation set and print the shape of every sample.
  for image, label in iterate(get_val_files()):
    print(image.shape, label.shape)
|