"""KiTS19 dataset utilities: file discovery, preprocessing, augmentation and sliding-window inference."""
import functools
import os
import random
from pathlib import Path

import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
from scipy import signal, ndimage
from tqdm import tqdm

from tinygrad.tensor import Tensor
from tinygrad.helpers import fetch

BASEDIR = Path(__file__).parent / "kits19" / "data"
TRAIN_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "train"
VAL_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "val"


@functools.lru_cache(None)
def get_train_files():
    """Return the sorted entries of BASEDIR used for training.

    An entry qualifies when its stem starts with "case", its numeric suffix is
    below 210, and it is not one of the validation entries.
    """
    # Build the exclusion set once instead of scanning a list per directory entry.
    val_files = set(get_val_files())
    return sorted([
        x for x in BASEDIR.iterdir()
        if x.stem.startswith("case") and int(x.stem.split("_")[-1]) < 210 and x not in val_files
    ])


@functools.lru_cache(None)
def get_val_files():
    """Return the sorted entries of BASEDIR whose numeric suffix is listed in the evaluation-cases file."""
    data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/retired_benchmarks/unet3d/pytorch/evaluation_cases.txt").read_text()
    # Split once; membership is tested for every directory entry.
    val_cases = set(data.split("\n"))
    return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in val_cases])


def load_pair(file_path):
    """Load "imaging.nii.gz" and "segmentation.nii.gz" from the directory `file_path`.

    Returns:
        (image, label, image_spacings): image as float32 and label as uint8, each
        with a new leading axis of size 1, plus entries 1..3 of the image header's
        "pixdim" field as a list.
    """
    image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
    image_spacings = image.header["pixdim"][1:4].tolist()
    image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
    image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
    return image, label, image_spacings


def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
    """Resample `image` (trilinear) and `label` (nearest) from `image_spacings` to `target_spacing`.

    Both arrays are expected with a leading channel axis followed by three spatial
    axes. When the spacings already match, the inputs are returned untouched.
    """
    # Compare as lists: `image_spacings` usually arrives as a list while the
    # default `target_spacing` is a tuple, and list != tuple is always True.
    if list(image_spacings) != list(target_spacing):
        spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
        new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
        # F.interpolate wants a batch axis in front, hence the expand/squeeze pair.
        image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
        label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
        image = np.squeeze(image.numpy(), axis=0)
        label = np.squeeze(label.numpy(), axis=0)
    return image, label


def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
    """Clip `image` to [min_clip, max_clip], then subtract `mean` and divide by `std`."""
    image = np.clip(image, min_clip, max_clip)
    image = (image - mean) / std
    return image


def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
    """Edge-pad the three spatial axes of `image` and `label` so each is at least `roi_shape`.

    The leading axis is never padded; the missing extent is split between both
    sides, with the extra voxel (if odd) going to the high side.
    """
    current_shape = image.shape[1:]
    bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
    paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
    image = np.pad(image, paddings, mode="edge")
    label = np.pad(label, paddings, mode="edge")
    return image, label


def preprocess(file_path):
    """Run load -> resample -> intensity normalisation -> minimum-shape padding for one case."""
    image, label, image_spacings = load_pair(file_path)
    image, label = resample3d(image, label, image_spacings)
    image = normal_intensity(image.copy())
    image, label = pad_to_min_shape(image, label)
    return image, label


def preprocess_dataset(filenames, preprocessed_dir, val):
    """Preprocess every case in `filenames` and save it as "<case>_x.npy" / "<case>_y.npy".

    Args:
        filenames: iterable of case paths.
        preprocessed_dir: output directory (a Path); created if missing.
        val: only selects the wording of the progress-bar description.
    """
    os.makedirs(preprocessed_dir, exist_ok=True)
    for fn in tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}"):
        case = os.path.basename(fn)
        image, label = preprocess(fn)
        image, label = image.astype(np.float32), label.astype(np.uint8)
        np.save(preprocessed_dir / f"{case}_x.npy", image, allow_pickle=False)
        np.save(preprocessed_dir / f"{case}_y.npy", label, allow_pickle=False)


