ホーム » HuggingFace Transformers » HuggingFace Transformers 4.17 : Tutorials : 事前訓練済みモデルの再調整

HuggingFace Transformers 4.17 : Tutorials : 事前訓練済みモデルの再調整

HuggingFace Transformers 4.17 : Tutorials : 事前訓練済みモデルの再調整 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/18/2022 (v4.17.0)

* 本ページは、HuggingFace Transformers の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

HuggingFace Transformers : Tutorials : 事前訓練済みモデルの再調整

事前訓練モデルの使用は大きな利点があります。それは計算コスト、カーボン・フットプリントを削減して、スクラッチから訓練する必要なく最先端のモデルを利用することを可能にします。 Transformers は広範囲のタスクに対して数千の事前訓練済みモデルへのアクセスを提供しています。事前訓練済みモデルを使用するとき、貴方のタスクに固有のデータセット上でそれを訓練します。これは再調整 (= fine-tuning, 微調整) として知られる、非常に強力な訓練テクニックです。このチュートリアルでは、事前訓練済みモデルを貴方の選択した深層学習フレームワークで再調整します。

  • Transformers Trainer で事前訓練済みモデルを再調整する。

  • TensorFlow with Keras で事前訓練済みモデルを再調整する。

  • native PyTorch で事前訓練済みモデルを再調整する。

 

データセットを準備する

事前訓練済みモデルを再調整できる前に、データセットをダウンロードしてそれを訓練のために準備します。前のチュートリアルは訓練のためにデータを処理する方法を紹介しましたが、今はそれらのスキルをテストする機会を得ています!

Yelp Reviews データセットをロードすることから始めます :

from datasets import load_dataset

dataset = load_dataset("yelp_review_full")
dataset.keys()
dict_keys(['train', 'test'])
len(dataset['train']), len(dataset['test'])
(650000, 50000)
dataset['train'][0]
{'label': 4,
 'text': "dr. goldberg offers everything i look for in a general practitioner.  he's nice and easy to talk to without being patronizing; he's always on time in seeing his patients; he's affiliated with a top-notch hospital (nyu) which my parents have explained to me is very important in case something happens and you need surgery; and you can get referrals to see specialists without having to see him first.  really, what more do you need?  i'm sitting here trying to think of any complaints i have about him, but i'm really drawing a blank."}
dataset['test'][0]
{'label': 0,
 'text': 'I got \'new\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\nI took the tire over to Flynn\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\'d give me a new tire \\"this time\\". \\nI will never go back to Flynn\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'}

今ではご存知のように、テキストを処理して、可変なシークエンス長を扱うためにパディングと truncation ストラテジーを含めるためにはトークナイザーが必要です。データセットを 1 ステップで処理するため、データセット全体に対して前処理関数を適用する Datasets map メソッドを使用します。

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets['train'][0].keys()
dict_keys(['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'])
print(tokenized_datasets['train'][0])
{'label': 4, 'text': "dr. goldberg offers everything i look for in a general practitioner.  he's nice and easy to talk to without being patronizing; he's always on time in seeing his patients; he's affiliated with a top-notch hospital (nyu) which my parents have explained to me is very important in case something happens and you need surgery; and you can get referrals to see specialists without having to see him first.  really, what more do you need?  i'm sitting here trying to think of any complaints i have about him, but i'm really drawing a blank.",
'input_ids': [101, 173, 1197, 119, 2284, 2953, 3272, 1917, 178, 1440, 1111, 1107, 170, 1704, 22351, 119, 1119, 112, 188, 3505, 1105, 3123, 1106, 2037, 1106, 1443, 1217, 10063, 4404, 132, 1119, 112, 188, 1579, 1113, 1159, 1107, 3195, 1117, 4420, 132, 1119, 112, 188, 6559, 1114, 170, 1499, 118, 23555, 2704, 113, 183, 9379, 114, 1134, 1139, 2153, 1138, 3716, 1106, 1143, 1110, 1304, 1696, 1107, 1692, 1380, 5940, 1105, 1128, 1444, 6059, 132, 1105, 1128, 1169, 1243, 5991, 16179, 1106, 1267, 18137, 1443, 1515, 1106, 1267, 1140, 1148, 119, 1541, 117, 1184, 1167, 1202, 1128, 1444, 136, 178, 112, 182, 2807, 1303, 1774, 1106, 1341, 1104, 1251, 11344, 178, 1138, 1164, 1140, 117, 1133, 178, 112, 182, 1541, 4619, 170, 9153, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

望むのであれば、かかる時間を削減するために (その上で) 再調整する完全なデータセットの小さいサブセットを作成することができます :

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
len(small_train_dataset), len(small_eval_dataset)
(1000, 1000)

 

Trainer による再調整

Transformers は Transformers モデルを訓練するために最適化された Trainer クラスを提供していて、貴方自身の訓練ループを手動で書くことなく訓練を開始することを容易にします。Trainer API はロギング, 勾配集積 (= accumulation), そして混合精度のような幅広い訓練オプションと機能をサポートします。

モデルをロードすることから始めて想定されるラベル数を指定します。Yelp Review dataset カード から、5 つのラベルがあることがわかります :

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Note : 事前訓練済み重みの一部は使用されずに幾つかの重みはランダムに初期化されたことについて警告を見るでしょう。心配しないでください、これは完全に普通のことです!BERT の事前訓練済みヘッドが捨てられ、ランダムに初期化された分類ヘッドで置き換えられます。シークエンス分類タスクでこの新しいモデルヘッドを再調整し、事前訓練済みモデルの知識をそれに転送します。

 

訓練ハイパーパラメータ

次に、TrainingArguments クラスを作成します、これは調整可能な総てのハイパーパラメータと、異なる訓練オプションを有効にするフラグを含みます。このチュートリアルのためにはデフォルトの訓練 ハイパーパラメータ で開始できますが、最適な設定を見つけるためにこれらで自由に実験してください。

訓練からチェックポイントをどこにセーブするか指定します :

from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="test_trainer")

 

メトリクス

Trainer は訓練の間に自動的にはモデル性能を評価しません。メトリクスを計算して報告する関数を Trainer に渡す必要があります。 Datasets ライブラリは load_metric (詳細はこの チュートリアル 参照) 関数でロード可能な単純な accuracy 関数を提供しています :

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

予測の精度を計算するためにメトリックの compute を呼び出します。予測を compute に渡す前に、予測をロジットに変換する必要があります (総ての Transformers モデルはロジットを返すことを忘れないでください) :

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

再調整の間に評価メトリックをモニタリングしたいのであれば、各エポックの最後に評価メトリックをレポートするために訓練引数で evaluation_strategy パラメータを指定します :

from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")

 

Trainer

モデル, 訓練引数, 訓練とテストデータセット, そして評価関数で Trainer オブジェクトを作成します :

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

そして train() を呼び出してモデルを再調整します :

trainer.train()

 

Keras で再調整

Transformers モデルはまた Keras API による TensorFlow での訓練もサポートしています。再調整できる前に幾つかの変更を行なう必要があるだけです。

 

データセットを TensorFlow 形式に変換する

DefaultDataCollator はモデルがその上で訓練するためにテンソルをバッチに集めます。TensorFlow テンソルを返すために return_tensors を指定することを確実にしてください :

from transformers import DefaultDataCollator

data_collator = DefaultDataCollator(return_tensors="tf")

Note : Trainer はデフォルトで DataCollatorWithPadding を使用しますので、データ collator (照合機) を明示的に指定する必要はありません。

次に、トークン化されたデータセットを to_tf_dataset メソッドで TensorFlow データセットに変換します。columns で入力を、label_cols でラベルを指定します :

tf_train_dataset = small_train_dataset.to_tf_dataset(
    columns=["attention_mask", "input_ids", "token_type_ids"],
    label_cols=["labels"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=8,
)

tf_validation_dataset = small_eval_dataset.to_tf_dataset(
    columns=["attention_mask", "input_ids", "token_type_ids"],
    label_cols=["labels"],
    shuffle=False,
    collate_fn=data_collator,
    batch_size=8,
)

 

Compile と fit

想定されるラベル数と共に TensorFlow モデルをロードします :

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

そして他の Keras モデルでそうするように compile してから fit でモデルを再調整します :

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=tf.metrics.SparseCategoricalAccuracy(),
)

model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3)

 

