MONAI 1.0 : tutorials : 2D 分類 – MedNIST データセットによる医用画像分類 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/01/2022 (1.0.1)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
MONAI 1.0 : tutorials : 2D 分類 – MedNIST データセットによる医用画像分類
このノートブックは MONAI 機能を既存の PyTorch プログラムに容易に統合する方法を示します。それは MedNIST データセットに基づいています、これは初心者のためにチュートリアルとして非常に適切です。このチュートリアルはまた MONAI 組込みのオクルージョン感度の機能も利用しています。
このチュートリアルでは、MedNIST データセットに基づく end-to-end な訓練と評価サンプルを紹介します。
以下のステップで進めます :
- 訓練とテスト用のデータセットを作成する。
- データを前処理するために MONAI 変換を利用します。
- 分類のために MONAI から DenseNet を利用します。
- モデルを PyTorch プログラムで訓練します。
- テストデータセット上で評価します。
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import matplotlib.pyplot as plt
import PIL
import torch
import numpy as np
from sklearn.metrics import classification_report
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import decollate_batch, DataLoader
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
from monai.transforms import (
Activations,
EnsureChannelFirst,
AsDiscrete,
Compose,
LoadImage,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
)
from monai.utils import set_determinism
print_config()
MONAI version: 0.9.dev2152 Numpy version: 1.19.5 Pytorch version: 1.10.0+cu111 MONAI flags: HAS_EXT = False, USE_COMPILED = False MONAI rev id: c5bd8aff8ba461d7b349eb92427d452481a7eb72 Optional dependencies: Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION. Nibabel version: 3.0.2 scikit-image version: 0.18.3 Pillow version: 7.1.2 Tensorboard version: 2.7.0 gdown version: 3.6.4 TorchVision version: 0.11.1+cu111 tqdm version: 4.62.3 lmdb version: 0.99 psutil version: 5.4.8 pandas version: 1.1.5 einops version: NOT INSTALLED or UNKNOWN VERSION. transformers version: NOT INSTALLED or UNKNOWN VERSION. mlflow version: NOT INSTALLED or UNKNOWN VERSION. For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical
データセットをダウンロードする
MedMNIST データセットは TCIA, RSNA Bone Age チャレンジ と NIH Chest X-ray データセット からの様々なセットから集められました。
データセットは Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) のお陰により Creative Commons CC BY-SA 4.0 ライセンス のもとで利用可能になっています。
MedNIST データセットを使用する場合、出典を明示してください。
resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
再現性のために決定論的訓練を設定する
set_determinism(seed=0)
データセットフォルダから画像ファイル名を読む
まず最初に、データセットファイルを確認して幾つかの統計情報を表示します。
データセットには 6 つのフォルダがあります : Hand, AbdomenCT, CXR, ChestCT, BreastMRI, HeadCT,
これらは分類モデルを訓練するためのラベルとして使用されるべきです。
class_names = sorted(x for x in os.listdir(data_dir)
if os.path.isdir(os.path.join(data_dir, x)))
num_class = len(class_names)
image_files = [
[
os.path.join(data_dir, class_names[i], x)
for x in os.listdir(os.path.join(data_dir, class_names[i]))
]
for i in range(num_class)
]
num_each = [len(image_files[i]) for i in range(num_class)]
image_files_list = []
image_class = []
for i in range(num_class):
image_files_list.extend(image_files[i])
image_class.extend([i] * num_each[i])
num_total = len(image_class)
image_width, image_height = PIL.Image.open(image_files_list[0]).size
print(f"Total image count: {num_total}")
print(f"Image dimensions: {image_width} x {image_height}")
print(f"Label names: {class_names}")
print(f"Label counts: {num_each}")
Total image count: 58954 Image dimensions: 64 x 64 Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT'] Label counts: [10000, 8954, 10000, 10000, 10000, 10000]
データセットから画像をランダムに選択して可視化して確認する
plt.subplots(3, 3, figsize=(8, 8))
for i, k in enumerate(np.random.randint(num_total, size=9)):
im = PIL.Image.open(image_files_list[k])
arr = np.array(im)
plt.subplot(3, 3, i + 1)
plt.xlabel(class_names[image_class[k]])
plt.imshow(arr, cmap="gray", vmin=0, vmax=255)
plt.tight_layout()
plt.show()
訓練、検証とテストデータのリストを準備する
データセットの 10% を検証用に、そして 10% をテスト用にランダムに選択します。
val_frac = 0.1
test_frac = 0.1
length = len(image_files_list)
indices = np.arange(length)
np.random.shuffle(indices)
test_split = int(test_frac * length)
val_split = int(val_frac * length) + test_split
test_indices = indices[:test_split]
val_indices = indices[test_split:val_split]
train_indices = indices[val_split:]
train_x = [image_files_list[i] for i in train_indices]
train_y = [image_class[i] for i in train_indices]
val_x = [image_files_list[i] for i in val_indices]
val_y = [image_class[i] for i in val_indices]
test_x = [image_files_list[i] for i in test_indices]
test_y = [image_class[i] for i in test_indices]
print(
f"Training count: {len(train_x)}, Validation count: "
f"{len(val_x)}, Test count: {len(test_x)}")
Training count: 47156, Validation count: 5913, Test count: 5885
データを前処理するために MONAI 変換、Dataset と Dataloader を定義する
train_transforms = Compose(
[
LoadImage(image_only=True),
EnsureChannelFirst(),
ScaleIntensity(),
RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
RandFlip(spatial_axis=0, prob=0.5),
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
]
)
val_transforms = Compose(
[LoadImage(image_only=True), EnsureChannelFirst(), ScaleIntensity()])
y_pred_trans = Compose([Activations(softmax=True)])
y_trans = Compose([AsDiscrete(to_onehot=num_class)])
class MedNISTDataset(torch.utils.data.Dataset):
def __init__(self, image_files, labels, transforms):
self.image_files = image_files
self.labels = labels
self.transforms = transforms
def __len__(self):
return len(self.image_files)
def __getitem__(self, index):
return self.transforms(self.image_files[index]), self.labels[index]
train_ds = MedNISTDataset(train_x, train_y, train_transforms)
train_loader = DataLoader(
train_ds, batch_size=300, shuffle=True, num_workers=10)
val_ds = MedNISTDataset(val_x, val_y, val_transforms)
val_loader = DataLoader(
val_ds, batch_size=300, num_workers=10)
test_ds = MedNISTDataset(test_x, test_y, val_transforms)
test_loader = DataLoader(
test_ds, batch_size=300, num_workers=10)
ネットワークと optimizer を定義する
- バッチ毎にモデルがどのくらい更新されるかについて学習率を設定します。
- 総エポック数を設定します、シャッフルしてランダムな変換を行ないますので、総てのエポックの訓練データは異なります。
そしてこれは get start チュートリアルに過ぎませんので、4 エポックだけ訓練しましょう。
10 エポック訓練すれば、モデルはテストデータセット上で 100% 精度を達成できます。
- MONAI からの DenseNet を使用して GPU デバイスに移します、この DenseNet は 2D と 3D 分類タスクの両方をサポートできます。
- Adam optimizer を使用します。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DenseNet121(spatial_dims=2, in_channels=1,
out_channels=num_class).to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
max_epochs = 4
val_interval = 1
auc_metric = ROCAUCMetric()
モデル訓練
典型的な PyTorch 訓練を実行します、これはエポック・ループとステップ・ループを実行して、総てのエポック後に検証を行ないます。ベストの検証精度を得た場合、モデル重みをファイルにセーブします。
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
for epoch in range(max_epochs):
print("-" * 10)
print(f"epoch {epoch + 1}/{max_epochs}")
model.train()
epoch_loss = 0
step = 0
for batch_data in train_loader:
step += 1
inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print(
f"{step}/{len(train_ds) // train_loader.batch_size}, "
f"train_loss: {loss.item():.4f}")
epoch_len = len(train_ds) // train_loader.batch_size
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
y_pred = torch.tensor([], dtype=torch.float32, device=device)
y = torch.tensor([], dtype=torch.long, device=device)
for val_data in val_loader:
val_images, val_labels = (
val_data[0].to(device),
val_data[1].to(device),
)
y_pred = torch.cat([y_pred, model(val_images)], dim=0)
y = torch.cat([y, val_labels], dim=0)
y_onehot = [y_trans(i) for i in decollate_batch(y, detach=False)]
y_pred_act = [y_pred_trans(i) for i in decollate_batch(y_pred)]
auc_metric(y_pred_act, y_onehot)
result = auc_metric.aggregate()
auc_metric.reset()
del y_pred_act, y_onehot
metric_values.append(result)
acc_value = torch.eq(y_pred.argmax(dim=1), y)
acc_metric = acc_value.sum().item() / len(acc_value)
if result > best_metric:
best_metric = result
best_metric_epoch = epoch + 1
torch.save(model.state_dict(), os.path.join(
root_dir, "best_metric_model.pth"))
print("saved new best metric model")
print(
f"current epoch: {epoch + 1} current AUC: {result:.4f}"
f" current accuracy: {acc_metric:.4f}"
f" best AUC: {best_metric:.4f}"
f" at epoch: {best_metric_epoch}"
)
print(
f"train completed, best_metric: {best_metric:.4f} "
f"at epoch: {best_metric_epoch}")
---------- epoch 1/4 (...) epoch 1 average loss: 0.7768 saved new best metric model current epoch: 1 current AUC: 0.9984 current accuracy: 0.9618 best AUC: 0.9984 at epoch: 1 ---------- epoch 2/4 (...) epoch 2 average loss: 0.1612 saved new best metric model current epoch: 2 current AUC: 0.9997 current accuracy: 0.9863 best AUC: 0.9997 at epoch: 2 ---------- epoch 3/4 (...) epoch 3 average loss: 0.0743 saved new best metric model current epoch: 3 current AUC: 0.9999 current accuracy: 0.9924 best AUC: 0.9999 at epoch: 3 ---------- epoch 4/4 (...) epoch 4 average loss: 0.0462 saved new best metric model current epoch: 4 current AUC: 1.0000 current accuracy: 0.9944 best AUC: 1.0000 at epoch: 4 train completed, best_metric: 1.0000 at epoch: 4
(...) epoch 9 average loss: 0.0094 saved new best metric model current epoch: 9 current AUC: 1.0000 current accuracy: 0.9997 best AUC: 1.0000 at epoch: 9 ---------- epoch 10/10 (...) epoch 10 average loss: 0.0080 current epoch: 10 current AUC: 1.0000 current accuracy: 0.9997 best AUC: 1.0000 at epoch: 9 train completed, best_metric: 1.0000 at epoch: 9
損失とメトリックをプロットする
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.subplot(1, 2, 2)
plt.title("Val AUC")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.show()
テストデータセット上でモデルを評価する
訓練と検証の後、検証テスト上のベストモデルを既に得ています。
モデルをテストデータセット上でそれが堅牢で over fitting していないかを確認するために評価する必要があります。
分類レポートを生成するためにこれらの予測を使用します。
model.load_state_dict(torch.load(
os.path.join(root_dir, "best_metric_model.pth")))
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
for test_data in test_loader:
test_images, test_labels = (
test_data[0].to(device),
test_data[1].to(device),
)
pred = model(test_images).argmax(dim=1)
for i in range(len(pred)):
y_true.append(test_labels[i].item())
y_pred.append(pred[i].item())
print(classification_report(
y_true, y_pred, target_names=class_names, digits=4))
Note: you may need to restart the kernel to use updated packages. precision recall f1-score support AbdomenCT 0.9816 0.9917 0.9867 969 BreastMRI 0.9968 0.9831 0.9899 944 CXR 0.9979 0.9928 0.9954 973 ChestCT 0.9938 0.9990 0.9964 959 Hand 0.9934 0.9934 0.9934 1055 HeadCT 0.9960 0.9990 0.9975 985 accuracy 0.9932 5885 macro avg 0.9932 0.9932 0.9932 5885 weighted avg 0.9932 0.9932 0.9932 5885
precision recall f1-score support AbdomenCT 0.9980 0.9990 0.9985 995 BreastMRI 0.9989 1.0000 0.9994 880 CXR 1.0000 0.9969 0.9985 982 ChestCT 0.9990 1.0000 0.9995 1014 Hand 0.9981 0.9990 0.9986 1048 HeadCT 1.0000 0.9990 0.9995 976 accuracy 0.9990 5895 macro avg 0.9990 0.9990 0.9990 5895 weighted avg 0.9990 0.9990 0.9990 5895
データディレクトリのクリーンアップ
一時ディレクトリが使用された場合ディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上