MONAI 0.7 : tutorials : 2D レジストレーション – 2D XRay レジストレーション・デモ (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/10/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : 2D レジストレーション – 2D XRay レジストレーション・デモ
このノートブックは学習ベースの 64 x 64 X-Ray ハンドのアフィン・レジストレーションの素早いデモを示します。
このデモは MONAI のレジストレーション機能の使用方法を示す toy サンプルです。
このデモは主として以下を使用します :
- アフィン変換パラメータを予測するためにアフィンヘッドを持つ UNet ライクなレジストレーション・ネットワーク ;
- moving 画像を変換するための、MONAI C++/CUDA として実装された、warp 関数。
環境のセットアップ
BUILD_MONAI=1 フラグで “pip install” すると、MONA レポジトリから最新のソースコードを取得し、MONAI の C++/CUDA 拡張を構築して、パッケージをインストールします。
env BUILD_MONAI=1 の設定は、関連する Python モジュールを呼び出すとき MONAI は Pytorch/Python ネイティブ実装の代わりにそれらの拡張を優先することを示します。
(コンパイルは数分から 10+ 分かかる場合があります。)
%env BUILD_MONAI=1
!python -c "import monai" || pip install -q git+https://github.com/Project-MONAI/MONAI#egg=monai[all]
インポートのセットアップ
from monai.utils import set_determinism, first
from monai.transforms import (
EnsureChannelFirstD,
Compose,
LoadImageD,
RandRotateD,
RandZoomD,
ScaleIntensityRanged,
EnsureTypeD,
)
from monai.data import DataLoader, Dataset, CacheDataset
from monai.config import print_config, USE_COMPILED
from monai.networks.nets import GlobalNet
from monai.networks.blocks import Warp
from monai.apps import MedNISTDataset
import numpy as np
import torch
from torch.nn import MSELoss
import matplotlib.pyplot as plt
print_config()
set_determinism(42)
MONAI version: 0.5.0+7.g9f4da6a Numpy version: 1.19.5 Pytorch version: 1.8.1+cu101 MONAI flags: HAS_EXT = True, USE_COMPILED = True MONAI rev id: 9f4da6acded249bba24c85eaee4ece256ed45815 Optional dependencies: Pytorch Ignite version: 0.4.4 Nibabel version: 3.0.2 scikit-image version: 0.16.2 Pillow version: 7.1.2 Tensorboard version: 2.4.1 gdown version: 3.6.4 TorchVision version: 0.9.1+cu101 ITK version: 5.1.2 tqdm version: 4.60.0 lmdb version: 0.99 psutil version: 5.4.8 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
ペア単位の訓練入力の構築
実際のデータファイルをダウンロードして unzip するために MedNISTDataset オブジェクトを使用します。そして hand クラスを選択し、ロードされたデータ辞書を “fixed_hand” と “moving_hand” に変換します、これらは合成訓練ペアを作成するために別々に前処理されます。
train_data = MedNISTDataset(root_dir="./", section="training", download=True, transform=None)
training_datadict = [
{"fixed_hand": item["image"], "moving_hand": item["image"]}
for item in train_data.data if item["label"] == 4 # label 4 is for xray hands
]
print("\n first training items: ", training_datadict[:3])
MedNIST.tar.gz: 59.0MB [00:07, 8.83MB/s] downloaded file: ./MedNIST.tar.gz. Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d. Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d. Loading dataset: 100%|██████████| 47164/47164 [00:00<00:00, 145309.19it/s] first training items: [{'fixed_hand': './MedNIST/Hand/003696.jpeg', 'moving_hand': './MedNIST/Hand/003696.jpeg'}, {'fixed_hand': './MedNIST/Hand/001404.jpeg', 'moving_hand': './MedNIST/Hand/001404.jpeg'}, {'fixed_hand': './MedNIST/Hand/008882.jpeg', 'moving_hand': './MedNIST/Hand/008882.jpeg'}]
train_transforms = Compose(
[
LoadImageD(keys=["fixed_hand", "moving_hand"]),
EnsureChannelFirstD(keys=["fixed_hand", "moving_hand"]),
ScaleIntensityRanged(keys=["fixed_hand", "moving_hand"],
a_min=0., a_max=255., b_min=0.0, b_max=1.0, clip=True,),
RandRotateD(keys=["moving_hand"], range_x=np.pi/4, prob=1.0, keep_size=True, mode="bicubic"),
RandZoomD(keys=["moving_hand"], min_zoom=0.9, max_zoom=1.1, prob=1.0, mode="bicubic", align_corners=False),
EnsureTypeD(keys=["fixed_hand", "moving_hand"]),
]
)
訓練ペアの可視化
check_ds = Dataset(data=training_datadict, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1, shuffle=True)
check_data = first(check_loader)
fixed_image = check_data["fixed_hand"][0][0]
moving_image = check_data["moving_hand"][0][0]
print(f"moving_image shape: {moving_image.shape}")
print(f"fixed_image shape: {fixed_image.shape}")
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("moving_image")
plt.imshow(moving_image, cmap="gray")
plt.subplot(1, 2, 2)
plt.title("fixed_image")
plt.imshow(fixed_image, cmap="gray")
plt.show()
moving_image shape: torch.Size([64, 64]) fixed_image shape: torch.Size([64, 64])
訓練パイプラインを作成する
訓練ペアを獲得して訓練プロセスを高速化するために CacheDataset を使用します。この訓練データは GlobalNet に供給されます、これは 画像レベルのアフィン変換パラメータを予測します。Warp 層は初期化されて訓練と推論の両方のために使用されます。
train_ds = CacheDataset(data=training_datadict[:1000], transform=train_transforms,
cache_rate=1.0, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2)
Loading dataset: 100%|██████████| 1000/1000 [00:01<00:00, 558.34it/s]
device = torch.device("cuda:0")
model = GlobalNet(
image_size=(64, 64),
spatial_dims=2,
in_channels=2, # moving and fixed
num_channel_initial=16,
depth=3).to(device)
image_loss = MSELoss()
if USE_COMPILED:
warp_layer = Warp(3, "border").to(device)
else:
warp_layer = Warp("bilinear", "border").to(device)
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
訓練ループ
max_epochs = 200
epoch_loss_values = []
for epoch in range(max_epochs):
print("-" * 10)
print(f"epoch {epoch + 1}/{max_epochs}")
model.train()
epoch_loss, step = 0, 0
for batch_data in train_loader:
step += 1
optimizer.zero_grad()
moving = batch_data["moving_hand"].to(device)
fixed = batch_data["fixed_hand"].to(device)
ddf = model(torch.cat((moving, fixed), dim=1))
pred_image = warp_layer(moving, ddf)
loss = image_loss(pred_image, fixed)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
# print(f"{step}/{len(train_ds) // train_loader.batch_size}, "
# f"train_loss: {loss.item():.4f}")
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
%matplotlib inline
plt.plot(epoch_loss_values)
幾つかの検証結果の視覚化
このセクションは moving vs fixed ハンドの初見の (i.e. 前に見てない) ペアのセットを作成して各ペアの間の変換を予測するためにネットワークを使用します。
val_ds = CacheDataset(data=training_datadict[2000:2500], transform=train_transforms,
cache_rate=1.0, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=16, num_workers=0)
for batch_data in val_loader:
moving = batch_data["moving_hand"].to(device)
fixed = batch_data["fixed_hand"].to(device)
ddf = model(torch.cat((moving, fixed), dim=1))
pred_image = warp_layer(moving, ddf)
break
fixed_image = fixed.detach().cpu().numpy()[:, 0]
moving_image = moving.detach().cpu().numpy()[:, 0]
pred_image = pred_image.detach().cpu().numpy()[:, 0]
Loading dataset: 100%|██████████| 500/500 [00:00<00:00, 803.96it/s]
%matplotlib inline
batch_size = 5
plt.subplots(batch_size, 3, figsize=(8, 10))
for b in range(batch_size):
# moving image
plt.subplot(batch_size, 3, b * 3 + 1)
plt.axis('off')
plt.title("moving image")
plt.imshow(moving_image[b], cmap="gray")
# fixed image
plt.subplot(batch_size, 3, b * 3 + 2)
plt.axis('off')
plt.title("fixed image")
plt.imshow(fixed_image[b], cmap="gray")
# warped moving
plt.subplot(batch_size, 3, b * 3 + 3)
plt.axis('off')
plt.title("predicted image")
plt.imshow(pred_image[b], cmap="gray")
plt.axis('off')
plt.show()
以上