ホーム » MONAI » MONAI 0.7 : tutorials : モジュール – 最適な学習率

MONAI 0.7 : tutorials : モジュール – 最適な学習率

MONAI 0.7 : tutorials : モジュール – 最適な学習率 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/13/2021 (0.7.0)

* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス 無料 Web セミナー開催中

◆ クラスキャットは人工知能・テレワークに関する各種サービスを提供しております。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com  ;  WebSite: https://www.classcat.com/  ;  Facebook

 

 

MONAI 0.7 : tutorials : モジュール – 最適な学習率

このノートブックは、ネットワークの学習率の値を調整するために LearningRateFinder を使用する方法を実演します。

このチュートリアルでは、MONAI の LearningRateFinder を調べるために MedNIST データセットを使用し、学習率の初期推定値を取得するためにそれを使用します。

そして最適化の過程で学習率を変化させるために PyTorch の周期的な学習率スケジューラの一つを採用します。これは改善された結果を与えることが示されています : https://arxiv.org/abs/1506.01186。これを optimizer (ADAM) のデフォルトの学習率と LearningRateFinder により提案された学習率と比較します。

この 2D 分類は非常に簡単ですので、それを少し難しく (そして高速に) するため、小さいネットワーク、画像のサブセットだけを使用し (訓練と検証のためにそれぞれ ~250 と ~25)、画像をクロップして (64×64 から 20×20 に) どのようなランダム変換も使用しません。より難しいシナリオでは、多分これらのどれも行なわないことを望むでしょう。

 

環境のセットアップ

!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib

 

インポートのセットアップ

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import torch
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import decollate_batch
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet
from monai.networks.utils import eval_mode
from monai.optimizers import LearningRateFinder
from monai.transforms import (
    Activations,
    AsDiscrete,
    AddChanneld,
    CenterSpatialCropd,
    Compose,
    LoadImaged,
    ScaleIntensityd,
    EnsureTyped,
    EnsureType,
)
from monai.utils import set_determinism
from torch.utils.data import DataLoader
from tqdm import trange

print_config()
MONAI version: 0.6.0+1.g8365443a
Numpy version: 1.20.3
Pytorch version: 1.9.0a0+c3d40fd
MONAI flags: HAS_EXT = True, USE_COMPILED = False
MONAI rev id: 8365443ababac313340467e5987c7babe2b5b86a

Optional dependencies:
Pytorch Ignite version: 0.4.5
Nibabel version: 3.2.1
scikit-image version: 0.15.0
Pillow version: 8.2.0
Tensorboard version: 2.2.0
gdown version: 3.13.0
TorchVision version: 0.10.0a0
ITK version: 5.1.2
tqdm version: 4.53.0
lmdb version: 1.2.1
psutil version: 5.8.0
pandas version: 1.1.4
einops version: 0.3.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

 

データディレクトリのセットアップ

MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

 

再現性のために決定論的訓練を設定する

set_determinism(seed=0)

 

MONAI 変換を定義してデータセットとデーたローダを取得する

transforms = Compose(
    [
        LoadImaged(keys="image"),
        AddChanneld(keys="image"),
        ScaleIntensityd(keys="image"),
        CenterSpatialCropd(keys="image", roi_size=(20, 20)),
        EnsureTyped(keys="image"),
    ]
)
# Set fraction of images used for testing to be very high, then don't use it. In this way, we can reduce the number
# of images in both train and val. Makes it faster and makes the training a little harder.
def get_data(section):
    ds = MedNISTDataset(
        root_dir=root_dir,
        transform=transforms,
        section=section,
        download=True,
        num_workers=10,
        val_frac=0.0005,
        test_frac=0.995,
    )
    loader = DataLoader(ds, batch_size=30, shuffle=True, num_workers=10)
    return ds, loader


train_ds, train_loader = get_data("training")
val_ds, val_loader = get_data("validation")

print(len(train_ds))
print(len(val_ds))
print(train_ds[0]["image"].shape)
num_classes = train_ds.get_num_classes()

