やむやむもやむなし

やむやむもやむなし

自然言語処理やエンジニアリングのメモ

学習済みELMoをAllenNLPで読み込む -りたーんず!-

この記事は自然言語処理アドベントカレンダー 2019の15日目です。

きっかけ

[1]

f:id:ymym3412:20191211154119p:plain

[2]

f:id:ymym3412:20191211154201p:plain

[3]

f:id:ymym3412:20191211154322p:plain

[4]

f:id:ymym3412:20191211154804p:plain

ストックマークさんが本気を出したんだ。
俺も覚悟を決めなくてはならない。

ということで、ストックマークさんがあらためて出してくれた学習済み日本語ELMoを使って、こちらの記事ではできなかった学習済みELMoをAllenNLPで読み込むことを今度こそ成し遂げます。

後述するELMoForManyLangs -> AllenNLPのスクリプトGithubで公開しています。 github.com

AllenNLPとELMoForManyLangs

ストックマーク社が公開している日本語学習済みELMoはELMoForManyLangsというGIthubリポジトリのコードを使って学習されています。

github.com

以前の記事でも触れましたが、このELMoForManyLangsとAllenNLPにはある程度の互換性があり、ELMoForManyLangsで学習させたモデルに変換をかませることでAllenNLPで使えるようにすることができます。

ここからはELMoForManyLangs -> AllenNLPへのConvertについて紹介していきます。

ELMoの変換

ELMoForManyLangs/AllenNLPで使われているELMoはおおまかに分けて以下の5つのパーツから成り立っています。 それぞれの変換の方法を説明していきます。

  1. Char Emb
  2. Char Convolution
  3. Highway Net
  4. Projection
  5. Bi-LSTM

f:id:ymym3412:20191215025035p:plain

1.Char Embedding

文字通り文字の埋め込み表現です。 これは特に工夫もなくchar embeddingの重みをそのままAllenNLPに持ち込むことができます。 ひとつだけ注意点として、AllenNLPではpadding用のtokenは基本的にはid: 0で指定されるので、char embeddingのid :0の埋め込みがpadding用のtokenのものになるように入れ替えておきます。 (語彙とidのマッピングを行うchar.dicの中身も変えておきましょう)

def create_char_embed_weight(hdf5_file, embedding_layer, char_dic):
    emb = embedding_layer.state_dict()['embedding.weight'].cpu().numpy()
    # AllenNLP makes padding token id zero
    # Swap top and padding vector
    pad_token_id = char_dic['<pad>']
    emb[0], emb[pad_token_id] = emb[pad_token_id], emb[0]
    hdf5_file.create_dataset('char_embed', data=emb)

2.Char Convolution

charのembeddingに対して1d Convolutionをかける層です。 NLPでよく見られるembeddingを並べたものにConvolutionをかけてN-gramの特徴を抽出するものです。

def create_CNN_weight(hdf5_file, convolutions):
    for i, conv1d in enumerate(convolutions):
        state_dict = conv1d.state_dict()
        weight = state_dict['weight'].cpu().numpy()  # width * char_emb_dim * out_ch
        bias = state_dict['bias'].cpu().numpy()

        weight = np.transpose(weight)
        weight = weight.reshape(1, *weight.shape)  # 1 * in_ch * char_emb_dim * width
        hdf5_file.create_dataset('CNN/W_cnn_{}'.format(i), data=weight)
        hdf5_file.create_dataset('CNN/b_cnn_{}'.format(i), data=bias)

3.Highway Net

Highway NetはResNetのSkip connectionのように層の出力値に入力値を足し合わせるような構造になっています。
Skip Connectionと異なるのは層の出力値と入力値を足し合わせる割合をゲーティングで制御しているという点です。


y = H(x, W_h) \cdot (1 - T(x, W_T))  + x \cdot T(x, W_T)

H (transform)はxを変換するLinear層、T (carry)は各次元のゲーティングを行うためのベクトルを生成するLinea層です。
この変換を複数回繰り返す(ストックマークのものは2回)して最終的な出力を得ます。

