top of page

分類 Classification

公開·1名のメンバー

事前学習済みモデルを使うとジブリの分類の精度が上がった


事前学習済みのモデルとしてResNet18というモデルを使います。これは224×224画素のデータで学習してありますが、適合的平均プーリング関数(AdaptiveAvgPool2d)というレイヤー関数が組み込まれているおかげで、画像の画素数によらず、入力画像として受け付けてくれるらしいです。


PyTorch用にデータも作ったし、あとはモデルを動かすだけなんだけども、

メモリが足りないとのエラー。


訓練用データは470枚程度なので一発でやれるかと甘く見ていました。こま切れにしてバッチ処理します。

PyTorchには僕のような見通しの甘い人のために、訓練用データと訓練用ラベルをくっつけて、データを作るモジュールが用意されています。

train_dataset = torch.utils.data.TensorDataset(x_train, y_train)

これで画像データとナウシカ、キキ、シータが0,1,2の整数で入ったtrain_datasetができました。あとはこれをデータローダーに渡してfor文で回すことでバッチごとの作業をしてくれます。


後は見守るだけ....かと思いきや、またまたエラー

cuDNN error: CUDNN_STATUS_NOT_INITIALIZED

なんか嫌なエラーが出た。このCUDA関係のエラーが出ると解決方法がさっぱりわかりません。

ネットで調べると、NVIDIAのドライバーをアップデートしたら動いたとの記事があった。こんな単純なので直るのかと半信半疑でやってみたところ、動きました。簡単すぎて拍子抜け。



結果は、

前回自作したモデル

 正解   予測   数

シータ  シータ   30

     キキ    5

     キキ    55

キキ   シータ   7

     キキ   51

ナウシカ シータ    9

    ナウシカ  21(前回は31と書いていましたが21の間違いでした。)


今回の訓練済みモデル(resnet18)のファインチューニング

 正解   予測   数

シータ シータ   40

キキ シータ 3

   キキ   55

ナウシカ シータ   7

     ナウシカ 23


accuracyは0.8台前半から0.913へ劇的に向上しています(でも少し過学習気味)。

正解がキキのものは間違いが7から3に半減しました。正解がナウシカも間違いが9から7に減少しました。ナウシカの精度向上が控えめなのが気になります。シータについては完全に正解に見えますが、シータと予測したものが全部で50個あり、そのうち本当にシータだったものは40個ですので適合率は0.8です。キキは全部で58個、そのうち55個を予測できましたので再現率が約0.95となります。キキが最も選びやすかったようです。


今やTensorflowやPytorchを使って一からモデルを作って学習させることは稀で、事前学習したモデルを使って作業するのはもう常識となっているそうです。

今回の結果を見るとそれも分かるような気がします。


これなら深層学習について学習しなくても、学習済みモデルをチョコチョコっといじるだけで高精度のものが作れそうな気がします。

では内燃機関の仕組みを知らなくても車の運転ができるように、深層学習をやるのに合成関数とか、誤差逆伝播、勾配降下法などの知識はもう必要ないのでしょうか。


僕はそうは思いませんが、いかがでしょうか。

閲覧数:180

新規投稿をお知らせします。

登録ありがとうございます。

© 2023 by Healthy Together. Proudly created with Wix.com

bottom of page