

    这些数据是根据CC BY-SA 4.0许可证提供的,因此完全是开源的。
    下载VerSe数据集后,我打开了其中一个*.nii.gz*文件。读取文件并取出CT扫描的某个切片后,我使用NumPy的transpose函数调整坐标轴顺序,从而以轴向、矢状和冠状三种不同视图查看切片。
    数据准备的任务是从原始图像和遮罩文件中的每个3D CT扫描文件生成图像切片。
    作者:Mazi Boustani
    import numpy as np 
    import pandas as pd
    import os
    from os import listdir
    from os.path import splitext
    import glob
    import shutil
    import random
    from pathlib import Path
    from PIL import Image
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    %matplotlib inline
       import nibabel as nib
       raise ImportError('Install NIBABEL')
    import torch
    import torch.nn as nn
    from torch import Tensor
    import torch.nn.functional as F
    from torch import optim
    import torchvision.transforms as T
    from torch.utils.data import DataLoader, random_split
    from torch.utils.data import Dataset
    # set folder paths for train and validation data
    # NOTE(review): absolute, machine-specific path; point it at the local dataset copy.
    data_folder_path = "/Users/mazi/Projects/other/CT/data"
    train_data = data_folder_path + "/verse_19_20_training/"
    validation_data = data_folder_path + "/verse_19_20_validation/"
    # get one image to load
    train_data_raw_image = train_data + "/rawdata/sub-verse521/sub-verse521_dir-ax_ct.nii.gz"
    one_image = nib.load(train_data_raw_image)
    # look at image shape
    # look at image header. To understand header please refer to: https://brainder.org/2012/09/23/the-nifti-file-format/
    # look at the raw data
    one_image_data = one_image.get_fdata()
    # Visualize one image in three different angles
    # The volume is used as loaded for the first view; the other two views are
    # produced by permuting its axes and flipping along the first axis.
    one_image_data_axial = one_image_data
    # change the view: reverse the axis order (2, 1, 0), then flip axis 0
    one_image_data_sagittal = np.transpose(one_image_data, [2,1,0])
    one_image_data_sagittal = np.flip(one_image_data_sagittal, axis=0)
    # change the view: axis order (2, 0, 1), then flip axis 0
    one_image_data_coronal = np.transpose(one_image_data, [2,0,1])
    one_image_data_coronal = np.flip(one_image_data_coronal, axis=0)
    # One subplot per view; the slice indices (10, 260, 200) are hard-coded
    # and index the last axis of each re-oriented array.
    fig, ax = plt.subplots(1, 3, figsize = (60, 60))
    ax[0].imshow(one_image_data_axial[:,:,10], cmap ='bone')
    ax[0].set_title("Axial view", fontsize=60)
    ax[1].imshow(one_image_data_sagittal[:,:,260], cmap ='bone')
    ax[1].set_title("Sagittal view", fontsize=60)
    ax[2].imshow(one_image_data_coronal[:,:,200], cmap ='bone')
    ax[2].set_title("Coronal view", fontsize=60)
    # Overlay a mask on top of raw image (one slice of CT-scan)
    # The path variable is rebound to the loaded mask array on the next line.
    train_data_mask_image = train_data + "derivatives/sub-verse521/sub-verse521_dir-ax_seg-vert_msk.nii.gz"
    train_data_mask_image = nib.load(train_data_mask_image).get_fdata()
    # Same transpose + flip as the sagittal view above, applied to the raw volume
    rotated_raw = np.transpose(one_image_data, [2,1,0])
    rotated_raw = np.flip(rotated_raw, axis=0)
    plt.imshow(rotated_raw[:,:,260], cmap ='bone', interpolation='none')
    # Zero-valued mask entries are replaced by NaN in place -- presumably so
    # that imshow leaves them undrawn and the raw slice stays visible (TODO confirm).
    train_data_mask_image[train_data_mask_image == 0 ] = np.nan
    # Re-orient the mask exactly like the raw volume so both slices line up
    rotated_mask = np.transpose(train_data_mask_image, [2,1,0])
    rotated_mask = np.flip(rotated_mask, axis=0)
    plt.imshow(rotated_mask[:,:,260], cmap ='cool')
    # Set paths to store processed train and validation raw images and masks
    processed_train = "./processed_train/"
    processed_validation = "./processed_validation/"
    processed_train_raw_images = processed_train + "raw_images/"
    processed_train_masks = processed_train + "masks/"
    processed_validation_raw_images = processed_validation + "raw_images/"
    processed_validation_masks = processed_validation + "masks/"
    # Read all 2019 and 2020 raw files, both train and validation
    raw_train_files = glob.glob(os.path.join(train_data, 'rawdatanii.gz'))
    raw_validation_files = glob.glob(os.path.join(validation_data, 'rawdatanii.gz'))
    # Read all 2019 and 2020 raw files, both train and validation
    raw_train_files = glob.glob(os.path.join(train_data, 'rawdatanii.gz'))
    raw_validation_files = glob.glob(os.path.join(validation_data, 'rawdatanii.gz'))
    print("Raw images count train: {0}, validation: {1}".format(len(raw_train_files), len(raw_validation_files)))
    # Read all 2019 and 2020 derivatives files, both train and validation
    masks_train_files = glob.glob(os.path.join(train_data, 'derivativesnii.gz'))
    masks_validation_files = glob.glob(os.path.join(validation_data, 'derivativesnii.gz'))
    print("Masks images count train: {0}, validation: {1}".format(len(masks_train_files), len(masks_validation_files)))
    def read_file(nii_file):
       Read .nii.gz file.
         nii_file (str): a file path.
         3D numpy array of CT image data.
       return np.asanyarray(nib.load(nii_file).dataobj)
    def save_file(raw_data, label_data, file_name, index, output_raw_file_path, output_label_file_path):
       Save file into npz format.
         raw_data (array): 2D numpy array of raw image data.
         label_data (array): 2D numpy array of label image data.
         file_name (str): file name.
         index (int): slice of CT image.
         output_raw_file_path (str): Path to all raw files.
         output_label_file_path (str): Path to all mask files.
       # replace all non-zero pixels to 1
       label_data = np.where(label_data > 0, 1, label_data)
       unique_values = np.unique(label_data)
       # if data has pixel with value of 1 means it is a positive datapoint
       if len(unique_values) > 1:
           raw_file_name = "{0}{1}_{2}.png".format(output_raw_file_path, file_name, index)
           im = Image.fromarray(raw_data)
           im = im.convert("L")
           label_file_name = "{0}{1}_{2}.png".format(output_label_file_path, file_name, index)
           im = Image.fromarray(label_data)
           im = im.convert("L")
    def is_diagonal(matrix):
       Check if givem matrix is diagonal or not.
           matrix (np array): numpy array
       for i in range(0, 3):
           for j in range(0, 3) :
               if ((i != j) and (matrix[i][j] != 0)):
                   return False
       return True
    def generate_data(raw_file, label_file, file_name, output_raw_file_path, output_label_file_path):
       Main function to read each raw and label file and generate series of images
       per each slice.
         raw_file (str): path to raw file.
         label_file (str): path to label file.
         file_name (str): file name.
         output_raw_file_path (str): Path to all raw files.
         output_label_file_path (str): Path to all mask files.
       # If skip every 2 slice. Adjacent slices can be very similar to each other and
       # will generate redundant data
       skip_slice = 3
       continue_it = True
       raw_data = read_file(raw_file)
       label_data = read_file(label_file)
       if "split" in raw_file:
           continue_it = False
       affine = nib.load(raw_file).affine
       if is_diagonal(affine[:3, :3]):
           transposed_raw_data = np.transpose(raw_data, [2,1,0])
           transposed_raw_data = np.flip(transposed_raw_data)
           transposed_label_data = np.transpose(label_data, [2,1,0])
           transposed_label_data = np.flip(transposed_label_data)
           transposed_raw_data = np.rot90(raw_data)
           transposed_raw_data = np.flip(transposed_raw_data)
           transposed_label_data = np.rot90(label_data)
           transposed_label_data = np.flip(transposed_label_data) 
       if continue_it:
           if transposed_raw_data.shape:
               slice_count = transposed_raw_data.shape[-1]
               print("File name: ", file_name, " - Slice count: ", slice_count)
               # skip some slices
               for each_slice in range(1, slice_count, skip_slice):
    # Loop over raw images and masks and generate 'PNG' images.
    print("Processing started.")
    for each_raw_file in raw_train_files:
       raw_file_name = each_raw_file.split("/")[-1].split("_ct.nii.gz")[0]
       for each_mask_file in masks_train_files:
           if raw_file_name in each_mask_file.split("/")[-1]:
    print("Processing train data done.")
    # Loop over raw images and masks and generate 'PNG' images.
    for each_raw_file in raw_validation_files:
       raw_file_name = each_raw_file.split("/")[-1].split("_ct.nii.gz")[0]
       for each_mask_file in masks_validation_files:
           if raw_file_name in each_mask_file.split("/")[-1]:
    print("Processing validation data done.")
    # Define model parameters
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # image size to convert to
    IMAGE_HEIGHT = 250
    IMAGE_WIDTH = 250
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 10
    EPOCHS = 10
    # Set the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # UNet model parts
    # Source code: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
    class DoubleConv(nn.Module):
       """(convolution => [BN] => ReLU) * 2"""
       def __init__(self, in_channels, out_channels, mid_channels=None):
           if not mid_channels:
               mid_channels = out_channels
           self.double_conv = nn.Sequential(
               nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
               nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
       def forward(self, x):
           return self.double_conv(x)
    class Down(nn.Module):
       """Downscaling with maxpool then double conv"""
       def __init__(self, in_channels, out_channels):
           self.maxpool_conv = nn.Sequential(
               DoubleConv(in_channels, out_channels)
       def forward(self, x):
           return self.maxpool_conv(x)
    class Up(nn.Module):
       """Upscaling then double conv"""
       def __init__(self, in_channels, out_channels, bilinear=True):
     # if bilinear, use the normal convolutions to reduce the number of channels
           if bilinear:
                self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
                self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
                self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
                self.conv = DoubleConv(in_channels, out_channels)
       def forward(self, x1, x2):
           x1 = self.up(x1)
           # input is CHW
           diffY = x2.size()[2] - x1.size()[2]
           diffX = x2.size()[3] - x1.size()[3]
           x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                           diffY // 2, diffY - diffY // 2])
           # if you have padding issues, see
           # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
           # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
           x = torch.cat([x2, x1], dim=1)
           return self.conv(x)
    class OutConv(nn.Module):
       def __init__(self, in_channels, out_channels):
           super(OutConv, self).__init__()
           self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
       def forward(self, x):
           return self.conv(x)
    # Defining UNet architecture
    # Source code: https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py
    class UNet(nn.Module):
       def __init__(self, n_channels, n_classes, bilinear=True):
           super(UNet, self).__init__(
           self.n_channels = n_channels
           self.n_classes = n_classes
           self.bilinear = bilinear
           self.inc = DoubleConv(n_channels, 64)
           self.down1 = Down(64, 128)
           self.down2 = Down(128, 256)
           self.down3 = Down(256, 512)
           factor = 2 if bilinear else 1
           self.down4 = Down(512, 1024 // factor)
           self.up1 = Up(1024, 512 // factor, bilinear)
           self.up2 = Up(512, 256 // factor, bilinear)
           self.up3 = Up(256, 128 // factor, bilinear)
           self.up4 = Up(128, 64, bilinear)
           self.outc = OutConv(64, n_classes)
       def forward(self, x):
           x1 = self.inc(x)
           x2 = self.down1(x1)
           x3 = self.down2(x2)
           x4 = self.down3(x3)
           x5 = self.down4(x4)
           x = self.up1(x5, x4)
           x = self.up2(x, x3)
           x = self.up3(x, x2)
           x = self.up4(x, x1)
           logits = self.outc(x)
           return logits
    # Define PyTorch dataset class
    # This class will access the images and masks, preprocess them for training and validation
    class VerSeDataset(Dataset):
       def __init__(self, raw_images_path, masks_path, images_name):
           self.raw_images_path = raw_images_path
           self.masks_path = masks_path
           self.images_name = images_name
       def __len__(self):
           return len(self.images_name)
       def __getitem__(self, index):
           # get image and mask for a given index
           img_path = os.path.join(self.raw_images_path, self.images_name[index])
           mask_path = os.path.join(self.masks_path, self.images_name[index])
           # read the image and mask
           image = Image.open(img_path)
           mask = Image.open(mask_path)
           # resize image and change the shape to (1, image_width, image_height)
           w, h = image.size
           image = image.resize((w, h), resample=Image.BICUBIC)
           image = T.Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))(image)
           image_ndarray = np.asarray(image)
           image_ndarray = image_ndarray.reshape(1, image_ndarray.shape[0], image_ndarray.shape[1])
           # resize the mask. Mask shape is (image_width, image_height)
           mask = mask.resize((w, h), resample=Image.NEAREST)
           mask = T.Resize(size=(IMAGE_WIDTH, IMAGE_HEIGHT))(mask)
           mask_ndarray = np.asarray(mask)
           return {
               'image': torch.as_tensor(image_ndarray.copy()).float().contiguous(),
               'mask': torch.as_tensor(mask_ndarray.copy()).float().contiguous(
    # Get path for all images and masks
    # os.listdir returns bare file names (not full paths) in arbitrary order.
    train_images_paths = os.listdir(processed_train_raw_images)
    train_masks_paths = os.listdir(processed_train_masks)
    validation_images_paths = os.listdir(processed_validation_raw_images)
    validation_masks_paths = os.listdir(processed_validation_masks)
    # Load both images and masks data
    # Only the image name lists are handed to the datasets; the mask name lists
    # collected above are not used here.
    # NOTE: `train_data` is rebound here -- it previously held the training
    # folder path string.
    train_data = VerSeDataset(processed_train_raw_images, processed_train_masks, train_images_paths)
    valid_data = VerSeDataset(processed_validation_raw_images, processed_validation_masks, validation_images_paths)
    # Create PyTorch DataLoader
    # Training batches are shuffled, validation batches keep the dataset order.
    train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)
    # Looking at one image and mask from one batch just to check them visually
    next_image = next(iter(valid_dataloader))
    fig, ax = plt.subplots(1, 2, figsize = (60, 60))
    # First sample of the batch; the image additionally has its first axis indexed at 0
    ax[0].imshow(next_image['image'][0][0,:,:], cmap ='bone')
    ax[0].set_title("Raw image", fontsize=60)
    ax[1].imshow(next_image['mask'][0][:,:], cmap ='bone')
    ax[1].set_title("Mask image", fontsize=60)
    # Defining Dice loss class
    # Source code: https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch
    class DiceLoss(nn.Module):
       def __init__(self, weight=None, size_average=True):
           super(DiceLoss, self).__init__()
       def forward(self, inputs, targets, smooth=1):
           inputs = torch.sigmoid(inputs)      
           # flatten label and prediction tensors
           inputs = inputs.view(-1)
           targets = targets.view(-1)
           intersection = (inputs * targets).sum()                            
           dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
           bce = F.binary_cross_entropy_with_logits(inputs, targets)
           pred = torch.sigmoid(inputs)
           loss = bce * 0.5 + dice * (1 - 0.5)
          # subtract 1 to calculate loss from dice value
           return 1 - dice
    # Define model as UNet
    model = UNet(n_channels=1, n_classes=1)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    # Train and validate
    train_loss = []
    val_loss = []
    for epoch in range(EPOCHS):  
       train_running_loss = 0.0
       counter = 0
       with tqdm(total=len(train_data), desc=f'Epoch {epoch + 1}/{EPOCHS}', unit='img') as pbar:
           for batch in train_dataloader:
               image = batch['image'].to(DEVICE)
               mask = batch['mask'].to(DEVICE)
               outputs = model(image)
               outputs = outputs.squeeze(1)
               loss = DiceLoss()(outputs, mask)
               train_running_loss += loss.item()
               pbar.set_postfix(**{'loss (batch)': loss.item()})
       valid_running_loss = 0.0
       counter = 0
       with torch.no_grad():
           for i, data in enumerate(valid_dataloader):
               counter += 1
               image = data['image'].to(DEVICE)
               mask = data['mask'].to(DEVICE)
               outputs = model(image)
               outputs = outputs.squeeze(1)
               loss = DiceLoss()(outputs, mask)
               valid_running_loss += loss.item()
    Epoch 1/10: 100%|██████████| 4790/4790 [4:00:34<00:00,  3.01s/img, loss (batch)=0.385]  
    Epoch 2/10: 100%|██████████| 4790/4790 [4:00:02<00:00,  3.01s/img, loss (batch)=0.268]  
    Epoch 3/10: 100%|██████████| 4790/4790 [3:57:30<00:00,  2.98s/img, loss (batch)=0.152]  
    Epoch 4/10: 100%|██████████| 4790/4790 [3:57:05<00:00,  2.97s/img, loss (batch)=0.105]  
    Epoch 5/10: 100%|██████████| 4790/4790 [4:08:29<00:00,  3.11s/img, loss (batch)=0.103]   
    Epoch 6/10: 100%|██████████| 4790/4790 [4:04:12<00:00,  3.06s/img, loss (batch)=0.0874]  
    Epoch 7/10: 100%|██████████| 4790/4790 [4:02:00<00:00,  3.03s/img, loss (batch)=0.0759]  
    Epoch 8/10: 100%|██████████| 4790/4790 [3:58:32<00:00,  2.99s/img, loss (batch)=0.0655]  
    Epoch 9/10: 100%|██████████| 4790/4790 [4:00:47<00:00,  3.02s/img, loss (batch)=0.0644]  
    Epoch 10/10: 100%|██████████| 4790/4790 [4:08:54<00:00,  3.12s/img, loss (batch)=0.0604]  
    # Plot train vs validation loss
    plt.figure(figsize=(10, 7))
    plt.plot(train_loss, color="orange", label='train loss')
    plt.plot(val_loss, color="red", label='validation loss')
    # Save the trained model
       'epoch': EPOCHS,
       'model_state_dict': model.state_dict(),
       'optimizer_state_dict': optimizer.state_dict(),
    }, "./unet_model.pth")
    # Visually look at one prediction 
    next_image = next(iter(valid_dataloader))
    # do predict
    outputs = model(next_image['image'].float())
    outputs = outputs.detach().cpu()
    loss = DiceLoss()(outputs, next_image['mask'])
    print("Dice Score: ", 1- loss.item())
    outputs[outputs<=0.0] = 0
    outputs[outputs>0.0] = 1.0
    # plot all three images
    fig, ax = plt.subplots(1, 3, figsize = (60, 60))
    ax[0].imshow(next_image['image'][0][0,:,:], cmap ='bone')
    ax[0].set_title("Raw Image", fontsize=60)
    ax[1].imshow(next_image['mask'][0][:,:], cmap ='bone')
    ax[1].set_title("True Mask", fontsize=60)
    ax[2].imshow(outputs[0,0,:,:], cmap ='bone')
    ax[2].set_title("Predicted Mask", fontsize=60)
    未来的工作:这个任务也可以用3D UNet完成,这可能是学习脊柱结构的更好方法。