ホーム » MONAI » MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (辞書版)

MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (辞書版)

MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (辞書版) (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/14/2021 (0.7.0)

* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス 無料 Web セミナー開催中

◆ クラスキャットは人工知能・テレワークに関する各種サービスを提供しております。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com  ;  WebSite: https://www.classcat.com/  ;  Facebook

 

 

MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (辞書版)

このノートブックは GanTrainer、モジュール化された敵対的学習のための MONAI ワークフロー・エンジンを示します。MedNIST ハンド CT スキャン・データセットを使用して医療画像再構築ネットワークを訓練します。辞書バージョン。

MONAI フレームワークは敵対的生成ネットワークを簡単に設計し、訓練して評価するために使用できます。このノートブックは、ハンド CT スキャンの画像を再構築するために単純な GAN モデルを設計して訓練する MONAI コンポーネントを使用する実例を示します。

ネットワーク・アーキテクチャと損失関数についての詳細は MONAI Mednist GAN チュートリアル を読んでください。

 

Step 1: セットアップ

環境のセットアップ

!python -c "import monai" || pip install -q "monai-weekly[ignite, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
from monai.utils import set_determinism
from monai.transforms import (
    AddChannelD,
    Compose,
    LoadImageD,
    RandFlipD,
    RandRotateD,
    RandZoomD,
    ScaleIntensityD,
    EnsureTypeD,
)
from monai.networks.nets import Discriminator, Generator
from monai.networks import normal_init
from monai.handlers import CheckpointSaver, MetricLogger, StatsHandler
from monai.engines.utils import GanKeys, default_make_latent
from monai.engines import GanTrainer
from monai.data import CacheDataset, DataLoader
from monai.config import print_config
from monai.apps import download_and_extract
import numpy as np
import torch
import matplotlib.pyplot as plt
import tempfile
import sys
import shutil
import os
import logging

 

インポートのセットアップ

print_config()
MONAI version: 0.6.0rc1+23.gc6793fd0
Numpy version: 1.20.3
Pytorch version: 1.9.0a0+c3d40fd
MONAI flags: HAS_EXT = True, USE_COMPILED = False
MONAI rev id: c6793fd0f316a448778d0047664aaf8c1895fe1c

Optional dependencies:
Pytorch Ignite version: 0.4.5
Nibabel version: 3.2.1
scikit-image version: 0.15.0
Pillow version: 7.0.0
Tensorboard version: 2.5.0
gdown version: 3.13.0
TorchVision version: 0.10.0a0
ITK version: 5.1.2
tqdm version: 4.53.0
lmdb version: 1.2.1
psutil version: 5.8.0
pandas version: 1.1.4
einops version: 0.3.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

 

データディレクトリのセットアップ

MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical

 

データセットをダウンロードする

データセットをダウンロードして展開します。

MedMNIST データセットは TCIA, RSNA Bone Age チャレンジNIH Chest X-ray データセット からの様々なセットから集められました。

データセットは Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) のお陰により Creative Commons CC BY-SA 4.0 ライセンス のもとで利用可能になっています。

MedNIST データセットを使用する場合、出典を明示してください、e.g. https://github.com/Project-MONAI/tutorials/blob/master/2d_classification/mednist_tutorial.ipynb

resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"

compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)
hand_dir = os.path.join(data_dir, "Hand")
training_datadict = [
    {"hand": os.path.join(hand_dir, filename)}
    for filename in os.listdir(hand_dir)
]
print(training_datadict[:5])
[{'hand': '/workspace/data/medical/MedNIST/Hand/000317.jpeg'}, {'hand': '/workspace/data/medical/MedNIST/Hand/002344.jpeg'}, {'hand': '/workspace/data/medical/MedNIST/Hand/000816.jpeg'}, {'hand': '/workspace/data/medical/MedNIST/Hand/004046.jpeg'}, {'hand': '/workspace/data/medical/MedNIST/Hand/003316.jpeg'}]

 

Step 2: MONAI コンポーネントを初期化する

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
set_determinism(0)
device = torch.device("cuda:0")

 

画像変換チェインを作成する

セーブされたディスク画像を利用可能なテンソルに変換するために処理パイプラインを定義します。