ELMoForManyLangsでは入力値の変換を行うtransform: Hと入力値をそのまま足すためのcarry: Cが同じLinear層で実装されていますが、AllenNLPでは別々のLinear層として実装されています。
なので、Linear層の重みとバイアスを分割して別々に保存します。

def create_hightway_weight(hdf5_file, hightway_layers):
    """
    In ELMoForManyLangs, highway layer has linear layer.
    The weight of linear layer consist of two part, carry weight and non-linear weight.
    First half of weight is carry weight, and latter half is non-linear part.
    See also https://medium.com/jim-fleming/highway-networks-with-tensorflow-1e6dfa667daa
    """
    for i, layer in enumerate(hightway_layers):
        state_dict = layer.state_dict()
        # input_dim * 2, input_dim
        weight = state_dict['weight'].cpu().numpy()
        # input_dim
        bias = state_dict['bias'].cpu().numpy()

        input_dim = weight.shape[1]
        w_carry = weight[:input_dim, :]
        w_transform = weight[input_dim:, :]
        b_carry = bias[:input_dim]
        b_transform = bias[input_dim:]

        hdf5_file.create_dataset('CNN_high_{}/W_carry'.format(i), data=np.transpose(w_carry))
        hdf5_file.create_dataset('CNN_high_{}/W_transform'.format(i), data=np.transpose(w_transform))
        hdf5_file.create_dataset('CNN_high_{}/b_carry'.format(i), data=b_carry)
        hdf5_file.create_dataset('CNN_high_{}/b_transform'.format(i), data=b_transform)

4.Projection

Highway Netで得た表現をLSTMへ投入する次元へと変換するLinear層です。
ELMoForManyLangsでは単語の埋め込みと文字の埋め込みの両方を利用して入力値の埋め込みを得ることができますが、AllenNLPでは文字を使っての埋め込みしか対応していません。
そのためこのProjection層も文字の埋め込みの変換の重みだけを抜き出してAllenNLPに持ち込みます。

def create_projection_weight(hdf5_file, projection, word_dim):
    # In ELMoForManyLangs, embedding is created by concat of word emb and char emb.
    # So transfer only char emb projection.
    weight = projection.state_dict()['weight'].cpu().numpy()[:, word_dim:]
    bias = projection.state_dict()['bias'].cpu().numpy()
    hdf5_file.create_dataset('CNN_proj/W_proj', data=np.transpose(weight))
    hdf5_file.create_dataset('CNN_proj/b_proj', data=bias)

5.Bi-LSTM

最後にELMoの根幹となるBiLSTMの重みです。
これは以前の記事でも触れたようにLstmCellWithProjectionというクラスが単層単方向のLSTMを表現するクラスであり、これを2つ合わせてBi-LSTMを、それを2層重ねてELMoのLSTMを構築しています。

ELMoのLSTMの重みは以下のような構造となっており、これとAllenNLPへの対応付けが次のようになっています。

# ELMoForManyLangsの学習済みモデルの読み込み
from ELMoForManyLangs.elmoformanylangs import Embedder

e = Embedder('ja')
for k in e.model.encoder.state_dict().keys():
    print(k)

>>>
forward_layer_0.input_linearity.weight
forward_layer_0.state_linearity.weight
forward_layer_0.state_linearity.bias
forward_layer_0.state_projection.weight
backward_layer_0.input_linearity.weight
backward_layer_0.state_linearity.weight
backward_layer_0.state_linearity.bias
backward_layer_0.state_projection.weight
forward_layer_1.input_linearity.weight
forward_layer_1.state_linearity.weight
forward_layer_1.state_linearity.bias
forward_layer_1.state_projection.weight
backward_layer_1.input_linearity.weight
backward_layer_1.state_linearity.weight
backward_layer_1.state_linearity.bias
backward_layer_1.state_projection.weight
W_0  = {input_linearity.weightとstate_linearity.weightをdim=1でconcatしたもの}  
B  = {state_linearity.bias}  
W_P_0 =  {state_projection.weight}  

上記の変換を以下のコードで実装しています。

