MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (配列版) (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/14/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (配列版)
このノートブックは GanTrainer、モジュール化された敵対的学習のための MONAI ワークフロー・エンジンを示します。MedNIST ハンド CT スキャン・データセットを使用して医療画像再構築ネットワークを訓練します。配列バージョン。
MONAI フレームワークは敵対的生成ネットワークを簡単に設計し、訓練して評価するために使用できます。このノートブックは、ハンド CT スキャンの画像を再構築するために単純な GAN モデルを設計して訓練する MONAI コンポーネントを使用する実例を示します。
ネットワーク・アーキテクチャと損失関数についての詳細は MONAI Mednist GAN チュートリアル を読んでください。
Step 1: セットアップ
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[ignite, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
from monai.utils import set_determinism
from monai.transforms import (
AddChannel,
Compose,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
EnsureType,
Transform,
)
from monai.networks.nets import Discriminator, Generator
from monai.networks import normal_init
from monai.handlers import CheckpointSaver, MetricLogger, StatsHandler
from monai.engines.utils import GanKeys, default_make_latent
from monai.engines import GanTrainer
from monai.data import CacheDataset, DataLoader
from monai.config import print_config
from monai.apps import download_and_extract
import numpy as np
import torch
import matplotlib.pyplot as plt
import shutil
import sys
import logging
import tempfile
import os
インポートのセットアップ
print_config()
MONAI version: 0.6.0rc1+23.gc6793fd0 Numpy version: 1.20.3 Pytorch version: 1.9.0a0+c3d40fd MONAI flags: HAS_EXT = True, USE_COMPILED = False MONAI rev id: c6793fd0f316a448778d0047664aaf8c1895fe1c Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.15.0 Pillow version: 7.0.0 Tensorboard version: 2.5.0 gdown version: 3.13.0 TorchVision version: 0.10.0a0 ITK version: 5.1.2 tqdm version: 4.53.0 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.1.4 einops version: 0.3.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical
データセットをダウンロードする
データセットをダウンロードして展開します。
MedMNIST データセットは TCIA, RSNA Bone Age チャレンジ と NIH Chest X-ray データセット からの様々なセットから集められました。
データセットは Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) のお陰により Creative Commons CC BY-SA 4.0 ライセンス のもとで利用可能になっています。
MedNIST データセットを使用する場合、出典を明示してください、e.g. https://github.com/Project-MONAI/tutorials/blob/master/2d_classification/mednist_tutorial.ipynb。
resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
hands = [
os.path.join(data_dir, "Hand", x)
for x in os.listdir(os.path.join(data_dir, "Hand"))
]
print(hands[:5])
['/workspace/data/medical/MedNIST/Hand/000317.jpeg', '/workspace/data/medical/MedNIST/Hand/002344.jpeg', '/workspace/data/medical/MedNIST/Hand/000816.jpeg', '/workspace/data/medical/MedNIST/Hand/004046.jpeg', '/workspace/data/medical/MedNIST/Hand/003316.jpeg']
Step 2: MONAI コンポーネントを初期化する
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
set_determinism(0)
device = torch.device("cuda:0")
画像変換チェインを作成する
セーブされたディスク画像を利用可能なテンソルに変換するために処理パイプラインを定義します。
class LoadTarJpeg(Transform):
def __call__(self, data):
return plt.imread(data)
train_transforms = Compose(
[
LoadTarJpeg(),
AddChannel(),
ScaleIntensity(),
RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
RandFlip(spatial_axis=0, prob=0.5),
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
EnsureType(),
]
)
データセットとデーたローダを作成する
データを保持して訓練の間にバッチを提示します。
real_dataset = CacheDataset(hands, train_transforms)
100%|██████████| 10000/10000 [00:09<00:00, 1092.83it/s]
batch_size = 300
real_dataloader = DataLoader(
real_dataset, batch_size=batch_size, shuffle=True, num_workers=10)
# We don't need to do any preparing so just return "as is"
def prepare_batch(batchdata, device=None, non_blocking=False):
return batchdata.to(device=device, non_blocking=non_blocking)
generator と discriminator を定義する
基本的なコンピュータビジョン GAN ネットワークをライブラリからロードします。
# define networks
disc_net = Discriminator(
in_shape=(1, 64, 64),
channels=(8, 16, 32, 64, 1),
strides=(2, 2, 2, 2, 1),
num_res_units=1,
kernel_size=5,
).to(device)
latent_size = 64
gen_net = Generator(
latent_shape=latent_size,
start_shape=(latent_size, 8, 8),
channels=[32, 16, 8, 1],
strides=[2, 2, 2, 1],
)
gen_net.conv.add_module("activation", torch.nn.Sigmoid())
gen_net = gen_net.to(device)
# initialize both networks
disc_net.apply(normal_init)
gen_net.apply(normal_init)
# define optimizors
learning_rate = 2e-4
betas = (0.5, 0.999)
disc_opt = torch.optim.Adam(disc_net.parameters(), learning_rate, betas=betas)
gen_opt = torch.optim.Adam(gen_net.parameters(), learning_rate, betas=betas)
# define loss functions
disc_loss_criterion = torch.nn.BCELoss()
gen_loss_criterion = torch.nn.BCELoss()
real_label = 1
fake_label = 0
def discriminator_loss(gen_images, real_images):
real = real_images.new_full((real_images.shape[0], 1), real_label)
gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
realloss = disc_loss_criterion(disc_net(real_images), real)
genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
return (genloss + realloss) / 2
def generator_loss(gen_images):
output = disc_net(gen_images)
cats = output.new_full(output.shape, real_label)
return gen_loss_criterion(output, cats)
訓練ハンドラを作成する
モデル訓練の間に操作を実行します。
metric_logger = MetricLogger(
loss_transform=lambda x: {
GanKeys.GLOSS: x[GanKeys.GLOSS], GanKeys.DLOSS: x[GanKeys.DLOSS]},
metric_transform=lambda x: x,
)
handlers = [
StatsHandler(
name="batch_training_loss",
output_transform=lambda x: {
GanKeys.GLOSS: x[GanKeys.GLOSS],
GanKeys.DLOSS: x[GanKeys.DLOSS],
},
),
CheckpointSaver(
save_dir=os.path.join(root_dir, "hand-gan"),
save_dict={"g_net": gen_net, "d_net": disc_net},
save_interval=10,
save_final=True,
epoch_level=True,
),
metric_logger,
]
GanTrainer を作成する
敵対的学習のための MONAI ワークフロー・エンジン。GanTrainer によってコンポーネントはここで集められます。
Goodfellow et al. 2014 https://arxiv.org/abs/1406.2661 に基づいた訓練ループを使用します。
- ランダムな潜在コードから m 個の fakes を生成します。
- これらの fakes と現在のバッチ reals で D を更新します、d_train_steps 回反復されます。
- 新しいランダムな潜在コードから m fakes を生成します。
- discriminator フィードバックを使用してこれらの fakes で generator を更新します。
disc_train_steps = 5
max_epochs = 50
trainer = GanTrainer(
device,
max_epochs,
real_dataloader,
gen_net,
gen_opt,
generator_loss,
disc_net,
disc_opt,
discriminator_loss,
d_prepare_batch=prepare_batch,
d_train_steps=disc_train_steps,
g_update_latents=True,
latent_shape=latent_size,
key_train_metric=None,
train_handlers=handlers,
)
Step 3: 訓練の開始
trainer.run()
結果を評価する
G と D の損失カーブを崩れていないか調べます。
g_loss = [loss[1][GanKeys.GLOSS] for loss in metric_logger.loss]
d_loss = [loss[1][GanKeys.DLOSS] for loss in metric_logger.loss]
plt.figure(figsize=(12, 5))
plt.semilogy(g_loss, label="Generator Loss")
plt.semilogy(d_loss, label="Discriminator Loss")
plt.grid(True, "both", "both")
plt.legend()
plt.show()
画像再構築を見る
ランダムな潜在コードで訓練された generator の出力を見ます。
test_img_count = 10
test_latents = default_make_latent(test_img_count, latent_size).to(device)
fakes = gen_net(test_latents)
fig, axs = plt.subplots(2, (test_img_count // 2), figsize=(20, 8))
axs = axs.flatten()
for i, ax in enumerate(axs):
ax.axis("off")
ax.imshow(fakes[i, 0].cpu().data.numpy(), cmap="gray")
データディレクトリのクリーンアップ
一時ディレクトリが作成された場合にはディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上