ホーム » 「MONAI 0.7」タグがついた投稿 (ページ 2)
タグアーカイブ: MONAI 0.7
MONAI 0.7 : tutorials : モジュール – CSV データセットで CSV ファイルのロード
MONAI 0.7 : tutorials : モジュール – CSV データセットで CSV ファイルのロード (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/20/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
MONAI 0.7 : tutorials : モジュール – CSV データセットで CSV ファイルのロード
チュートリアルは CSVDataset と CSVIterableDataset の使い方を示し、複数の CSV ファイルをロードして後処理ロジックを実行します。
このチュートリアルは CSVDataset と CSVIterableDataset モジュールに基づいて CSV ファイルからデータをロードする方法を示します。そしてデータ上で後処理ロジックを実行します。
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[pandas, pillow]"
%matplotlib inline
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import shutil
import sys
import matplotlib.pyplot as plt
import pandas as pd
import PIL
import numpy as np
from monai.data import CSVDataset, CSVIterableDataset, DataLoader
from monai.apps import download_and_extract
from monai.config import print_config
from monai.transforms import Compose, LoadImaged, ToNumpyd
from monai.utils import first
print_config()
MONAI version: 0.6.0rc1+23.gc6793fd0 Numpy version: 1.20.3 Pytorch version: 1.9.0a0+c3d40fd MONAI flags: HAS_EXT = True, USE_COMPILED = False MONAI rev id: c6793fd0f316a448778d0047664aaf8c1895fe1c Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.15.0 Pillow version: 8.2.0 Tensorboard version: 2.5.0 gdown version: 3.13.0 TorchVision version: 0.10.0a0 ITK version: 5.1.2 tqdm version: 4.53.0 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.1.4 einops version: 0.3.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical
データセットをダウンロードする
ここではデモで MedNIST データセットの幾つかの画像を使用します。データセットをダウンロードして展開してください。
MedMNIST データセットは TCIA, RSNA Bone Age チャレンジ と NIH Chest X-ray データセット からの様々なセットから集められました。
データセットは Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) のお陰により Creative Commons CC BY-SA 4.0 ライセンス のもとで利用可能になっています。
resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
ハンド・カテゴリーの幾つかの医療画像をプロットする
plt.subplots(1, 5, figsize=(10, 10))
for i in range(5):
filename = f"00000{i}.jpeg"
im = PIL.Image.open(os.path.join(data_dir, "Hand", filename))
arr = np.array(im)
plt.subplot(3, 3, i + 1)
plt.xlabel(filename)
plt.imshow(arr, cmap="gray", vmin=0, vmax=255)
plt.tight_layout()
plt.show()
テストのために 3 CSV ファイルを生成する
ここで 画像の特性をストアするために 3 CSV ファイルを生成します、欠損値を含みます。
test_data1 = [
["subject_id", "label", "image", "ehr_0", "ehr_1", "ehr_2"],
["s000000", 5, os.path.join(data_dir, "Hand", "000000.jpeg"), 2.007843256, 2.29019618, 2.054902077],
["s000001", 0, os.path.join(data_dir, "Hand", "000001.jpeg"), 6.839215755, 6.474509716, 5.862744808],
["s000002", 4, os.path.join(data_dir, "Hand", "000002.jpeg"), 3.772548914, 4.211764812, 4.635294437],
["s000003", 1, os.path.join(data_dir, "Hand", "000003.jpeg"), 3.333333254, 3.235294342, 3.400000095],
["s000004", 9, os.path.join(data_dir, "Hand", "000004.jpeg"), 6.427451134, 6.254901886, 5.976470947],
]
test_data2 = [
["subject_id", "ehr_3", "ehr_4", "ehr_5", "ehr_6", "ehr_7", "ehr_8"],
["s000000", 3.019608021, 3.807843208, 3.584313869, 3.141176462, 3.1960783, 4.211764812],
["s000001", 5.192157269, 5.274509907, 5.250980377, 4.647058964, 4.886274338, 4.392156601],
["s000002", 5.298039436, 9.545097351, 12.57254887, 6.799999714, 2.1960783, 1.882352948],
["s000003", 3.164705753, 3.086274624, 3.725490093, 3.698039293, 3.698039055, 3.701960802],
["s000004", 6.26274538, 7.717647076, 9.584313393, 6.082352638, 2.662744999, 2.34117651],
]
test_data3 = [
["subject_id", "ehr_9", "ehr_10", "meta_0", "meta_1", "meta_2"],
["s000000", 6.301961422, 6.470588684, "TRUE", "TRUE", "TRUE"],
["s000001", 5.219608307, 7.827450752, "FALSE", "TRUE", "FALSE"],
["s000002", 1.882352948, 2.031372547, "TRUE", "FALSE", "TRUE"],
["s000003", 3.309803963, 3.729412079, "FALSE", "FALSE", "TRUE"],
["s000004", 2.062745094, 2.34117651, "FALSE", "TRUE", "TRUE"],
# generate missing values in the row
["s000005", 3.353655643, 1.675674543, "TRUE", "TRUE", "FALSE"],
]
def prepare_csv_file(data, filepath):
with open(filepath, "w") as f:
for d in data:
f.write((",".join([str(i) for i in d])) + "\n")
filepath1 = os.path.join(data_dir, "test_data1.csv")
filepath2 = os.path.join(data_dir, "test_data2.csv")
filepath3 = os.path.join(data_dir, "test_data3.csv")
prepare_csv_file(test_data1, filepath1)
prepare_csv_file(test_data2, filepath2)
prepare_csv_file(test_data3, filepath3)
CSVDataset でシングル CSV ファイルをロードする
dataset = CSVDataset(filename=filepath1)
# construct pandas table to show the data, `CSVDataset` inherits from PyTorch Dataset
print(pd.DataFrame(dataset.data))
subject_id label image \ 0 s000000 5 /workspace/data/medical/MedNIST/Hand/000000.jpeg 1 s000001 0 /workspace/data/medical/MedNIST/Hand/000001.jpeg 2 s000002 4 /workspace/data/medical/MedNIST/Hand/000002.jpeg 3 s000003 1 /workspace/data/medical/MedNIST/Hand/000003.jpeg 4 s000004 9 /workspace/data/medical/MedNIST/Hand/000004.jpeg ehr_0 ehr_1 ehr_2 0 2.007843 2.290196 2.054902 1 6.839216 6.474510 5.862745 2 3.772549 4.211765 4.635294 3 3.333333 3.235294 3.400000 4 6.427451 6.254902 5.976471
複数の CSV ファイルをロードしてテーブルを結合する
dataset = CSVDataset([filepath1, filepath2, filepath3], on="subject_id")
# construct pandas table to show the joined data of 3 tables
print(pd.DataFrame(dataset.data))
subject_id label image \ 0 s000000 5 /workspace/data/medical/MedNIST/Hand/000000.jpeg 1 s000001 0 /workspace/data/medical/MedNIST/Hand/000001.jpeg 2 s000002 4 /workspace/data/medical/MedNIST/Hand/000002.jpeg 3 s000003 1 /workspace/data/medical/MedNIST/Hand/000003.jpeg 4 s000004 9 /workspace/data/medical/MedNIST/Hand/000004.jpeg ehr_0 ehr_1 ehr_2 ehr_3 ehr_4 ehr_5 ehr_6 \ 0 2.007843 2.290196 2.054902 3.019608 3.807843 3.584314 3.141176 1 6.839216 6.474510 5.862745 5.192157 5.274510 5.250980 4.647059 2 3.772549 4.211765 4.635294 5.298039 9.545097 12.572549 6.800000 3 3.333333 3.235294 3.400000 3.164706 3.086275 3.725490 3.698039 4 6.427451 6.254902 5.976471 6.262745 7.717647 9.584313 6.082353 ehr_7 ehr_8 ehr_9 ehr_10 meta_0 meta_1 meta_2 0 3.196078 4.211765 6.301961 6.470589 True True True 1 4.886274 4.392157 5.219608 7.827451 False True False 2 2.196078 1.882353 1.882353 2.031373 True False True 3 3.698039 3.701961 3.309804 3.729412 False False True 4 2.662745 2.341177 2.062745 2.341177 False True True
3 CSV ファイルから選択された行と選択された列だけをロードする
ここでは rows: 0 – 1 と 3, columns: “subject_id”, “label”, “ehr_1”, “ehr_7”, “meta_1” をロードします。
dataset = CSVDataset(
filename=[filepath1, filepath2, filepath3],
row_indices=[[0, 2], 3], # load row: 0, 1, 3
col_names=["subject_id", "label", "ehr_1", "ehr_7", "meta_1"],
)
# construct pandas table to show the joined and selected data
print(pd.DataFrame(dataset.data))
subject_id label ehr_1 ehr_7 meta_1 0 s000000 5 2.290196 3.196078 True 1 s000001 0 6.474510 4.886274 True 2 s000003 1 3.235294 3.698039 False
新しいカラムを生成するためにカラムをロードしてグループ分けする
ここでは 3 CSV ファイルをロードして新しい ehr カラムを生成するために総ての ehr_* カラムをグループ分けし、そして新しい meta カラムを生成するために総ての meta_* カラムをグループ分けします。
dataset = CSVDataset(
filename=[filepath1, filepath2, filepath3],
col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"],
col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta": ["meta_0", "meta_1", "meta_2"]},
)
# construct pandas table to show the joined, selected and generated data
print(pd.DataFrame(dataset.data))
subject_id image ehr_0 \ 0 s000000 /workspace/data/medical/MedNIST/Hand/000000.jpeg 2.007843 1 s000001 /workspace/data/medical/MedNIST/Hand/000001.jpeg 6.839216 2 s000002 /workspace/data/medical/MedNIST/Hand/000002.jpeg 3.772549 3 s000003 /workspace/data/medical/MedNIST/Hand/000003.jpeg 3.333333 4 s000004 /workspace/data/medical/MedNIST/Hand/000004.jpeg 6.427451 ehr_1 ehr_2 ehr_3 ehr_4 ehr_5 ehr_6 ehr_7 \ 0 2.290196 2.054902 3.019608 3.807843 3.584314 3.141176 3.196078 1 6.474510 5.862745 5.192157 5.274510 5.250980 4.647059 4.886274 2 4.211765 4.635294 5.298039 9.545097 12.572549 6.800000 2.196078 3 3.235294 3.400000 3.164706 3.086275 3.725490 3.698039 3.698039 4 6.254902 5.976471 6.262745 7.717647 9.584313 6.082353 2.662745 ehr_8 ehr_9 ehr_10 meta_0 meta_1 meta_2 \ 0 4.211765 6.301961 6.470589 True True True 1 4.392157 5.219608 7.827451 False True False 2 1.882353 1.882353 2.031373 True False True 3 3.701961 3.309804 3.729412 False False True 4 2.341177 2.062745 2.341177 False True True ehr meta 0 [2.007843256, 2.29019618, 2.054902077, 3.01960... [True, True, True] 1 [6.839215755, 6.474509716, 5.8627448079999995,... [False, True, False] 2 [3.7725489139999997, 4.211764812, 4.635294437,... [True, False, True] 3 [3.333333254, 3.235294342, 3.400000095, 3.1647... [False, False, True] 4 [6.427451134, 6.254901886, 5.976470947, 6.2627... [False, True, True]
ロードして欠損値を埋めてデータ型を変換する
このチュートリアルでは、s000005 画像は CSV file1 と file2 で多くの欠損値を持ちます。ここでは幾つかのカラムを選択して画像カラムの欠損値にデフォルト値を設定して、ehr_1 を int 型に変換もしてみます。
dataset = CSVDataset(
filename=[filepath1, filepath2, filepath3],
col_names=["subject_id", "label", "ehr_0", "ehr_1", "ehr_9", "meta_1"],
col_types={"label": {"default": "No label"}, "ehr_1": {"type": int, "default": 0}},
how="outer", # will load the NaN values in this merge mode
)
# construct pandas table to show the joined, selected and converted data
print(pd.DataFrame(dataset.data))
subject_id label ehr_0 ehr_1 ehr_9 meta_1 0 s000000 5 2.007843 2 6.301961 True 1 s000001 0 6.839216 6 5.219608 True 2 s000002 4 3.772549 4 1.882353 False 3 s000003 1 3.333333 3 3.309804 False 4 s000004 9 6.427451 6 2.062745 True 5 s000005 No label NaN 0 3.353656 True
ロードされたデータ上で変換を実行する
ここでは image 値から JPG 画像をロードして、ehr グループを numpy 配列に変換します。
dataset = CSVDataset(
filename=[filepath1, filepath2, filepath3],
col_groups={"ehr": [f"ehr_{i}" for i in range(5)]},
transform=Compose([LoadImaged(keys="image"), ToNumpyd(keys="ehr")]),
)
# test the transformed `ehr` data:
for item in dataset:
print(type(item["ehr"]), item["ehr"])
# plot the transformed image array
plt.subplots(1, 5, figsize=(10, 10))
for i in range(5):
plt.subplot(3, 3, i + 1)
plt.xlabel(dataset[i]["subject_id"])
plt.imshow(dataset[i]["image"], cmap="gray", vmin=0, vmax=255)
plt.tight_layout()
plt.show()
<class 'numpy.ndarray'> [2.00784326 2.29019618 2.05490208 3.01960802 3.80784321] <class 'numpy.ndarray'> [6.83921575 6.47450972 5.86274481 5.19215727 5.27450991] <class 'numpy.ndarray'> [3.77254891 4.21176481 4.63529444 5.29803944 9.54509735] <class 'numpy.ndarray'> [3.33333325 3.23529434 3.4000001 3.16470575 3.08627462] <class 'numpy.ndarray'> [6.42745113 6.25490189 5.97647095 6.26274538 7.71764708]
CSVIterableDataset で CSV ファイルをロードする
CSVIterableDataset は非常に大きな CSV ファイルからデータチャンクをロードするように設計されています、それは最初に総てのコンテンツをロードする必要がありません。そしてそれは行の選択を除いて CSVDataset の上の機能の殆どをサポートできます。
ここでは DataLoader のマルチ処理方式で CSVIterableDataset を使用して CSV ファイルをロードします。
dataset = CSVIterableDataset(filename=[filepath1, filepath2, filepath3])
# set num workers = 0 for mac / win
num_workers = 2 if sys.platform == "linux" else 0
dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=2)
print(first(dataloader))
{'subject_id': ['s000000', 's000001'], 'label': tensor([5, 0]), 'image': ['/workspace/data/medical/MedNIST/Hand/000000.jpeg', '/workspace/data/medical/MedNIST/Hand/000001.jpeg'], 'ehr_0': tensor([2.0078, 6.8392], dtype=torch.float64), 'ehr_1': tensor([2.2902, 6.4745], dtype=torch.float64), 'ehr_2': tensor([2.0549, 5.8627], dtype=torch.float64), 'ehr_3': tensor([3.0196, 5.1922], dtype=torch.float64), 'ehr_4': tensor([3.8078, 5.2745], dtype=torch.float64), 'ehr_5': tensor([3.5843, 5.2510], dtype=torch.float64), 'ehr_6': tensor([3.1412, 4.6471], dtype=torch.float64), 'ehr_7': tensor([3.1961, 4.8863], dtype=torch.float64), 'ehr_8': tensor([4.2118, 4.3922], dtype=torch.float64), 'ehr_9': tensor([6.3020, 5.2196], dtype=torch.float64), 'ehr_10': tensor([6.4706, 7.8275], dtype=torch.float64), 'meta_0': tensor([ True, False]), 'meta_1': tensor([True, True]), 'meta_2': tensor([ True, False])}
データディレクトリのクリーンアップ
一時ディレクトリが使用された場合にはディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上
MONAI 0.7 : tutorials : 配備 – BentoML による MedNIST 分類器の配備
MONAI 0.7 : tutorials : 配備 – BentoML による MedNIST 分類器の配備 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/19/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
MONAI 0.7 : tutorials : 配備 – BentoML による MedNIST 分類器の配備
これは MONAI ネットワークを訓練して BentoML を web サーバとして使用して配備するサンプルです、BentoML レポジトリをローカルで使用するかコンテナサービスとして使用します。
このノートブックは BentoML を使用して訓練済みモデルをアーティファクトにパッケージ化するプロセスを実演します、これは推論を実行するローカルプログラムとして、同じことを行なう web サービスとして、そして Docker コンテナ化された web サービスとして実行できます。BentoML は AWS や Azure のような既存のプラットフォームでモデルを配備する様々な方法を提供しますが、ここではローカル配備にフォーカスします、研究者はこれを行なう傾向にあるためです。このチュートリアルは ここの MONAI チュートリアル のような MedNIST 分類器を訓練してから BentoML チュートリアル で説明されているパッケージ化を行ないます。
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]"
!python -c "import bentoml" || pip install -q bentoml
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import tempfile
import glob
import PIL.Image
import torch
import numpy as np
from ignite.engine import Events
from monai.apps import download_and_extract
from monai.config import print_config
from monai.networks.nets import DenseNet121
from monai.engines import SupervisedTrainer
from monai.transforms import (
AddChannel,
Compose,
LoadImage,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
EnsureType,
)
from monai.utils import set_determinism
set_determinism(seed=0)
print_config()
MONAI version: 0.4.0+119.g9898a89 Numpy version: 1.19.2 Pytorch version: 1.7.1 MONAI flags: HAS_EXT = False, USE_COMPILED = False MONAI rev id: 9898a89d24364a9be3525d066a7492adf00b9e6b Optional dependencies: Pytorch Ignite version: 0.4.2 Nibabel version: 3.2.1 scikit-image version: 0.18.1 Pillow version: 8.1.0 Tensorboard version: 2.4.1 gdown version: 3.12.2 TorchVision version: 0.8.2 ITK version: 5.1.2 tqdm version: 4.56.0 lmdb version: 1.0.0 psutil version: 5.8.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データセットをダウンロードする
MedMNIST データセットは TCIA, RSNA Bone Age チャレンジ と NIH Chest X-ray データセット からの様々なセットから集められました。
データセットは Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) のお陰により Creative Commons CC BY-SA 4.0 ライセンス のもとで利用可能になっています。
MedNIST データセットを使用する場合、出典を明示してください。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
MedNIST.tar.gz: 0.00B [00:00, ?B/s] /tmp/tmpxxp5z205 MedNIST.tar.gz: 59.0MB [00:04, 15.4MB/s] downloaded file: /tmp/tmpxxp5z205/MedNIST.tar.gz. Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d. Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
subdirs = sorted(glob.glob(f"{data_dir}/*/"))
class_names = [os.path.basename(sd[:-1]) for sd in subdirs]
image_files = [glob.glob(f"{sb}/*") for sb in subdirs]
image_files_list = sum(image_files, [])
image_class = sum(([i] * len(f) for i, f in enumerate(image_files)), [])
image_width, image_height = PIL.Image.open(image_files_list[0]).size
print(f"Label names: {class_names}")
print(f"Label counts: {list(map(len, image_files))}")
print(f"Total image count: {len(image_class)}")
print(f"Image dimensions: {image_width} x {image_height}")
Label names: ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT'] Label counts: [10000, 8954, 10000, 10000, 10000, 10000] Total image count: 58954 Image dimensions: 64 x 64
セットアップと訓練
ここでは変換シークエンスを作成してネットワークを訓練します、検証とテストはこれが実際に動作することを私達は知っていてそしてここでは必要ないので省略します :
train_transforms = Compose(
[
LoadImage(image_only=True),
AddChannel(),
ScaleIntensity(),
RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
RandFlip(spatial_axis=0, prob=0.5),
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
EnsureType(),
]
)
class MedNISTDataset(torch.utils.data.Dataset):
def __init__(self, image_files, labels, transforms):
self.image_files = image_files
self.labels = labels
self.transforms = transforms
def __len__(self):
return len(self.image_files)
def __getitem__(self, index):
return self.transforms(self.image_files[index]), self.labels[index]
# just one dataset and loader, we won't bother with validation or testing
train_ds = MedNISTDataset(image_files_list, image_class, train_transforms)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)
device = torch.device("cuda:0")
net = DenseNet121(spatial_dims=2, in_channels=1, out_channels=len(class_names)).to(device)
loss_function = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(net.parameters(), 1e-5)
max_epochs = 5
def _prepare_batch(batch, device, non_blocking):
return tuple(b.to(device) for b in batch)
trainer = SupervisedTrainer(device, max_epochs, train_loader, net, opt, loss_function, prepare_batch=_prepare_batch)
@trainer.on(Events.EPOCH_COMPLETED)
def _print_loss(engine):
print(f"Epoch {engine.state.epoch}/{engine.state.max_epochs} Loss: {engine.state.output[0]['loss']}")
trainer.run()
Epoch 1/5 Loss: 0.231450617313385 Epoch 2/5 Loss: 0.07256477326154709 Epoch 3/5 Loss: 0.04309789836406708 Epoch 4/5 Loss: 0.04549304023385048 Epoch 5/5 Loss: 0.025731785222887993
ここでネットワークが Torchscript オブジェクトとしてセーブされますが後で見るようこれは必要ありません。
torch.jit.script(net).save("classifier.zip")
BentoML セットアップ
BentoML はサービスリクエストをメソッド呼び出しとしてラップする API を通してプラットフォームを提供します。これは明らかに Flask が動作する方法と似ていますが (これはここで使用される基礎技術の一つです)、これの上にはネットワーク (アーティファクト) のストア、リクエストの IO コンポーネントの処理、そしてデータのキャッシュのための様々な機能が提供されます。私達が提供する必要があるものは望むサービスを表わすスクリプトファイルで、BentoML は提供するアーティファクトと一緒にこれを取得して別の場所にストアします、これはローカルで実行したりサーバにアップロードすることができます (Docker レジストリのようなものです)。
下のスクリプトは MONAI コードを含む API を作成します。変換シークエンスはデータストリームを画像に変えるために特殊な読み取り変換 (= read Transform) を必要としますが、それ以外は訓練のために上で使用されたようなコードです。ネットワークはアーティファクトとしてストアされ、これは実際には BentoML バンドルでストアされた重みです。これは実行時に自動的にロードされますが、望むならば代わりに Torchscript モデルをロードすることもできるでしょう、特に MONAI コードに依存しない API を望む場合には。
スクリプトは最初にファイルに書き出される必要があります :
%%writefile mednist_classifier_bentoml.py
from typing import BinaryIO, List
import numpy as np
from PIL import Image
import torch
from monai.transforms import (
AddChannel,
Compose,
Transform,
ScaleIntensity,
EnsureType,
)
import bentoml
from bentoml.frameworks.pytorch import PytorchModelArtifact
from bentoml.adapters import FileInput, JsonOutput
from bentoml.utils import cached_property
MEDNIST_CLASSES = ["AbdomenCT", "BreastMRI", "CXR", "ChestCT", "Hand", "HeadCT"]
class LoadStreamPIL(Transform):
"""Load an image file from a data stream using PIL."""
def __init__(self, mode=None):
self.mode = mode
def __call__(self, stream):
img = Image.open(stream)
if self.mode is not None:
img = img.convert(mode=self.mode)
return np.array(img)
@bentoml.env(pip_packages=["torch", "numpy", "monai", "pillow"])
@bentoml.artifacts([PytorchModelArtifact("classifier")])
class MedNISTClassifier(bentoml.BentoService):
@cached_property
def transform(self):
return Compose([LoadStreamPIL("L"), AddChannel(), ScaleIntensity(), EnsureType()])
@bentoml.api(input=FileInput(), output=JsonOutput(), batch=True)
def predict(self, file_streams: List[BinaryIO]) -> List[str]:
img_tensors = list(map(self.transform, file_streams))
batch = torch.stack(img_tensors).float()
with torch.no_grad():
outputs = self.artifacts.classifier(batch)
_, output_classes = outputs.max(dim=1)
return [MEDNIST_CLASSES[oc] for oc in output_classes]
Overwriting mednist_classifier_bentoml.py
今はスクリプトがロードされて分類器アーティファクトはネットワーク状態とともにパックされます。そしてこれはローカルマシンのレポジトリ・ディレクトリに保存されます :
from mednist_classifier_bentoml import MedNISTClassifier # noqa: E402
bento_svc = MedNISTClassifier()
bento_svc.pack('classifier', net.cpu().eval())
saved_path = bento_svc.save()
print(saved_path)
[2021-03-02 00:39:04,202] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService [2021-03-02 00:39:04,204] WARNING - pip package requirement torch already exist [2021-03-02 00:39:05,494] INFO - BentoService bundle 'MedNISTClassifier:20210302003904_AC4A5D' saved to: /home/localek10/bentoml/repository/MedNISTClassifier/20210302003904_AC4A5D /home/localek10/bentoml/repository/MedNISTClassifier/20210302003904_AC4A5D
このレポジトリの内容を見ることができます、これはコードとセットアップ・スクリプトを含みます :
!ls -l {saved_path}
total 44 -rwxr--r-- 1 localek10 bioeng 2411 Mar 2 00:39 bentoml-init.sh -rw-r--r-- 1 localek10 bioeng 875 Mar 2 00:39 bentoml.yml -rwxr--r-- 1 localek10 bioeng 699 Mar 2 00:39 docker-entrypoint.sh -rw-r--r-- 1 localek10 bioeng 1205 Mar 2 00:39 Dockerfile -rw-r--r-- 1 localek10 bioeng 70 Mar 2 00:39 environment.yml -rw-r--r-- 1 localek10 bioeng 72 Mar 2 00:39 MANIFEST.in drwxr-xr-x 4 localek10 bioeng 4096 Mar 2 00:39 MedNISTClassifier -rw-r--r-- 1 localek10 bioeng 5 Mar 2 00:39 python_version -rw-r--r-- 1 localek10 bioeng 298 Mar 2 00:39 README.md -rw-r--r-- 1 localek10 bioeng 69 Mar 2 00:39 requirements.txt -rw-r--r-- 1 localek10 bioeng 1691 Mar 2 00:39 setup.py
このレポジトリはストアされたプログラムのように実行できます、そこでは使用したい名前と API 名 (“predict”) によりそれを起動してファイルとして入力を提供します :
!bentoml run MedNISTClassifier:latest predict --input-file {image_files[0][0]}
[2021-03-02 00:39:16,999] INFO - Getting latest version MedNISTClassifier:20210302003904_AC4A5D [2021-03-02 00:39:19,508] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService [2021-03-02 00:39:19,508] WARNING - pip package requirement torch already exist [2021-03-02 00:39:20,329] INFO - {'service_name': 'MedNISTClassifier', 'service_version': '20210302003904_AC4A5D', 'api': 'predict', 'task': {'data': {'uri': 'file:///tmp/tmphl16qkwk/MedNIST/AbdomenCT/006160.jpeg', 'name': '006160.jpeg'}, 'task_id': '6d4680de-f719-4e04-abde-00c7d8a6110d', 'cli_args': ('--input-file', '/tmp/tmphl16qkwk/MedNIST/AbdomenCT/006160.jpeg'), 'inference_job_args': {}}, 'result': {'data': '"Hand"', 'http_status': 200, 'http_headers': (('Content-Type', 'application/json'),)}, 'request_id': '6d4680de-f719-4e04-abde-00c7d8a6110d'} "Hand"
サービスはまた Flask web サーバでも実行できます。以下のスクリプトはサービスを開始し、進むのを待ち、予測を得るために POST リクエストとしてテストファイルを送るために curl を使用して、そしてサーバを kill します :
%%bash -s {image_files[0][0]}
# filename passed in as an argument to the cell
test_file=$1
# start the Flask-based server, sending output to /dev/null for neatness
bentoml serve --port=8000 MedNISTClassifier:latest &> /dev/null &
# recall the PID of the server and wait for it to start
lastpid=$!
sleep 5
# send the test file using curl and capture the returned string
result=$(curl -s -X POST "http://127.0.0.1:8000/predict" -F image=@$test_file)
# kill the server
kill $lastpid
echo "Prediction: $result"
Prediction: "AbdomenCT"
The service can be packaged as a Docker container to be started elsewhere as a server:
!bentoml containerize MedNISTClassifier:latest -t mednist-classifier:latest
[2021-03-02 00:40:48,846] INFO - Getting latest version MedNISTClassifier:20210302003904_AC4A5D Found Bento: /home/localek10/bentoml/repository/MedNISTClassifier/20210302003904_AC4A5D Containerizing MedNISTClassifier:20210302003904_AC4A5D with local YataiService and docker daemon from local environment\WARNING: No swap limit support |Build container image: mednist-classifier:latest
!docker image ls
REPOSITORY TAG IMAGE ID CREATED SIZE mednist-classifier latest 326ab3f07478 15 seconds ago 2.94GB <none> <none> 87e9c5c97297 2 days ago 2.94GB <none> <none> cb62f45a9163 2 days ago 1.14GB bentoml/model-server 0.11.0-py38 387830631375 6 weeks ago 1.14GB sshtest latest 1be604ad1135 3 months ago 225MB ubuntu 20.04 9140108b62dc 5 months ago 72.9MB ubuntu latest 9140108b62dc 5 months ago 72.9MB nvcr.io/nvidia/pytorch 20.09-py3 86042df4bd3c 5 months ago 11.1GB pytorch/pytorch 1.6.0-cuda10.1-cudnn7-runtime 6a2d656bcf94 7 months ago 3.47GB pytorch/pytorch latest 6a2d656bcf94 7 months ago 3.47GB python 3.7 22c70bba8283 7 months ago 920MB ubuntu 16.04 c522ac0d6194 7 months ago 126MB python 3.7-alpine 6a5ca85ed89b 9 months ago 72.5MB alpine 3.12 a24bb4013296 9 months ago 5.57MB hello-world latest bf756fb1ae65 14 months ago 13.3kB
if directory is None:
shutil.rmtree(root_dir)
以上
MONAI 0.7 : モジュール概要 (2) データセットとデータローダ
MONAI 0.7 : モジュール概要 (2) データセットとデータローダ (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/19/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : モジュール概要 (2) データセットとデータローダ
Datasets と DataLoader
1. 訓練を高速化するためのキャッシュ IO と変換データ
ユーザは望まれるモデル品質を獲得するためにデータに渡り多くの (潜在的には数千の) エポックでモデルを訓練する必要が多くの場合あります。ネイティブ PyTorch 実装はデータを繰り返しロードして訓練の間総てのエポックについて同じ前処理ステップを実行する場合がありますが、これは時間がかかり不必要である可能性があります、特に医用画像ボリュームが大きいときには。
MONAI は、変換チェインの最初のランダム化された変換の前に、中間的な結果をストアして訓練の間これらの変換ステップを高速化するためにマルチスレッド CacheDataset と LMDBDataset を提供しています。この機能を有効にすると Datasets 実験の 10x 訓練スピードアップを潜在的に与えられるでしょう。
2. 中間結果を永続的なストレージにキャッシュする
PersistentDataset は CacheDataset に類似しています、そこでは (ハイパーパラメータ調整のときのような) 実験的な実行の間やデータセット全体のサイズが利用可能なメモリを越えるときの迅速な検索のために、中間的なキャッシュ値がディスク・ストレージや LMDB に永続化されます。PersistentDataset は Dataset 実験で CacheDataset と比較したとき同様の性能を獲得できました。
3. 大規模データベースのための SmartCache 機構
大規模なボリュームのデータセットによる訓練の間、効率的なアプローチはエポックでデータセットのサブセットだけを使用して訓練して総てのエポックでサブセットの部分を動的に置き換えることです。それは NVIDIA Clara-train SDK の SmartCache メカニズムです。
MONAI は PyTorch 版 SmartCache を SmartCacheDataset を提供しています。各エポックで、キャッシュの項目だけが訓練に使用され、同時に、別のスレッドがキャシュにない項目に変換シークエンスを提供することにより置換項目を準備しています。1 エポックが完了すれば、SmartCache は置換項目で同じ数の項目を置き換えます。
例えば、5 画像 : [image1, image2, image3, image4, image5] を持ち、そして cache_num=4, replace_rate=0.25 とします。するとキャッシュされて置換される、実際の訓練画像は以下のようになります :
epoch 1: [image1, image2, image3, image4] epoch 2: [image2, image3, image4, image5] epoch 3: [image3, image4, image5, image1] epoch 3: [image4, image5, image1, image2] epoch N: [image[N % 5] ...]
SmartCacheDataset の完全なサンプルは Distributed training with SmartCache で利用可能です。
4. マルチ PyTorch データセットの zip と出力の融合
MONAI は複数の PyTorch データセットを関連付けて出力データを (同じ対応するバッチインデックスで) タプルに連結するための ZipDataset を提供します、これは様々なデータソースに基づいて複雑な訓練プロセスを実行するの役立つことができます。
例えば :
class DatasetA(Dataset):
def __getitem__(self, index: int):
return image_data[index]
class DatasetB(Dataset):
def __getitem__(self, index: int):
return extra_data[index]
dataset = ZipDataset([DatasetA(), DatasetB()], transform)
5. PatchDataset
monai.data.PatchDataset は画像- とパッチ-レベル前処理の両方を組み合わせる柔軟な API を提供します :
image_dataset = Dataset(input_images, transforms=image_transforms)
patch_dataset = PatchDataset(
dataset=image_dataset, patch_func=sampler,
samples_per_image=n_samples, transform=patch_transforms)
それはカスタマイズ可能なパッチサンプリング・ストラテジーを使用してユーザ指定の image_transforms と patch_transforms をサポートします、これはマルチプロセス・コンテキストで 2 レベルの計算を切り離します。
6. 公開医療データの事前定義されたデータセット
医療ドメインのポピュラーな訓練データで素早く始めるために、MONAI は幾つかのデータ固有のデータセットを提供しています、これは AWS ストレージからのダウンロード、データファイルの抽出を含み、変換とともに訓練/評価項目の生成をサポートしています。そしてそれらはデフォルトの動作を変更する JSON config ファイルを簡単に変更できるという点で柔軟です。
MONAI は公開データセットの新しい寄与を常に歓迎しています、既存のデータセットを参照してダウンロードと抽出 API を活用してください、等々。公開データセット・チュートリアル は MedNISTDataset と DecathlonDataset で訓練ワークフローを素早くセットアップする方法と公開データのために新しいデータセットを作成する方法を示しています。
事前定義されたデータセットの一般的なワークフロー :
7. 交差検証のためにデータセットを分割する
MONAI の partition_dataset は訓練と検証のためあるいは交差検証のために様々なタイプの分割を実行できます。それは指定されたランダムシードに基づくシャッフルをサポートし、各データセットが一つの分割を含む、データセットのセットを返します。そしてそれは指定された比率に基づいてデータセットを分割したり、num_partitions に均等に分割することもできます。与えられたクラスラベルについて、総てのパーティションでクラスの同じ比率を確実にすることもできます。
8. CSV データセットと IterableDataset
CSV テーブルは画像データに加えて患者の人口統計, 検査結果, 画像収集パラメータと他の非画像データのような補助的な情報を組み込むためによく使用され、MONAI は CSV ファイルをロードするために CSVDataset をそしてスケーラブルなデータアクセスにより大規模な CSV ファイルをロードするために CSVIterableDataset をロードするために CSVIterableDataset を提供しています。ロードの間の通常の前処理変換に加えて、それはまた複数の CSV ファイルロード、テーブルの結合、行と列の選択とグループ化もサポートします。CSVDatasets チュートリアル は詳細な使用例を示しています。
9. ThreadDataLoader vs. DataLoader
変換が軽量である場合、特に総てのデータを RAM にキャッシュするとき、PyTorch DataLoader のマルチプロセッシングは不必要な IPC 時間を引き起こして総てのエポックの後に GPU 利用率の低下を引き起こす可能性があります。MONAI は変換を個別のスレッドで実行する ThreadDataLoader を提供しています :
ThreadDataLoader サンプルは 脾臓高速訓練チュートリアル で利用可能です。
以上
MONAI 0.7 : tutorials : モジュール – 脾臓セグメンテーション・タスクの後処理変換
MONAI 0.7 : tutorials : モジュール – 脾臓セグメンテーション・タスクの後処理変換 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/17/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : モジュール – 脾臓セグメンテーション・タスクの後処理変換
このノートブックは脾臓セグメンテーション・タスクのモデル出力に基づいて幾つかの後処理変換の使用方法を示します。
MONAI はモデル出力を処理するための後処理変換を提供しています。現在、変換は以下を含みます :
- Activations: 活性化層の追加 (Sigmoid, Softmax, etc.)。
- AsDiscrete: 離散値 (Argmax, One-Hot, Threshold 値等) に変換します。
- SplitChannel: マルチチャネル・データを複数のシングル・チャネルに分割する。
- KeepLargestConnectedComponent: セグメンテーション結果の輪郭を抽出します、これは元の画像へのマップに使用できてモデルを評価できます。
- LabelToContour: 接続コンポーネント分析に基づいてセグメンテーション・ノイズを除去する。
MONAI は同じデータ上で様々な前処理変換や後処理変換を適用して結果を結合するためのマルチ変換チェインをサポートしています、それはデータ辞書の指定された項目のコピーを作成する CopyItems 変換と想定される次元の指定項目を組み合わせるために ConcatItems 変換を提供し、そしてまたメモリを節約するために不要な項目を削除するための DeleteItems 変換も提供します。
典型的な使用方法は入力画像の 3 つの異なる強度範囲をスケールして結合することです :
このチュートリアルは脾臓セグメンテーションのモデル出力に基づいて上の後処理変換の幾つかを示します。
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, skimage, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
from monai.utils import set_determinism
from monai.transforms import (
AsDiscrete,
AddChanneld,
Compose,
CropForegroundd,
KeepLargestConnectedComponent,
LabelToContour,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
Spacingd,
EnsureTyped,
EnsureType,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
print_config()
MONAI version: 0.6.0rc1+23.gc6793fd0 Numpy version: 1.20.3 Pytorch version: 1.9.0a0+c3d40fd MONAI flags: HAS_EXT = True, USE_COMPILED = False MONAI rev id: c6793fd0f316a448778d0047664aaf8c1895fe1c Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.15.0 Pillow version: 7.0.0 Tensorboard version: 2.5.0 gdown version: 3.13.0 TorchVision version: 0.10.0a0 ITK version: 5.1.2 tqdm version: 4.53.0 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.1.4 einops version: 0.3.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical
データセットのダウンロード
データセットをダウンロードして展開します。データセットは http://medicaldecathlon.com/ に由来します。
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
MSD 脾臓データセット・パスの設定
train_images = sorted(
glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(
glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
{"image": image_name, "label": label_name}
for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
再現性のための決定論的訓練の設定
set_determinism(seed=0)
訓練と検証のための変換のセットアップ
train_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(keys=["image", "label"], pixdim=(
1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"], a_min=-57, a_max=164,
b_min=0.0, b_max=1.0, clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
# randomly crop out patch samples from big image
# based on pos / neg ratio. the image centers
# of negative samples must be in valid image area
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(96, 96, 96),
pos=1,
neg=1,
num_samples=4,
image_key="image",
image_threshold=0,
),
EnsureTyped(keys=["image", "label"]),
]
)
val_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(keys=["image", "label"], pixdim=(
1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"], a_min=-57, a_max=164,
b_min=0.0, b_max=1.0, clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
EnsureTyped(keys=["image", "label"]),
]
)
訓練と検証のための CacheDataset と DataLoader を定義する
train_ds = CacheDataset(
data=train_files, transform=train_transforms,
cache_rate=1.0, num_workers=4)
# train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
# use batch_size=2 to load images and use RandCropByPosNegLabeld
# to generate 2 x 4 images for network training
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_ds = CacheDataset(
data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
# val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
100%|██████████| 32/32 [00:48<00:00, 1.51s/it] 100%|██████████| 9/9 [00:11<00:00, 1.32s/it]
モデル、損失、Optimizer を作成する
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
model = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")
典型的な PyTorch 訓練プロセスを実行する
max_epochs = 160
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
for epoch in range(max_epochs):
print("-" * 10)
print(f"epoch {epoch + 1}/{max_epochs}")
model.train()
epoch_loss = 0
step = 0
for batch_data in train_loader:
step += 1
inputs, labels = (
batch_data["image"].to(device),
batch_data["label"].to(device),
)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print(
f"{step}/{len(train_ds) // train_loader.batch_size},"
f"train_loss: {loss.item():.4f}")
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
# validation progress
if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
for val_data in val_loader:
val_inputs, val_labels = (
val_data["image"].to(device),
val_data["label"].to(device),
)
roi_size = (160, 160, 160)
sw_batch_size = 4
val_outputs = sliding_window_inference(
val_inputs, roi_size, sw_batch_size, model)
val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
val_labels = [post_label(i) for i in decollate_batch(val_labels)]
dice_metric(y_pred=val_outputs, y=val_labels)
metric = dice_metric.aggregate().item()
dice_metric.reset()
metric_values.append(metric)
if metric > best_metric:
best_metric = metric
best_metric_epoch = epoch + 1
torch.save(model.state_dict(),
"best_metric_model_post_transforms.pth")
print("saved new best metric model")
print(
f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
f"\nbest mean dice: {best_metric:.4f} "
f"at epoch: {best_metric_epoch}"
)
検証データセット上で後処理変換を実行する
ここでは AsDiscrete, KeepLargestConnectedComponent と LabelToContour をテストします。
model.load_state_dict(torch.load("best_metric_model_post_transforms.pth"))
model.eval()
with torch.no_grad():
for i, val_data in enumerate(val_loader):
roi_size = (160, 160, 160)
sw_batch_size = 4
val_data = val_data["image"].to(device)
val_output = sliding_window_inference(
val_data, roi_size, sw_batch_size, model)
# plot the slice [:, :, 80]
plt.figure("check", (20, 4))
plt.subplot(1, 5, 1)
plt.title(f"image {i}")
plt.imshow(val_data.detach().cpu()[0, 0, :, :, 80], cmap="gray")
plt.subplot(1, 5, 2)
plt.title(f"argmax {i}")
argmax = [AsDiscrete(argmax=True)(i) for i in decollate_batch(val_output)]
plt.imshow(argmax[0].detach().cpu()[0, :, :, 80])
plt.subplot(1, 5, 3)
plt.title(f"largest {i}")
largest = [KeepLargestConnectedComponent(applied_labels=[1])(i) for i in argmax]
plt.imshow(largest[0].detach().cpu()[0, :, :, 80])
plt.subplot(1, 5, 4)
plt.title(f"contour {i}")
contour = [LabelToContour()(i) for i in largest]
plt.imshow(contour[0].detach().cpu()[0, :, :, 80])
plt.subplot(1, 5, 5)
plt.title(f"map image {i}")
map_image = contour[0] + val_data[0]
plt.imshow(map_image.detach().cpu()[0, :, :, 80], cmap="gray")
plt.show()
データディレクトリのクリーンアップ
一時ディレクトリが使用された場合にはディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上
MONAI 0.7 : tutorials : 3D セグメンテーション – 脾臓 3D セグメンテーション (Lightning 版)
MONAI 0.7 : tutorials : 3D セグメンテーション – 脾臓 3D セグメンテーション (Lightning 版) (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/16/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : 3D セグメンテーション – 脾臓 3D セグメンテーション (Lightning 版)
このノートブックは MONAI を PyTorch Lightning フレームワークと連携して使用できる可能性を示します。
このチュートリアルは MONAI をどのように PyTorch Lightning フレームワークと連携して使用できるかを実演します。
以下の MONAI の機能の使用方法を実演します :
- 辞書形式データのための変換。
- メタデータとともに Nifti 画像をロードする。
- チャネル次元がない場合チャネル dim をデータに追加する。
- 想定される範囲で医療画像強度をスケールする。
- ポジティブ/ネガティブ・ラベル比率に基づいてバランスの取れた画像のバッチをクロップする。
- 訓練と検証を高速化するキャシュ IO と変換。
- 3D セグメンテーション・タスクのための 3D UNet モデル、Dice 損失関数、Mean Dice メトリック。
- スライディング・ウィンドウ推論法。
- 再現性のための決定論的訓練。
Spleen データセットは http://medicaldecathlon.com/ からダウンロードできます。
- Target: Spleen
- Modality: CT
- Size: 61 3D volumes (41 Training + 20 Testing)
- Source: Memorial Sloan Kettering Cancer Center
- Challenge: Large ranging foreground size
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"
!python -c "import matplotlib" || pip install -q matplotlib
!pip install -q pytorch-lightning==1.4.0
%matplotlib inline
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from monai.utils import set_determinism
from monai.transforms import (
AsDiscrete,
AddChanneld,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
Spacingd,
EnsureTyped,
EnsureType,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, list_data_collate, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import pytorch_lightning
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
print_config()
MONAI version: 0.6.0+1.g8365443a Numpy version: 1.20.3 Pytorch version: 1.9.0a0+c3d40fd MONAI flags: HAS_EXT = True, USE_COMPILED = False MONAI rev id: 8365443ababac313340467e5987c7babe2b5b86a Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.15.0 Pillow version: 8.2.0 Tensorboard version: 2.2.0 gdown version: 3.13.0 TorchVision version: 0.10.0a0 ITK version: 5.1.2 tqdm version: 4.53.0 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.1.4 einops version: 0.3.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical/
データセットのダウンロード
データセットをダウンロードして展開します。データセットは http://medicaldecathlon.com/ に由来します。
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
LightningModule を定義する
LightningModule は訓練コードのリファクタリングを含みます。以下のモジュールは spleen_segmentation_3d.ipynb のコードのリファクタリングです :
class Net(pytorch_lightning.LightningModule):
def __init__(self):
super().__init__()
self._model = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
norm=Norm.BATCH,
)
self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
self.post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])
self.post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
self.dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
self.best_val_dice = 0
self.best_val_epoch = 0
def forward(self, x):
return self._model(x)
def prepare_data(self):
# set up the correct data path
train_images = sorted(
glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(
glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
{"image": image_name, "label": label_name}
for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
# set deterministic training for reproducibility
set_determinism(seed=0)
# define the data transforms
train_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"], a_min=-57, a_max=164,
b_min=0.0, b_max=1.0, clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
# randomly crop out patch samples from
# big image based on pos / neg ratio
# the image centers of negative samples
# must be in valid image area
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(96, 96, 96),
pos=1,
neg=1,
num_samples=4,
image_key="image",
image_threshold=0,
),
# user can also add other random transforms
# RandAffined(
# keys=['image', 'label'],
# mode=('bilinear', 'nearest'),
# prob=1.0,
# spatial_size=(96, 96, 96),
# rotate_range=(0, 0, np.pi/15),
# scale_range=(0.1, 0.1, 0.1)),
EnsureTyped(keys=["image", "label"]),
]
)
val_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"], a_min=-57, a_max=164,
b_min=0.0, b_max=1.0, clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
EnsureTyped(keys=["image", "label"]),
]
)
# we use cached datasets - these are 10x faster than regular datasets
self.train_ds = CacheDataset(
data=train_files, transform=train_transforms,
cache_rate=1.0, num_workers=4,
)
self.val_ds = CacheDataset(
data=val_files, transform=val_transforms,
cache_rate=1.0, num_workers=4,
)
# self.train_ds = monai.data.Dataset(
# data=train_files, transform=train_transforms)
# self.val_ds = monai.data.Dataset(
# data=val_files, transform=val_transforms)
def train_dataloader(self):
train_loader = torch.utils.data.DataLoader(
self.train_ds, batch_size=2, shuffle=True,
num_workers=4, collate_fn=list_data_collate,
)
return train_loader
def val_dataloader(self):
val_loader = torch.utils.data.DataLoader(
self.val_ds, batch_size=1, num_workers=4)
return val_loader
def configure_optimizers(self):
optimizer = torch.optim.Adam(self._model.parameters(), 1e-4)
return optimizer
def training_step(self, batch, batch_idx):
images, labels = batch["image"], batch["label"]
output = self.forward(images)
loss = self.loss_function(output, labels)
tensorboard_logs = {"train_loss": loss.item()}
return {"loss": loss, "log": tensorboard_logs}
def validation_step(self, batch, batch_idx):
images, labels = batch["image"], batch["label"]
roi_size = (160, 160, 160)
sw_batch_size = 4
outputs = sliding_window_inference(
images, roi_size, sw_batch_size, self.forward)
loss = self.loss_function(outputs, labels)
outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
labels = [self.post_label(i) for i in decollate_batch(labels)]
self.dice_metric(y_pred=outputs, y=labels)
return {"val_loss": loss, "val_number": len(outputs)}
def validation_epoch_end(self, outputs):
val_loss, num_items = 0, 0
for output in outputs:
val_loss += output["val_loss"].sum().item()
num_items += output["val_number"]
mean_val_dice = self.dice_metric.aggregate().item()
self.dice_metric.reset()
mean_val_loss = torch.tensor(val_loss / num_items)
tensorboard_logs = {
"val_dice": mean_val_dice,
"val_loss": mean_val_loss,
}
if mean_val_dice > self.best_val_dice:
self.best_val_dice = mean_val_dice
self.best_val_epoch = self.current_epoch
print(
f"current epoch: {self.current_epoch} "
f"current mean dice: {mean_val_dice:.4f}"
f"\nbest mean dice: {self.best_val_dice:.4f} "
f"at epoch: {self.best_val_epoch}"
)
return {"log": tensorboard_logs}
訓練の実行
# initialise the LightningModule
net = Net()
# set up loggers and checkpoints
log_dir = os.path.join(root_dir, "logs")
tb_logger = pytorch_lightning.loggers.TensorBoardLogger(
save_dir=log_dir
)
# initialise Lightning's trainer.
trainer = pytorch_lightning.Trainer(
gpus=[0],
max_epochs=600,
logger=tb_logger,
checkpoint_callback=True,
num_sanity_val_steps=1,
)
# train
trainer.fit(net)
print(
f"train completed, best_metric: {net.best_val_dice:.4f} "
f"at epoch {net.best_val_epoch}")
train completed, best_metric: 0.9498 at epoch 563
tensorboard で訓練を見る
%load_ext tensorboard
%tensorboard --logdir=log_dir
The tensorboard extension is already loaded. To reload it, use: %reload_ext tensorboard Reusing TensorBoard on port 6006 (pid 27668), started 1:35:41 ago. (Use '!kill 27668' to kill it.)
入力画像とラベルでベストなモデル出力を確認する
net.eval()
device = torch.device("cuda:0")
net.to(device)
with torch.no_grad():
for i, val_data in enumerate(net.val_dataloader()):
roi_size = (160, 160, 160)
sw_batch_size = 4
val_outputs = sliding_window_inference(
val_data["image"].to(device), roi_size, sw_batch_size, net
)
# plot the slice [:, :, 80]
plt.figure("check", (18, 6))
plt.subplot(1, 3, 1)
plt.title(f"image {i}")
plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
plt.subplot(1, 3, 2)
plt.title(f"label {i}")
plt.imshow(val_data["label"][0, 0, :, :, 80])
plt.subplot(1, 3, 3)
plt.title(f"output {i}")
plt.imshow(torch.argmax(
val_outputs, dim=1).detach().cpu()[0, :, :, 80])
plt.show()
データディレクトリのクリーンアップ
一時ディレクトリが使用された場合はディレクトリを削除する。
if directory is None:
shutil.rmtree(root_dir)
以上
MONAI 0.7 : tutorials : 3D セグメンテーション – 脾臓 3D セグメンテーション
MONAI 0.7 : tutorials : 3D セグメンテーション – 脾臓 3D セグメンテーション (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/16/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : 3D セグメンテーション – 脾臓 3D セグメンテーション
このノートブックは MSD 脾臓データセット に基づいた 3D セグメンテーションの end-to-end な訓練と評価サンプルです。このサンプルは PyTorch ベースのプログラムで MONAI モジュールの柔軟性を示します :
- 辞書ベースの訓練データ構造のための変換。
- メタデータと共に NIfTI 画像をロードする。
- 想定する範囲で医療画像強度をスケールする。
- ポジティブ/ネガティブ・ラベル比率に基づいてバランスの取れた画像パッチサンプルのバッチをクロップする。
- 訓練と検証を高速化するキャッシュ IO と変換。
- 3D セグメンテーション・タスクのための 3D UNet, Dice 損失関数, Mean Dice メトリック。
- スライディング・ウィンドウ推論。
- 再現性のための決定論的訓練。
このチュートリアルは MONAI を既存の PyTorch 医療 DL プログラムに統合する方法を示します。
そして以下の機能を簡単に使用することができます :
- 辞書形式データのための変換。
- メタデータとともに Nifti 画像をロードする。
- チャネル次元がない場合チャネル dim をデータに追加する。
- 想定される範囲で医療画像強度をスケールする。
- ポジティブ/ネガティブ・ラベル比率に基づいてバランスの取れた画像のバッチをクロップする。
- 訓練と検証を高速化するキャシュ IO と変換。
- 3D セグメンテーション・タスクのための 3D UNet モデル、Dice 損失関数、Mean Dice メトリック。
- スライディング・ウィンドウ推論法。
- 再現性のための決定論的訓練。
Spleen データセットは http://medicaldecathlon.com/ からダウンロードできます。
- Target: Spleen
- Modality: CT
- Size: 61 3D volumes (41 Training + 20 Testing)
- Source: Memorial Sloan Kettering Cancer Center
- Challenge: Large ranging foreground size
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
from monai.utils import first, set_determinism
from monai.transforms import (
AsDiscrete,
AsDiscreted,
EnsureChannelFirstd,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
Spacingd,
EnsureTyped,
EnsureType,
Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
print_config()
MONAI version: 0.6.0+22.g027947bf Numpy version: 1.20.3 Pytorch version: 1.9.0a0+c3d40fd MONAI flags: HAS_EXT = False, USE_COMPILED = False MONAI rev id: 027947bf91ff0dfac94f472ed1855cd49e3feb8d Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.15.0 Pillow version: 8.2.0 Tensorboard version: 2.5.0 gdown version: 3.13.0 TorchVision version: 0.10.0a0 tqdm version: 4.53.0 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.1.4 einops version: 0.3.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical/
データセットのダウンロード
データセットをダウンロードして展開します。データセットは http://medicaldecathlon.com/ に由来します。
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
MSD 脾臓データセット・パスの設定
train_images = sorted(
glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(
glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
{"image": image_name, "label": label_name}
for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
再現性のための決定論的訓練の設定
set_determinism(seed=0)
訓練と検証のための変換のセットアップ
ここではデータセットを増強するために幾つかの変換を使用します :
- LoadImaged は NIfTI 形式ファイルから脾臓 CT 画像とラベルをロードします。
- AddChanneld は元のデータがチャネル dim を持たないとき、「チャネル first」shape を構築するために 1 dim 追加します。
- Spacingd はアフィン行列に基づいて pixdim=(1.5, 1.5, 2.) により spacing を調整します。
- Orientationd はアフィン行列に基づいてデータの向きを統一します。
- ScaleIntensityRanged は強度範囲 [-57, 164] を抽出して [0, 1] にスケールします。
- CropForegroundd は画像とラベルの valid body 領域にフォーカスするために総てのゼロ境界 (= border) を削除します。
- RandCropByPosNegLabeld は pos / neg 比率に基づいて大きな画像からランダムにパッチサンプルをクロップします。
ネガティブサンプルの画像の中心は valid body 領域内になければなりません。
- RandAffined は PyTorch アフィン変換に基づいて回転, スケール, shear, 並行移動等を一緒に効率的に実行します。
- EnsureTyped は更なるステップのために numpy 配列を PyTorch テンソルに変換します。
train_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Spacingd(keys=["image", "label"], pixdim=(
1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"], a_min=-57, a_max=164,
b_min=0.0, b_max=1.0, clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(96, 96, 96),
pos=1,
neg=1,
num_samples=4,
image_key="image",
image_threshold=0,
),
# user can also add other random transforms
# RandAffined(
# keys=['image', 'label'],
# mode=('bilinear', 'nearest'),
# prob=1.0, spatial_size=(96, 96, 96),
# rotate_range=(0, 0, np.pi/15),
# scale_range=(0.1, 0.1, 0.1)),
EnsureTyped(keys=["image", "label"]),
]
)
val_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Spacingd(keys=["image", "label"], pixdim=(
1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"], a_min=-57, a_max=164,
b_min=0.0, b_max=1.0, clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
EnsureTyped(keys=["image", "label"]),
]
)
DataLoaer で変換を確認する
check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")
# plot the slice [:, :, 80]
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 80], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 80])
plt.show()
image shape: torch.Size([226, 157, 113]), label shape: torch.Size([226, 157, 113])
訓練と検証のために CacheDataset と DataLoader を定義する
ここで訓練と検証プロセスを高速化するために CacheDataset を使用し、それは通常の Dataset よりも 10x 高速です。ベストなパフォーマンスを得るためには、総てのデータをキャッシュするために cache_rate=1.0 を設定します、メモリが十分でない場合には、低い値を設定してください。ユーザはまた cache_rate の代わりに cache_num を設定することもできて、2 つの設定の最小値を使用します。そしてキャッシュの間にマルチスレッドを有効にするために num_workers を設定します。通常の Dataset を試したい場合は、下でコメントされたコードを単に使用するように変更してください。
train_ds = CacheDataset(
data=train_files, transform=train_transforms,
cache_rate=1.0, num_workers=4)
# train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
# use batch_size=2 to load images and use RandCropByPosNegLabeld
# to generate 2 x 4 images for network training
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
val_ds = CacheDataset(
data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
# val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
Loading dataset: 100%|██████████| 32/32 [00:32<00:00, 1.02s/it] Loading dataset: 100%|██████████| 9/9 [00:07<00:00, 1.18it/s]
モデル、損失、Optimizer を作成する
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
model = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")
典型的な PyTorch 訓練プロセスを実行する
max_epochs = 600
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
for epoch in range(max_epochs):
print("-" * 10)
print(f"epoch {epoch + 1}/{max_epochs}")
model.train()
epoch_loss = 0
step = 0
for batch_data in train_loader:
step += 1
inputs, labels = (
batch_data["image"].to(device),
batch_data["label"].to(device),
)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print(
f"{step}/{len(train_ds) // train_loader.batch_size}, "
f"train_loss: {loss.item():.4f}")
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
for val_data in val_loader:
val_inputs, val_labels = (
val_data["image"].to(device),
val_data["label"].to(device),
)
roi_size = (160, 160, 160)
sw_batch_size = 4
val_outputs = sliding_window_inference(
val_inputs, roi_size, sw_batch_size, model)
val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
val_labels = [post_label(i) for i in decollate_batch(val_labels)]
# compute metric for current iteration
dice_metric(y_pred=val_outputs, y=val_labels)
# aggregate the final mean dice result
metric = dice_metric.aggregate().item()
# reset the status for next validation round
dice_metric.reset()
metric_values.append(metric)
if metric > best_metric:
best_metric = metric
best_metric_epoch = epoch + 1
torch.save(model.state_dict(), os.path.join(
root_dir, "best_metric_model.pth"))
print("saved new best metric model")
print(
f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
f"\nbest mean dice: {best_metric:.4f} "
f"at epoch: {best_metric_epoch}"
)
print(
f"train completed, best_metric: {best_metric:.4f} "
f"at epoch: {best_metric_epoch}")
train completed, best_metric: 0.9510 at epoch: 598
損失とメトリックをプロットする
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.show()
入力画像とラベルでベストなモデル出力を確認する
model.load_state_dict(torch.load(
os.path.join(root_dir, "best_metric_model.pth")))
model.eval()
with torch.no_grad():
for i, val_data in enumerate(val_loader):
roi_size = (160, 160, 160)
sw_batch_size = 4
val_outputs = sliding_window_inference(
val_data["image"].to(device), roi_size, sw_batch_size, model
)
# plot the slice [:, :, 80]
plt.figure("check", (18, 6))
plt.subplot(1, 3, 1)
plt.title(f"image {i}")
plt.imshow(val_data["image"][0, 0, :, :, 80], cmap="gray")
plt.subplot(1, 3, 2)
plt.title(f"label {i}")
plt.imshow(val_data["label"][0, 0, :, :, 80])
plt.subplot(1, 3, 3)
plt.title(f"output {i}")
plt.imshow(torch.argmax(
val_outputs, dim=1).detach().cpu()[0, :, :, 80])
plt.show()
if i == 2:
break
元の画像 spacing 上の評価
val_org_transforms = Compose(
[
LoadImaged(keys=["image", "label"]),
EnsureChannelFirstd(keys=["image", "label"]),
Spacingd(keys=["image"], pixdim=(
1.5, 1.5, 2.0), mode="bilinear"),
Orientationd(keys=["image"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"], a_min=-57, a_max=164,
b_min=0.0, b_max=1.0, clip=True,
),
CropForegroundd(keys=["image"], source_key="image"),
EnsureTyped(keys=["image", "label"]),
]
)
val_org_ds = Dataset(
data=val_files, transform=val_org_transforms)
val_org_loader = DataLoader(val_org_ds, batch_size=1, num_workers=4)
post_transforms = Compose([
EnsureTyped(keys="pred"),
Invertd(
keys="pred",
transform=val_org_transforms,
orig_keys="image",
meta_keys="pred_meta_dict",
orig_meta_keys="image_meta_dict",
meta_key_postfix="meta_dict",
nearest_interp=False,
to_tensor=True,
),
AsDiscreted(keys="pred", argmax=True, to_onehot=True, num_classes=2),
AsDiscreted(keys="label", to_onehot=True, num_classes=2),
])
model.load_state_dict(torch.load(
os.path.join(root_dir, "best_metric_model.pth")))
model.eval()
with torch.no_grad():
for val_data in val_org_loader:
val_inputs = val_data["image"].to(device)
roi_size = (160, 160, 160)
sw_batch_size = 4
val_data["pred"] = sliding_window_inference(
val_inputs, roi_size, sw_batch_size, model)
val_data = [post_transforms(i) for i in decollate_batch(val_data)]
val_outputs, val_labels = from_engine(["pred", "label"])(val_data)
# compute metric for current iteration
dice_metric(y_pred=val_outputs, y=val_labels)
# aggregate the final mean dice result
metric_org = dice_metric.aggregate().item()
# reset the status for next validation round
dice_metric.reset()
print("Metric on original image spacing: ", metric_org)
Metric on original image spacing: 0.9637420177459717
データディレクトリのクリーンアップ
一時ディレクトリが使用された場合にはディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上
MONAI 0.7 : tutorials : モジュール – 3D 画像変換 (幾何学的変換)
MONAI 0.7 : tutorials : モジュール – 3D 画像変換 (幾何学的変換) (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/15/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : モジュール – 3D 画像変換 (幾何学的変換)
このノートブックは volumetric 画像の変換を実演します。
このノートブックは 3D 画像のための MONAI の変換モジュールを紹介します。
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"
!python -c "import matplotlib" || pip install -q matplotlib
from monai.transforms import (
AddChanneld,
LoadImage,
LoadImaged,
Orientationd,
Rand3DElasticd,
RandAffined,
Spacingd,
)
from monai.config import print_config
from monai.apps import download_and_extract
import numpy as np
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
print_config()
MONAI version: 0.4.0+721.g75b7a446 Numpy version: 1.21.2 Pytorch version: 1.10.0a0+3fd9dcf MONAI flags: HAS_EXT = False, USE_COMPILED = False MONAI rev id: 75b7a4462647bfbe9bc8e7d8e5bff1238b87990d Optional dependencies: Pytorch Ignite version: 0.4.6 Nibabel version: 3.2.1 scikit-image version: 0.18.3 Pillow version: 8.2.0 Tensorboard version: 2.6.0 gdown version: 3.13.1 TorchVision version: 0.11.0a0 tqdm version: 4.62.1 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.3.3 einops version: 0.3.2 transformers version: 4.10.2 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(f"root dir is: {root_dir}")
root dir is: /workspace/data/medical
データセットのダウンロード
データセットをダウンロードして展開します。データセットは http://medicaldecathlon.com/ に由来します。
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_dir = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
MSD 脾臓データセット・パスを設定する
以下は Task09_Spleen/imagesTr と Task09_Spleen/labelsTr からの画像とラベルをペアにグループ化します。
train_images = sorted(
glob.glob(os.path.join(data_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(
glob.glob(os.path.join(data_dir, "labelsTr", "*.nii.gz")))
data_dicts = [
{"image": image_name, "label": label_name}
for image_name, label_name in zip(train_images, train_labels)
]
train_data_dicts, val_data_dicts = data_dicts[:-9], data_dicts[-9:]
画像ファイル名は辞書のリストに体系化されます。
train_data_dicts[0]
{'image': '/workspace/data/medical/Task09_Spleen/imagesTr/spleen_10.nii.gz', 'label': '/workspace/data/medical/Task09_Spleen/labelsTr/spleen_10.nii.gz'}
データ辞書のリスト, train_data_dicts は PyTorch のデータローダで使用できます。
例えば :
from torch.utils.data import DataLoader
data_loader = DataLoader(train_data_dicts)
for training_sample in data_loader:
# run the deep learning training with training_sample
このチュートリアルの残りは、最終的に深層学習モデルにより消費される train_data_dict をデータ配列に変換する「変換」のセットを提示します。
NIfTI ファイルをロードする
MONAI の一つの設計上の選択は、それは高位ワークフロー・コンポーネントだけでなく、最小限機能する形で比較的低位の API も提供することです。
例えば、LoadImage クラスは基礎となる Nibabel 画像ローダの単純な呼び出し可能なラッパーです。幾つかの必要なシステムパラメータでローダを構築した後、NIfTI ファイル名と共にローダインスタンスを呼び出すと画像データ配列とアフィン情報やボクセル・サイズのようなメタデータを返します。
loader = LoadImage(dtype=np.float32)
image, metadata = loader(train_data_dicts[0]["image"])
print(f"input: {train_data_dicts[0]['image']}")
print(f"image shape: {image.shape}")
print(f"image affine:\n{metadata['affine']}")
print(f"image pixdim:\n{metadata['pixdim']}")
input: /workspace/data/medical/Task09_Spleen/imagesTr/spleen_10.nii.gz image shape: (512, 512, 55) image affine: [[ 0.97656202 0. 0. -499.02319336] [ 0. 0.97656202 0. -499.02319336] [ 0. 0. 5. 0. ] [ 0. 0. 0. 1. ]] image pixdim: [1. 0.976562 0.976562 5. 0. 0. 0. 0. ]
多くの場合、入力のグループを訓練サンプルとしてロードすることを望みます。例えば教師あり画像セグメンテーション・ネットワークの訓練は訓練サンプルとして画像とラベルのペアを必要とします。
入力のグループが一貫して前処理されることを保証するため、MONAI はまた最小限機能する変換のための辞書ベースのインターフェイスも提供します。
LoadImaged は LoadImage の対応する辞書ベース版です :
loader = LoadImaged(keys=("image", "label"))
data_dict = loader(train_data_dicts[0])
print(f"input:, {train_data_dicts[0]}")
print(f"image shape: {data_dict['image'].shape}")
print(f"label shape: {data_dict['label'].shape}")
print(f"image pixdim:\n{data_dict['image_meta_dict']['pixdim']}")
input:, {'image': '/workspace/data/medical/Task09_Spleen/imagesTr/spleen_10.nii.gz', 'label': '/workspace/data/medical/Task09_Spleen/labelsTr/spleen_10.nii.gz'} image shape: (512, 512, 55) label shape: (512, 512, 55) image pixdim: [1. 0.976562 0.976562 5. 0. 0. 0. 0. ]
image, label = data_dict["image"], data_dict["label"]
plt.figure("visualize", (8, 4))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 30], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 30])
plt.show()
チャネル次元を追加する
MONAI の画像変換の殆どは入力データが次の shape を持つことを仮定しています :
[num_channels, spatial_dim_1, spatial_dim_2, … ,spatial_dim_n]
(チャネル 1st は PyTorch で一般に使用されるので) それらが一貫して解釈できるためです 。
ここでは入力画像は shape (512, 512, 55) を持ちますが、これは受け入れられる shape ではありませんので (チャネル次元が欠落しています)、shape を更新するために呼び出される変換を作成します :
add_channel = AddChanneld(keys=["image", "label"])
datac_dict = add_channel(data_dict)
print(f"image shape: {datac_dict['image'].shape}")
image shape: (1, 512, 512, 55)
今は幾つかの強度と空間変換を行なう準備ができました。
一貫したボクセルサイズへの再サンプリング
入力ボリュームは異なるボクセルサイズを持つかもしれません。以下の変換はボリュームが (1.5, 1.5, 5.) mm ボクセルサイズを持つように正規化するために作成します。変換は data_dict[‘image.affine’] からの元のボクセルサイズを読むように設定されています、これは対応する NIfTI ファイルからのもので、LoadImaged により先にロードされます。
spacing = Spacingd(keys=["image", "label"], pixdim=(
1.5, 1.5, 5.0), mode=("bilinear", "nearest"))
data_dict = spacing(datac_dict)
print(f"image shape: {data_dict['image'].shape}")
print(f"label shape: {data_dict['label'].shape}")
print(f"image affine after Spacing:\n{data_dict['image_meta_dict']['affine']}")
print(f"label affine after Spacing:\n{data_dict['label_meta_dict']['affine']}")
image shape: (1, 334, 334, 55) label shape: (1, 334, 334, 55) image affine after Spacing: [[ 1.5 0. 0. -499.02319336] [ 0. 1.5 0. -499.02319336] [ 0. 0. 5. 0. ] [ 0. 0. 0. 1. ]] label affine after Spacing: [[ 1.5 0. 0. -499.02319336] [ 0. 1.5 0. -499.02319336] [ 0. 0. 5. 0. ] [ 0. 0. 0. 1. ]]
spacing の変更を追跡するため、data_dict は Spacingd により更新されました :
- image.original_affine キーが data_dict に追加されて、元のアフィンを記録します。
- image.affine キーは現在のアフィンを持つように更新されます。
image, label = data_dict["image"], data_dict["label"]
plt.figure("visualise", (8, 4))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[0, :, :, 30], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[0, :, :, 30])
plt.show()
指定された軸コードへの再方向付け (= Reorientation)
時に総ての入力ボリュームを一貫した軸方向にすることが望ましいです。デフォルトの軸ラベルは Left (L), Right (R), Posterior (P), Anterior (A), Inferior (I), Superior (S) です。以下の変換はボリュームを ‘Posterior, Left, Inferior’ (PLI) 方向を持つように再方向付けるために作成されます :
orientation = Orientationd(keys=["image", "label"], axcodes="PLI")
data_dict = orientation(data_dict)
print(f"image shape: {data_dict['image'].shape}")
print(f"label shape: {data_dict['label'].shape}")
print(f"image affine after Spacing:\n{data_dict['image_meta_dict']['affine']}")
print(f"label affine after Spacing:\n{data_dict['label_meta_dict']['affine']}")
image shape: (1, 334, 334, 55) label shape: (1, 334, 334, 55) image affine after Spacing: [[ 0. -1.5 0. 0.47680664] [ -1.5 0. 0. 0.47680664] [ 0. 0. -5. 270. ] [ 0. 0. 0. 1. ]] label affine after Spacing: [[ 0. -1.5 0. 0.47680664] [ -1.5 0. 0. 0.47680664] [ 0. 0. -5. 270. ] [ 0. 0. 0. 1. ]]
image, label = data_dict["image"], data_dict["label"]
plt.figure("visualise", (8, 4))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[0, :, :, 30], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[0, :, :, 30])
plt.show()
ランダムなアフィン変換
以下のアフィン変換は (300, 300, 50) 画像パッチを出力するように定義されています。パッチの一は x, y と z 軸のそれぞれについて (-40, 40), (-40, 40), (-2, 2) の範囲でランダムに選択されます。変換は画像の中心に対して相対的です。3D 回転角度は z 軸周りに (-45, 45) 度、そして x と y 軸周りに 5 度からランダムに選択されます。ランダムなスケーリング因子は各軸に沿って (1.0 – 0.15, 1.0 + 0.15) からランダムに選択されます。
rand_affine = RandAffined(
keys=["image", "label"],
mode=("bilinear", "nearest"),
prob=1.0,
spatial_size=(300, 300, 50),
translate_range=(40, 40, 2),
rotate_range=(np.pi / 36, np.pi / 36, np.pi / 4),
scale_range=(0.15, 0.15, 0.15),
padding_mode="border",
)
rand_affine.set_random_state(seed=123)
元の画像の様々なランダム化されたバージョンを生成するためにこのセルを再実行できます。
affined_data_dict = rand_affine(data_dict)
print(f"image shape: {affined_data_dict['image'].shape}")
image, label = affined_data_dict["image"][0], affined_data_dict["label"][0]
plt.figure("visualise", (8, 4))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 15], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 15])
plt.show()
image shape: (1, 300, 300, 50)
ランダムな elastic 変形
同様に、以下の elastic 変形は (300, 300, 10) 画像パッチを出力するように定義されています。画像はアフィン変換と elastic 変形の組み合わせから再サンプリングされます。sigma_range は変形の滑らかさを制御します (15 より大きいと CPU 上では遅くなる可能性があります)。magnitude_range は変形の振幅を制御します (500 より大きいと、画像が非現実的になります)。
rand_elastic = Rand3DElasticd(
keys=["image", "label"],
mode=("bilinear", "nearest"),
prob=1.0,
sigma_range=(5, 8),
magnitude_range=(100, 200),
spatial_size=(300, 300, 10),
translate_range=(50, 50, 2),
rotate_range=(np.pi / 36, np.pi / 36, np.pi),
scale_range=(0.15, 0.15, 0.15),
padding_mode="border",
)
rand_elastic.set_random_state(seed=123)
元の画像の様々なランダム化されたバージョンを生成するためにこのセルを再実行できます。
deformed_data_dict = rand_elastic(data_dict)
print(f"image shape: {deformed_data_dict['image'].shape}")
image, label = deformed_data_dict["image"][0], deformed_data_dict["label"][0]
plt.figure("visualise", (8, 4))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, 5], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, 5])
plt.show()
image shape: (1, 300, 300, 10)
データディレクトリのクリーンアップ
一時ディレクトリが使用された場合にはディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上
MONAI 0.7 : tutorials : 高速化 – MONAI 機能による高速訓練
MONAI 0.7 : tutorials : 高速化 – MONAI 機能による高速訓練 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/15/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : 高速化 – MONAI 機能による高速訓練
このドキュメントは、訓練パイプラインをプロファイルする方法、データセットを分析して適切なアルゴリズムを選択する方法、そしてシングル GPU、マルチ GPU 更にはマルチノードで GPU 利用を最適化する方法の詳細を紹介します。
このチュートリアルは PyTorch 訓練プログラムと MONAI 最適化訓練プログラムを示し、そしてパフォーマンスを比較します :
- AMP (Auto 混合精度)
- 決定論的変換のための CacheDataset
- データを GPU とキャッシュに移してから、GPU 上でランダムな変換を実行する。
- 軽量タスクではマルチスレッド化された ThreadDataLoader は PyTorch DataLoader よりも高速です。
- 通常の Dice 損失の代わりに MONAI DiceCE 損失を使用する。
- 通常の Adam optimizer の代わりに MONAI Novograd optimizer を使用する。
V100 GPU で、1 分内に 0.95 の検証平均 dice (損失) への訓練収束を獲得できます、それは同じメトリック達成するときの PyTorch の通常の実装と比べておよそ 200x 高速化しています。そして総てのエポックは通常の訓練よりも 20x 高速です。
それは 脾臓 3D セグメンテーション・チュートリアル ノートブックからの変更で、脾臓データセットは http://medicaldecathlon.com/ からダウンロードできます。
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
インポートのセットアップ
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import math
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
import torch
from torch.optim import Adam
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import (
CacheDataset,
DataLoader,
ThreadDataLoader,
Dataset,
decollate_batch,
)
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm
from monai.networks.nets import UNet
from monai.optimizers import Novograd
from monai.transforms import (
AddChanneld,
AsDiscrete,
Compose,
CropForegroundd,
FgBgToIndicesd,
LoadImaged,
Orientationd,
RandCropByPosNegLabeld,
ScaleIntensityRanged,
Spacingd,
ToDeviced,
EnsureTyped,
EnsureType,
)
from monai.utils import get_torch_version_tuple, set_determinism
print_config()
if get_torch_version_tuple() < (1, 6):
raise RuntimeError(
"AMP feature only exists in PyTorch version greater than v1.6."
)
MONAI version: 0.2.0+1008.gf65d296f Numpy version: 1.21.2 Pytorch version: 1.10.0a0+3fd9dcf MONAI flags: HAS_EXT = False, USE_COMPILED = False MONAI rev id: f65d296fe780f869dca84b6714dc36b94794930e Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.18.3 Pillow version: 8.2.0 Tensorboard version: 2.6.0 gdown version: 3.13.0 TorchVision version: 0.11.0a0 tqdm version: 4.62.1 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.3.2 einops version: 0.3.2 transformers version: NOT INSTALLED or UNKNOWN VERSION. For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(f"root dir is: {root_dir}")
root dir is: /workspace/data/medical
データセットのダウンロード
Decathlon 脾臓データセットをダウンロードして展開します。
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = os.path.join(root_dir, "Task09_Spleen.tar")
data_root = os.path.join(root_dir, "Task09_Spleen")
if not os.path.exists(data_root):
download_and_extract(resource, compressed_file, root_dir, md5)
MSD 脾臓データセット・パスの設定
train_images = sorted(
glob.glob(os.path.join(data_root, "imagesTr", "*.nii.gz"))
)
train_labels = sorted(
glob.glob(os.path.join(data_root, "labelsTr", "*.nii.gz"))
)
data_dicts = [
{"image": image_name, "label": label_name}
for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
訓練と検証のための変換のセットアップ
def transformations(fast=False):
train_transforms = [
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
Orientationd(keys=["image", "label"], axcodes="RAS"),
# change to execute transforms with Tensor data
EnsureTyped(keys=["image", "label"]),
]
if fast:
# move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
train_transforms.append(
ToDeviced(keys=["image", "label"], device="cuda:0")
)
train_transforms.extend([
ScaleIntensityRanged(
keys=["image"],
a_min=-57,
a_max=164,
b_min=0.0,
b_max=1.0,
clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
# pre-compute foreground and background indexes
# and cache them to accelerate training
FgBgToIndicesd(
keys="label",
fg_postfix="_fg",
bg_postfix="_bg",
image_key="image",
),
# randomly crop out patch samples from big
# image based on pos / neg ratio
# the image centers of negative samples
# must be in valid image area
RandCropByPosNegLabeld(
keys=["image", "label"],
label_key="label",
spatial_size=(96, 96, 96),
pos=1,
neg=1,
num_samples=4,
fg_indices_key="label_fg",
bg_indices_key="label_bg",
),
])
val_transforms = [
LoadImaged(keys=["image", "label"]),
AddChanneld(keys=["image", "label"]),
Spacingd(
keys=["image", "label"],
pixdim=(1.5, 1.5, 2.0),
mode=("bilinear", "nearest"),
),
Orientationd(keys=["image", "label"], axcodes="RAS"),
ScaleIntensityRanged(
keys=["image"],
a_min=-57,
a_max=164,
b_min=0.0,
b_max=1.0,
clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
EnsureTyped(keys=["image", "label"]),
]
if fast:
# move the data to GPU and cache to avoid CPU -> GPU sync in every epoch
val_transforms.append(
ToDeviced(keys=["image", "label"], device="cuda:0")
)
return Compose(train_transforms), Compose(val_transforms)
訓練手順を定義する
典型的な PyTorch 通常の学習手続きについては、 モデルを訓練するために通常の Dataset, DataLoader, Adam optimizer と Dice 損失を使用します。
MONAI 高速訓練手順については、主として以下の機能を導入します :
- AMP (auto 混合精度): AMP は PyTorch v1.6 でリリースされた重要な機能で、NVIDIA CUDA 11 は AMP の強力なサポートを追加して訓練スピードを大幅に改善しました。
- CacheDataset: キャッシュ機構を持つ Dataset で、訓練の間にデータをロードして決定論的変換の結果をキャッシュできます。
- ToDeviced 変換: データを GPU に移して CacheDataset でキャッシュしてから、直接 GPU 上でランダムな変換を実行し、総てのエポックでの CPU -> GPU 同期を回避します。総ての MONAI 変換が GPU 演算をサポートはしてないことに注意してください、作業は進行中です。
- ThreadDataLoader: マルチプロセス処理の代わりにマルチスレッドを使用します、殆どの計算の結果を既にキャッシュしていますので軽量タスクでは DataLoader よりも高速です。
- Novograd optimizer: Novograd はペーパー "Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks" < https://arxiv.org/pdf/1905.11286.pdf > に基づいています。
- DiceCE 損失関数: Dice 損失と交差エントロピー損失を計算して、これら 2 つの損失の重み付けられた合計を返します。
def train_process(fast=False):
max_epochs = 300
learning_rate = 2e-4
val_interval = 1 # do validation for every epoch
train_trans, val_trans = transformations(fast=fast)
# set CacheDataset, ThreadDataLoader and DiceCE loss for MONAI fast training
if fast:
train_ds = CacheDataset(
data=train_files,
transform=train_trans,
cache_rate=1.0,
num_workers=8,
)
val_ds = CacheDataset(
data=val_files, transform=val_trans, cache_rate=1.0, num_workers=5
)
# disable multi-workers because `ThreadDataLoader` works with multi-threads
train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=4, shuffle=True)
val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1)
loss_function = DiceCELoss(to_onehot_y=True, softmax=True, squared_pred=True, batch=True)
else:
train_ds = Dataset(data=train_files, transform=train_trans)
val_ds = Dataset(data=val_files, transform=val_trans)
# num_worker=4 is the best parameter according to the test
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
device = torch.device("cuda:0")
model = UNet(
spatial_dims=3,
in_channels=1,
out_channels=2,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
norm=Norm.BATCH,
).to(device)
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=True, num_classes=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=True, num_classes=2)])
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
# set Novograd optimizer for MONAI training
if fast:
# Novograd paper suggests to use a bigger LR than Adam,
# because Adam does normalization by element-wise second moments
optimizer = Novograd(model.parameters(), learning_rate * 10)
scaler = torch.cuda.amp.GradScaler()
else:
optimizer = Adam(model.parameters(), learning_rate)
best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
epoch_times = []
total_start = time.time()
for epoch in range(max_epochs):
epoch_start = time.time()
print("-" * 10)
print(f"epoch {epoch + 1}/{max_epochs}")
model.train()
epoch_loss = 0
step = 0
for batch_data in train_loader:
step_start = time.time()
step += 1
inputs, labels = (
batch_data["image"].to(device),
batch_data["label"].to(device),
)
optimizer.zero_grad()
# set AMP for MONAI training
if fast:
with torch.cuda.amp.autocast():
outputs = model(inputs)
loss = loss_function(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
print(
f"{step}/{epoch_len}, train_loss: {loss.item():.4f}"
f" step time: {(time.time() - step_start):.4f}"
)
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
for val_data in val_loader:
val_inputs, val_labels = (
val_data["image"].to(device),
val_data["label"].to(device),
)
roi_size = (160, 160, 160)
sw_batch_size = 4
# set AMP for MONAI validation
if fast:
with torch.cuda.amp.autocast():
val_outputs = sliding_window_inference(
val_inputs, roi_size, sw_batch_size, model
)
else:
val_outputs = sliding_window_inference(
val_inputs, roi_size, sw_batch_size, model
)
val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
val_labels = [post_label(i) for i in decollate_batch(val_labels)]
dice_metric(y_pred=val_outputs, y=val_labels)
metric = dice_metric.aggregate().item()
dice_metric.reset()
metric_values.append(metric)
if metric > best_metric:
best_metric = metric
best_metric_epoch = epoch + 1
best_metrics_epochs_and_time[0].append(best_metric)
best_metrics_epochs_and_time[1].append(best_metric_epoch)
best_metrics_epochs_and_time[2].append(
time.time() - total_start
)
torch.save(model.state_dict(), "best_metric_model.pth")
print("saved new best metric model")
print(
f"current epoch: {epoch + 1} current"
f" mean dice: {metric:.4f}"
f" best mean dice: {best_metric:.4f}"
f" at epoch: {best_metric_epoch}"
)
print(
f"time consuming of epoch {epoch + 1} is:"
f" {(time.time() - epoch_start):.4f}"
)
epoch_times.append(time.time() - epoch_start)
total_time = time.time() - total_start
print(
f"train completed, best_metric: {best_metric:.4f}"
f" at epoch: {best_metric_epoch}"
f" total time: {total_time:.4f}"
)
return (
max_epochs,
epoch_loss_values,
metric_values,
epoch_times,
best_metrics_epochs_and_time,
total_time,
)
決定論を有効にして通常の PyTorch 訓練を実行する
set_determinism(seed=0)
regular_start = time.time()
(
epoch_num,
epoch_loss_values,
metric_values,
epoch_times,
best,
train_time,
) = train_process(fast=False)
total_time = time.time() - regular_start
print(
f"total time of {epoch_num} epochs with regular PyTorch training: {total_time:.4f}"
)
決定論を有効にして MONAI 最適化訓練を実行する
set_determinism(seed=0)
monai_start = time.time()
(
epoch_num,
m_epoch_loss_values,
m_metric_values,
m_epoch_times,
m_best,
m_train_time,
) = train_process(fast=True)
m_total_time = time.time() - monai_start
print(
f"total time of {epoch_num} epochs with MONAI fast training: {m_train_time:.4f},"
f" time of preparing cache: {(m_total_time - m_train_time):.4f}"
)
訓練損失と検証メトリクスをプロットする
plt.figure("train", (12, 12))
plt.subplot(2, 2, 1)
plt.title("Regular Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.grid(alpha=0.4, linestyle=":")
plt.plot(x, y, color="red")
plt.subplot(2, 2, 2)
plt.title("Regular Val Mean Dice")
x = [i + 1 for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.ylim(0, 1)
plt.grid(alpha=0.4, linestyle=":")
plt.plot(x, y, color="red")
plt.subplot(2, 2, 3)
plt.title("Fast Epoch Average Loss")
x = [i + 1 for i in range(len(m_epoch_loss_values))]
y = m_epoch_loss_values
plt.xlabel("epoch")
plt.grid(alpha=0.4, linestyle=":")
plt.plot(x, y, color="green")
plt.subplot(2, 2, 4)
plt.title("Fast Val Mean Dice")
x = [i + 1 for i in range(len(m_metric_values))]
y = m_metric_values
plt.xlabel("epoch")
plt.ylim(0, 1)
plt.grid(alpha=0.4, linestyle=":")
plt.plot(x, y, color="green")
plt.show()
合計時間と総てのエポック時間をプロットする
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Total Train Time(300 epochs)")
plt.bar(
"regular PyTorch", total_time, 1, label="Regular training", color="red"
)
plt.bar("Fast", m_total_time, 1, label="Fast training", color="green")
plt.ylabel("secs")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.subplot(1, 2, 2)
plt.title("Epoch Time")
x = [i + 1 for i in range(len(epoch_times))]
plt.xlabel("epoch")
plt.ylabel("secs")
plt.plot(x, epoch_times, label="Regular training", color="red")
plt.plot(x, m_epoch_times, label="Fast training", color="green")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.show()
メトリクスを取得するための合計時間をプロットする
def get_best_metric_time(threshold, best_values):
for i, v in enumerate(best_values[0]):
if v > threshold:
return best_values[2][i]
return -1
def get_best_metric_epochs(threshold, best_values):
for i, v in enumerate(best_values[0]):
if v > threshold:
return best_values[1][i]
return -1
def get_label(index):
if index == 0:
return "Regular training"
elif index == 1:
return "Fast training"
else:
return None
plt.figure("train", (18, 6))
plt.subplot(1, 3, 1)
plt.title("Metrics Time")
plt.xlabel("secs")
plt.ylabel("best mean_dice")
plt.plot(best[2], best[0], label="Regular training", color="red")
plt.plot(m_best[2], m_best[0], label="Fast training", color="green")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.subplot(1, 3, 2)
plt.title("Typical Metrics Time")
plt.xlabel("best mean_dice")
plt.ylabel("secs")
labels = ["0.90", "0.90 ", "0.93", "0.93 ", "0.95", "0.95 ", "0.97", "0.97 "]
x_values = [0.9, 0.9, 0.93, 0.93, 0.95, 0.95, 0.97, 0.97]
for i, (l, x) in enumerate(zip(labels, x_values)):
value = int(get_best_metric_time(x, best if i % 2 == 0 else m_best))
color = "red" if i % 2 == 0 else "green"
plt.bar(l, value, 0.5, label=get_label(i), color=color)
plt.text(l, value, "%s" % value, ha="center", va="bottom")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.subplot(1, 3, 3)
plt.title("Typical Metrics Epochs")
plt.xlabel("best mean_dice")
plt.ylabel("epochs")
for i, (l, x) in enumerate(zip(labels, x_values)):
value = int(get_best_metric_epochs(x, best if i % 2 == 0 else m_best))
color = "red" if i % 2 == 0 else "green"
plt.bar(l, value, 0.5, label=get_label(i), color=color)
plt.text(l, value, "%s" % value, ha="center", va="bottom")
plt.grid(alpha=0.4, linestyle=":")
plt.legend(loc="best")
plt.show()
データディレクトリのクリーンアップ
一時ディレクトリが使用された場合にはディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上
MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (配列版)
MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (配列版) (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/14/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (配列版)
このノートブックは GanTrainer、モジュール化された敵対的学習のための MONAI ワークフロー・エンジンを示します。MedNIST ハンド CT スキャン・データセットを使用して医療画像再構築ネットワークを訓練します。配列バージョン。
MONAI フレームワークは敵対的生成ネットワークを簡単に設計し、訓練して評価するために使用できます。このノートブックは、ハンド CT スキャンの画像を再構築するために単純な GAN モデルを設計して訓練する MONAI コンポーネントを使用する実例を示します。
ネットワーク・アーキテクチャと損失関数についての詳細は MONAI Mednist GAN チュートリアル を読んでください。
Step 1: セットアップ
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[ignite, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
from monai.utils import set_determinism
from monai.transforms import (
AddChannel,
Compose,
RandFlip,
RandRotate,
RandZoom,
ScaleIntensity,
EnsureType,
Transform,
)
from monai.networks.nets import Discriminator, Generator
from monai.networks import normal_init
from monai.handlers import CheckpointSaver, MetricLogger, StatsHandler
from monai.engines.utils import GanKeys, default_make_latent
from monai.engines import GanTrainer
from monai.data import CacheDataset, DataLoader
from monai.config import print_config
from monai.apps import download_and_extract
import numpy as np
import torch
import matplotlib.pyplot as plt
import shutil
import sys
import logging
import tempfile
import os
インポートのセットアップ
print_config()
MONAI version: 0.6.0rc1+23.gc6793fd0 Numpy version: 1.20.3 Pytorch version: 1.9.0a0+c3d40fd MONAI flags: HAS_EXT = True, USE_COMPILED = False MONAI rev id: c6793fd0f316a448778d0047664aaf8c1895fe1c Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.15.0 Pillow version: 7.0.0 Tensorboard version: 2.5.0 gdown version: 3.13.0 TorchVision version: 0.10.0a0 ITK version: 5.1.2 tqdm version: 4.53.0 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.1.4 einops version: 0.3.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical
データセットをダウンロードする
データセットをダウンロードして展開します。
MedMNIST データセットは TCIA, RSNA Bone Age チャレンジ と NIH Chest X-ray データセット からの様々なセットから集められました。
データセットは Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) のお陰により Creative Commons CC BY-SA 4.0 ライセンス のもとで利用可能になっています。
MedNIST データセットを使用する場合、出典を明示してください、e.g. https://github.com/Project-MONAI/tutorials/blob/master/2d_classification/mednist_tutorial.ipynb。
resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
hands = [
os.path.join(data_dir, "Hand", x)
for x in os.listdir(os.path.join(data_dir, "Hand"))
]
print(hands[:5])
['/workspace/data/medical/MedNIST/Hand/000317.jpeg', '/workspace/data/medical/MedNIST/Hand/002344.jpeg', '/workspace/data/medical/MedNIST/Hand/000816.jpeg', '/workspace/data/medical/MedNIST/Hand/004046.jpeg', '/workspace/data/medical/MedNIST/Hand/003316.jpeg']
Step 2: MONAI コンポーネントを初期化する
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
set_determinism(0)
device = torch.device("cuda:0")
画像変換チェインを作成する
セーブされたディスク画像を利用可能なテンソルに変換するために処理パイプラインを定義します。
class LoadTarJpeg(Transform):
def __call__(self, data):
return plt.imread(data)
train_transforms = Compose(
[
LoadTarJpeg(),
AddChannel(),
ScaleIntensity(),
RandRotate(range_x=np.pi / 12, prob=0.5, keep_size=True),
RandFlip(spatial_axis=0, prob=0.5),
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
EnsureType(),
]
)
データセットとデーたローダを作成する
データを保持して訓練の間にバッチを提示します。
real_dataset = CacheDataset(hands, train_transforms)
100%|██████████| 10000/10000 [00:09<00:00, 1092.83it/s]
batch_size = 300
real_dataloader = DataLoader(
real_dataset, batch_size=batch_size, shuffle=True, num_workers=10)
# We don't need to do any preparing so just return "as is"
def prepare_batch(batchdata, device=None, non_blocking=False):
return batchdata.to(device=device, non_blocking=non_blocking)
generator と discriminator を定義する
基本的なコンピュータビジョン GAN ネットワークをライブラリからロードします。
# define networks
disc_net = Discriminator(
in_shape=(1, 64, 64),
channels=(8, 16, 32, 64, 1),
strides=(2, 2, 2, 2, 1),
num_res_units=1,
kernel_size=5,
).to(device)
latent_size = 64
gen_net = Generator(
latent_shape=latent_size,
start_shape=(latent_size, 8, 8),
channels=[32, 16, 8, 1],
strides=[2, 2, 2, 1],
)
gen_net.conv.add_module("activation", torch.nn.Sigmoid())
gen_net = gen_net.to(device)
# initialize both networks
disc_net.apply(normal_init)
gen_net.apply(normal_init)
# define optimizors
learning_rate = 2e-4
betas = (0.5, 0.999)
disc_opt = torch.optim.Adam(disc_net.parameters(), learning_rate, betas=betas)
gen_opt = torch.optim.Adam(gen_net.parameters(), learning_rate, betas=betas)
# define loss functions
disc_loss_criterion = torch.nn.BCELoss()
gen_loss_criterion = torch.nn.BCELoss()
real_label = 1
fake_label = 0
def discriminator_loss(gen_images, real_images):
real = real_images.new_full((real_images.shape[0], 1), real_label)
gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
realloss = disc_loss_criterion(disc_net(real_images), real)
genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
return (genloss + realloss) / 2
def generator_loss(gen_images):
output = disc_net(gen_images)
cats = output.new_full(output.shape, real_label)
return gen_loss_criterion(output, cats)
訓練ハンドラを作成する
モデル訓練の間に操作を実行します。
metric_logger = MetricLogger(
loss_transform=lambda x: {
GanKeys.GLOSS: x[GanKeys.GLOSS], GanKeys.DLOSS: x[GanKeys.DLOSS]},
metric_transform=lambda x: x,
)
handlers = [
StatsHandler(
name="batch_training_loss",
output_transform=lambda x: {
GanKeys.GLOSS: x[GanKeys.GLOSS],
GanKeys.DLOSS: x[GanKeys.DLOSS],
},
),
CheckpointSaver(
save_dir=os.path.join(root_dir, "hand-gan"),
save_dict={"g_net": gen_net, "d_net": disc_net},
save_interval=10,
save_final=True,
epoch_level=True,
),
metric_logger,
]
GanTrainer を作成する
敵対的学習のための MONAI ワークフロー・エンジン。GanTrainer によってコンポーネントはここで集められます。
Goodfellow et al. 2014 https://arxiv.org/abs/1406.2661 に基づいた訓練ループを使用します。
- ランダムな潜在コードから m 個の fakes を生成します。
- これらの fakes と現在のバッチ reals で D を更新します、d_train_steps 回反復されます。
- 新しいランダムな潜在コードから m fakes を生成します。
- discriminator フィードバックを使用してこれらの fakes で generator を更新します。
disc_train_steps = 5
max_epochs = 50
trainer = GanTrainer(
device,
max_epochs,
real_dataloader,
gen_net,
gen_opt,
generator_loss,
disc_net,
disc_opt,
discriminator_loss,
d_prepare_batch=prepare_batch,
d_train_steps=disc_train_steps,
g_update_latents=True,
latent_shape=latent_size,
key_train_metric=None,
train_handlers=handlers,
)
Step 3: 訓練の開始
trainer.run()
結果を評価する
G と D の損失カーブを崩れていないか調べます。
g_loss = [loss[1][GanKeys.GLOSS] for loss in metric_logger.loss]
d_loss = [loss[1][GanKeys.DLOSS] for loss in metric_logger.loss]
plt.figure(figsize=(12, 5))
plt.semilogy(g_loss, label="Generator Loss")
plt.semilogy(d_loss, label="Discriminator Loss")
plt.grid(True, "both", "both")
plt.legend()
plt.show()
画像再構築を見る
ランダムな潜在コードで訓練された generator の出力を見ます。
test_img_count = 10
test_latents = default_make_latent(test_img_count, latent_size).to(device)
fakes = gen_net(test_latents)
fig, axs = plt.subplots(2, (test_img_count // 2), figsize=(20, 8))
axs = axs.flatten()
for i, ax in enumerate(axs):
ax.axis("off")
ax.imshow(fakes[i, 0].cpu().data.numpy(), cmap="gray")
データディレクトリのクリーンアップ
一時ディレクトリが作成された場合にはディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上
MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (辞書版)
MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (辞書版) (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/14/2021 (0.7.0)
* 本ページは、MONAI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- テレワーク & オンライン授業を支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
MONAI 0.7 : tutorials : モジュール – GAN ワークフロー・エンジン (辞書版)
このノートブックは GanTrainer、モジュール化された敵対的学習のための MONAI ワークフロー・エンジンを示します。MedNIST ハンド CT スキャン・データセットを使用して医療画像再構築ネットワークを訓練します。辞書バージョン。
MONAI フレームワークは敵対的生成ネットワークを簡単に設計し、訓練して評価するために使用できます。このノートブックは、ハンド CT スキャンの画像を再構築するために単純な GAN モデルを設計して訓練する MONAI コンポーネントを使用する実例を示します。
ネットワーク・アーキテクチャと損失関数についての詳細は MONAI Mednist GAN チュートリアル を読んでください。
Step 1: セットアップ
環境のセットアップ
!python -c "import monai" || pip install -q "monai-weekly[ignite, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline
from monai.utils import set_determinism
from monai.transforms import (
AddChannelD,
Compose,
LoadImageD,
RandFlipD,
RandRotateD,
RandZoomD,
ScaleIntensityD,
EnsureTypeD,
)
from monai.networks.nets import Discriminator, Generator
from monai.networks import normal_init
from monai.handlers import CheckpointSaver, MetricLogger, StatsHandler
from monai.engines.utils import GanKeys, default_make_latent
from monai.engines import GanTrainer
from monai.data import CacheDataset, DataLoader
from monai.config import print_config
from monai.apps import download_and_extract
import numpy as np
import torch
import matplotlib.pyplot as plt
import tempfile
import sys
import shutil
import os
import logging
インポートのセットアップ
print_config()
MONAI version: 0.6.0rc1+23.gc6793fd0 Numpy version: 1.20.3 Pytorch version: 1.9.0a0+c3d40fd MONAI flags: HAS_EXT = True, USE_COMPILED = False MONAI rev id: c6793fd0f316a448778d0047664aaf8c1895fe1c Optional dependencies: Pytorch Ignite version: 0.4.5 Nibabel version: 3.2.1 scikit-image version: 0.15.0 Pillow version: 7.0.0 Tensorboard version: 2.5.0 gdown version: 3.13.0 TorchVision version: 0.10.0a0 ITK version: 5.1.2 tqdm version: 4.53.0 lmdb version: 1.2.1 psutil version: 5.8.0 pandas version: 1.1.4 einops version: 0.3.0 For details about installing the optional dependencies, please visit: https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies
データディレクトリのセットアップ
MONAI_DATA_DIRECTORY 環境変数でディレクトリを指定できます。これは結果をセーブしてダウンロードを再利用することを可能にします。指定されない場合、一時ディレクトリが使用されます。
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
/workspace/data/medical
データセットをダウンロードする
データセットをダウンロードして展開します。
MedMNIST データセットは TCIA, RSNA Bone Age チャレンジ と NIH Chest X-ray データセット からの様々なセットから集められました。
データセットは Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) のお陰により Creative Commons CC BY-SA 4.0 ライセンス のもとで利用可能になっています。
MedNIST データセットを使用する場合、出典を明示してください、e.g. https://github.com/Project-MONAI/tutorials/blob/master/2d_classification/mednist_tutorial.ipynb。
resource = "https://drive.google.com/uc?id=1QsnnkvZyJPcbRoV_ArW8SnE1OTuoVbKE"
md5 = "0bc7306e7427e00ad1c5526a6677552d"
compressed_file = os.path.join(root_dir, "MedNIST.tar.gz")
data_dir = os.path.join(root_dir, "MedNIST")
if not os.path.exists(data_dir):
download_and_extract(resource, compressed_file, root_dir, md5)
hand_dir = os.path.join(data_dir, "Hand")
training_datadict = [
{"hand": os.path.join(hand_dir, filename)}
for filename in os.listdir(hand_dir)
]
print(training_datadict[:5])
[{'hand': '/workspace/data/medical/MedNIST/Hand/000317.jpeg'}, {'hand': '/workspace/data/medical/MedNIST/Hand/002344.jpeg'}, {'hand': '/workspace/data/medical/MedNIST/Hand/000816.jpeg'}, {'hand': '/workspace/data/medical/MedNIST/Hand/004046.jpeg'}, {'hand': '/workspace/data/medical/MedNIST/Hand/003316.jpeg'}]
Step 2: MONAI コンポーネントを初期化する
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
set_determinism(0)
device = torch.device("cuda:0")
画像変換チェインを作成する
セーブされたディスク画像を利用可能なテンソルに変換するために処理パイプラインを定義します。
train_transforms = Compose(
[
LoadImageD(keys=["hand"]),
AddChannelD(keys=["hand"]),
ScaleIntensityD(keys=["hand"]),
RandRotateD(keys=["hand"], range_x=np.pi /
12, prob=0.5, keep_size=True),
RandFlipD(keys=["hand"], spatial_axis=0, prob=0.5),
RandZoomD(keys=["hand"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
EnsureTypeD(keys=["hand"]),
]
)
データセットとデーたローダを作成する
データを保持して訓練の間にバッチを提示します。
real_dataset = CacheDataset(training_datadict, train_transforms)
100%|██████████| 10000/10000 [00:09<00:00, 1000.72it/s]
batch_size = 300
real_dataloader = DataLoader(
real_dataset, batch_size=batch_size, shuffle=True, num_workers=10)
def prepare_batch(batchdata, device=None, non_blocking=False):
return batchdata["hand"].to(device=device, non_blocking=non_blocking)
generator と discriminator を定義する
基本的なコンピュータビジョン GAN ネットワークをライブラリからロードします。
# define networks
disc_net = Discriminator(
in_shape=(1, 64, 64),
channels=(8, 16, 32, 64, 1),
strides=(2, 2, 2, 2, 1),
num_res_units=1,
kernel_size=5,
).to(device)
latent_size = 64
gen_net = Generator(
latent_shape=latent_size,
start_shape=(latent_size, 8, 8),
channels=[32, 16, 8, 1],
strides=[2, 2, 2, 1],
)
gen_net.conv.add_module("activation", torch.nn.Sigmoid())
gen_net = gen_net.to(device)
# initialize both networks
disc_net.apply(normal_init)
gen_net.apply(normal_init)
# define optimizors
learning_rate = 2e-4
betas = (0.5, 0.999)
disc_opt = torch.optim.Adam(disc_net.parameters(), learning_rate, betas=betas)
gen_opt = torch.optim.Adam(gen_net.parameters(), learning_rate, betas=betas)
# define loss functions
disc_loss_criterion = torch.nn.BCELoss()
gen_loss_criterion = torch.nn.BCELoss()
real_label = 1
fake_label = 0
def discriminator_loss(gen_images, real_images):
real = real_images.new_full((real_images.shape[0], 1), real_label)
gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
realloss = disc_loss_criterion(disc_net(real_images), real)
genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
return (genloss + realloss) / 2
def generator_loss(gen_images):
output = disc_net(gen_images)
cats = output.new_full(output.shape, real_label)
return gen_loss_criterion(output, cats)
訓練ハンドラを作成する
モデル訓練の間に操作を実行します。
metric_logger = MetricLogger(
loss_transform=lambda x: {
GanKeys.GLOSS: x[GanKeys.GLOSS], GanKeys.DLOSS: x[GanKeys.DLOSS]},
metric_transform=lambda x: x,
)
handlers = [
StatsHandler(
name="batch_training_loss",
output_transform=lambda x: {
GanKeys.GLOSS: x[GanKeys.GLOSS],
GanKeys.DLOSS: x[GanKeys.DLOSS],
},
),
CheckpointSaver(
save_dir=os.path.join(root_dir, "hand-gan"),
save_dict={"g_net": gen_net, "d_net": disc_net},
save_interval=10,
save_final=True,
epoch_level=True,
),
metric_logger,
]
GanTrainer を作成する
敵対的学習のための MONAI ワークフロー・エンジン。GanTrainer によってコンポーネントはここで集められます。
Goodfellow et al. 2014 https://arxiv.org/abs/1406.2661 に基づいた訓練ループを使用します。
- ランダムな潜在コードから m 個の fakes を生成します。
- これらの fakes と現在のバッチ reals で D を更新します、d_train_steps 回反復されます。
- 新しいランダムな潜在コードから m fakes を生成します。
- discriminator フィードバックを使用してこれらの fakes で generator を更新します。
disc_train_steps = 5
max_epochs = 50
trainer = GanTrainer(
device,
max_epochs,
real_dataloader,
gen_net,
gen_opt,
generator_loss,
disc_net,
disc_opt,
discriminator_loss,
d_prepare_batch=prepare_batch,
d_train_steps=disc_train_steps,
g_update_latents=True,
latent_shape=latent_size,
key_train_metric=None,
train_handlers=handlers,
)
Step 3: 訓練の開始
trainer.run()
結果を評価する
G と D の損失カーブを崩れていないか調べます。
g_loss = [loss[1][GanKeys.GLOSS] for loss in metric_logger.loss]
d_loss = [loss[1][GanKeys.DLOSS] for loss in metric_logger.loss]
plt.figure(figsize=(12, 5))
plt.semilogy(g_loss, label="Generator Loss")
plt.semilogy(d_loss, label="Discriminator Loss")
plt.grid(True, "both", "both")
plt.legend()
plt.show()
画像再構築を見る
ランダムな潜在コードで訓練された generator の出力を見ます。
test_img_count = 10
test_latents = default_make_latent(test_img_count, latent_size).to(device)
fakes = gen_net(test_latents)
fig, axs = plt.subplots(2, (test_img_count // 2), figsize=(20, 8))
axs = axs.flatten()
for i, ax in enumerate(axs):
ax.axis("off")
ax.imshow(fakes[i, 0].cpu().data.numpy(), cmap="gray")
データディレクトリのクリーンアップ
一時ディレクトリが作成された場合にはディレクトリを削除します。
if directory is None:
shutil.rmtree(root_dir)
以上