IT男のアジャイルキャリア奮闘記

IT業界でキャリアを築きたい貴方にお役立ち情報を発信していきます

MENU

多クラス分類BERTを利用した事前学習モデルのファインチューニングTips

BERTを利用したテキスト多クラス分類モデルのファインチューニングについてTipsを紹介します。

 

目次

参考サイト

今回以下の英語サイトを参考にしました。

www.thepythoncode.com 

環境

Google Colabratoryです。

colab.research.google.com

目的

BERTで用意されている事前学習モデルをファインチューニング(転移学習に近い)し、与えられるテストデータに対してより分類精度の高いモデルの作成を目指します。

 

ライブラリのインストール&インポート

参考サイト同様に実行します。

f:id:hungrycrazyman:20210611073909p:plain

データ&モデルのロード

モデルをロードする前にどのモデルをロードするかをパラメータで指定してあげます。

f:id:hungrycrazyman:20210611074051p:plain

 

トークナイザーをロードします。トークナイザーは訓練用のテキストデータに対し、トークン化(テキスト情報をベクトル化)するためのものです。

f:id:hungrycrazyman:20210611074234p:plain

 

データをロードします。参考サイトではライブラリを利用して訓練データ・検証データ・訓練データラベル・検証データラベル・ターゲット名リスト(targetnames)を用意しています。モデルを学習させるためのデータを使用する際はどんなデータがINPUTになるかを確認することは必須ですのでデータを個別に確認していきましょう。

f:id:hungrycrazyman:20210611074609p:plain

train_texts, valid_textsは各要素に文章を格納するリストである必要があります。

データの中身は下記のようなイメージです。

['I have a pen', 'I have an apple', 'Oh, APPLEPEN!']

 

train_labels, valid_labelsは元々がテキストラベルを数値にカテゴライズした数値ラベルが要素として格納されているリストである必要があります。

データの中身は下記のようなイメージです。

[1,0,2]

 

target_namesは最後にモデルの推論で出力した数値ラベルをテキストラベルに変換する際に利用しますので、数値ラベル昇順で概要するテキストラベルが要素として格納されたリストである必要があります。

例えば、['apple','orange','banana']というtrain_textsがあり、train_labelsは[1,0,2]となっていたとします。学習したモデルに推論させた時に出力が1と出た時、テキストとして'apple'を返したいので['orange','apple','banana']となっているtarget_namesリストを利用することで要素1を指定したときにtarget_namesから'apple'を取り出すことができます。

 

テキストデータはこのままでは学習に使えないので先ほど作成したトークナイザーで数値ベクトルにトークン化します。

f:id:hungrycrazyman:20210611080330p:plain

 

この後のモデルのチューニングではpytorchというフレームワークを利用するため、pytorch用のデータセットにデータを変換してあげる必要があります。

f:id:hungrycrazyman:20210611080507p:plain

 

モデルのファインチューニング(トレーニング)

事前学習モデルをダウンロードします。

f:id:hungrycrazyman:20210611080611p:plain

ここでパラメーターとして最初に指定したmodel_nameと作成したtarget_namesを挿入します。

 

モデルの学習で使用する評価指標のCall Back関数を設定します。

f:id:hungrycrazyman:20210611081023p:plain

 

レーニングする前にハイパーパラメータを設定します。

f:id:hungrycrazyman:20210611080032p:plain

 

レーニングするトレーナーインスタンスを作成します。

f:id:hungrycrazyman:20210611080143p:plain

この時にこれまで作成したmodel, args(ハイパーパラメータ),dataset,metrics(評価指標)をパラメーターとして挿入します。

 

モデルを学習(ファインチューニング)します。

f:id:hungrycrazyman:20210611081112p:plain

google colabで実行中の様子

f:id:hungrycrazyman:20210611082020p:plain

 

ファインチューニング完了!

Loss(損失関数) が下がっていき、Accuracy(精度)が上がってますね。

f:id:hungrycrazyman:20210611083357p:plain



 学習したモデルを評価してみます。

f:id:hungrycrazyman:20210611081250p:plain

私の例では以下のようなOutputになりました。まだまだ精度は低いですね。。。笑

f:id:hungrycrazyman:20210611083500p:plain

 

最後にファインチューニングしたモデルを保存しましょう!

f:id:hungrycrazyman:20210611083544p:plain

 

次回は推論編を発信しようと思います!

 

ではまた、会いましょう。