# tinygrad/extra/datasets/kits19.py

import random
import functools
from pathlib import Path
import numpy as np
import nibabel as nib
from scipy import signal, ndimage
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad.helpers import fetch
BASEDIR = Path(__file__).parent / "kits19" / "data"
TRAIN_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "train"
VAL_PREPROCESSED_DIR = Path(__file__).parent / "kits19" / "preprocessed" / "val"
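
# KiTS19 train/val split: the validation cases are the fixed list published by
# MLCommons; the training split is every case below 210 not in that list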
@functools.lru_cache(None)
def get_train_files():
  return sorted([x for x in BASEDIR.iterdir() if x.stem.startswith("case") and int(x.stem.split("_")[-1]) < 210 and x not in get_val_files()])

@functools.lru_cache(None)
def get_val_files():
  data = fetch("https://raw.githubusercontent.com/mlcommons/training/master/retired_benchmarks/unet3d/pytorch/evaluation_cases.txt").read_text()
  return sorted([x for x in BASEDIR.iterdir() if x.stem.split("_")[-1] in data.split("\n")])
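
# load one case as (1, depth, height, width) image/label arrays plus the voxel spacing from the NIfTI header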
def load_pair(file_path):
  image, label = nib.load(file_path / "imaging.nii.gz"), nib.load(file_path / "segmentation.nii.gz")
  image_spacings = image.header["pixdim"][1:4].tolist()
  image, label = image.get_fdata().astype(np.float32), label.get_fdata().astype(np.uint8)
  image, label = np.expand_dims(image, 0), np.expand_dims(label, 0)
  return image, label, image_spacings
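
# resample both volumes to a common voxel spacing: trilinear for the image,
# nearest-neighbor for the label so class ids are never blended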
def resample3d(image, label, image_spacings, target_spacing=(1.6, 1.2, 1.2)):
  # image_spacings is a list, so compare against list(target_spacing); comparing
  # a list to a tuple is never equal and would resample even at target spacing
  if image_spacings != list(target_spacing):
    spc_arr, targ_arr, shp_arr = np.array(image_spacings), np.array(target_spacing), np.array(image.shape[1:])
    new_shape = (spc_arr / targ_arr * shp_arr).astype(int).tolist()
    image = F.interpolate(torch.from_numpy(np.expand_dims(image, axis=0)), size=new_shape, mode="trilinear", align_corners=True)
    label = F.interpolate(torch.from_numpy(np.expand_dims(label, axis=0)), size=new_shape, mode="nearest")
    image = np.squeeze(image.numpy(), axis=0)
    label = np.squeeze(label.numpy(), axis=0)
  return image, label
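
# clip CT intensities to the MLPerf reference window, then z-score normalize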
def normal_intensity(image, min_clip=-79.0, max_clip=304.0, mean=101.0, std=76.9):
  image = np.clip(image, min_clip, max_clip)
  image = (image - mean) / std
  return image
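
# edge-pad symmetrically so every volume is at least roi_shape in each spatial dim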
def pad_to_min_shape(image, label, roi_shape=(128, 128, 128)):
  current_shape = image.shape[1:]
  bounds = [max(0, roi_shape[i] - current_shape[i]) for i in range(3)]
  paddings = [(0, 0)] + [(bounds[i] // 2, bounds[i] - bounds[i] // 2) for i in range(3)]
  image = np.pad(image, paddings, mode="edge")
  label = np.pad(label, paddings, mode="edge")
  return image, label
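
# full preprocessing pipeline for one case: resample, normalize, pad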
def preprocess(file_path):
  image, label, image_spacings = load_pair(file_path)
  image, label = resample3d(image, label, image_spacings)
  image = normal_intensity(image.copy())
  image, label = pad_to_min_shape(image, label)
  return image, label
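
# preprocess every case once and cache the arrays as .npy so later epochs skip the NIfTI work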
def preprocess_dataset(filenames, preprocessed_dir, val):
  if not preprocessed_dir.is_dir(): os.makedirs(preprocessed_dir)
  for fn in tqdm(filenames, desc=f"preprocessing {'validation' if val else 'training'}"):
    case = os.path.basename(fn)
    image, label = preprocess(fn)
    image, label = image.astype(np.float32), label.astype(np.uint8)
    np.save(preprocessed_dir / f"{case}_x.npy", image, allow_pickle=False)
    np.save(preprocessed_dir / f"{case}_y.npy", label, allow_pickle=False)
def iterate(files, preprocessed_dir=None, val=True, shuffle=False, bs=1):
  order = list(range(0, len(files)))
  if shuffle: random.shuffle(order)
  for i in range(0, len(files), bs):
    samples = []
    for idx in order[i:i+bs]:  # renamed from i to avoid shadowing the batch index
      if preprocessed_dir is not None:
        x_cached_path, y_cached_path = preprocessed_dir / f"{os.path.basename(files[idx])}_x.npy", preprocessed_dir / f"{os.path.basename(files[idx])}_y.npy"
        if x_cached_path.exists() and y_cached_path.exists():
          samples += [(np.load(x_cached_path), np.load(y_cached_path))]
        else: samples += [preprocess(files[idx])]  # cache miss: preprocess on the fly instead of silently dropping the sample
      else: samples += [preprocess(files[idx])]
    X, Y = [x[0] for x in samples], [x[1] for x in samples]
    if val:
      yield X[0][None], Y[0]
    else:
      X_preprocessed, Y_preprocessed = [], []
      for x, y in zip(X, Y):
        x, y = rand_balanced_crop(x, y)
        x, y = rand_flip(x, y)
        x, y = x.astype(np.float32), y.astype(np.uint8)
        x = random_brightness_augmentation(x)
        x = gaussian_noise(x)
        X_preprocessed.append(x)
        Y_preprocessed.append(y)
      yield np.stack(X_preprocessed, axis=0), np.stack(Y_preprocessed, axis=0)
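
# separable 3D gaussian importance map for sliding-window blending: the outer
# products give g[i]*g[j]*g[k]; np.cbrt takes its geometric mean, softening the
# border falloff, and the peak is normalized to 1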
def gaussian_kernel(n, std):
  gaussian_1d = signal.windows.gaussian(n, std)
  gaussian_2d = np.outer(gaussian_1d, gaussian_1d)
  gaussian_3d = np.outer(gaussian_2d, gaussian_1d)
  gaussian_3d = gaussian_3d.reshape(n, n, n)
  gaussian_3d = np.cbrt(gaussian_3d)
  gaussian_3d /= gaussian_3d.max()
  return gaussian_3d
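
# pad a (N, C, D, H, W) volume so the sliding window tiles it exactly; returns the
# torch-style paddings list (last dim first) so the caller can undo the padding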
def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-2.2, dim=3):
  bounds = [(strides[i] - volume.shape[2:][i] % strides[i]) % strides[i] for i in range(dim)]
  bounds = [bounds[i] if (volume.shape[2:][i] + bounds[i]) >= roi_shape[i] else bounds[i] + strides[i] for i in range(dim)]
  paddings = [bounds[2]//2, bounds[2]-bounds[2]//2, bounds[1]//2, bounds[1]-bounds[1]//2, bounds[0]//2, bounds[0]-bounds[0]//2, 0, 0, 0, 0]
  return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
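
# gaussian-weighted sliding-window evaluation over a full volume: overlapping
# roi_shape patches are predicted, each weighted by the gaussian kernel so patch
# borders contribute less, and the sum is normalized by the accumulated weights.
# a hypothetical call (argument names are illustrative, not from this file):
#   pred, label = sliding_window_inference(mdl, x_np, y_np, roi_shape=(128, 128, 128), overlap=0.5)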
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5, gpus=None):
  from tinygrad.engine.jit import TinyJit
  mdl_run = TinyJit(lambda x: model(x).realize())
  image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
  strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]
  bounds = [image_shape[i] % strides[i] for i in range(dim)]
  bounds = [bounds[i] if bounds[i] < strides[i] // 2 else 0 for i in range(dim)]
  inputs = inputs[
    ...,
    bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
    bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
    bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
  ]
  labels = labels[
    ...,
    bounds[0]//2:image_shape[0]-(bounds[0]-bounds[0]//2),
    bounds[1]//2:image_shape[1]-(bounds[1]-bounds[1]//2),
    bounds[2]//2:image_shape[2]-(bounds[2]-bounds[2]//2),
  ]
  inputs, paddings = pad_input(inputs, roi_shape, strides)
  padded_shape = inputs.shape[2:]
  size = [(inputs.shape[2:][i] - roi_shape[i]) // strides[i] + 1 for i in range(dim)]
  result = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  norm_map = np.zeros((1, 3, *padded_shape), dtype=np.float32)
  norm_patch = gaussian_kernel(roi_shape[0], 0.125 * roi_shape[0])
  norm_patch = np.expand_dims(norm_patch, axis=0)
  for i in range(0, strides[0] * size[0], strides[0]):
    for j in range(0, strides[1] * size[1], strides[1]):
      for k in range(0, strides[2] * size[2], strides[2]):
        out = mdl_run(Tensor(inputs[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k], device=gpus)).numpy()
        result[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += out * norm_patch
        norm_map[..., i:roi_shape[0]+i, j:roi_shape[1]+j, k:roi_shape[2]+k] += norm_patch
  result /= norm_map
  result = result[..., paddings[4]:image_shape[0]+paddings[4], paddings[2]:image_shape[1]+paddings[2], paddings[0]:image_shape[2]+paddings[0]]
  return result, labels
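
# flip each spatial axis independently with probability 1/len(axis)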
def rand_flip(image, label, axis=(1, 2, 3)):
  prob = 1 / len(axis)
  for ax in axis:
    if random.random() < prob:
      image = np.flip(image, axis=ax).copy()
      label = np.flip(label, axis=ax).copy()
  return image, label
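
# with probability prob, scale intensities by (1 + f) with f ~ U(low, high),
# matching the MLPerf reference augmentation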
def random_brightness_augmentation(image, low=0.7, high=1.3, prob=0.1):
  if random.random() < prob:
    factor = np.random.uniform(low=low, high=high, size=1)
    image = (image * (1 + factor)).astype(image.dtype)
  return image
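
# with probability prob, add gaussian noise whose std is itself drawn from U(0, std)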
def gaussian_noise(image, mean=0.0, std=0.1, prob=0.1):
  if random.random() < prob:
    scale = np.random.uniform(low=0.0, high=std)
    noise = np.random.normal(loc=mean, scale=scale, size=image.shape).astype(image.dtype)
    image += noise
  return image
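
# foreground-biased crop: pick a random foreground class, keep its two largest
# connected components, and crop a patch_size window around one of them; adjust()
# grows or shrinks a component's bounding slice along one axis to exactly patch_size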
def _rand_foreg_cropb(image, label, patch_size):
  def adjust(foreg_slice, label, idx):
    diff = patch_size[idx - 1] - (foreg_slice[idx].stop - foreg_slice[idx].start)
    sign = -1 if diff < 0 else 1
    diff = abs(diff)
    ladj = 0 if diff == 0 else random.randrange(diff)
    hadj = diff - ladj
    low = max(0, foreg_slice[idx].start - sign * ladj)
    high = min(label.shape[idx], foreg_slice[idx].stop + sign * hadj)
    diff = patch_size[idx - 1] - (high - low)
    if diff > 0 and low == 0: high += diff
    elif diff > 0: low -= diff
    return low, high
  cl = np.random.choice(np.unique(label[label > 0]))
  foreg_slices = ndimage.find_objects(ndimage.label(label==cl)[0])
  foreg_slices = [x for x in foreg_slices if x is not None]
  slice_volumes = [np.prod([s.stop - s.start for s in sl]) for sl in foreg_slices]
  slice_idx = np.argsort(slice_volumes)[-2:]
  foreg_slices = [foreg_slices[i] for i in slice_idx]
  if not foreg_slices: return _rand_crop(image, label, patch_size)  # pass patch_size through; the bare call was missing this required argument
  foreg_slice = foreg_slices[random.randrange(len(foreg_slices))]
  low_x, high_x = adjust(foreg_slice, label, 1)
  low_y, high_y = adjust(foreg_slice, label, 2)
  low_z, high_z = adjust(foreg_slice, label, 3)
  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
  return image, label
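
# uniform random crop to patch_size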
def _rand_crop(image, label, patch_size):
  ranges = [s - p for s, p in zip(image.shape[1:], patch_size)]
  cord = [0 if x == 0 else random.randrange(x) for x in ranges]
  low_x, high_x = cord[0], cord[0] + patch_size[0]
  low_y, high_y = cord[1], cord[1] + patch_size[1]
  low_z, high_z = cord[2], cord[2] + patch_size[2]
  image = image[:, low_x:high_x, low_y:high_y, low_z:high_z]
  label = label[:, low_x:high_x, low_y:high_y, low_z:high_z]
  return image, label
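
# crop around foreground with probability `oversampling`, else crop uniformly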
def rand_balanced_crop(image, label, patch_size=(128, 128, 128), oversampling=0.4):
  if random.random() < oversampling:
    image, label = _rand_foreg_cropb(image, label, patch_size)
  else:
    image, label = _rand_crop(image, label, patch_size)
  return image, label
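
# quick sanity check: stream the validation volumes and print their shapes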
if __name__ == "__main__":
  for X, Y in iterate(get_val_files()):
    print(X.shape, Y.shape)