Ahogrammer

Deep Dive Into NLP, ML and Cloud

デバッガを使ってKerasのモデルをデバッグする

3行まとめ

  • Keras で作成したモデルをデバッグしたい。
  • Keras には標準でデバッガが用意されていない。
  • Keras の Session オブジェクトを tfdbg でラップしてデバッグする。

オープンソースニューラルネットワークライブラリである Keras は計算グラフに基づいています。 一般的な Keras のプログラムは以下の段階から構成されています。

  1. Sequentialモデルか Functional API を用いて、モデルを構築する
  2. 計算グラフを実行し、学習や予測を行う

ここで問題となるのが、2番目の計算グラフの実行中にエラーやバグが発生した場合、デバッグが難しいことです。 なぜなら、内部構造や内部状態が外から見えないため、どのような状況で計算に失敗したのかわからないからです。

Kerasには今のところdebuggerは存在しませんが、TensorFlowには tfdbg と呼ばれるdebuggerが用意されています。 tfdbgを使うことで、学習中や予測中の内部構造や内部状態を明らかにすることができます。

本記事では tfdbg をKerasで使ってデバッグを行う方法を紹介します。 tfdbgを使うため、KerasのバックエンドにはTensorFlowを想定しています。 では、短いサンプルコードを通してデバッガの動作を見てみます。 以下のコードは、簡単な評判分析のモデルを Sequentialモデルで定義し、学習します。

import keras.backend as K
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Embedding, LSTM
from keras.datasets import imdb
from tensorflow.python import debug as tf_debug

sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
K.set_session(sess)

max_features = 20000
maxlen = 80
batch_size = 32

(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)

model = Sequential()
model.add(Embedding(max_features, 128))
model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=batch_size,
          epochs=15, validation_data=(x_test, y_test))
score, acc = model.evaluate(x_test, y_test,
                            batch_size=batch_size)

この例でわかるように、Keras の Session オブジェクトをデバッグ用のクラス(LocalCLIDebugWrapperSession)でラップしています。 これにより、計算グラフを実行した際に、 tfdbg インターフェースが起動します。この インターフェース内でマウスクリックやコマンドを使うと、その後のコードの実行や、グラフのノードや属性の調査をできます。詳しい使い方については、tfdbgのドキュメントを見てください。

f:id:Hironsan:20170915085618p:plain

まとめ

  • Kerasで作成したモデルをデバッグしたい。
  • Kerasには標準でデバッガが用意されていない。
  • Kerasのsessionオブジェクトをtfdbgでラップしてデバッグする。

参考資料

TensorFlow のデバッガ tfdbg に関する公式ドキュメント
Debugging TensorFlow Programs

Google Developers blog による tfdbg の解説
Debug TensorFlow Models with tfdbg

Keras backends の公式ドキュメント
Keras backends