Ahogrammer

Deep Dive Into NLP, ML and Cloud

KerasのLambda層でreshapeしたとき、保存に失敗する(場合がある)話

TL;DR

  • keras.backendのreshapeを使ってLambda層でreshapeしたい
  • reshapeのshapeにテンソルを指定するとモデルの保存(save)に失敗する
  • saveではなくsave_weightsを使うと保存できる

背景

まず問題が起きる状況について説明しておきたい。 簡単にまとめると以下のような状況だ。

  • ミニバッチごとにshapeの異なるデータを入力する
  • モデルの途中にreshapeが入っている
  • reshapeには入力のshapeを使用する

ミニバッチごとにshapeの異なるデータを入力するというのは、言語処理なんかではよく行われると思う。 たとえば、LSTMでの計算を効率化するために、ミニバッチ内の最大の系列長に合わせてpaddingを行い、系列長を揃える場合が挙げられる。

上記の状況を示すプログラムは以下の通り:

import keras.backend as K
from keras.layers import Input, Lambda

emb_size = 25
x = Input(batch_shape=(None, None, None, emb_size))
s = K.shape(x)
pred = Lambda(K.reshape, 
              arguments={'shape': (-1, s[-2], emb_size)})(x)
model = Model(inputs=[x], outputs=pred)
model.compile('sgd', 'categorical_crossentropy')
model.save('model.h5')

実際にはこんなモデルはありえないが、説明のためのtoy programだと了承していただきたい。

コードについて簡単に説明しておく。 まず、入力のうち3つの次元は動的に決まる。このうち一つはバッチサイズである。 この入力に対して、keras.backendのshapeを使ってshapeを取得している。ちなみに取得したshapeはTensorで表される:

>>> s = K.shape(x)
>>> s
<tf.Tensor 'Shape:0' shape=(4,) dtype=int32>

次に、取得したshapeとbackendのreshapeを使って入力を変換している。 Reshape層を使わないのは、Reshape層のshapeにTensorを指定できないからである。 Tensorを指定すると「Tensorはboolみたいには使用できないよ」という例外が出る:

>>> s = K.shape(x)
>>> Reshape(target_shape=(-1, s[-2], emb_size))(x)
Traceback (most recent call last):
...
    raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

それならば、K.int_shapeで得られるTensorではないshapeを指定すればよいではないか?という話になるがそれもできない。 というのも、K.int_shapeで得られたshapeのs[-2]はNoneであり、ReshapeにはNoneを指定できないからである。 そのときに発生する例外は以下の通り:

>>> s = K.int_shape(x)
>>> Reshape(target_shape=(-1, s[-2], emb_size))(x)
Traceback (most recent call last):
...
TypeError: unorderable types: NoneType() < int()

したがって、keras.backendのreshapeを使うことになる。 ただ、Kerasの場合、レイヤーを接続してモデルを構築しないと「keras_historyがないよ」という例外が発生する。 したがって、Lambdaでreshapeをラップしている。

pred = Lambda(K.reshape, 
              arguments={'shape': (-1, s[-2], emb_size)})(x)

このtoy programではcompileはうまくいく。また学習も順調に進む。

model = Model(inputs=[x], outputs=pred)
model.compile('sgd', 'categorical_crossentropy')

問題

しかし以下のようにsaveメソッドを使ってモデルを保存すると問題が起きる。

model.save('model.h5')

確認しているだけでも以下の5つの例外がランダム(に見えるよう)に発生する:

  • TypeError: cannot serialize ‘_io.TextIOWrapper’ object
  • TypeError: object.new(PyCapsule) is not safe, use PyCapsule.new()
  • AttributeError: ‘NoneType’ object has no attribute ‘update’
  • TypeError: cannot deepcopy this pattern object
  • TypeError: can’t pickle module objects

調べていくと、これらの例外はreshape時にテンソルを指定すると、Lambdaでシリアライズができないことに起因するようだ。

解決策

save_weightsでネットワークの重みだけ保存することで解決する。

model.save_weights('model_weights.h5')

まとめ

  • keras.backendのreshapeを使ってLambda層でreshapeしたい
  • reshapeのshapeにテンソルを指定するとモデルの保存(save)に失敗する
  • saveではなくsave_weightsを使うと保存できる