ホーム » MONAI » MONAI 0.7 : tutorials : 2D 分類 – MedNIST データセットによる医用画像分類

MONAI 0.7 : tutorials : 2D 分類 – MedNIST データセットによる医用画像分類

MONAI 0.7 : tutorials : 2D 分類 – MedNIST データセットによる医用画像分類 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/05/2021 (0.7.0)

* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス 無料 Web セミナー開催中

◆ クラスキャットは人工知能・テレワークに関する各種サービスを提供しております。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com  ;  WebSite: https://www.classcat.com/  ;  Facebook

 

 

MONAI 0.7 : tutorials : 2D 分類 – MedNIST データセットによる医用画像分類

このノートブックは MONAI 機能を既存の PyTorch プログラムに容易に統合する方法を示します。それは MedNIST データセットに基づいています、これは初心者のためにチュートリアルとして非常に適切です。このチュートリアルはまた MONAI 組込みのオクルージョン感度の機能も利用しています。

また、このチュートリアルでは、MONAIに内蔵されているオクルージョン感度機能も利用しています。

このチュートリアルでは、MedNIST データセットに基づく end-to-end な訓練と評価サンプルを紹介します。

以下のステップで進めます :

  • 訓練とテストのためにデータセットを作成する。
  • データを前処理するために MONAI 変換を利用します。
  • 分類のために MONAI から DenseNet を利用します。
  • モデルを PyTorch プログラムで訓練します。
  • テストデータセット上で評価します。

 

環境のセットアップ

!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

 

インポートのセットアップ

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import matplotlib.pyplot as plt
import PIL
import torch
import numpy as np
from sklearn.metrics import classification_report

from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import decollate_batch
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet121
from monai.transforms import (
    Activations,
    AddChannel,
    AsDiscrete,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    EnsureType,
)
from monai.utils import set_determinism

print_config()
MONAI version: 0.6.0rc1+15.gf3d436a0
Numpy version: 1.20.3
Pytorch version: 1.9.0a0+c3d40fd
MONAI flags: HAS_EXT = True, USE_COMPILED = False
MONAI rev id: f3d436a09deefcf905ece2faeec37f55ab030003

Optional dependencies:
Pytorch Ignite version: 0.4.5
Nibabel version: 3.2.1
scikit-image version: 0.15.0
Pillow version: 8.2.0
Tensorboard version: 2.5.0
gdown version: 3.13.0
TorchVision version: 0.10.0a0
ITK version: 5.1.2
tqdm version: 4.53.0
lmdb version: 1.2.1
psutil version: 5.8.0
pandas version: 1.1.4
einops version: 0.3.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

 

データディレクトリのセットアップ

MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical

 

データセットをダウンロードする

MedMNIST データセットは TCIA, RSNA Bone Age チャレンジNIH Chest X-ray データセット からの様々なセットから集められました。

データセットは Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) のお陰により Creative Commons CC BY-SA 4.0 ライセンス のもとで利用可能になっています。

MedNIST データセットを使用する場合、出典を明示してください。

resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"

compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
    download_and_extract(resource, compressed_file, root_dir, md5)

 

再現性のために決定論的訓練を設定する

set_determinism(seed=0)

 

データセットフォルダから画像ファイル名を読む

まず最初に、データセットファイルを確認して幾つかの統計を表示します。
データセットには 6 つのフォルダがあります : Hand, AbdomenCT, CXR, ChestCT, BreastMRI, HeadCT,
これらは分類モデルを訓練するためにラベルとして使用されるべきです。

class_names = sorted(x for x in os.listdir(data_dir)
                     if os.path.isdir(os.path.join(data_dir, x)))
num_class = len(class_names)
image_files = [
    [
        os.path.join(data_dir, class_names[i], x)
        for x in os.listdir(os.path.join(data_dir, class_names[i]))
    ]
    for i in range(num_class)
]
num_each = [len(image_files[i]) for i in range(num_class)]
image_files_list = []
image_class = []
for i in range(num_class):
    image_files_list.extend(image_files[i])
    image_class.extend([i] * num_each[i])
num_total = len(image_class)
image_width, image_height = PIL.Image.open(image_files_list[0]).size

print(f"Total image count: {num_total}")
print(f"Image dimensions: {image_width} x {image_height}")
print(f"Label names: {class_names}")
print(f"Label counts: {num_each}")
Total image count: 58954
Image dimensions: 64 x 64
Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
Label counts: [10000, 8954, 10000, 10000, 10000, 10000]

 

データセットから画像をランダムに選択して可視化して確認する

