エポック終了時に学習曲線図を保存するコールバックを作成(keras)
今回したこと
エポック終了時にそれまでの学習曲線を図として保存するコールバックを作成した.
通常はkeras.callbacks.History()を使用し,学習が終わってから1度のみhistoryを取得するが,途中で学習を止めた際にはhistoryが取得できないのでコールバックを自作した.
作成したコールバックは下記gistで公開している.
https://gist.github.com/catdance124/0976c5dbacdaeeaa7a6ac852c1f59cff
使う際には下記のように
from plot_history import PlotHistory dir_name = './dst' title = f'{model_name}_{optimizer_name}' ph = PlotHistory(save_interval=5, dir_name=dir_name, csv_output=True, title=title) cbs = [ph] model.fit_generator( generator=train_generator, steps_per_epoch=train_datagen.steps_per_epoch, epochs=args.epochs, callbacks=cbs )
上記だと
./dst/配下に下記のファイルが5エポックごとに保存される.
がっつり過学習しているが今回はその話はしない
コード説明
大まかに2つの要素からなる
plot_history関数
この関数はコールバック専用ではなく,通常の学習で得られるhistoryを渡しても描画できるように作成した.
辞書形式のhistoryを受け取り,trainのacc/loss,valがあればval_acc/val_lossをplotする.
オプションとして受け取ったhistoryをcsvアウトプットできる.後からhistoryを見たいときに便利.
下記は工夫点のみに触れるので全体のコードは載せていない.
def plot_history(history, begin_epoch=1, dir_name=None, csv_output=True, title='learning_curve'): # plot init settings val_exist = 'val_acc' in history.keys() plt.figure(figsize=(18, 7)) plt.suptitle(title, fontsize=16)
受け取ったhistoryにvalについての情報があるかをval_existに格納しておく.
今後val_existでval_acc/val_lossを描画するかどうかを判断する.
plt.suptitle()を用いて全体としてのグラフタイトルを表示する.
# plot accuracy settings plt.subplot(121) plt.title(f'model accuracy') plt.xlabel('epoch') plt.ylabel('accuracy') plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True)) # plot accuracy plt.plot(list(range(begin_epoch+1, len(history['acc'][begin_epoch:])+1)), history['acc'][begin_epoch:]) if val_exist: plt.plot(list(range(begin_epoch+1, len(history['val_acc'][begin_epoch:])+1)), history['val_acc'][begin_epoch:]) plt.legend(['acc', 'val_acc'], loc='lower right') else: plt.legend(['acc'], loc='lower right') # plot loss settings plt.subplot(122) plt.title(f'model loss') plt.xlabel('epoch') plt.ylabel('loss') plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True)) # plot loss plt.plot(list(range(begin_epoch+1, len(history['loss'][begin_epoch:])+1)), history['loss'][begin_epoch:]) if val_exist: plt.plot(list(range(begin_epoch+1, len(history['val_loss'][begin_epoch:])+1)), history['val_loss'][begin_epoch:]) plt.legend(['loss', 'val_loss'], loc='upper right') else: plt.legend(['loss'], loc='upper right')
ここはacc/lossでほとんど同じコード.valがあればvalも描画する.
plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True))
の部分でepoch軸は整数のみを取るようにしている.
# show or save? if dir_name is None: plt.show() else: plt.savefig(f'{dir_name}/learning_curve.png') if csv_output: values = [] for key in history.keys(): values.append(history[key]) values = np.array(values) with open(f'./{dir_name}/history.csv', 'w') as f_handle: writer = csv.writer(f_handle, lineterminator="\n") writer.writerows([history.keys()]) # header np.savetxt(f_handle, values.T, fmt="%.6f", delimiter=',') plt.close()
PlotHistoryクラス
keras.callbacks.Callbackを継承し,epoch_endにplot_history()を呼び出すようにする.
下記も工夫点のみに触れ,全体のコードは載せていない.
class PlotHistory(Callback): def __init__(self, save_interval=1, dir_name='./', csv_output=False, title=''): def on_train_begin(self, logs=None): self.history = {} self.history['loss'] = [] self.history['acc'] = [] self.do_validation = self.params['do_validation'] if self.do_validation: self.history['val_loss'] = [] self.history['val_acc'] = []
学習開始時にloss/acc | val_loss/val_accの空リストを持っておく.
また,valが渡されるかどうかはself.params['do_validation']
で判断することができる.
def on_epoch_end(self, epoch, logs=None): self.history['loss'].append(logs.get('loss')) self.history['acc'].append(logs.get('acc')) if self.do_validation: self.history['val_loss'].append(logs.get('val_loss')) self.history['val_acc'].append(logs.get('val_acc')) if (epoch-1) % self.interval == 0: plot_history(history=self.history, dir_name=self.dir_name, csv_output=self.csv_output, title=self.title) def on_train_end(self, logs=None): plot_history(history=self.history, dir_name=self.dir_name, csv_output=self.csv_output, title=self.title)
毎エポックの終わりon_epoch_endでそのエポックのloss/ accをリストにappendしていく.
指定されたintervalでplot_history()を呼び出すようにした.
終わりに
今回はエポック終了時に学習曲線図を保存するコールバックを作成した.
keras.callbacks.Callbackを継承すればエポック終わり・バッチ学習終わりなど好きなタイミングで処理を行うことができる.
コールバックとして実装しておけば学習モデル・タスクを問わず使い回せることが多いのでとても便利だと思う.
実装したコードはここに
https://gist.github.com/catdance124/0976c5dbacdaeeaa7a6ac852c1f59cff