学習済みELMoをAllenNLPで読み込もうとした
Stockmark社が公開している学習済みELMoをAllenNLPで読み込もうとして、ちょっと足りずにできなかった話です。
2019/12/15追記
リベンジしました
学習済みELMoをAllenNLPで読み込む -りたーんず!-
ELMoとは
ELMoの日本語での解説は多く出ているのでここではあまり深くは行いません。
ELMoは双方向LSTMを使って学習させた言語モデルで、このLSTMの出力をtokenに対する文脈を考慮したベクトルとして扱います。
このベクトルを単語ベクトルやRNNの隠れそうのベクトルにconcatするだけでタスクの精度向上を狙えるというものです。
モデルのアーキテクチャに依存せず、ただconcatするだけで利用できるためとても汎用性が高いです。
日本語の学習済みELMoについてはStockmark社が提供しているものやHIT-SCIR/ELMoForManyLangsのリポジトリでホスティングされているものなどが存在します。
Stockmark社が提供している学習済みELMoもこの「ELMoForManyLangs」のリポジトリのコードを使って学習させたものとなっています。
今回はこのELMoForManyLangsで学習させた学習済みELMoをAllenNLPで読み込ませます。
AllenNLPとELMoForManyLangsの互換性について
そもそもELMoForManyLangsのリポジトリで学習させたELMoはAllenNLPで使用することができるのでしょうか?
結論から言えば「Yes」です。
AllenNLPでELMoを使用する際には、TokenをELMoで変換して文脈ベクトルを獲得するためのELMoTokenEmbedderというクラスが存在します。
このクラスはELMoに関する設定ファイル(json)とELMoで使用される重み(hdf5)を引数として受け取り学習済みELMoを読みこんで使用します。
このELMoTokenEmbedderは受け取ったTokenを埋め込むための処理諸々を実装しているクラスで、ELMoの重み自体は、
ELMoTokenEmbedder.Elmo._ElmoBiLmが保持しています。
AllenNLPのELMoでは、受け取ったtokenを文字の埋め込みに変換する「_ElmoCharacterEncoder」と埋め込んだベクトルをBiLSTMで変換する「ElmoLstm」の2つで構成されており、それぞれ学習する重みを持っています。
_ElmoBiLm ├── ElmoLstm └── _ElmoCharacterEncoder
ElmoLstm
ElmoLstmはBi-LSTMのforwardとbackwardのlayerがそれぞれ「LstmCellWithProjection」というクラスを使って分けて実装されており、これを指定した層の数(ELMoの場合2層)だけ保持する設計になっています。
幸いにも、ELMoForManyLangsでもこのLstmCellWithProjectionを使ってELMoのBi-LSTMのlayerを実装しているため、保存した重みをそのまま持ち込むことができます。
(しかしELMoForManyLangsの方ではAllenNLPのライブラリを使うわけではなく、コピペしたコードが生で置かれている...これは良いのだろうか)
_ElmoCharacterEncoder
AllenNLPの「_ElmoCharacterEncoder」では、文字を埋め込みそれをLSTMやCNNで変換することでTokenの文字をベクトルに変換しています。
こちらは同名のクラスはありませんが、ELMoForManyLangsでも「ConvTokenEmbedder」といったクラスを用意しており、ほぼ同じ内部構造をしているため移行は(おそらく)可能です。
双方の設定ファイルを見てみると、(一部名前が違っていたりしますが)非常に類似している部分が多く、ここのパラメータを探っていけばうまく相互での対応付けを知ることができます。
AllenNLP
{ "lstm": { "use_skip_connections": true, "projection_dim": 512, "cell_clip": 3, "proj_clip": 3, "dim": 4096, "n_layers": 2 }, "char_cnn": { "activation": "relu", "filters": [[1, 32], [2, 32], [3, 64], [4, 128], [5, 256], [6, 512], [7, 1024]], "n_highway": 2, "embedding": { "dim": 16 }, "n_characters": 262, "max_characters_per_token": 50 } }
ELMoForManyLangs
{ "encoder": { "name": "elmo", "projection_dim": 512, "cell_clip": 3, "proj_clip": 3, "dim": 4096, "n_layers": 2 }, "token_embedder": { "name": "cnn", "activation": "relu", "filters": [[1, 32], [2, 32], [3, 64], [4, 128], [5, 256], [6, 512], [7, 1024]], "n_highway": 2, "word_dim": 100, "char_dim": 0, "max_characters_per_token": 50 }, "classifier": { "name": "sampled_softmax", "n_samples": 8192 }, "dropout": 0.1 }
ファイルの保存形式
ただし厄介なのが、双方で学習した重みを保存する方式が違うということです。
AllenNLPではhdf5で、ELMoForManyLangsではpickleで重みを保存しています。
そのためそのままではAllenNLPで読み込みことができず、少し工夫をしてやる必要があります。
AllenNLPで読み込むためのhdf5形式では、以下のような階層構造を持ったファイルとして保存されている必要があります。各要素は重みを保存したnumpyのarrayです。
CNN/W_cnn_0 CNN/W_cnn_1 CNN/W_cnn_2 CNN/W_cnn_3 CNN/W_cnn_4 CNN/W_cnn_5 CNN/W_cnn_6 CNN/b_cnn_0 CNN/b_cnn_1 CNN/b_cnn_2 CNN/b_cnn_3 CNN/b_cnn_4 CNN/b_cnn_5 CNN/b_cnn_6 CNN_high_0/W_carry CNN_high_0/W_transform CNN_high_0/b_carry CNN_high_0/b_transform CNN_high_1/W_carry CNN_high_1/W_transform CNN_high_1/b_carry CNN_high_1/b_transform CNN_proj/W_proj CNN_proj/b_proj RNN_0/RNN/MultiRNNCell/Cell0/LSTMCell/B RNN_0/RNN/MultiRNNCell/Cell0/LSTMCell/W_0 RNN_0/RNN/MultiRNNCell/Cell0/LSTMCell/W_P_0 RNN_0/RNN/MultiRNNCell/Cell1/LSTMCell/B RNN_0/RNN/MultiRNNCell/Cell1/LSTMCell/W_0 RNN_0/RNN/MultiRNNCell/Cell1/LSTMCell/W_P_0 RNN_1/RNN/MultiRNNCell/Cell0/LSTMCell/B RNN_1/RNN/MultiRNNCell/Cell0/LSTMCell/W_0 RNN_1/RNN/MultiRNNCell/Cell0/LSTMCell/W_P_0 RNN_1/RNN/MultiRNNCell/Cell1/LSTMCell/B RNN_1/RNN/MultiRNNCell/Cell1/LSTMCell/W_0 RNN_1/RNN/MultiRNNCell/Cell1/LSTMCell/W_P_0 char_embed
CNN prefixのものが_ElmoCharacterEncoderの、RNN prefixのものがElmoLstmの、char_embedがCharacterEncoderで使用する文字の埋め込み用の重みです。
これらの形式にあうようにELMoForManyLangsの重みを変換してやる必要があります。
ELMoForManyLangsのモデルがどのような重みを保持しているのか見てみます。
# ELMoForManyLangsの学習済みモデルの読み込み from ELMoForManyLangs.elmoformanylangs import Embedder e = Embedder('ja') for k in e.model.encoder.state_dict().keys(): print(k) >>> forward_layer_0.input_linearity.weight forward_layer_0.state_linearity.weight forward_layer_0.state_linearity.bias forward_layer_0.state_projection.weight backward_layer_0.input_linearity.weight backward_layer_0.state_linearity.weight backward_layer_0.state_linearity.bias backward_layer_0.state_projection.weight forward_layer_1.input_linearity.weight forward_layer_1.state_linearity.weight forward_layer_1.state_linearity.bias forward_layer_1.state_projection.weight backward_layer_1.input_linearity.weight backward_layer_1.state_linearity.weight backward_layer_1.state_linearity.bias backward_layer_1.state_projection.weight
ここでAllenNLPとELMoForManyLangsとの重みの対応付けですが、
W_0 = {input_linearity.weightとstate_linearity.weightをdim=1でconcatしたもの} B = {state_linearity.bias} W_P_0 = {state_projection.weight}
となっているので、そうなるように変換します。
(AllenNLPの重みの読み込み処理はこちら)
またhdf5では RNN_{direction(forwardなら0)}/RNN/MultiRNNCell/Cell{layer}/LSTMCell/W_0
となっておるのでそこにも注意が必要です。
# hdf5向けに階層構造を作る def create_lstm_weight(hdf5_file, state_dict, direction, layer): direction_num = 0 if direction == 'forward' else 1 base_key = f'{direction}_layer_{layer}.' concat_weight = torch.cat([state_dict[base_key + 'input_linearity.weight'], state_dict[base_key + 'state_linearity.weight']], dim=1) # weight hdf5_file.create_dataset( f'RNN_{direction_num}/RNN/MultiRNNCell/Cell{layer}/LSTMCell/W_0', data=np.transpose(concat_weight.cpu()) ) # bias hdf5_file.create_dataset( f'RNN_{direction_num}/RNN/MultiRNNCell/Cell{layer}/LSTMCell/B', data=np.transpose(state_dict[base_key + 'state_linearity.bias'].cpu()) ) # projection hdf5_file.create_dataset( f'RNN_{direction_num}/RNN/MultiRNNCell/Cell{layer}/LSTMCell/W_P_0', data=np.transpose(state_dict[base_key + 'state_projection.weight'].cpu()) ) from ELMoForManyLangs.elmoformanylangs import Embedder e = Embedder('ja') # hdf5形式で保存する with h5py.File('pretrained_weight.h5', 'w') as f: # LSTM directions = ['forward', 'backward'] layers = [0,1] for direction in directions: for layer in layers: create_lstm_weight(f, e.model.encoder.state_dict(), direction, layer)
これでAllenNLPで読み込む用の変換は完了です。
ちょっと足りなくないかぁ?
そうですね。CNNの部分の重みをまだ変換できていません。
AllenNLPのELMoでは、Tokenの埋め込みへの変換には文字単位で埋め込む「_ElmoCharacterEncoder」しか使用することができません。
しかし、調べてみるとStockmark社が公開している学習済みELMoモデルでは単語単位での埋め込みしか行っておらず、そのままAllenNLPで動かすことはできませんでした。
ひとまずAllenNLPで読み込めるところまで持っていく
足りないCNNとchar_embedの部分は別で公開されている学習済みELMoから流用してきます。
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json" weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
l = [ 'CNN/W_cnn_0', 'CNN/W_cnn_1', 'CNN/W_cnn_2', 'CNN/W_cnn_3', 'CNN/W_cnn_4', 'CNN/W_cnn_5', 'CNN/W_cnn_6', 'CNN/b_cnn_0', 'CNN/b_cnn_1', 'CNN/b_cnn_2', 'CNN/b_cnn_3', 'CNN/b_cnn_4', 'CNN/b_cnn_5', 'CNN/b_cnn_6', 'CNN_high_0/W_carry', 'CNN_high_0/W_transform', 'CNN_high_0/b_carry', 'CNN_high_0/b_transform', 'CNN_high_1/W_carry', 'CNN_high_1/W_transform', 'CNN_high_1/b_carry', 'CNN_high_1/b_transform', 'CNN_proj/W_proj', 'CNN_proj/b_proj', 'char_embed' ] with h5py.File('pretrained_weight.h5', 'w') as f: # LSTM directions = ['forward', 'backward'] layers = [0,1] for direction in directions: for layer in layers: create_lstm_weight(f, e.model.encoder.state_dict(), direction, layer) # CNNとchar_embedは他から持ってくる # CNN with h5py.File('elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5', 'r') as pf: for path in l: if path.startswith('CNN'): f.create_dataset(path, data=pf[path].value) # char_embed f.create_dataset('char_embed', data=pf['char_embed'].value)
これでAllenNLPで読み込むのに必要な重みが全て揃いました。あとは
from allennlp.modules.token_embedders import ElmoTokenEmbedder elmo_embedder = ElmoTokenEmbedder(options_file='elmo_2x4096_512_2048cnn_2xhighway_options.json', weight_file='pretrained_weight.h5')
とやって問題なく読み込めれば成功です。
まとめ
ちょっと惜しかった
ELMoForManyLangsでchar_dimを16に設定していればそのまま持っていくことも可能だったと思われる。
どこかでcharの設定も行ったELMoForManyLangsでELMoを学習させて、それをAllenNLPに持っていくことも試してみたいものだ。
(Stockmarkさんが本気出したりしないだろうか)
ただ、記事中でも書いたが、ELMoForManyLangsの方でLstmCellWithProjectionがAllenNLPのライブラリではなく、リポジトリで生書きされているため、どこかでVer追従できなくなる恐れがあるため、そこは対策を考えなくてはならない。