def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
    """Yield (X, Y) batches for `files`.

    Args:
        files: sequence of case paths.
        preprocessed_dir: optional Path holding cached "<case>_x.npy"/"<case>_y.npy"
            pairs; a case without a complete cached pair is preprocessed on the fly.
        val: when True, yield only the first sample of each batch, as
            (x with a new leading axis, y). When False, crop and augment every
            sample and yield the stacked batch.
        shuffle: shuffle the visiting order once before iterating.
        bs: number of cases gathered per step.
    """
    order = list(range(0, len(files)))
    if shuffle:
        random.shuffle(order)
    for start in range(0, len(files), bs):
        samples = []
        for idx in order[start:start + bs]:
            sample = None
            if preprocessed_dir is not None:
                case = os.path.basename(files[idx])
                x_cached_path = preprocessed_dir / f"{case}_x.npy"
                y_cached_path = preprocessed_dir / f"{case}_y.npy"
                if x_cached_path.exists() and y_cached_path.exists():
                    sample = (np.load(x_cached_path), np.load(y_cached_path))
            # No cache directory or an incomplete cache entry: compute the sample
            # rather than silently dropping it from the batch.
            if sample is None:
                sample = preprocess(files[idx])
            samples.append(sample)
        X, Y = [s[0] for s in samples], [s[1] for s in samples]
        if val:
            yield X[0][None], Y[0]
        else:
            X_preprocessed, Y_preprocessed = [], []
            for x, y in zip(X, Y):
                x, y = rand_balanced_crop(x, y)
                x, y = rand_flip(x, y)
                x, y = x.astype(np.float32), y.astype(np.uint8)
                x = random_brightness_augmentation(x)
                x = gaussian_noise(x)
                X_preprocessed.append(x)
                Y_preprocessed.append(y)
            yield np.stack(X_preprocessed, axis=0), np.stack(Y_preprocessed, axis=0)


def gaussian_kernel(n, std):
    """Return an (n, n, n) weighting volume built from a 1-D Gaussian window.

    The triple outer product is cube-rooted and then scaled so its maximum is 1.
    """
    gaussian_1d = signal.windows.gaussian(n, std)
    gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
    gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
    gaussian_3d = gaussian_3d.reshape(n, n, n)
    gaussian_3d = np.cbrt(gaussian_3d)
    gaussian_3d /= gaussian_3d.max()
    return gaussian_3d


def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
    """Pad the spatial axes of `volume` so sliding windows of `roi_shape` tile it with `strides`.

    Each spatial extent is brought up to a multiple of its stride; one more stride
    is added if the result would still be smaller than the ROI.

    Returns:
        (padded_volume, paddings): the padded numpy array and the padding list in
        `torch.nn.functional.pad` order (last axis first, low side before high side).
    """
    bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
    bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
    paddings = [
        bounds[2] // 2, bounds[2] - bounds[2] // 2,
        bounds[1] // 2, bounds[1] - bounds[1] // 2,
        bounds[0] // 2, bounds[0] - bounds[0] // 2,
        0, 0,
        0, 0,
    ]
    return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings


def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5, gpus=None):
    """Run `model` over overlapping windows of `inputs` and blend the outputs.

    Window outputs are accumulated with a Gaussian weight patch and divided by the
    accumulated weights. The accumulator is allocated with batch size 1 and 3
    channels.

    Args:
        model: callable taking a tinygrad Tensor and returning one (`.realize()` is called on it).
        inputs: numpy array with two leading axes followed by the spatial axes.
        labels: array cropped with the same spatial slices as `inputs`.
        roi_shape: window size per spatial axis.
        overlap: fractional overlap between neighbouring windows.
        gpus: forwarded as `device=` when building the input Tensor.

    Returns:
        (result, labels): blended predictions and the cropped labels.
    """
    from tinygrad.engine.jit import TinyJit
    mdl_run = TinyJit(lambda x: model(x).realize())
    image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
    strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
    bounds = [image_shape[i] % strides[i] for i in range(dim)]
    # A remainder smaller than half a stride is cropped away; larger ones are
    # left in place (set to 0 here) and handled by pad_input below.
    bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)]
    crop = (
        ...,
        slice(bounds[0] // 2, image_shape[0] - (bounds[0] - bounds[0] // 2)),
        slice(bounds[1] // 2, image_shape[1] - (bounds[1] - bounds[1] // 2)),
        slice(bounds[2] // 2, image_shape[2] - (bounds[2] - bounds[2] // 2)),
    )
    inputs = inputs[crop]
    labels = labels[crop]
    inputs, paddings = pad_input(inputs, roi_shape, strides)
    padded_shape = inputs.shape[2:]
    size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
    result = np.zeros((1, 3, *padded_shape), dtype=np.float32)
    norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32)
    # The weight patch is cubic with side roi_shape[0].
    norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0])
    norm_patch = np.expand_dims(norm_patch, axis=0)
    for i in range(0, strides[0] * size[0], strides[0]):
        for j in range(0, strides[1] * size[1], strides[1]):
            for k in range(0, strides[2] * size[2], strides[2]):
                window = (..., slice(i, roi_shape[0] + i), slice(j, roi_shape[1] + j), slice(k, roi_shape[2] + k))
                out = mdl_run(Tensor(inputs[window], device=gpus)).numpy()
                result[window] += out * norm_patch
                norm_map[window] += norm_patch
    result /= norm_map
    # paddings is in F.pad order, so indices 4/2/0 are the low-side pads of
    # spatial axes 0/1/2.
    result = result[
        ...,
        paddings[4]:image_shape[0] + paddings[4],
        paddings[2]:image_shape[1] + paddings[2],
        paddings[0]:image_shape[2] + paddings[0],
    ]
    return result, labels