train_transforms = Compose(
    [
        LoadImageD(keys=["hand"]),
        AddChannelD(keys=["hand"]),
        ScaleIntensityD(keys=["hand"]),
        RandRotateD(keys=["hand"], range_x=np.pi /
                    12, prob=0.5, keep_size=True),
        RandFlipD(keys=["hand"], spatial_axis=0, prob=0.5),
        RandZoomD(keys=["hand"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
        EnsureTypeD(keys=["hand"]),
    ]
)

 

データセットとデーたローダを作成する

データを保持して訓練の間にバッチを提示します。

real_dataset = CacheDataset(training_datadict, train_transforms)
100%|██████████| 10000/10000 [00:09<00:00, 1000.72it/s]
batch_size = 300
real_dataloader = DataLoader(
    real_dataset, batch_size=batch_size, shuffle=True, num_workers=10)


def prepare_batch(batchdata, device=None, non_blocking=False):
    return batchdata["hand"].to(device=device, non_blocking=non_blocking)

 

generator と discriminator を定義する

基本的なコンピュータビジョン GAN ネットワークをライブラリからロードします。

# define networks
disc_net = Discriminator(
    in_shape=(1, 64, 64),
    channels=(8, 16, 32, 64, 1),
    strides=(2, 2, 2, 2, 1),
    num_res_units=1,
    kernel_size=5,
).to(device)

latent_size = 64
gen_net = Generator(
    latent_shape=latent_size,
    start_shape=(latent_size, 8, 8),
    channels=[32, 16, 8, 1],
    strides=[2, 2, 2, 1],
)
gen_net.conv.add_module("activation", torch.nn.Sigmoid())
gen_net = gen_net.to(device)

# initialize both networks
disc_net.apply(normal_init)
gen_net.apply(normal_init)

# define optimizors
learning_rate = 2e-4
betas = (0.5, 0.999)
disc_opt = torch.optim.Adam(disc_net.parameters(), learning_rate, betas=betas)
gen_opt = torch.optim.Adam(gen_net.parameters(), learning_rate, betas=betas)

# define loss functions
disc_loss_criterion = torch.nn.BCELoss()
gen_loss_criterion = torch.nn.BCELoss()
real_label = 1
fake_label = 0


def discriminator_loss(gen_images, real_images):
    real = real_images.new_full((real_images.shape[0], 1), real_label)
    gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)

    realloss = disc_loss_criterion(disc_net(real_images), real)
    genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)

    return (genloss + realloss) / 2


def generator_loss(gen_images):
    output = disc_net(gen_images)
    cats = output.new_full(output.shape, real_label)
    return gen_loss_criterion(output, cats)

 

訓練ハンドラを作成する

モデル訓練の間に操作を実行します。

metric_logger = MetricLogger(
    loss_transform=lambda x: {
        GanKeys.GLOSS: x[GanKeys.GLOSS], GanKeys.DLOSS: x[GanKeys.DLOSS]},
    metric_transform=lambda x: x,
)

handlers = [
    StatsHandler(
        name="batch_training_loss",
        output_transform=lambda x: {
            GanKeys.GLOSS: x[GanKeys.GLOSS],
            GanKeys.DLOSS: x[GanKeys.DLOSS],
        },
    ),
    CheckpointSaver(
        save_dir=os.path.join(root_dir, "hand-gan"),
        save_dict={"g_net": gen_net, "d_net": disc_net},
        save_interval=10,
        save_final=True,
        epoch_level=True,
    ),
    metric_logger,
]

 

GanTrainer を作成する

敵対的学習のための MONAI ワークフロー・エンジン。GanTrainer によってコンポーネントはここで集められます。

Goodfellow et al. 2014 https://arxiv.org/abs/1406.2661 に基づいた訓練ループを使用します。

訓練ループ : データサイズ m の各バッチについて

  1. ランダムな潜在コードから m 個の fakes を生成します。
  2. これらの fakes と現在のバッチ reals で D を更新します、d_train_steps 回反復されます。
  3. 新しいランダムな潜在コードから m fakes を生成します。
  4. discriminator フィードバックを使用してこれらの fakes で generator を更新します。
disc_train_steps = 5
max_epochs = 50

trainer = GanTrainer(
    device,
    max_epochs,
    real_dataloader,
    gen_net,
    gen_opt,
    generator_loss,
    disc_net,
    disc_opt,
    discriminator_loss,
    d_prepare_batch=prepare_batch,
    d_train_steps=disc_train_steps,
    g_update_latents=True,
    latent_shape=latent_size,
    key_train_metric=None,
    train_handlers=handlers,
)

 

Step 3: 訓練の開始

trainer.run()

 

結果を評価する

G と D の損失カーブを崩れていないか調べます。

g_loss = [loss[1][GanKeys.GLOSS] for loss in metric_logger.loss]
d_loss = [loss[1][GanKeys.DLOSS] for loss in metric_logger.loss]
plt.figure(figsize=(12, 5))
plt.semilogy(g_loss, label="Generator Loss")
plt.semilogy(d_loss, label="Discriminator Loss")
plt.grid(True, "both", "both")
plt.legend()
plt.show()

 

画像再構築を見る

ランダムな潜在コードで訓練された generator の出力を見ます。

test_img_count = 10
test_latents = default_make_latent(test_img_count, latent_size).to(device)
fakes = gen_net(test_latents)

fig, axs = plt.subplots(2, (test_img_count // 2), figsize=(20, 8))
axs = axs.flatten()
for i, ax in enumerate(axs):
    ax.axis("off")
    ax.imshow(fakes[i, 0].cpu().data.numpy(), cmap="gray")

 

データディレクトリのクリーンアップ

一時ディレクトリが作成された場合にはディレクトリを削除します。

if directory is None:
    shutil.rmtree(root_dir)
 

以上



ClassCat® Chatbot

人工知能開発支援

◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com

カテゴリー