plt.subplots(3, 3, figsize=(8, 8))
for i, k in enumerate(np.random.randint(num_total, size=9)):
    im = PIL.Image.open(image_files_list[k])
    arr = np.array(im)
    plt.subplot(3, 3, i + 1)
    plt.xlabel(class_names[image_class[k]])
    plt.imshow(arr, cmap="gray", vmin=0, vmax=255)
plt.tight_layout()
plt.show()

 

訓練、検証とテストデータのリストを準備する

データセットの 10% を検証用に、そして 10% をテスト用にランダムに選択します。

val_frac = 0.1
test_frac = 0.1
length = len(image_files_list)
indices = np.arange(length)
np.random.shuffle(indices)

test_split = int(test_frac * length)
val_split = int(val_frac * length) + test_split
test_indices = indices[:test_split]
val_indices = indices[test_split:val_split]
train_indices = indices[val_split:]

train_x = [image_files_list[i] for i in train_indices]
train_y = [image_class[i] for i in train_indices]
val_x = [image_files_list[i] for i in val_indices]
val_y = [image_class[i] for i in val_indices]
test_x = [image_files_list[i] for i in test_indices]
test_y = [image_class[i] for i in test_indices]

print(
    f"Training count: {len(train_x)}, Validation count: "
    f"{len(val_x)}, Test count: {len(test_x)}")
Training count: 47156, Validation count: 5913, Test count: 5885

 

データを前処理するために MONAI 変換、Dataset と Dataloader を定義する

train_transforms = Compose(
    [
        LoadImage(image_only=True),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        EnsureType(),
    ]
)

val_transforms = Compose(
    [LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])

y_pred_trans = Compose([EnsureType(), Activations(softmax=True)])
y_trans = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=num_class)])
class MedNISTDataset(torch.utils.data.Dataset):
    def __init__(self, image_files, labels, transforms):
        self.image_files = image_files
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        return self.transforms(self.image_files[index]), self.labels[index]


train_ds = MedNISTDataset(train_x, train_y, train_transforms)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=300, shuffle=True, num_workers=10)

val_ds = MedNISTDataset(val_x, val_y, val_transforms)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=300, num_workers=10)

test_ds = MedNISTDataset(test_x, test_y, val_transforms)
test_loader = torch.utils.data.DataLoader(
    test_ds, batch_size=300, num_workers=10)

 