def rand_flip(image, label, axis=(1, 2, 3)):
    """Flip `image` and `label` together along each axis in `axis`, each with probability 1/len(axis)."""
    prob = 1 / len(axis)
    for ax in axis:
        if random.random() < prob:
            image = np.flip(image, axis=ax).copy()
            label = np.flip(label, axis=ax).copy()
    return image, label


def random_brightness_augmentation(image, low=0.7, high=1.3, prob=0.1):
    """With probability `prob`, scale `image` by (1 + factor), factor drawn uniformly from [low, high)."""
    if random.random() < prob:
        factor = np.random.uniform(low=low, high=high, size=1)
        # NOTE(review): with the defaults this multiplies by a value in [1.7, 2.3),
        # not [0.7, 1.3) -- confirm this matches the intended augmentation.
        image = (image * (1 + factor)).astype(image.dtype)
    return image


def gaussian_noise(image, mean=0.0, std=0.1, prob=0.1):
    """With probability `prob`, add Gaussian noise to `image` in place and return it.

    The noise scale is itself drawn uniformly from [0, std).
    """
    if random.random() < prob:
        scale = np.random.uniform(low=0.0, high=std)
        noise = np.random.normal(loc=mean, scale=scale, size=image.shape).astype(image.dtype)
        image += noise
    return image


def _rand_foreg_cropb(image, label, patch_size):
    """Crop a `patch_size` window around a randomly chosen foreground component.

    One foreground class is picked at random, its connected components are found,
    and one of the (up to) two components with the largest bounding boxes is used
    to place the crop. Falls back to `_rand_crop` when there is no foreground.
    """
    def adjust(foreg_slice, label, idx):
        # Grow (or shrink) the component's extent along axis `idx` to the patch
        # size, splitting the difference randomly between both sides.
        diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
        sign = -1 if diff < 0 else 1
        diff = abs(diff)
        ladj = 0 if diff == 0 else random.randrange(diff)
        hadj = diff - ladj
        low = max(0, foreg_slice[idx].start - sign * ladj)
        high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
        # Clamping to the volume may have shortened the window; extend the
        # opposite side to restore the full patch size.
        diff = patch_size[idx - 1] - (high - low)
        if diff > 0 and low == 0:
            high += diff
        elif diff > 0:
            low -= diff
        return low, high

    foreground_classes = np.unique(label[label > 0])
    # np.random.choice raises on an empty array, so handle "no foreground" first.
    if foreground_classes.size == 0:
        return _rand_crop(image, label, patch_size)
    cl = np.random.choice(foreground_classes)
    foreg_slices = ndimage.find_objects(ndimage.label(label == cl)[0])
    foreg_slices = [x for x in foreg_slices if x is not None]
    slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
    slice_idx = np.argsort(slice_volumes)[-2:]
    foreg_slices = [foreg_slices[i] for i in slice_idx]
    if not foreg_slices:
        return _rand_crop(image, label, patch_size)
    foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
    low_x, high_x = adjust(foreg_slice, label, 1)
    low_y, high_y = adjust(foreg_slice, label, 2)
    low_z, high_z = adjust(foreg_slice, label, 3)
    image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
    label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
    return image, label


def _rand_crop(image, label, patch_size):
    """Crop the same uniformly random `patch_size` window out of `image` and `label`."""
    ranges = [s - p for s, p in zip(image.shape[1:], patch_size)]
    cord = [0 if x == 0 else random.randrange(x) for x in ranges]
    low_x, high_x = cord[0], cord[0] + patch_size[0]
    low_y, high_y = cord[1], cord[1] + patch_size[1]
    low_z, high_z = cord[2], cord[2] + patch_size[2]
    image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
    label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
    return image, label


def rand_balanced_crop(image, label, patch_size=(128, 128, 128), oversampling=0.4):
    """Crop around foreground with probability `oversampling`, otherwise crop uniformly at random."""
    if random.random() < oversampling:
        image, label = _rand_foreg_cropb(image, label, patch_size)
    else:
        image, label = _rand_crop(image, label, patch_size)
    return image, label


if __name__ == "__main__":
    for X, Y in iterate(get_val_files()):
        print(X.shape, Y.shape)