wide and deep

エポック終了時に学習曲線図を保存するコールバックを作成(keras)

今回したこと

エポック終了時にそれまでの学習曲線を図として保存するコールバックを作成した.
通常はkeras.callbacks.History()を使用し,学習が終わってから1度のみhistoryを取得するが,途中で学習を止めた際にはhistoryが取得できないのでコールバックを自作した.

作成したコールバックは下記gistで公開している.
https://gist.github.com/catdance124/0976c5dbacdaeeaa7a6ac852c1f59cff

使う際には下記のように

from plot_history import PlotHistory
dir_name = './dst'
title = f'{model_name}_{optimizer_name}'
ph = PlotHistory(save_interval=5, dir_name=dir_name, csv_output=True, title=title)
cbs = [ph]
model.fit_generator(
    generator=train_generator,
    steps_per_epoch=train_datagen.steps_per_epoch,
    epochs=args.epochs,
    callbacks=cbs
)

上記だと
./dst/配下に下記のファイルが5エポックごとに保存される.

f:id:catdance124:20190919193530p:plain
実際に作成された学習曲線図
がっつり過学習しているが今回はその話はしない

コード説明

大まかに2つの要素からなる

  • 渡されたhistory(dic)をmatplotlibで画像化するplot_history関数
  • コールバック用のPlotHistoryクラス

plot_history関数

この関数はコールバック専用ではなく,通常の学習で得られるhistoryを渡しても描画できるように作成した.
辞書形式のhistoryを受け取り,trainのacc/loss,valがあればval_acc/val_lossをplotする.
オプションとして受け取ったhistorycsvアウトプットできる.後からhistoryを見たいときに便利.
下記は工夫点のみに触れるので全体のコードは載せていない.

def plot_history(history, begin_epoch=1, dir_name=None, csv_output=True, title='learning_curve'):
    # plot init settings
    val_exist = 'val_acc' in history.keys()
    plt.figure(figsize=(18, 7))
    plt.suptitle(title, fontsize=16)

受け取ったhistoryにvalについての情報があるかをval_existに格納しておく.
今後val_existでval_acc/val_lossを描画するかどうかを判断する.
plt.suptitle()を用いて全体としてのグラフタイトルを表示する.

    # plot accuracy settings
    plt.subplot(121)
    plt.title(f'model accuracy')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True))
    # plot accuracy
    plt.plot(list(range(begin_epoch+1, len(history['acc'][begin_epoch:])+1)), history['acc'][begin_epoch:])
    if val_exist:
        plt.plot(list(range(begin_epoch+1, len(history['val_acc'][begin_epoch:])+1)), history['val_acc'][begin_epoch:])
        plt.legend(['acc', 'val_acc'], loc='lower right')
    else:
        plt.legend(['acc'], loc='lower right')
    
    # plot loss settings
    plt.subplot(122)
    plt.title(f'model loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True))
    # plot loss
    plt.plot(list(range(begin_epoch+1, len(history['loss'][begin_epoch:])+1)), history['loss'][begin_epoch:])
    if val_exist:
        plt.plot(list(range(begin_epoch+1, len(history['val_loss'][begin_epoch:])+1)), history['val_loss'][begin_epoch:])
        plt.legend(['loss', 'val_loss'], loc='upper right')
    else:
        plt.legend(['loss'], loc='upper right')

ここはacc/lossでほとんど同じコード.valがあればvalも描画する.
plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True))の部分でepoch軸は整数のみを取るようにしている.

    # show or save?
    if dir_name is None:
        plt.show()
    else:
        plt.savefig(f'{dir_name}/learning_curve.png')
        if csv_output:
            values = []
            for key in history.keys():
                values.append(history[key])
            values = np.array(values)
            with open(f'./{dir_name}/history.csv', 'w') as f_handle:
                writer = csv.writer(f_handle, lineterminator="\n")
                writer.writerows([history.keys()])  # header
                np.savetxt(f_handle, values.T, fmt="%.6f", delimiter=',')
    plt.close()

出力ディレクトリが指定されていれば画像として出力する.
更にcsv_outputの指定があればcsvも出力する.

PlotHistoryクラス

keras.callbacks.Callbackを継承し,epoch_endにplot_history()を呼び出すようにする.
下記も工夫点のみに触れ,全体のコードは載せていない.

class PlotHistory(Callback):
    def __init__(self, save_interval=1, dir_name='./', csv_output=False, title=''):

    def on_train_begin(self, logs=None):
        self.history = {}
        self.history['loss'] = []
        self.history['acc'] = []
        self.do_validation = self.params['do_validation']
        if self.do_validation:
            self.history['val_loss'] = []
            self.history['val_acc'] = []

学習開始時にloss/acc | val_loss/val_accの空リストを持っておく.
また,valが渡されるかどうかはself.params['do_validation']で判断することができる.

    def on_epoch_end(self, epoch, logs=None):
        self.history['loss'].append(logs.get('loss'))
        self.history['acc'].append(logs.get('acc'))
        if self.do_validation:
            self.history['val_loss'].append(logs.get('val_loss'))
            self.history['val_acc'].append(logs.get('val_acc'))
        if (epoch-1) % self.interval == 0:
            plot_history(history=self.history, dir_name=self.dir_name, csv_output=self.csv_output, title=self.title)

    def on_train_end(self, logs=None):
        plot_history(history=self.history, dir_name=self.dir_name, csv_output=self.csv_output, title=self.title)

毎エポックの終わりon_epoch_endでそのエポックのloss/ accをリストにappendしていく.
指定されたintervalでplot_history()を呼び出すようにした.

終わりに

今回はエポック終了時に学習曲線図を保存するコールバックを作成した.
keras.callbacks.Callbackを継承すればエポック終わり・バッチ学習終わりなど好きなタイミングで処理を行うことができる.
コールバックとして実装しておけば学習モデル・タスクを問わず使い回せることが多いのでとても便利だと思う.
実装したコードはここに
https://gist.github.com/catdance124/0976c5dbacdaeeaa7a6ac852c1f59cff

参考にしたサイト

qiita.com
minus9d.hatenablog.com
keras.io