native PyTorch で再調整

Trainer は訓練ループを処理してモデルを単一行のコードで再調整することを可能にします。独自の訓練ループを書くことを好むユーザについては、native PyTorch で Transformers モデルを再調整することもできます。

この時点で、ノートブックを再起動するか、あるいは何某かのメモリを解放するために以下のコードを実行する必要があるかもしれません :

del model
del pytorch_model
del trainer
torch.cuda.empty_cache()

 
次に、tokenized_dataset を訓練用に準備するため手動で後処理します。

  1. text カラムを削除します、モデルは入力として raw テキストを受け取らないからです :
    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
    

  2. label カラムを labels に名前変更します、モデルは引数が labels と命名されていることを想定しているからです :
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    

  3. リストの代わりに PyTorch テンソルを返すようにデータセットの形式を設定します :
    tokenized_datasets.set_format("torch")
    

それから再調整をスピードアップするために前に示されたようにデータセットの小さいなサブセットを作成します :

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

 

DataLoader

データのバッチに対してイテレートできるように、訓練とテストデータセットのために DataLoader を作成します :

from torch.utils.data import DataLoader

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)

想定されるラベルの数と共にモデルをロードします :

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

 

Optimizer と学習率スケジューラ

モデルを再調整するために optimizer と学習率スケジューラを作成します。PyTorch からの AdamW optimizer を使用しましょう :

from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

Trainer からのデフォルトの学習率スケジューラを作成します :

from transformers import get_scheduler

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

最後に、アクセス可能な GPU が持つならばそれを利用するように device を指定します。そうでないなら、CPU 上の訓練は数分ではなく数時間かかるかもしれません。

import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

Great, now you are ready to train!

 

訓練ループ

訓練進捗を追跡するため、訓練ステップ数に対してプログレスバーを追加する tqdm ライブラリを使用します :

from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

 

メトリクス

Trainer に評価関数を追加する必要があるのと同様に、貴方自身の訓練ループを書く時に同じことを行なう必要があります。しかし各エポックの最後にメトリックを計算してレポートする代わりに、今回は add_batch で総てのバッチを蓄積して最後の最後にメトリックを計算します。

metric = load_metric("accuracy")
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()
 

以上



ClassCat® Chatbot

人工知能開発支援
クラスキャットは 人工知能研究開発支援 サービスを提供しています :
  • テクニカルコンサルティングサービス
  • 実証実験 (プロトタイプ構築)
  • アプリケーションへの実装
  • 人工知能研修サービス
◆ お問合せ先 ◆
(株)クラスキャット
セールス・インフォメーション
E-Mail:sales-info@classcat.com

カテゴリー