MONAI 1.0 : tutorials : モジュール – MedNIST で GAN (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/08/2021 (1.0.1)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
MONAI 1.0 : tutorials : モジュール – MedNIST で GAN
このノートブックはランダムな入力テンソルから画像を生成するネットワークを訓練するための MONAI の使用方法を例示します。個別の Generator と Discriminator ネットワークに対処する単純な GAN が採用されます。
これは以下のステップを進みます :
- 遠隔ソースからデータをロードする
- このデータからのデータセットと変換を構築する
- ネットワークを定義する
- 訓練と評価
環境セットアップ
!python -c "import monai" || pip install -q monai-weekly
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
from monai.utils import progress_bar, set_determinism
from monai.transforms import (
AddChannel,
Compose,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
EnsureType,
Transform,
)
from monai.networks.nets import Discriminator, Generator
from monai.networks import normal_init
from monai.data import CacheDataset
from monai.config import print_config
from monai.apps import download_and_extract
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import tempfile
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
print_config()
MONAI version: 0.6.0rc1+23.gc6793fd0 Numpy version: 1.20.3 Pytorch version: 1.9.0a0+c3d40fd MONAI flags: HAS_EXT = True, USE_COMPILED = False MONAI rev id: c6793fd0f316a448778d0047664aaf8c1895fe1c Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.15.0 Pillow version: 7.0.0 Tensorboard version: 2.5.0 gdown version: 3.13.0 TorchVision version: 0.10.0a0 ITK version: 5.1.2 tqdm version: 4.53.0 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.1.4 einops version: 0.3.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
再現性のための決定論的訓練
set_determinism(seed=0)
訓練変数の定義
disc_train_interval = 1
disc_train_steps = 5
batch_size = 300
latent_size = 64
max_epochs = 50
real_label = 1
gen_label = 0
learning_rate = 2e-4
betas = (0.5, 0.999)
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。
これは結果をセーブしてダウンロードを再利用することを可能にします。
指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical
データセットをダウンロードする
MedMNIST データセットは TCIA, RSNA Bone Age チャレンジ と NIH Chest X-ray データセット からの様々なセットから集められました。
データセットは Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) のお陰により Creative Commons CC BY-SA 4.0 ライセンス のもとで利用可能になっています。
MedNIST データセットを使用する場合、出典を明示してください、e.g. https://github.com/Project-MONAI/tutorials/blob/master/2d_classification/mednist_tutorial.ipynb。
ここではファイルシステムを使用することなく tar ファイルからダウンロードして読む方法を示すためと、ハンド X-rays の画像だけを望むために、遠隔ソースからデータをロードする方法は異なります。これは分類サンプルではないのでカテゴリーデータは必要ありませんので、tarball をダウンロードし、標準ライブラリを使用してそれをオープンし、そしてハンドのためのファイル名の総てを recall します :
resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
hands = [
os.path.join(data_dir, "Hand", x)
for x in os.listdir(os.path.join(data_dir, "Hand"))
]
tarfile から実際の画像データをロードするため、Matplotlib を使用してこれを行なう変換タイプを定義します。これはデータを準備するために他の変換とともに使用され、ランダム化された増強変換が続きます。ここでは tarball からの準備された画像の総てをキャッシュするために CacheDataset クラスが使用されますので、ランダム化された回転、反転とズーム操作により増強されることになる準備された画像の総てをメモリに持ちます :
class LoadTarJpeg(Transform):
def __call__(self, data):
return plt.imread(data)
train_transforms = Compose(
[
LoadTarJpeg(),
AddChannel(),
ScaleIntensity(),
RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
RandFlip(spatial_axis=0, prob=0.5),
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
EnsureType(),
]
)
train_ds = CacheDataset(hands, train_transforms)
train_loader = torch.utils.data.DataLoader(
train_ds, batch_size=batch_size, shuffle=True, num_workers=10
)
100%|██████████| 10000/10000 [00:05<00:00, 1691.00it/s]
今は generator と discriminator ネットワークを定義します。パラメータは tar ファイルからロードされた (1, 64, 64) の画像サイズに合うように注意深く選択されています。discriminator への入力画像は非常に小さい画像を生成するために 4 回ダウンサンプリングされます、これらは平坦化されて完全結合層への入力として渡されます。generator への入力潜在ベクトルは shape (64, 8, 8) の出力を生成するために完全結合層に渡されます、そしてこれはリアル画像と同じ shape である最終的な出力を生成するために 3 回アップサンプリングされます。結果を改善するためにネットワークは正規化スキームで初期化されます :
device = torch.device("cuda:0")
disc_net = Discriminator(
in_shape=(1, 64, 64),
channels=(8, 16, 32, 64, 1),
strides=(2, 2, 2, 2, 1),
num_res_units=1,
kernel_size=5,
).to(device)
gen_net = Generator(
latent_shape=latent_size, start_shape=(64, 8, 8),
channels=[32, 16, 8, 1], strides=[2, 2, 2, 1],
)
# initialize both networks
disc_net.apply(normal_init)
gen_net.apply(normal_init)
# input images are scaled to [0,1] so enforce the same of generated outputs
gen_net.conv.add_module("activation", torch.nn.Sigmoid())
gen_net = gen_net.to(device)
今は generator と discriminator のための損失計算プロセスをラップするヘルパー関数とともに使用する損失関数を定義します。optimizer もまた定義します :
disc_loss = torch.nn.BCELoss()
gen_loss = torch.nn.BCELoss()
disc_opt = torch.optim.Adam(disc_net.parameters(), learning_rate, betas=betas)
gen_opt = torch.optim.Adam(gen_net.parameters(), learning_rate, betas=betas)
def discriminator_loss(gen_images, real_images):
"""
The discriminator loss if calculated by comparing its
prediction for real and generated images.
"""
real = real_images.new_full((real_images.shape[0], 1), real_label)
gen = gen_images.new_full((gen_images.shape[0], 1), gen_label)
realloss = disc_loss(disc_net(real_images), real)
genloss = disc_loss(disc_net(gen_images.detach()), gen)
return (realloss + genloss) / 2
def generator_loss(input):
"""
The generator loss is calculated by determining how well
the discriminator was fooled by the generated images.
"""
output = disc_net(input)
cats = output.new_full(output.shape, real_label)
return gen_loss(output, cats)
今は幾つかのエポックの間データセットに渡り反復することにより訓練します。各バッチのための generator 訓練ステージの後、discriminator は同じリアルと生成画像上で幾つかのステップの間訓練されます。
epoch_loss_values = [(0, 0)]
gen_step_loss = []
disc_step_loss = []
step = 0
for epoch in range(max_epochs):
gen_net.train()
disc_net.train()
epoch_loss = 0
for i, batch_data in enumerate(train_loader):
progress_bar(
i, len(
train_loader),
f"epoch {epoch + 1}, avg loss: {epoch_loss_values[-1][1]:.4f}",
)
real_images = batch_data.to(device)
latent = torch.randn(real_images.shape[0], latent_size).to(device)
gen_opt.zero_grad()
gen_images = gen_net(latent)
loss = generator_loss(gen_images)
loss.backward()
gen_opt.step()
epoch_loss += loss.item()
gen_step_loss.append((step, loss.item()))
if step % disc_train_interval == 0:
disc_total_loss = 0
for _ in range(disc_train_steps):
disc_opt.zero_grad()
dloss = discriminator_loss(gen_images, real_images)
dloss.backward()
disc_opt.step()
disc_total_loss += dloss.item()
disc_step_loss.append((step, disc_total_loss / disc_train_steps))
step += 1
epoch_loss /= step
epoch_loss_values.append((step, epoch_loss))
33/34 epoch 50, avg loss: 0.0563 [============================= ]
generator と discriminator のための個別の損失値は一緒にグラフ化できます。これらは、discriminator を騙す generator の能力がリアルとフェイク画像の間を正確に識別するネットワークの能力と均衡するにつれて、均衡に達するはずです。
plt.figure(figsize=(12, 5))
plt.semilogy(*zip(*gen_step_loss), label="Generator Loss")
plt.semilogy(*zip(*disc_step_loss), label="Discriminator Loss")
plt.grid(True, "both", "both")
plt.legend()
最後に幾つかランダムに生成された画像を示します。望ましくは期待されるように殆どの画像が 4 本の指と親指を持つことです (polydactyl (= 多指の) サンプルがデータセットに多くは存在していないと仮定して)。このデモ目的のノートブックは長くはネットワークを訓練しません、デフォルトの 50 エポックを越えた訓練は結果を改善するはずです。
test_size = 10
test_latent = torch.randn(test_size, latent_size).to(device)
test_images = gen_net(test_latent)
fig, axs = plt.subplots(1, test_size, figsize=(20, 4))
for i, ax in enumerate(axs):
ax.axis("off")
ax.imshow(test_images[i, 0].cpu().data.numpy(), cmap="gray")
以上