def create_lstm_weight(hdf5_file, encoder):
    state_dict = encoder.state_dict()
    directions = ['forward', 'backward']
    layers = [0, 1]
    for direction in directions:
        for layer in layers:
            direction_num = 0 if direction == 'forward' else 1
            base_key = f'{direction}_layer_{layer}.'
            concat_weight = torch.cat([state_dict[base_key + 'input_linearity.weight'], state_dict[base_key + 'state_linearity.weight']], dim=1)
            # weight
            hdf5_file.create_dataset(
                f'RNN_{direction_num}/RNN/MultiRNNCell/Cell{layer}/LSTMCell/W_0',
                data=np.transpose(concat_weight.cpu())
            )
            # bias
            hdf5_file.create_dataset(
                f'RNN_{direction_num}/RNN/MultiRNNCell/Cell{layer}/LSTMCell/B',
                data=np.transpose(state_dict[base_key + 'state_linearity.bias'].cpu())
            )
            # projection
            hdf5_file.create_dataset(
                f'RNN_{direction_num}/RNN/MultiRNNCell/Cell{layer}/LSTMCell/W_P_0',
                data=np.transpose(state_dict[base_key + 'state_projection.weight'].cpu())
            )

config.json

あとは設定ファイルの変換です。
こちらはattributeがほとんど一緒なのでAllenNLP側のフォーマットに合わせてえいやと変換します。

def convert_config(config):
    """
    convert ELMoForManyLangs config to AllenNLP
    """
    allennlp_config = {}

    char_cnn_dict = {}
    char_cnn_dict['activation'] = config['token_embedder']['activation']
    char_cnn_dict['filters'] = config['token_embedder']['filters']
    char_cnn_dict['n_highway'] = config['token_embedder']['n_highway']
    char_cnn_dict['embedding'] = {'dim': config['token_embedder']['char_dim']}
    char_cnn_dict['max_characters_per_token'] = config['token_embedder']['max_characters_per_token']
    allennlp_config['char_cnn'] = char_cnn_dict

    lstm_dict = {}
    # Currently, AllenNLP support lstm with skip connection only
    lstm_dict['use_skip_connections'] = True
    lstm_dict['projection_dim'] = config['encoder']['projection_dim']
    lstm_dict['cell_clip'] = config['encoder']['cell_clip']
    lstm_dict['proj_clip'] = config['encoder']['proj_clip']
    lstm_dict['dim'] = config['encoder']['dim']
    lstm_dict['n_layers'] = config['encoder']['n_layers']
    allennlp_config['lstm'] = lstm_dict

    return allennlp_config

上記のスクリプトを実行すれば、AllenNLP向けに変換したconfigのjson、ELMoForManyLangsのpickleファイルをhdf5フォーマットに変換したモデルのファイル、変換した語彙のファイルが出来上がります。
これでAllenNLPにファイルを持ち込む用意が出来ました。

変換したファイルを使って以下のコードがエラーを吐かずに実行できれば変換は成功です。

from allennlp.modules.token_embedders import ElmoTokenEmbedder

options_file = 'allennlp_config.json'
weight_file = 'allennlp_elmo.hdf5'
elmo_embedder = ElmoTokenEmbedder(options_file, weight_file)

AllenNLPでELMoを使った学習

AllenNLPではひとつひとつの単語をToken、Tokenをまとめた文章をField、文章のFieldやラベルのFieldなど学習に使うデータをひとつにまとめたものをInstanceと呼びます。
そして、TokenにIDを振るIndexer、IDで表現されたTokenを埋め込み表現に変換するTokenEmbedder、この2つを使って単語を埋め込み表現へと変換します。

ELMoについても同様で、Tokenを受け取って文字のIDを振るELMoTokenCharactersIndexerと文字IDからELMo表現を得るElmoTokenEmbedderが用意されています。

単語を受け取ってそれをIDにマッピングするIndexerですが、公式のELMoのTutorialを読むと単語とIDへのマッピングを行わなくても各単語のELMo表現の計算が出来てしまっています。何故でしょうか?

from allennlp.modules.elmo import Elmo, batch_to_ids

options_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

# Compute two different representation for each token.
# Each representation is a linear weighted combination for the
# 3 layers in ELMo (i.e., charcnn, the outputs of the two BiLSTM))
elmo = Elmo(options_file, weight_file, 2, dropout=0)

