KerasのLambda層でreshapeしたとき、保存に失敗する(場合がある)話
TL;DR
- keras.backendのreshapeを使ってLambda層でreshapeしたい
- reshapeのshapeにテンソルを指定するとモデルの保存(save)に失敗する
- saveではなくsave_weightsを使うと保存できる
背景
まず問題が起きる状況について説明しておきたい。 簡単にまとめると以下のような状況だ。
- ミニバッチごとにshapeの異なるデータを入力する
- モデルの途中にreshapeが入っている
- reshapeには入力のshapeを使用する
ミニバッチごとにshapeの異なるデータを入力するというのは、言語処理なんかではよく行われると思う。 たとえば、LSTMでの計算を効率化するために、ミニバッチ内の最大の系列長に合わせてpaddingを行い、系列長を揃える場合が挙げられる。
上記の状況を示すプログラムは以下の通り:
import keras.backend as K from keras.layers import Input, Lambda emb_size = 25 x = Input(batch_shape=(None, None, None, emb_size)) s = K.shape(x) pred = Lambda(K.reshape, arguments={'shape': (-1, s[-2], emb_size)})(x) model = Model(inputs=[x], outputs=pred) model.compile('sgd', 'categorical_crossentropy') model.save('model.h5')
実際にはこんなモデルはありえないが、説明のためのtoy programだと了承していただきたい。
コードについて簡単に説明しておく。 まず、入力のうち3つの次元は動的に決まる。このうち一つはバッチサイズである。 この入力に対して、keras.backendのshapeを使ってshapeを取得している。ちなみに取得したshapeはTensorで表される:
>>> s = K.shape(x) >>> s <tf.Tensor 'Shape:0' shape=(4,) dtype=int32>
次に、取得したshapeとbackendのreshapeを使って入力を変換している。 Reshape層を使わないのは、Reshape層のshapeにTensorを指定できないからである。 Tensorを指定すると「Tensorはboolみたいには使用できないよ」という例外が出る:
>>> s = K.shape(x) >>> Reshape(target_shape=(-1, s[-2], emb_size))(x) Traceback (most recent call last): ... raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. " TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
それならば、K.int_shapeで得られるTensorではないshapeを指定すればよいではないか?という話になるがそれもできない。 というのも、K.int_shapeで得られたshapeのs[-2]はNoneであり、ReshapeにはNoneを指定できないからである。 そのときに発生する例外は以下の通り:
>>> s = K.int_shape(x) >>> Reshape(target_shape=(-1, s[-2], emb_size))(x) Traceback (most recent call last): ... TypeError: unorderable types: NoneType() < int()
したがって、keras.backendのreshapeを使うことになる。 ただ、Kerasの場合、レイヤーを接続してモデルを構築しないと「keras_historyがないよ」という例外が発生する。 したがって、Lambdaでreshapeをラップしている。
pred = Lambda(K.reshape, arguments={'shape': (-1, s[-2], emb_size)})(x)
このtoy programではcompileはうまくいく。また学習も順調に進む。
model = Model(inputs=[x], outputs=pred) model.compile('sgd', 'categorical_crossentropy')
問題
しかし以下のようにsaveメソッドを使ってモデルを保存すると問題が起きる。
model.save('model.h5')
確認しているだけでも以下の5つの例外がランダム(に見えるよう)に発生する:
- TypeError: cannot serialize ‘_io.TextIOWrapper’ object
- TypeError: object.new(PyCapsule) is not safe, use PyCapsule.new()
- AttributeError: ‘NoneType’ object has no attribute ‘update’
- TypeError: cannot deepcopy this pattern object
- TypeError: can’t pickle module objects
調べていくと、これらの例外はreshape時にテンソルを指定すると、Lambdaでシリアライズができないことに起因するようだ。
解決策
save_weightsでネットワークの重みだけ保存することで解決する。
model.save_weights('model_weights.h5')
まとめ
- keras.backendのreshapeを使ってLambda層でreshapeしたい
- reshapeのshapeにテンソルを指定するとモデルの保存(save)に失敗する
- saveではなくsave_weightsを使うと保存できる