y_pred_trans = Compose([EnsureType(), Activations(softmax=True)])
y_trans = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=num_classes)])
Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
file /home/rbrown/data/MONAI/MedNIST.tar.gz exists, skip downloading.
extracted file /home/rbrown/data/MONAI/MedNIST exists, skip extracting.
Loading dataset: 100%|██████████| 249/249 [00:00<00:00, 913.43it/s]
Loading dataset: 100%|██████████| 25/25 [00:00<00:00, 1242.09it/s]
Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
file /home/rbrown/data/MONAI/MedNIST.tar.gz exists, skip downloading.
extracted file /home/rbrown/data/MONAI/MedNIST exists, skip extracting.
249
25
torch.Size([1, 20, 20])

 

視覚化して確認するためにデータセットから画像をランダムにピックアップする

%matplotlib inline
fig, axes = plt.subplots(3, 3, figsize=(15, 15), facecolor="white")
for i, k in enumerate(np.random.randint(len(train_ds), size=9)):
    data = train_ds[k]
    im, title = data["image"], data["class_name"]
    ax = axes[i // 3, i % 3]
    im_show = ax.imshow(im[0])
    ax.set_title(title, fontsize=25)
    ax.axis("off")

 

損失関数とネットワークを定義する

device = "cuda" if torch.cuda.is_available() else "cpu"
loss_function = torch.nn.CrossEntropyLoss()
auc_metric = ROCAUCMetric()


def get_new_net():
    return DenseNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=num_classes,
        init_features=2,
        growth_rate=2,
        block_config=(2,),
    ).to(device)


model = get_new_net()

 

最適な学習率を推定する

MONAI の LearningRateFinder を使用して学習率の初期推定値を得ます。それは範囲 1e-5, 1e0 内にあると過程します。それが当てはまらないのであれば (プロットで気付くでしょう)、より大きな/異なるウィンドウに渡り単に再試行できるでしょう。

結果をプロットして、最も急な勾配を持つ学習率を抽出することができます。

%matplotlib inline
lower_lr, upper_lr = 1e-3, 1e-0
optimizer = torch.optim.Adam(model.parameters(), lower_lr)
lr_finder = LearningRateFinder(model, optimizer, loss_function, device=device)
lr_finder.range_test(train_loader, val_loader, end_lr=upper_lr, num_iter=20)
steepest_lr, _ = lr_finder.get_steepest_gradient()
ax = plt.subplots(1, 1, figsize=(15, 15), facecolor="white")[1]
_ = lr_finder.plot(ax=ax)
Computing optimal learning rate:  90%|█████████ | 18/20 [00:14<00:01,  1.26it/s]
Stopping early, the loss has diverged
Resetting model and optimizer

 

ライブ・プロッティング

この関数は、プロットが総ての反復で更新されるような、range/trange 回りの単なるラッパーです。

def plot_range(data, wrapped_generator):
    plt.ion()
    for q in data.values():
        for d in q.values():
            if isinstance(d, dict):
                ax = d["line"].axes
                ax.legend()
                fig = ax.get_figure()
    fig.show()

    for i in wrapped_generator:
        yield i
        for q in data.values():
            for d in q.values():
                if isinstance(d, dict):
                    d["line"].set_data(d["x"], d["y"])
                    ax = d["line"].axes
                    ax.legend()
                    ax.relim()
                    ax.autoscale_view()
        fig.canvas.draw()

 

訓練

訓練は vanilla ループとは少し違って見えますが、これは様々な学習率手法 (標準、最急 (= steepest) と周期的) の各々に渡りループし、その結果それらは同時に更新できるからです。

def get_model_optimizer_scheduler(d):
    d["model"] = get_new_net()

    if "lr_lims" in d:
        d["optimizer"] = torch.optim.Adam(
            d["model"].parameters(), d["lr_lims"][0]
        )
        d["scheduler"] = torch.optim.lr_scheduler.CyclicLR(
            d["optimizer"],
            base_lr=d["lr_lims"][0],
            max_lr=d["lr_lims"][1],
            step_size_up=d["step"],
            cycle_momentum=False,
        )
    elif "lr_lim" in d:
        d["optimizer"] = torch.optim.Adam(d["model"].parameters(), d["lr_lim"])
    else:
        d["optimizer"] = torch.optim.Adam(d["model"].parameters())