# use batch_to_ids to convert sentences to character ids
sentences = [['First', 'sentence', '.'], ['Another', '.']]
# 事前学習済みファイルに紐づく文字:IDマッピングを必要としていない
character_ids = batch_to_ids(sentences)

embeddings = elmo(character_ids)

# embeddings['elmo_representations'] is length two list of tensors.
# Each element contains one layer of ELMo representations with shape
# (2, 3, 1024).
#   2    - the batch size
#   3    - the sequence length of the batch
#   1024 - the length of each ELMo vector

これはELMoTokenCharactersIndexerが内部で使用しているELMoCharacterMapperというクラスの影響で、このクラスは受け取った文字のUnicodeのバイト表現を使ってマッピングを行っています(例: y -> U+0079 -> 121)。

def convert_word_to_char_ids(self, word: str) -> List[int]:
    if word in self.tokens_to_add:
        char_ids = [ELMoCharacterMapper.padding_character] * ELMoCharacterMapper.max_word_length
        char_ids[0] = ELMoCharacterMapper.beginning_of_word_character
        char_ids[1] = self.tokens_to_add[word]
        char_ids[2] = ELMoCharacterMapper.end_of_word_character
    elif word == ELMoCharacterMapper.bos_token:
        char_ids = ELMoCharacterMapper.beginning_of_sentence_characters
    elif word == ELMoCharacterMapper.eos_token:
        char_ids = ELMoCharacterMapper.end_of_sentence_characters
    else:
        word_encoded = word.encode("utf-8", "ignore")[
            : (ELMoCharacterMapper.max_word_length - 2)
        ]
        char_ids = [ELMoCharacterMapper.padding_character] * ELMoCharacterMapper.max_word_length
        char_ids[0] = ELMoCharacterMapper.beginning_of_word_character
        # unicodeのIDをそのままchar embeddingのIDにしている
        for k, chr_id in enumerate(word_encoded, start=1):
            char_ids[k] = chr_id
        char_ids[len(word_encoded) + 1] = ELMoCharacterMapper.end_of_word_character

    # +1 one for masking
    return [c + 1 for c in char_ids]

そのため、マルチバイト文字を使用するとうまくIDにマッピング出来ないという問題があります("あ"の場合、UTF-8で「\xe7\x94\xb7」なので3つのIDに分割されてしまいます)。

この問題を回避するためにELMoForManyLangsの学習で得た日本語の文字とIDへのマッピング「char.dic」を使う新しいクラスを自作します。

from typing import Dict, List

from overrides import overrides
import torch

from allennlp.common.checks import ConfigurationError
from allennlp.common.util import pad_sequence_to_length
from allennlp.data.tokenizers.token import Token
from allennlp.data.token_indexers.token_indexer import TokenIndexer
from allennlp.data.vocabulary import Vocabulary


def _make_bos_eos(
    character: int,
    padding_character: int,
    beginning_of_word_character: int,
    end_of_word_character: int,
    max_word_length: int,
):
    char_ids = [padding_character] * max_word_length
    char_ids[0] = beginning_of_word_character
    char_ids[1] = character
    char_ids[2] = end_of_word_character
    return char_ids


