MONAI 0.7 : tutorials : 高速化 – MONAI 機能による高速訓練 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/15/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : 高速化 – MONAI 機能による高速訓練
このドキュメントは、訓練パイプラインをプロファイルする方法、データセットを分析して適切なアルゴリズムを選択する方法、そしてシングル GPU、マルチ GPU 更にはマルチノードで GPU 利用を最適化する方法の詳細を紹介します。
このチュートリアルは PyTorch 訓練プログラムと MONAI 最適化訓練プログラムを示し、そしてパフォーマンスを比較します :
- AMP (Auto 混合精度)
- 決定論的変換のための CacheDataset
- データを GPU とキャッシュに移してから、GPU 上でランダムな変換を実行する。
- 軽量タスクではマルチスレッド化された ThreadDataLoader は PyTorch DataLoader よりも高速です。
- 通常の Dice 損失の代わりに MONAI DiceCE 損失を使用する。
- 通常の Adam optimizer の代わりに MONAI Novograd optimizer を使用する。
V100 GPU で、1 分内に 0.95 の検証平均 dice (損失) への訓練収束を獲得できます、それは同じメトリック達成するときの PyTorch の通常の実装と比べておよそ 200x 高速化しています。そして総てのエポックは通常の訓練よりも 20x 高速です。
それは 脾臓 3D セグメンテーション・チュートリアル ノートブックからの変更で、脾臓データセットは http://medicaldecathlon.com/ からダウンロードできます。
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import math
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
import torch
from torch.optim import Adam
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import (
CacheDataset,
DataLoader,
ThreadDataLoader,
Dataset,
decollate_batch,
)
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.optimizers import Novograd
from monai.transforms import (
AddChanneld,
AsDiscrete,
Compose,
CropForegroundd,
FgBgToIndicesd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
Spacingd,
ToDeviced,
EnsureTyped,
EnsureType,
)
from monai.utils import get_torch_version_tuple, set_determinism
print_config()
if get_torch_version_tuple() < (1, 6):
raise RuntimeError(
"AMP feature only exists in PyTorch version greater than v1.6."
)
MONAI version: 0.2.0+1008.gf65d296f Numpy version: 1.21.2 Pytorch version: 1.10.0a0+3fd9dcf MONAI flags: HAS_EXT = False, USE_COMPILED = False MONAI rev id: f65d296fe780f869dca84b6714dc36b94794930e Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.18.3 Pillow version: 8.2.0 Tensorboard version: 2.6.0 gdown version: 3.13.0 TorchVision version: 0.11.0a0 tqdm version: 4.62.1 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.3.2 einops version: 0.3.2 transformers version: NOT INSTALLED or UNKNOWN VERSION. For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(f"root dir is: {root_dir}")
root dir is: /workspace/data/medical
データセットのダウンロード
Decathlon 脾臓データセットをダウンロードして展開します。
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_root = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_root):
download_and_extract(resource, compressed_file, root_dir, md5)
MSD 脾臓データセット・パスの設定
train_images = sorted(
glob.glob(os.path.join(data_root, "imagesTr", "*.nii.gz"))
)
train_labels = sorted(
glob.glob(os.path.join(data_root, "labelsTr", "*.nii.gz"))
)
data_dicts = [
{"image": image_name, "label": label_name}
for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
訓練と検証のための変換のセットアップ
def transformations(fast=False):
train_transforms = [
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
Orientationd(keys=["image", "label"], axcodes="RAS"),
# change to execute transforms with Tensor data
EnsureTyped(keys=["image", "label"]),
]
if fast:
# move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
train_transforms.append(
ToDeviced(keys=["image", "label"], device="cuda:0")
)
train_transforms.extend([
ScaleIntensityRanged(
keys=["image"],
a_min=-57,
a_max=164,
b_min=0.0,
b_max=1.0,
clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
# pre-compute foreground and background indexes
# and cache them to accelerate training
FgBgToIndicesd(
keys="label",
fg_postfix="_fg",
bg_postfix="_bg",
image_key="image",
),
# randomly crop out patch samples from big
# image based on pos / neg ratio
# the image centers of negative samples
# must be in valid image area
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(96, 96, 96),
pos=1,
neg=1,
num_samples=4,
fg_indices_key="label_fg",
bg_indices_key="label_bg",
),
])
val_transforms = [
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"],
a_min=-57,
a_max=164,
b_min=0.0,
b_max=1.0,
clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
EnsureTyped(keys=["image", "label"]),
]
if fast:
# move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
val_transforms.append(
ToDeviced(keys=["image", "label"], device="cuda:0")
)
return Compose(train_transforms), Compose(val_transforms)
訓練手順を定義する
典型的な PyTorch 通常の学習手続きについては、 モデルを訓練するために通常の Dataset, DataLoader, Adam optimizer と Dice 損失を使用します。
MONAI 高速訓練手順については、主として以下の機能を導入します :
- AMP (auto 混合精度): AMP は PyTorch v1.6 でリリースされた重要な機能で、NVIDIA CUDA 11 は AMP の強力なサポートを追加して訓練スピードを大幅に改善しました。
- CacheDataset: キャッシュ機構を持つ Dataset で、訓練の間にデータをロードして決定論的変換の結果をキャッシュできます。
- ToDeviced 変換: データを GPU に移して CacheDataset でキャッシュしてから、直接 GPU 上でランダムな変換を実行し、総てのエポックでの CPU -> GPU 同期を回避します。総ての MONAI 変換が GPU 演算をサポートはしてないことに注意してください、作業は進行中です。
- ThreadDataLoader: マルチプロセス処理の代わりにマルチスレッドを使用します、殆どの計算の結果を既にキャッシュしていますので軽量タスクでは DataLoader よりも高速です。
- Novograd optimizer: Novograd はペーパー "Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks" < https://arxiv.org/pdf/1905.11286.pdf > に基づいています。
- DiceCE 損失関数: Dice 損失と交差エントロピー損失を計算して、これら 2 つの損失の重み付けられた合計を返します。
def train_process(fast=False):
max_epochs = 300
learning_rate = 2e-4
val_interval = 1 # do validation for every epoch
train_trans, val_trans = transformations(fast=fast)
# set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
if fast:
train_ds = CacheDataset(
data=train_files,
transform=train_trans,
cache_rate=1.0,
num_workers=8,
)
val_ds = CacheDataset(
data=val_files, transform=val_trans, cache_rate=1.0, num_workers=5
)
# disable multi-workers because `ThreadDataLoader` works with multi-threads
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
else:
train_ds = Dataset(data=train_files, transform=train_trans)
val_ds = Dataset(data=val_files, transform=val_trans)
# num_worker=4 is the best parameter according to the test
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
device = torch.device("cuda:0")
model = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
norm=Norm.BATCH,
).to(device)
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
# set Novograd optimizer for MONAI training
if fast:
# Novograd paper suggests to use a bigger LR than Adam,
# because Adam does normalization by element-wise second moments
optimizer = Novograd(model.parameters(), learning_rate * 10)
scaler = torch.cuda.amp.GradScaler()
else:
optimizer = Adam(model.parameters(), learning_rate)
best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
epoch_times = []
total_start = time.time()
for epoch in range(max_epochs):
epoch_start = time.time()
print("-" * 10)
print(f"epoch {epoch + 1}/{max_epochs}")
model.train()
epoch_loss = 0
step = 0
for batch_data in train_loader:
step_start = time.time()
step += 1
inputs, labels = (
batch_data["image"].to(device),
batch_data["label"].to(device),
)
optimizer.zero_grad()
# set AMP for MONAI training
if fast:
with torch.cuda.amp.autocast():
outputs = model(inputs)
loss = loss_function(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
print(
f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
f" step time: {(time.time() - step_start):.4f}"
)
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
for val_data in val_loader:
val_inputs, val_labels = (
val_data["image"].to(device),
val_data["label"].to(device),
)
roi_size = (160, 160, 160)
sw_batch_size = 4
# set AMP for MONAI validation
if fast:
with torch.cuda.amp.autocast():
val_outputs = sliding_window_inference(
val_inputs, roi_size, sw_batch_size, model
)
else:
val_outputs = sliding_window_inference(
val_inputs, roi_size, sw_batch_size, model
)
val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
val_labels = [post_label(i) for i in decollate_batch(val_labels)]
dice_metric(y_pred=val_outputs, y=val_labels)
metric = dice_metric.aggregate().item()
dice_metric.reset()
metric_values.append(metric)
if metric > best_metric:
best_metric = metric
best_metric_epoch = epoch + 1
best_metrics_epochs_and_time[0].append(best_metric)
best_metrics_epochs_and_time[1].append(best_metric_epoch)
best_metrics_epochs_and_time[2].append(
time.time() - total_start
)
torch.save(model.state_dict(), "best_metric_model.pth")
print("saved new best metric model")
print(
f"current epoch: {epoch + 1} current"
f" mean dice: {metric:.4f}"
f" best mean dice: {best_metric:.4f}"
f" at epoch: {best_metric_epoch}"
)
print(
f"time consuming of epoch {epoch + 1} is:"
f" {(time.time() - epoch_start):.4f}"
)
epoch_times.append(time.time() - epoch_start)
total_time = time.time() - total_start
print(
f"train completed, best_metric: {best_metric:.4f}"
f" at epoch: {best_metric_epoch}"
f" total time: {total_time:.4f}"
)
return (
max_epochs,
epoch_loss_values,
metric_values,
epoch_times,
best_metrics_epochs_and_time,
total_time,
)
決定論を有効にして通常の PyTorch 訓練を実行する
set_determinism(seed=0)
regular_start = time.time()
(
epoch_num,
epoch_loss_values,
metric_values,
epoch_times,
best,
train_time,
) = train_process(fast=False)
total_time = time.time() - regular_start
print(
f"total time of {epoch_num} epochs with regular PyTorch training: {total_time:.4f}"
)
決定論を有効にして MONAI 最適化訓練を実行する
set_determinism(seed=0)
monai_start = time.time()
(
epoch_num,
m_epoch_loss_values,
m_metric_values,
m_epoch_times,
m_best,
m_train_time,
) = train_process(fast=True)
m_total_time = time.time() - monai_start
print(
f"total time of {epoch_num} epochs with MONAI fast training: {m_train_time:.4f},"
f" time of preparing cache: {(m_total_time - m_train_time):.4f}"
)
訓練損失と検証メトリクスをプロットする
plt.figure("train", (12, 12))
plt.subplot(2, 2, 1)
plt.title("Regular Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.grid(alpha=0.4, linestyle=":")
plt.plot(x, y, color="red")
plt.subplot(2, 2, 2)
plt.title("Regular Val Mean Dice")
x = [i + 1 for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.ylim(0, 1)
plt.grid(alpha=0.4, linestyle=":")
plt.plot(x, y, color="red")
plt.subplot(2, 2, 3)
plt.title("Fast Epoch Average Loss")
x = [i + 1 for i in range(len(m_epoch_loss_values))]
y = m_epoch_loss_values
plt.xlabel("epoch")
plt.grid(alpha=0.4, linestyle=":")
plt.plot(x, y, color="green")
plt.subplot(2, 2, 4)
plt.title("Fast Val Mean Dice")
x = [i + 1 for i in range(len(m_metric_values))]
y = m_metric_values
plt.xlabel("epoch")
plt.ylim(0, 1)
plt.grid(alpha=0.4, linestyle=":")
plt.plot(x, y, color="green")
plt.show()
合計時間と総てのエポック時間をプロットする
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Total Train Time(300 epochs)")
plt.bar(
"regular PyTorch", total_time, 1, label="Regular training", color="red"
)
plt.bar("Fast", m_total_time, 1, label="Fast training", color="green")
plt.ylabel("secs")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.subplot(1, 2, 2)
plt.title("Epoch Time")
x = [i + 1 for i in range(len(epoch_times))]
plt.xlabel("epoch")
plt.ylabel("secs")
plt.plot(x, epoch_times, label="Regular training", color="red")
plt.plot(x, m_epoch_times, label="Fast training", color="green")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.show()
メトリクスを取得するための合計時間をプロットする
def get_best_metric_time(threshold, best_values):
for i, v in enumerate(best_values[0]):
if v > threshold:
return best_values[2][i]
return -1
def get_best_metric_epochs(threshold, best_values):
for i, v in enumerate(best_values[0]):
if v > threshold:
return best_values[1][i]
return -1
def get_label(index):
if index == 0:
return "Regular training"
elif index == 1:
return "Fast training"
else:
return None
plt.figure("train", (18, 6))
plt.subplot(1, 3, 1)
plt.title("Metrics Time")
plt.xlabel("secs")
plt.ylabel("best mean_dice")
plt.plot(best[2], best[0], label="Regular training", color="red")
plt.plot(m_best[2], m_best[0], label="Fast training", color="green")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.subplot(1, 3, 2)
plt.title("Typical Metrics Time")
plt.xlabel("best mean_dice")
plt.ylabel("secs")
labels = ["0.90", "0.90 ", "0.93", "0.93 ", "0.95", "0.95 ", "0.97", "0.97 "]
x_values = [0.9, 0.9, 0.93, 0.93, 0.95, 0.95, 0.97, 0.97]
for i, (l, x) in enumerate(zip(labels, x_values)):
value = int(get_best_metric_time(x, best if i % 2 == 0 else m_best))
color = "red" if i % 2 == 0 else "green"
plt.bar(l, value, 0.5, label=get_label(i), color=color)
plt.text(l, value, "%s" % value, ha="center", va="bottom")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.subplot(1, 3, 3)
plt.title("Typical Metrics Epochs")
plt.xlabel("best mean_dice")
plt.ylabel("epochs")
for i, (l, x) in enumerate(zip(labels, x_values)):
value = int(get_best_metric_epochs(x, best if i % 2 == 0 else m_best))
color = "red" if i % 2 == 0 else "green"
plt.bar(l, value, 0.5, label=get_label(i), color=color)
plt.text(l, value, "%s" % value, ha="center", va="bottom")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.show()
データディレクトリのクリーンアップ
一時ディレクトリが使用された場合にはディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上