def train(max_epochs, axes, data):
    for d in data.keys():
        get_model_optimizer_scheduler(data[d])

        for q, i in enumerate(["train", "auc", "acc"]):
            data[d][i] = {"x": [], "y": []}
            (data[d][i]["line"],) = axes[q].plot(
                data[d][i]["x"], data[d][i]["y"], label=d
            )

        val_interval = 1

    for epoch in plot_range(data, trange(max_epochs)):

        for d in data.keys():
            data[d]["epoch_loss"] = 0
        for batch_data in train_loader:
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)

            for d in data.keys():
                data[d]["optimizer"].zero_grad()
                outputs = data[d]["model"](inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                data[d]["optimizer"].step()
                if "scheduler" in data[d]:
                    data[d]["scheduler"].step()
                data[d]["epoch_loss"] += loss.item()
        for d in data.keys():
            data[d]["epoch_loss"] /= len(train_loader)
            data[d]["train"]["x"].append(epoch + 1)
            data[d]["train"]["y"].append(data[d]["epoch_loss"])

        if (epoch + 1) % val_interval == 0:
            with eval_mode(*[data[d]["model"] for d in data.keys()]):
                for d in data:
                    data[d]["y_pred"] = torch.tensor(
                        [], dtype=torch.float32, device=device
                    )
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)
                    for d in data:
                        data[d]["y_pred"] = torch.cat(
                            [data[d]["y_pred"], data[d]["model"](val_images)],
                            dim=0,
                        )
                    y = torch.cat([y, val_labels], dim=0)

                for d in data:
                    y_onehot = [y_trans(i) for i in decollate_batch(y)]
                    y_pred_act = [y_pred_trans(i) for i in decollate_batch(data[d]["y_pred"])]
                    auc_metric(y_pred_act, y_onehot)
                    auc_result = auc_metric.aggregate()
                    auc_metric.reset()
                    del y_pred_act, y_onehot
                    data[d]["auc"]["x"].append(epoch + 1)
                    data[d]["auc"]["y"].append(auc_result)

                    acc_value = torch.eq(data[d]["y_pred"].argmax(dim=1), y)
                    acc_metric = acc_value.sum().item() / len(acc_value)
                    data[d]["acc"]["x"].append(epoch + 1)
                    data[d]["acc"]["y"].append(acc_metric)
%matplotlib notebook
fig, axes = plt.subplots(3, 1, figsize=(10, 10), facecolor="white")
for ax in axes:
    ax.set_xlabel("Epoch")
axes[0].set_ylabel("Train loss")
axes[1].set_ylabel("AUC")
axes[2].set_ylabel("ACC")

# In the paper referenced at the top of this notebook, a step
# size of 8 times the number of iterations per epoch is suggested.
step_size = 8 * len(train_loader)

max_epochs = 100
data = {}
data["Default LR"] = {}
data["Steepest LR"] = {"lr_lim": steepest_lr}
data["Cyclical LR"] = {
    "lr_lims": (0.8 * steepest_lr, 1.2 * steepest_lr),
    "step": step_size,
}

train(max_epochs, axes, data)

100%|██████████| 100/100 [03:11<00:00,  1.91s/it]

 

結び

当然のことながら、Steepest LR と Cyclical LR の両方はデフォルト LR よりも損失関数の迅速な収束を示します。

この例では Steepest LR と Cyclical LR の間に大きな違いはありません。より複雑な最適化問題では大きな違いが現れるかもしれませんが、ステップサイズ、そして下限と上限の周期的制限を自由に試してください。

 

データディレクトリのクリーンアップ

一時ディレクトリが使用された場合にはディレクトリを削除します。

if directory is None:
    shutil.rmtree(root_dir)
 

以上



ClassCat® Chatbot

人工知能開発支援

◆ クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com

カテゴリー