class CustomELMoCharacterMapper:
    """
    Maps individual tokens to sequences of character ids, compatible with ELMo.
    To be consistent with previously trained models, we include it here as special of existing
    character indexers.
    We allow to add optional additional special tokens with designated
    character ids with ``tokens_to_add``.
    """
    max_word_length = 50
    def __init__(self, tokens_to_add: Dict[str, int] = None) -> None:
        self.tokens_to_add = tokens_to_add or {}
        
        # setting special token
        self.beginning_of_sentence_character = self.tokens_to_add['<bos>']  # <begin sentence>
        self.end_of_sentence_character = self.tokens_to_add['<eos>']  # <end sentence>
        self.beginning_of_word_character = self.tokens_to_add['<bow>']  # <begin word>
        self.end_of_word_character = self.tokens_to_add['<eow>']  # <end word>
        self.padding_character = self.tokens_to_add['<pad>']  # <padding>
        self.oov_character = self.tokens_to_add['<oov>']
        
        self.max_word_length = 50

        # char ids 0-255 come from utf-8 encoding bytes
        # assign 256-300 to special chars

        self.beginning_of_sentence_characters = _make_bos_eos(
            self.beginning_of_sentence_character,
            self.padding_character,
            self.beginning_of_word_character,
            self.end_of_word_character,
            self.max_word_length,
        )
        self.end_of_sentence_characters = _make_bos_eos(
            self.end_of_sentence_character,
            self.padding_character,
            self.beginning_of_word_character,
            self.end_of_word_character,
            self.max_word_length,
        )

        self.bos_token = "<bos>"
        self.eos_token = "<eos>"

    def convert_word_to_char_ids(self, word: str) -> List[int]:
        if word in self.tokens_to_add:
            char_ids = [self.padding_character] * self.max_word_length
            char_ids[0] = self.beginning_of_word_character
            char_ids[1] = self.tokens_to_add[word]
            char_ids[2] = self.end_of_word_character
        elif word == self.bos_token:
            char_ids = self.beginning_of_sentence_characters
        elif word == self.eos_token:
            char_ids = self.end_of_sentence_characters
        else:
            word = word[: (self.max_word_length - 2)]
            char_ids = [self.padding_character] * self.max_word_length
            char_ids[0] = self.beginning_of_word_character
            for k, char in enumerate(word, start=1):
                char_ids[k] = self.tokens_to_add[char] if char in self.tokens_to_add else self.oov_character
            char_ids[len(word) + 1] = self.end_of_word_character

        # +1 one for masking
        # return [c + 1 for c in char_ids]
        return char_ids

    def __eq__(self, other) -> bool:
        if isinstance(self, other.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented


@TokenIndexer.register("custom_elmo_characters")
class CustomELMoTokenCharactersIndexer(TokenIndexer[List[int]]):
    """
    Convert a token to an array of character ids to compute ELMo representations.
    Parameters
    ----------
    namespace : ``str``, optional (default=``elmo_characters``)
    tokens_to_add : ``Dict[str, int]``, optional (default=``None``)
        If not None, then provides a mapping of special tokens to character
        ids. When using pre-trained models, then the character id must be
        less then 261, and we recommend using un-used ids (e.g. 1-32).
    token_min_padding_length : ``int``, optional (default=``0``)
        See :class:`TokenIndexer`.
    """

    def __init__(
        self,
        namespace: str = "elmo_characters",
        tokens_to_add: Dict[str, int] = None,
        token_min_padding_length: int = 0,
    ) -> None:
        super().__init__(token_min_padding_length)
        self._namespace = namespace
        self._mapper = CustomELMoCharacterMapper(tokens_to_add)

    @overrides
    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        pass

    @overrides
    def tokens_to_indices(
        self, tokens: List[Token], vocabulary: Vocabulary, index_name: str
    ) -> Dict[str, List[List[int]]]:
        # TODO(brendanr): Retain the token to index mappings in the vocabulary and remove this

        # https://github.com/allenai/allennlp/blob/master/allennlp/data/token_indexers/wordpiece_indexer.py#L113

        texts = [token.text for token in tokens]

        if any(text is None for text in texts):
            raise ConfigurationError(
                "ELMoTokenCharactersIndexer needs a tokenizer " "that retains text"
            )
        return {index_name: [self._mapper.convert_word_to_char_ids(text) for text in texts]}

    @overrides
    def get_padding_lengths(self, token: List[int]) -> Dict[str, int]:

        return {}

    @staticmethod
    def _default_value_for_padding():
        return [0] * CustomELMoCharacterMapper.max_word_length

    @overrides
    def as_padded_tensor(
        self,
        tokens: Dict[str, List[List[int]]],
        desired_num_tokens: Dict[str, int],
        padding_lengths: Dict[str, int],
    ) -> Dict[str, torch.Tensor]:

        return {
            key: torch.LongTensor(
                pad_sequence_to_length(
                    val, desired_num_tokens[key], default_value=self._default_value_for_padding
                )
            )
            for key, val in tokens.items()
        }

元のクラスの基本的な処理には手を加えておらず、Indexerに渡して文字とIDのマッピングの辞書を使ってIndexingを行うように修正しただけです。

このクラスを使ってTokenを文字のID列にIndexingできるようになれば、AllenNLPで日本語の学習済みELMoを使うことができるようになります。

with open('char_for_allennlp.dic') as f:
    char_dic = {line.split('\t')[0]: int(line.split('\t')[1].strip('\n')) for line in f}

char_indexer = CustomELMoTokenCharactersIndexer(tokens_to_add=char_dic)

def text_to_instance(word_list, label):
    tokens = [Token(word) for word in word_list]
    word_sentence_field = TextField(tokens, {"tokens":SingleIdTokenIndexer()})
    char_sentence_field = TextField(tokens, {'char_tokens': char_indexer})
    fields = {"tokens":word_sentence_field, 'char_tokens': char_sentence_field}
    if label is not None:
        label_field = LabelField(label, skip_indexing=True)
        fields["label"] = label_field
    return Instance(fields)

train_dataset = [text_to_instance([token.surface for token in document], label) for document, label in zip(processed, train_y)]

VOCAB_SIZE = 30000
vocab = Vocabulary.from_instances(train_dataset, max_vocab_size=VOCAB_SIZE)

BATCH_SIZE = 4
iterator = BucketIterator(batch_size=BATCH_SIZE, sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)

batch = next(iter(iterator(train_dataset)))
print(batch)

>>>{'tokens': {'tokens': tensor([[  163,   558,   219,  ...,  4557,   274,   150],
          [    1,     8,     2,  ...,     0,     0,     0],
          [   80,    52,   422,  ...,     0,     0,     0],
          [12897,    17,  6871,  ...,     0,     0,     0]])},
 'char_tokens': {'char_tokens': tensor([[[8639,  871, 8640,  ...,    0,    0,    0],
           [8639,   39,   43,  ...,    0,    0,    0],
           [8639,  263, 8640,  ...,    0,    0,    0],
           ...,
           [8639,  466,  508,  ...,    0,    0,    0],
           [8639,  557,  558,  ...,    0,    0,    0],
           [8639,  810,  376,  ...,    0,    0,    0]],
  
          [[8639,   27,  205,  ...,    0,    0,    0],
           [8639,   38, 8640,  ...,    0,    0,    0],
           [8639,   25, 8640,  ...,    0,    0,    0],
           ...,
           [   0,    0,    0,  ...,    0,    0,    0],
           [   0,    0,    0,  ...,    0,    0,    0],
           [   0,    0,    0,  ...,    0,    0,    0]],
  
          [[8639,  143, 8640,  ...,    0,    0,    0],
           [8639,   44, 8640,  ...,    0,    0,    0],
           [8639,   13,   14,  ...,    0,    0,    0],
           ...,
           [   0,    0,    0,  ...,    0,    0,    0],
           [   0,    0,    0,  ...,    0,    0,    0],
           [   0,    0,    0,  ...,    0,    0,    0]],
  
          [[8639,  202,   57,  ...,    0,    0,    0],
           [8639,   86, 8640,  ...,    0,    0,    0],
           [8639,  169, 2252,  ...,    0,    0,    0],
           ...,
           [   0,    0,    0,  ...,    0,    0,    0],
           [   0,    0,    0,  ...,    0,    0,    0],
           [   0,    0,    0,  ...,    0,    0,    0]]])},
 'label': tensor([4, 1, 4, 7])}

まとめ

ストックマークさんが公開している事前学習済み日本語ELMoをAllenNLPで読み込む手順を紹介しました。
AllenNLPはNLPの実験に特化した仕様になっているため、さくっとモデルを作って実験を回すのに非常に便利です。
ELMoによる埋め込みを得るモジュールの他にもBERTの埋め込みを得るモジュールも用意されているので、自分の用意したモデルの+1ポイントとして使ってみてはいかがでしょうか。

謝辞

文字表現も含めた事前学習済み日本語ELMoを公開してくださったストックマークさんにここで感謝を述べさせて頂きます。

参考文献