ネットワークと optimizer を定義する

  1. バッチ毎にモデルがどのくらい更新されるかについて学習率を設定します。

  2. 合計エポック数を設定します、シャッフルしてランダムな変換を行ないますので、総てのエポックの訓練データは異なります。そしてこれは get start チュートリアルに過ぎませんので、4 エポックだけ訓練しましょう。10 エポック訓練すれば、モデルはテストデータセット上で 100% 精度を達成できます。

  3. MONAI からの DenseNet を使用して GPU デバイスに移動します、この DenseNet は 2D と 3D 分類タスクの両方をサポートできます。

  4. Adam optimizer を使用します。
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = DenseNet121(spatial_dims=2, in_channels=1,
                        out_channels=num_class).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    max_epochs = 4
    val_interval = 1
    auc_metric = ROCAUCMetric()
    

     

    モデル訓練

    典型的な PyTorch 訓練を実行します、これはエポック・ループとステップ・ループを実行して、総てのエポック後に検証を行ないます。ベストの検証精度を得た場合、モデル重みをファイルにセーブします。

    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = []
    metric_values = []
    
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print(
                f"{step}/{len(train_ds) // train_loader.batch_size}, "
                f"train_loss: {loss.item():.4f}")
            epoch_len = len(train_ds) // train_loader.batch_size
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = (
                        val_data[0].to(device),
                        val_data[1].to(device),
                    )
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
                y_onehot = [y_trans(i) for i in decollate_batch(y)]
                y_pred_act = [y_pred_trans(i) for i in decollate_batch(y_pred)]
                auc_metric(y_pred_act, y_onehot)
                result = auc_metric.aggregate()
                auc_metric.reset()
                del y_pred_act, y_onehot
                metric_values.append(result)
                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                if result > best_metric:
                    best_metric = result
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(
                        root_dir, "best_metric_model.pth"))
                    print("saved new best metric model")
                print(
                    f"current epoch: {epoch + 1} current AUC: {result:.4f}"
                    f" current accuracy: {acc_metric:.4f}"
                    f" best AUC: {best_metric:.4f}"
                    f" at epoch: {best_metric_epoch}"
                )
    
    print(
        f"train completed, best_metric: {best_metric:.4f} "
        f"at epoch: {best_metric_epoch}")
    
    ----------
    epoch 1/4
    1/157, train_loss: 1.7898
    2/157, train_loss: 1.7560
    3/157, train_loss: 1.7461
    4/157, train_loss: 1.7241
    5/157, train_loss: 1.6973
    6/157, train_loss: 1.6706
    7/157, train_loss: 1.6388
    8/157, train_loss: 1.6210
    9/157, train_loss: 1.5989
    10/157, train_loss: 1.5599
    11/157, train_loss: 1.5826
    12/157, train_loss: 1.5339
    13/157, train_loss: 1.5235
    14/157, train_loss: 1.5098
    15/157, train_loss: 1.4746
    16/157, train_loss: 1.4584
    17/157, train_loss: 1.4365
    18/157, train_loss: 1.4328
    19/157, train_loss: 1.4274
    20/157, train_loss: 1.4327
    21/157, train_loss: 1.4017
    22/157, train_loss: 1.3231
    23/157, train_loss: 1.3180
    24/157, train_loss: 1.3001
    25/157, train_loss: 1.2958
    26/157, train_loss: 1.3021
    27/157, train_loss: 1.2336
    28/157, train_loss: 1.2154
    29/157, train_loss: 1.2595
    30/157, train_loss: 1.2050
    31/157, train_loss: 1.2161
    32/157, train_loss: 1.2106
    33/157, train_loss: 1.1495
    34/157, train_loss: 1.1550
    35/157, train_loss: 1.1246
    36/157, train_loss: 1.1607
    37/157, train_loss: 1.1126
    38/157, train_loss: 1.0987
    39/157, train_loss: 1.0694
    40/157, train_loss: 1.1181
    41/157, train_loss: 1.0576
    42/157, train_loss: 1.0703
    43/157, train_loss: 1.0414
    44/157, train_loss: 1.0446
    45/157, train_loss: 1.0313
    46/157, train_loss: 0.9786
    47/157, train_loss: 0.9767
    48/157, train_loss: 0.9579
    49/157, train_loss: 0.9659
    50/157, train_loss: 1.0069
    51/157, train_loss: 0.9868
    52/157, train_loss: 0.9637
    53/157, train_loss: 0.9301
    54/157, train_loss: 0.9382
    55/157, train_loss: 0.8923
    56/157, train_loss: 0.9034
    57/157, train_loss: 0.8674
    58/157, train_loss: 0.8707
    59/157, train_loss: 0.8876
    60/157, train_loss: 0.8628
    61/157, train_loss: 0.7709
    62/157, train_loss: 0.8494
    63/157, train_loss: 0.8264
    64/157, train_loss: 0.8011
    65/157, train_loss: 0.8186
    66/157, train_loss: 0.8016
    67/157, train_loss: 0.7813
    68/157, train_loss: 0.7447
    69/157, train_loss: 0.7201
    70/157, train_loss: 0.7323
    71/157, train_loss: 0.7332
    72/157, train_loss: 0.7379
    73/157, train_loss: 0.7495
    74/157, train_loss: 0.7157
    75/157, train_loss: 0.7007
    76/157, train_loss: 0.7058
    77/157, train_loss: 0.6814
    78/157, train_loss: 0.6738
    79/157, train_loss: 0.6449
    80/157, train_loss: 0.6393
    81/157, train_loss: 0.6238
    82/157, train_loss: 0.6234
    83/157, train_loss: 0.6262
    84/157, train_loss: 0.6044
    85/157, train_loss: 0.6005
    86/157, train_loss: 0.5783
    87/157, train_loss: 0.5570
    88/157, train_loss: 0.5569
    89/157, train_loss: 0.5633
    90/157, train_loss: 0.5199
    91/157, train_loss: 0.5846
    92/157, train_loss: 0.5915
    93/157, train_loss: 0.5612
    94/157, train_loss: 0.5785
    95/157, train_loss: 0.5654
    96/157, train_loss: 0.5437
    97/157, train_loss: 0.5429
    98/157, train_loss: 0.5053
    99/157, train_loss: 0.5221
    100/157, train_loss: 0.4928
    101/157, train_loss: 0.5064
    102/157, train_loss: 0.5104
    103/157, train_loss: 0.4830
    104/157, train_loss: 0.4901
    105/157, train_loss: 0.4957
    106/157, train_loss: 0.4913
    107/157, train_loss: 0.4722
    108/157, train_loss: 0.4756
    109/157, train_loss: 0.4803
    110/157, train_loss: 0.4534
    111/157, train_loss: 0.4383
    112/157, train_loss: 0.4437
    113/157, train_loss: 0.4264
    114/157, train_loss: 0.4067
    115/157, train_loss: 0.4268
    116/157, train_loss: 0.4136
    117/157, train_loss: 0.4115
    118/157, train_loss: 0.4082
    119/157, train_loss: 0.4246
    120/157, train_loss: 0.4018
    121/157, train_loss: 0.4161
    122/157, train_loss: 0.3475
    123/157, train_loss: 0.4233
    124/157, train_loss: 0.4043
    125/157, train_loss: 0.3192
    126/157, train_loss: 0.3611
    127/157, train_loss: 0.3475
    128/157, train_loss: 0.3580
    129/157, train_loss: 0.3286
    130/157, train_loss: 0.3688
    131/157, train_loss: 0.3205
    132/157, train_loss: 0.3745
    133/157, train_loss: 0.3709
    134/157, train_loss: 0.3535
    135/157, train_loss: 0.3729
    136/157, train_loss: 0.2839
    137/157, train_loss: 0.3614
    138/157, train_loss: 0.2981
    139/157, train_loss: 0.3280
    140/157, train_loss: 0.3031
    141/157, train_loss: 0.2994
    142/157, train_loss: 0.3296
    143/157, train_loss: 0.3032
    144/157, train_loss: 0.2965
    145/157, train_loss: 0.2955
    146/157, train_loss: 0.3379
    147/157, train_loss: 0.3339
    148/157, train_loss: 0.3092
    149/157, train_loss: 0.2723
    150/157, train_loss: 0.3171
    151/157, train_loss: 0.2933
    152/157, train_loss: 0.2841
    153/157, train_loss: 0.2723
    154/157, train_loss: 0.2954
    155/157, train_loss: 0.2955
    156/157, train_loss: 0.2969
    157/157, train_loss: 0.2674
    158/157, train_loss: 0.2465
    epoch 1 average loss: 0.7768
    saved new best metric model
    current epoch: 1 current AUC: 0.9984 current accuracy: 0.9618 best AUC: 0.9984 at epoch: 1
    ----------
    epoch 2/4
    1/157, train_loss: 0.2618
    2/157, train_loss: 0.2525
    3/157, train_loss: 0.2640
    4/157, train_loss: 0.2566
    5/157, train_loss: 0.2526
    6/157, train_loss: 0.2337
    7/157, train_loss: 0.2311
    8/157, train_loss: 0.2351
    9/157, train_loss: 0.2332
    10/157, train_loss: 0.2783
    11/157, train_loss: 0.2521
    12/157, train_loss: 0.2458
    13/157, train_loss: 0.2261
    14/157, train_loss: 0.2362
    15/157, train_loss: 0.2702
    16/157, train_loss: 0.2399
    17/157, train_loss: 0.2113
    18/157, train_loss: 0.2452
    19/157, train_loss: 0.2202
    20/157, train_loss: 0.2124
    21/157, train_loss: 0.2050
    22/157, train_loss: 0.2297
    23/157, train_loss: 0.2410
    24/157, train_loss: 0.2254
    25/157, train_loss: 0.2276
    26/157, train_loss: 0.2344
    27/157, train_loss: 0.1969
    28/157, train_loss: 0.2110
    29/157, train_loss: 0.2114
    30/157, train_loss: 0.2424
    31/157, train_loss: 0.2111
    32/157, train_loss: 0.1963
    33/157, train_loss: 0.1799
    34/157, train_loss: 0.1925
    35/157, train_loss: 0.2277
    36/157, train_loss: 0.2327
    37/157, train_loss: 0.1968
    38/157, train_loss: 0.2165
    39/157, train_loss: 0.1924
    40/157, train_loss: 0.1959
    41/157, train_loss: 0.1764
    42/157, train_loss: 0.2327
    43/157, train_loss: 0.1955
    44/157, train_loss: 0.1669
    45/157, train_loss: 0.1829
    46/157, train_loss: 0.1894
    47/157, train_loss: 0.2079
    48/157, train_loss: 0.1984
    49/157, train_loss: 0.2035
    50/157, train_loss: 0.1879
    51/157, train_loss: 0.1839
    52/157, train_loss: 0.1885
    53/157, train_loss: 0.1887
    54/157, train_loss: 0.1733
    55/157, train_loss: 0.1828
    56/157, train_loss: 0.1593
    57/157, train_loss: 0.1906
    58/157, train_loss: 0.1494
    59/157, train_loss: 0.1740
    60/157, train_loss: 0.1791
    61/157, train_loss: 0.1763
    62/157, train_loss: 0.1659
    63/157, train_loss: 0.1961
    64/157, train_loss: 0.1593
    65/157, train_loss: 0.1468
    66/157, train_loss: 0.1576
    67/157, train_loss: 0.1567
    68/157, train_loss: 0.1751
    69/157, train_loss: 0.1640
    70/157, train_loss: 0.1702
    71/157, train_loss: 0.1406
    72/157, train_loss: 0.1519
    73/157, train_loss: 0.1552
    74/157, train_loss: 0.1581
    75/157, train_loss: 0.1564
    76/157, train_loss: 0.1741
    77/157, train_loss: 0.1474
    78/157, train_loss: 0.1473
    79/157, train_loss: 0.1402
    80/157, train_loss: 0.1402
    81/157, train_loss: 0.1471
    82/157, train_loss: 0.1556
    83/157, train_loss: 0.1329
    84/157, train_loss: 0.1578
    85/157, train_loss: 0.1364
    86/157, train_loss: 0.1413
    87/157, train_loss: 0.1170
    88/157, train_loss: 0.1332
    89/157, train_loss: 0.1369
    90/157, train_loss: 0.1500
    91/157, train_loss: 0.1320
    92/157, train_loss: 0.1265
    93/157, train_loss: 0.1444
    94/157, train_loss: 0.1278
    95/157, train_loss: 0.1348
    96/157, train_loss: 0.1403
    97/157, train_loss: 0.1246
    98/157, train_loss: 0.1125
    99/157, train_loss: 0.1509
    100/157, train_loss: 0.1270
    101/157, train_loss: 0.1286
    102/157, train_loss: 0.1160
    103/157, train_loss: 0.1239
    104/157, train_loss: 0.1052
    105/157, train_loss: 0.1238
    106/157, train_loss: 0.1110
    107/157, train_loss: 0.1429
    108/157, train_loss: 0.1097
    109/157, train_loss: 0.1099
    110/157, train_loss: 0.1403
    111/157, train_loss: 0.1460
    112/157, train_loss: 0.1216
    113/157, train_loss: 0.1079
    114/157, train_loss: 0.1187
    115/157, train_loss: 0.1519
    116/157, train_loss: 0.1128
    117/157, train_loss: 0.1174
    118/157, train_loss: 0.1094
    119/157, train_loss: 0.1106
    120/157, train_loss: 0.1179
    121/157, train_loss: 0.1278
    122/157, train_loss: 0.1331
    123/157, train_loss: 0.1484
    124/157, train_loss: 0.1140
    125/157, train_loss: 0.1302
    126/157, train_loss: 0.1201
    127/157, train_loss: 0.1038
    128/157, train_loss: 0.1108
    129/157, train_loss: 0.1446
    130/157, train_loss: 0.0905
    131/157, train_loss: 0.1323
    132/157, train_loss: 0.0995
    133/157, train_loss: 0.0998
    134/157, train_loss: 0.0937
    135/157, train_loss: 0.1240
    136/157, train_loss: 0.0871
    137/157, train_loss: 0.1131
    138/157, train_loss: 0.1349
    139/157, train_loss: 0.1316
    140/157, train_loss: 0.0989
    141/157, train_loss: 0.1081
    142/157, train_loss: 0.1240
    143/157, train_loss: 0.1104
    144/157, train_loss: 0.0980
    145/157, train_loss: 0.1081
    146/157, train_loss: 0.0906
    147/157, train_loss: 0.1624
    148/157, train_loss: 0.1025
    149/157, train_loss: 0.1071
    150/157, train_loss: 0.0988
    151/157, train_loss: 0.1151
    152/157, train_loss: 0.1425
    153/157, train_loss: 0.1092
    154/157, train_loss: 0.0831
    155/157, train_loss: 0.1252
    156/157, train_loss: 0.0960
    157/157, train_loss: 0.1076
    158/157, train_loss: 0.0725
    epoch 2 average loss: 0.1612
    saved new best metric model
    current epoch: 2 current AUC: 0.9997 current accuracy: 0.9863 best AUC: 0.9997 at epoch: 2
    ----------
    epoch 3/4
    1/157, train_loss: 0.1095
    2/157, train_loss: 0.0810
    3/157, train_loss: 0.1085
    4/157, train_loss: 0.1033
    5/157, train_loss: 0.1527
    6/157, train_loss: 0.0988
    7/157, train_loss: 0.0935
    8/157, train_loss: 0.0903
    9/157, train_loss: 0.0941
    10/157, train_loss: 0.0742
    11/157, train_loss: 0.1127
    12/157, train_loss: 0.0803
    13/157, train_loss: 0.0937
    14/157, train_loss: 0.0810
    15/157, train_loss: 0.0965
    16/157, train_loss: 0.0705
    17/157, train_loss: 0.0802
    18/157, train_loss: 0.1040
    19/157, train_loss: 0.0940
    20/157, train_loss: 0.0758
    21/157, train_loss: 0.1002
    22/157, train_loss: 0.0720
    23/157, train_loss: 0.0773
    24/157, train_loss: 0.0906
    25/157, train_loss: 0.1002
    26/157, train_loss: 0.0948
    27/157, train_loss: 0.0731
    28/157, train_loss: 0.0938
    29/157, train_loss: 0.0731
    30/157, train_loss: 0.1004
    31/157, train_loss: 0.0829
    32/157, train_loss: 0.0864
    33/157, train_loss: 0.0729
    34/157, train_loss: 0.0773
    35/157, train_loss: 0.0719
    36/157, train_loss: 0.0875
    37/157, train_loss: 0.0897
    38/157, train_loss: 0.0740
    39/157, train_loss: 0.1091
    40/157, train_loss: 0.0652
    41/157, train_loss: 0.0899
    42/157, train_loss: 0.0853
    43/157, train_loss: 0.0727
    44/157, train_loss: 0.0843
    45/157, train_loss: 0.0878
    46/157, train_loss: 0.1135
    47/157, train_loss: 0.1041
    48/157, train_loss: 0.0935
    49/157, train_loss: 0.0879
    50/157, train_loss: 0.0922
    51/157, train_loss: 0.0862
    52/157, train_loss: 0.0692
    53/157, train_loss: 0.0784
    54/157, train_loss: 0.0986
    55/157, train_loss: 0.0707
    56/157, train_loss: 0.1013
    57/157, train_loss: 0.0598
    58/157, train_loss: 0.0639
    59/157, train_loss: 0.0587
    60/157, train_loss: 0.1027
    61/157, train_loss: 0.0711
    62/157, train_loss: 0.0775
    63/157, train_loss: 0.1045
    64/157, train_loss: 0.0655
    65/157, train_loss: 0.0621
    66/157, train_loss: 0.0636
    67/157, train_loss: 0.0774
    68/157, train_loss: 0.0875
    69/157, train_loss: 0.0664
    70/157, train_loss: 0.0707
    71/157, train_loss: 0.0814
    72/157, train_loss: 0.1022
    73/157, train_loss: 0.0820
    74/157, train_loss: 0.0829
    75/157, train_loss: 0.0809
    76/157, train_loss: 0.0975
    77/157, train_loss: 0.0684
    78/157, train_loss: 0.0686
    79/157, train_loss: 0.0831
    80/157, train_loss: 0.0671
    81/157, train_loss: 0.0647
    82/157, train_loss: 0.0574
    83/157, train_loss: 0.0611
    84/157, train_loss: 0.0886
    85/157, train_loss: 0.0674
    86/157, train_loss: 0.0609
    87/157, train_loss: 0.0582
    88/157, train_loss: 0.0584
    89/157, train_loss: 0.0751
    90/157, train_loss: 0.0720
    91/157, train_loss: 0.0727
    92/157, train_loss: 0.0664
    93/157, train_loss: 0.0681
    94/157, train_loss: 0.0791
    95/157, train_loss: 0.0880
    96/157, train_loss: 0.0746
    97/157, train_loss: 0.0730
    98/157, train_loss: 0.0818
    99/157, train_loss: 0.0617
    100/157, train_loss: 0.0646
    101/157, train_loss: 0.0607
    102/157, train_loss: 0.0749
    103/157, train_loss: 0.0656
    104/157, train_loss: 0.0607
    105/157, train_loss: 0.0713
    106/157, train_loss: 0.0725
    107/157, train_loss: 0.0711
    108/157, train_loss: 0.0642
    109/157, train_loss: 0.0624
    110/157, train_loss: 0.0685
    111/157, train_loss: 0.0542
    112/157, train_loss: 0.0771
    113/157, train_loss: 0.0786
    114/157, train_loss: 0.0580
    115/157, train_loss: 0.0698
    116/157, train_loss: 0.0847
    117/157, train_loss: 0.0542
    118/157, train_loss: 0.0760
    119/157, train_loss: 0.0817
    120/157, train_loss: 0.0904
    121/157, train_loss: 0.0705
    122/157, train_loss: 0.0541
    123/157, train_loss: 0.0521
    124/157, train_loss: 0.0555
    125/157, train_loss: 0.0491
    126/157, train_loss: 0.0504
    127/157, train_loss: 0.0598
    128/157, train_loss: 0.0417
    129/157, train_loss: 0.0470
    130/157, train_loss: 0.0584
    131/157, train_loss: 0.0498
    132/157, train_loss: 0.0532
    133/157, train_loss: 0.0515
    134/157, train_loss: 0.0575
    135/157, train_loss: 0.0506
    136/157, train_loss: 0.0519
    137/157, train_loss: 0.0727
    138/157, train_loss: 0.0796
    139/157, train_loss: 0.0469
    140/157, train_loss: 0.0509
    141/157, train_loss: 0.0501
    142/157, train_loss: 0.0603
    143/157, train_loss: 0.0499
    144/157, train_loss: 0.0585
    145/157, train_loss: 0.0590
    146/157, train_loss: 0.0447
    147/157, train_loss: 0.0699
    148/157, train_loss: 0.0595
    149/157, train_loss: 0.0372
    150/157, train_loss: 0.0446
    151/157, train_loss: 0.0576
    152/157, train_loss: 0.0735
    153/157, train_loss: 0.0464
    154/157, train_loss: 0.0742
    155/157, train_loss: 0.0356
    156/157, train_loss: 0.0492
    157/157, train_loss: 0.0644
    158/157, train_loss: 0.0630
    epoch 3 average loss: 0.0743
    saved new best metric model
    current epoch: 3 current AUC: 0.9999 current accuracy: 0.9924 best AUC: 0.9999 at epoch: 3
    ----------
    epoch 4/4
    1/157, train_loss: 0.0563
    2/157, train_loss: 0.0645
    3/157, train_loss: 0.0497
    4/157, train_loss: 0.0579
    5/157, train_loss: 0.0442
    6/157, train_loss: 0.0588
    7/157, train_loss: 0.0501
    8/157, train_loss: 0.0402
    9/157, train_loss: 0.0432
    10/157, train_loss: 0.0465
    11/157, train_loss: 0.0546
    12/157, train_loss: 0.0548
    13/157, train_loss: 0.0430
    14/157, train_loss: 0.0448
    15/157, train_loss: 0.0533
    16/157, train_loss: 0.0521
    17/157, train_loss: 0.0406
    18/157, train_loss: 0.0426
    19/157, train_loss: 0.0471
    20/157, train_loss: 0.0570
    21/157, train_loss: 0.0611
    22/157, train_loss: 0.0500
    23/157, train_loss: 0.0532
    24/157, train_loss: 0.0549
    25/157, train_loss: 0.0488
    26/157, train_loss: 0.0574
    27/157, train_loss: 0.0587
    28/157, train_loss: 0.0488
    29/157, train_loss: 0.0509
    30/157, train_loss: 0.0299
    31/157, train_loss: 0.0404
    32/157, train_loss: 0.0345
    33/157, train_loss: 0.0569
    34/157, train_loss: 0.0361
    35/157, train_loss: 0.0623
    36/157, train_loss: 0.0686
    37/157, train_loss: 0.0376
    38/157, train_loss: 0.0528
    39/157, train_loss: 0.0367
    40/157, train_loss: 0.0466
    41/157, train_loss: 0.0551
    42/157, train_loss: 0.0374
    43/157, train_loss: 0.0681
    44/157, train_loss: 0.0386
    45/157, train_loss: 0.0636
    46/157, train_loss: 0.0555
    47/157, train_loss: 0.0449
    48/157, train_loss: 0.0481
    49/157, train_loss: 0.0382
    50/157, train_loss: 0.0682
    51/157, train_loss: 0.0511
    52/157, train_loss: 0.0606
    53/157, train_loss: 0.0490
    54/157, train_loss: 0.0497
    55/157, train_loss: 0.0476
    56/157, train_loss: 0.0457
    57/157, train_loss: 0.0545
    58/157, train_loss: 0.0426
    59/157, train_loss: 0.0445
    60/157, train_loss: 0.0528
    61/157, train_loss: 0.0597
    62/157, train_loss: 0.0376
    63/157, train_loss: 0.0555
    64/157, train_loss: 0.0571
    65/157, train_loss: 0.0475
    66/157, train_loss: 0.0577
    67/157, train_loss: 0.0393
    68/157, train_loss: 0.0397
    69/157, train_loss: 0.0536
    70/157, train_loss: 0.0516
    71/157, train_loss: 0.0595
    72/157, train_loss: 0.0473
    73/157, train_loss: 0.0624
    74/157, train_loss: 0.0426
    75/157, train_loss: 0.0474
    76/157, train_loss: 0.0474
    77/157, train_loss: 0.0516
    78/157, train_loss: 0.0332
    79/157, train_loss: 0.0403
    80/157, train_loss: 0.0401
    81/157, train_loss: 0.0397
    82/157, train_loss: 0.0526
    83/157, train_loss: 0.0429
    84/157, train_loss: 0.0306
    85/157, train_loss: 0.0433
    86/157, train_loss: 0.0376
    87/157, train_loss: 0.0430
    88/157, train_loss: 0.0433
    89/157, train_loss: 0.0575
    90/157, train_loss: 0.0349
    91/157, train_loss: 0.0273
    92/157, train_loss: 0.0395
    93/157, train_loss: 0.0474
    94/157, train_loss: 0.0464
    95/157, train_loss: 0.0310
    96/157, train_loss: 0.0597
    97/157, train_loss: 0.0403
    98/157, train_loss: 0.0684
    99/157, train_loss: 0.0371
    100/157, train_loss: 0.0570
    101/157, train_loss: 0.0468
    102/157, train_loss: 0.0317
    103/157, train_loss: 0.0322
    104/157, train_loss: 0.0472
    105/157, train_loss: 0.0351
    106/157, train_loss: 0.0430
    107/157, train_loss: 0.0319
    108/157, train_loss: 0.0459
    109/157, train_loss: 0.0448
    110/157, train_loss: 0.0486
    111/157, train_loss: 0.0538
    112/157, train_loss: 0.0290
    113/157, train_loss: 0.0567
    114/157, train_loss: 0.0455
    115/157, train_loss: 0.0502
    116/157, train_loss: 0.0338
    117/157, train_loss: 0.0541
    118/157, train_loss: 0.0496
    119/157, train_loss: 0.0461
    120/157, train_loss: 0.0353
    121/157, train_loss: 0.0569
    122/157, train_loss: 0.0282
    123/157, train_loss: 0.0299
    124/157, train_loss: 0.0366
    125/157, train_loss: 0.0397
    126/157, train_loss: 0.0339
    127/157, train_loss: 0.0417
    128/157, train_loss: 0.0515
    129/157, train_loss: 0.0433
    130/157, train_loss: 0.0435
    131/157, train_loss: 0.0310
    132/157, train_loss: 0.0497
    133/157, train_loss: 0.0366
    134/157, train_loss: 0.0436
    135/157, train_loss: 0.0387
    136/157, train_loss: 0.0291
    137/157, train_loss: 0.0480
    138/157, train_loss: 0.0377
    139/157, train_loss: 0.0346
    140/157, train_loss: 0.0265
    141/157, train_loss: 0.0497
    142/157, train_loss: 0.0352
    143/157, train_loss: 0.0264
    144/157, train_loss: 0.0349
    145/157, train_loss: 0.0409
    146/157, train_loss: 0.0488
    147/157, train_loss: 0.0541
    148/157, train_loss: 0.0506
    149/157, train_loss: 0.0451
    150/157, train_loss: 0.0280
    151/157, train_loss: 0.0349
    152/157, train_loss: 0.0344
    153/157, train_loss: 0.0307
    154/157, train_loss: 0.0550
    155/157, train_loss: 0.0521
    156/157, train_loss: 0.0478
    157/157, train_loss: 0.0295
    158/157, train_loss: 0.1089
    epoch 4 average loss: 0.0462
    saved new best metric model
    current epoch: 4 current AUC: 1.0000 current accuracy: 0.9944 best AUC: 1.0000 at epoch: 4
    train completed, best_metric: 1.0000 at epoch: 4
    

     

    損失とメトリックをプロットする

    plt.figure("train", (12, 6))
    plt.subplot(1, 2, 1)
    plt.title("Epoch Average Loss")
    x = [i + 1 for i in range(len(epoch_loss_values))]
    y = epoch_loss_values
    plt.xlabel("epoch")
    plt.plot(x, y)
    plt.subplot(1, 2, 2)
    plt.title("Val AUC")
    x = [val_interval * (i + 1) for i in range(len(metric_values))]
    y = metric_values
    plt.xlabel("epoch")
    plt.plot(x, y)
    plt.show()
    

     

    テストデータセット上でモデルを評価する

    訓練と検証の後、検証テスト上のベストモデルを既に得ています。モデルをテストデータセット上でそれが堅牢で over fitting していないかを確認するために評価する必要があります。分類レポートを生成するためにこれらの予測を使用します。

    model.load_state_dict(torch.load(
        os.path.join(root_dir, "best_metric_model.pth")))
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for test_data in test_loader:
            test_images, test_labels = (
                test_data[0].to(device),
                test_data[1].to(device),
            )
            pred = model(test_images).argmax(dim=1)
            for i in range(len(pred)):
                y_true.append(test_labels[i].item())
                y_pred.append(pred[i].item())
    
    print(classification_report(
        y_true, y_pred, target_names=class_names, digits=4))
    
    Note: you may need to restart the kernel to use updated packages.
                  precision    recall  f1-score   support
    
       AbdomenCT     0.9816    0.9917    0.9867       969
       BreastMRI     0.9968    0.9831    0.9899       944
             CXR     0.9979    0.9928    0.9954       973
         ChestCT     0.9938    0.9990    0.9964       959
            Hand     0.9934    0.9934    0.9934      1055
          HeadCT     0.9960    0.9990    0.9975       985
    
        accuracy                         0.9932      5885
       macro avg     0.9932    0.9932    0.9932      5885
    weighted avg     0.9932    0.9932    0.9932      5885
    

     

    データディレクトリのクリーンアップ

    一時ディレクトリが使用された場合ディレクトリを削除します。

    if directory is None:
        shutil.rmtree(root_dir)
    
     

    以上



ClassCat® Chatbot

人工知能開発支援

◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com

カテゴリー