MONAI 0.7 : tutorials : 3D セグメンテーション – 脾臓 3D セグメンテーション (Lightning 版) (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/16/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : 3D セグメンテーション – 脾臓 3D セグメンテーション (Lightning 版)
このノートブックは MONAI を PyTorch Lightning フレームワークと連携して使用できる可能性を示します。
このチュートリアルは MONAI をどのように PyTorch Lightning フレームワークと連携して使用できるかを実演します。
以下の MONAI の機能の使用方法を実演します :
- 辞書形式データのための変換。
- メタデータとともに Nifti 画像をロードする。
- チャネル次元がない場合チャネル dim をデータに追加する。
- 想定される範囲で医療画像強度をスケールする。
- ポジティブ/ネガティブ・ラベル比率に基づいてバランスの取れた画像のバッチをクロップする。
- 訓練と検証を高速化するキャシュ IO と変換。
- 3D セグメンテーション・タスクのための 3D UNet モデル、Dice 損失関数、Mean Dice メトリック。
- スライディング・ウィンドウ推論法。
- 再現性のための決定論的訓練。
Spleen データセットは http://medicaldecathlon.com/ からダウンロードできます。
- Target: Spleen
- Modality: CT
- Size: 61 3D volumes (41 Training + 20 Testing)
- Source: Memorial Sloan Kettering Cancer Center
- Challenge: Large ranging foreground size
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"
!python -c "import matplotlib" || pip install -q matplotlib
!pip install -q pytorch-lightning==1.4.0
%matplotlib inline
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from monai.utils import set_determinism
from monai.transforms import (
AsDiscrete,
AddChanneld,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
Spacingd,
EnsureTyped,
EnsureType,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, list_data_collate, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import pytorch_lightning
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
print_config()
MONAI version: 0.6.0+1.g8365443a Numpy version: 1.20.3 Pytorch version: 1.9.0a0+c3d40fd MONAI flags: HAS_EXT = True, USE_COMPILED = False MONAI rev id: 8365443ababac313340467e5987c7babe2b5b86a Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.15.0 Pillow version: 8.2.0 Tensorboard version: 2.2.0 gdown version: 3.13.0 TorchVision version: 0.10.0a0 ITK version: 5.1.2 tqdm version: 4.53.0 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.1.4 einops version: 0.3.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical/
データセットのダウンロード
データセットをダウンロードして展開します。データセットは http://medicaldecathlon.com/ に由来します。
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
LightningModule を定義する
LightningModule は訓練コードのリファクタリングを含みます。以下のモジュールは spleen_segmentation_3d.ipynb のコードのリファクタリングです :
class Net(pytorch_lightning.LightningModule):
def __init__(self):
super().__init__()
self._model = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
norm=Norm.BATCH,
)
self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
self.post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])
self.post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
self.best_val_dice = 0
self.best_val_epoch = 0
def forward(self, x):
return self._model(x)
def prepare_data(self):
# set up the correct data path
train_images = sorted(
glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(
glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
{"image": image_name, "label": label_name}
for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
# set deterministic training for reproducibility
set_determinism(seed=0)
# define the data transforms
train_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"], a_min=-57, a_max=164,
b_min=0.0, b_max=1.0, clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
# randomly crop out patch samples from
# big image based on pos / neg ratio
# the image centers of negative samples
# must be in valid image area
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(96, 96, 96),
pos=1,
neg=1,
num_samples=4,
image_key="image",
image_threshold=0,
),
# user can also add other random transforms
# RandAffined(
# keys=['image', 'label'],
# mode=('bilinear', 'nearest'),
# prob=1.0,
# spatial_size=(96, 96, 96),
# rotate_range=(0, 0, np.pi/15),
# scale_range=(0.1, 0.1, 0.1)),
EnsureTyped(keys=["image", "label"]),
]
)
val_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"], a_min=-57, a_max=164,
b_min=0.0, b_max=1.0, clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
EnsureTyped(keys=["image", "label"]),
]
)
# we use cached datasets - these are 10x faster than regular datasets
self.train_ds = CacheDataset(
data=train_files, transform=train_transforms,
cache_rate=1.0, num_workers=4,
)
self.val_ds = CacheDataset(
data=val_files, transform=val_transforms,
cache_rate=1.0, num_workers=4,
)
# self.train_ds = monai.data.Dataset(
# data=train_files, transform=train_transforms)
# self.val_ds = monai.data.Dataset(
# data=val_files, transform=val_transforms)
def train_dataloader(self):
train_loader = torch.utils.data.DataLoader(
self.train_ds, batch_size=2, shuffle=True,
num_workers=4, collate_fn=list_data_collate,
)
return train_loader
def val_dataloader(self):
val_loader = torch.utils.data.DataLoader(
self.val_ds, batch_size=1, num_workers=4)
return val_loader
def configure_optimizers(self):
optimizer = torch.optim.Adam(self._model.parameters(), 1e-4)
return optimizer
def training_step(self, batch, batch_idx):
images, labels = batch["image"], batch["label"]
output = self.forward(images)
loss = self.loss_function(output, labels)
tensorboard_logs = {"train_loss": loss.item()}
return {"loss": loss, "log": tensorboard_logs}
def validation_step(self, batch, batch_idx):
images, labels = batch["image"], batch["label"]
roi_size = (160, 160, 160)
sw_batch_size = 4
outputs = sliding_window_inference(
images, roi_size, sw_batch_size, self.forward)
loss = self.loss_function(outputs, labels)
outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
labels = [self.post_label(i) for i in decollate_batch(labels)]
self.dice_metric(y_pred=outputs, y=labels)
return {"val_loss": loss, "val_number": len(outputs)}
def validation_epoch_end(self, outputs):
val_loss, num_items = 0, 0
for output in outputs:
val_loss += output["val_loss"].sum().item()
num_items += output["val_number"]
mean_val_dice = self.dice_metric.aggregate().item()
self.dice_metric.reset()
mean_val_loss = torch.tensor(val_loss / num_items)
tensorboard_logs = {
"val_dice": mean_val_dice,
"val_loss": mean_val_loss,
}
if mean_val_dice > self.best_val_dice:
self.best_val_dice = mean_val_dice
self.best_val_epoch = self.current_epoch
print(
f"current epoch: {self.current_epoch} "
f"current mean dice: {mean_val_dice:.4f}"
f"\nbest mean dice: {self.best_val_dice:.4f} "
f"at epoch: {self.best_val_epoch}"
)
return {"log": tensorboard_logs}
訓練の実行
# initialise the LightningModule
net = Net()
# set up loggers and checkpoints
log_dir = os.path.join(root_dir, "logs")
tb_logger = pytorch_lightning.loggers.TensorBoardLogger(
save_dir=log_dir
)
# initialise Lightning's trainer.
trainer = pytorch_lightning.Trainer(
gpus=[0],
max_epochs=600,
logger=tb_logger,
checkpoint_callback=True,
num_sanity_val_steps=1,
)
# train
trainer.fit(net)
print(
f"train completed, best_metric: {net.best_val_dice:.4f} "
f"at epoch {net.best_val_epoch}")
train completed, best_metric: 0.9498 at epoch 563
tensorboard で訓練を見る
%load_ext tensorboard
%tensorboard --logdir=log_dir
The tensorboard extension is already loaded. To reload it, use: %reload_ext tensorboard Reusing TensorBoard on port 6006 (pid 27668), started 1:35:41 ago. (Use '!kill 27668' to kill it.)
入力画像とラベルでベストなモデル出力を確認する
net.eval()
device = torch.device("cuda:0")
net.to(device)
with torch.no_grad():
for i, val_data in enumerate(net.val_dataloader()):
roi_size = (160, 160, 160)
sw_batch_size = 4
val_outputs = sliding_window_inference(
val_data["image"].to(device), roi_size, sw_batch_size, net
)
# plot the slice [:, :, 80]
plt.figure("check", (18, 6))
plt.subplot(1, 3, 1)
plt.title(f"image {i}")
plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
plt.subplot(1, 3, 2)
plt.title(f"label {i}")
plt.imshow(val_data["label"][0, 0, :, :, 80])
plt.subplot(1, 3, 3)
plt.title(f"output {i}")
plt.imshow(torch.argmax(
val_outputs, dim=1).detach().cpu()[0, :, :, 80])
plt.show()
データディレクトリのクリーンアップ
一時ディレクトリが使用された場合はディレクトリを削除する。
if directory is None:
shutil.rmtree(root_dir)
以上