やむやむもやむなし

やむやむもやむなし

自然言語処理やエンジニアリングのメモ

Grad-CAMを使ったNLPモデルの判断根拠の可視化

機械学習モデルの解釈性は業務で使う上ではなぜそのような予測を行ったかの判断根拠の可視化として、また学習させたモデルをデバックする際にどんな入力に反応して誤予測を引き起こしてしまったか分析する上で非常に重要な要素です。
画像分野ではGrad-CAMと呼ばれる勾配を使った予測根拠の可視化手法が提案されており、今回はその手法を使ってNLP向けのCNNモデルの判断根拠を可視化していきます。
実験で使用したノートブックはGithub上で公開しています。

github.com

機械学習モデルの解釈性

機械学習モデルに対する解釈性は近年では特に重要なトピックです。例えば

  • 業務の自動化を機械学習で行う場合に説明責任が生じる
  • DNNのデバッグをして性能改善を行いたい

といったときに機械学習モデルの解釈性は必要になります。
機械学習モデルの解釈性についてはステアラボ人工知能セミナーでの原聡先生の資料がとても分かりやすいです。

機械学習の解釈性には「大域的な説明(Global Interpretability)」と「局所的な説明(Local Interpretability)」のふたつに大きく分けられます。

大域的な説明

大域的な説明は複雑なモデルを決定木や線形回帰といった解釈が容易なモデルで近似することでモデルを説明する方法です。
説明したいモデルの全体を解釈しやすいモデルで近似することで、モデルがどのように予測を行うかというモデルの内部を説明しているのが特徴です。

局所的な説明

モデル全体を説明する大域的な説明とは異なり、特定の入力に対する予測結果の説明を行うのが局所的な説明です。
これは入力のどの特徴量、あるいはどの訓練データによってその予測が行われたかの根拠を提示するために用いられます。
有名な手法では「LIME(Local Interpretable Model-agnostic Explanations)」があります。
LIMEは説明を行いたい予測への入力データに摂動を加えたデータセットを生成し、それに対して説明を行いたいモデルfがどのように振る舞うかを解釈可能なモデルg (e.g. 決定木やLasso)で近似することで、特定の予測に対する特徴量の寄与度を測る方法です。
説明したいモデルをf、近似するモデルをg、説明した入力データxと摂動を加えたデータ間の距離を測る関数\pi_x、近似するモデルの複雑度を表す関数 \Omegaを使って以下の式を求めることでLIMEによって説明を行うモデルを構築します。


\begin{equation} 
\xi(x) = argmin_{g \in G} L(f, g, \pi_x) + \Omega(g) 
\end{equation}

深層学習のモデルの説明には局所的な説明を行うものが多く、後述するGrad-CAMも局所的な説明を行う手法の一種です。

Grad-CAM

Grad-CAMは画像認識の分野で使われている、分類の根拠を提示する局所的な説明手法です。
Grad-CAMではVGGやResNetのようにConvolutionやpoolingを繰り返し最後に全結合層に接続してクラス分類を行うようなモデルに対して、全結合層の前のConvolution層で生成された特徴マップが、予測したラベルに対してどれくらい影響を与えているかを以下のように勾配を使って計算します。


\begin{equation} 
\quad\quad \alpha_k^c = {\frac{1}{Z} \sum_i \sum_j} \frac{\partial y^c}{\partial A_{ij}^k} \quad\quad (1)
\end{equation}

\begin{equation} 
\quad\quad L^c_{\rm Grad-CAM} = ReLU \left(\sum_k \alpha_k^c A^k \right) \quad\quad (2)
\end{equation}

式(1)では、クラスcに対して特徴マップの各要素が微小に変化した際にクラスの確率がどれくらい変化するかを計算して特徴マップないで平滑化を行い、そして式(2)でその値を使って特徴マップ領域内の重要度を計算しています。
これにより、モデルの予測に対して入力画像のどの領域が影響を与えるかを計算して可視化することが可能となります。
またこれは予測に対する各特徴マップの勾配が計算できれば良いため、画像分類に限らず、画像キャプション生成やVisual Question Answeringなどにも利用することができます。

f:id:ymym3412:20190319002222p:plain

f:id:ymym3412:20190319002450p:plain

f:id:ymym3412:20190319002505p:plain

CNNの分類モデルの判断根拠の可視化

では画像分野で使われているモデルの説明手法をどうやってNLPで使うのか。
答えはシンプルで、NLP向けのCNNをベースにしたモデルを使えばよいのです。
幸い、以前に日本語の記事のカテゴリ分類を行うCNNモデルを作ったので、それを使ってGrad-CAMによる予測の判断根拠の可視化を行います。

学習

モデルは以前使用したKim[2014]のCNNによる文書分類モデルに少し改良を加えたものを使用します。

import torch
import torch.nn as nn
import torch.nn.functional as F

class  CNN_Text(nn.Module):
    
    def __init__(self, pretrained_wv, output_dim, kernel_num, kernel_sizes=[3,4,5], dropout=0.5, static=False):
        super(CNN_Text,self).__init__()
        
        weight = torch.from_numpy(pretrained_wv)
        self.embed = nn.Embedding.from_pretrained(weight, freeze=False)
        self.convs1 = nn.ModuleList([nn.Conv2d(1, kernel_num, (k, self.embed.weight.shape[1])) for k in kernel_sizes])
        self.bns1 = nn.ModuleList([nn.BatchNorm2d(kernel_num) for _ in kernel_sizes])
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(len(kernel_sizes)*kernel_num, output_dim)
        self.static = static

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3) #(N,Co,W)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x


    def forward(self, x):
        x = self.embed(x) # (N,W,D)
        
        if self.static:
            x = x.detach()

        x = x.unsqueeze(1) # (N,Ci,W,D)
        x = x.float()
        x = [F.relu(bn(conv(x))).squeeze(3) for conv, bn in zip(self.convs1, self.bns1)] #[(N,Co,W), ...]*len(Ks)

        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks)

        x = torch.cat(x, 1)
        x = self.dropout(x) # (N,len(Ks)*Co)
        logit = self.fc1(x) # (N,C)
        return logit

最適化にはAdaBoundを使用し、400epoch学習させたあとはSGDで学習率を下げながら学習させました。

from tensorboardX import SummaryWriter
import torch
import numpy as np
import adabound

from sklearn.metrics import accuracy_score

def calc_accuray(model, target_preprocessed, writer, ite, mode, use_cuda):
    model.eval()
    feature = torch.LongTensor(target_preprocessed['article'])
    if use_cuda:
        feature = feature.cuda()
    forward = model(feature)
    predicted_label = forward.argmax(dim=1).cpu()
    test_target = torch.LongTensor(np.argmax(target_preprocessed['label'], axis=1))
    accuracy = accuracy_score(test_target.numpy(), predicted_label.numpy())
    writer.add_scalar('data/{}_accuracy'.format(mode), accuracy, ite)
    model.train()
    return accuracy

writer = SummaryWriter()
output_dim = 9
kernel_num = 200
kernel_sizes = [3,4,5,7]
dropout = 0.5

model = CNN_Text(embedding, output_dim, kernel_num, kernel_sizes, dropout)
use_cuda = True
opt = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1, weight_decay=5e-4)
# opt = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
model.train()
if use_cuda:
    model = model.cuda()
    
for ite, b in enumerate(dp.iterate(preprocessed, batch_size=64, epoch=400)):
    feature = torch.LongTensor(b['article'])
    target = torch.LongTensor(np.argmax(b['label'], axis=1))
    if use_cuda:
        feature = feature.cuda()
        target = target.cuda()
        
    opt.zero_grad()
    logit = model(feature)
    loss = F.nll_loss(F.log_softmax(logit), target)
    loss.backward()
    opt.step()
    writer.add_scalar('data/training_loss', loss.item(), ite)
    
    # check training accuray
    calc_accuray(model, b, writer, ite, 'training', use_cuda)
    
    # check validation accuracy
    if ite % 100 == 0:
        # calc training accuracy
        calc_accuray(model, val_preprocessed, writer, ite, 'validation', use_cuda)

writer.close()

前処理の工夫やモデルの改善により、以前の実験より大幅にtestセットでの性能が改善しました。

              precision    recall  f1-score   support

           0       0.88      0.82      0.85       131
           1       0.75      0.91      0.82       131
           2       0.81      0.74      0.77       130
           3       0.74      0.62      0.68        77
           4       0.94      0.91      0.92       130
           5       0.75      0.90      0.82       126
           6       0.99      0.84      0.91       131
           7       0.93      0.93      0.93       135
           8       0.84      0.88      0.86       115

   micro avg       0.85      0.85      0.85      1106
   macro avg       0.85      0.84      0.84      1106
weighted avg       0.86      0.85      0.85      1106
accuracy: 0.849005424954792

Grad-CAMによる可視化

学習したCNNモデルの予測根拠をGrad-CAMを使って可視化していきます。
PyTorchでのGrad-CAMの実装はCNNを使った分類問題の判断根拠(画像編)のコードをお借りしました。

class GradCAM:
    def __init__(self, model, feature_layer):
        self.model = model
        self.feature_layer = feature_layer
        self.model.eval()
        self.feature_grad = None
        self.feature_map = None
        self.hooks = []

        def save_feature_grad(module, in_grad, out_grad):
            self.feature_grad = out_grad[0]
        self.hooks.append(self.feature_layer.register_backward_hook(save_feature_grad))

        def save_feature_map(module, inp, outp):
            self.feature_map = outp[0]
        self.hooks.append(self.feature_layer.register_forward_hook(save_feature_map))

    def forward(self, x):
        return self.model(x)

    def backward_on_target(self, output, target):
        self.model.zero_grad()
        one_hot_output = torch.zeros([1, output.size()[-1]]).cuda()
        one_hot_output[0][target] = 1
        output.backward(gradient=one_hot_output, retain_graph=True)

    def clear_hook(self):
        for hook in self.hooks:
            hook.remove()

まずは予測ラベルがあっていた場合に、どの単語が予測に影響を与えているかを見ていきます。
モデルはカーネルサイズを3, 4, 5, 7で設定しており、文書中の3-gram, 4-gram, 5-gram, 7-gramの関係を見て畳込んでいます。まずは3-gram単位で重要度のヒートマップを出してみます。

news_id = 20
# Grad-CAMのインスタンスを作成
grad_cam = GradCAM(model=model, feature_layer=model.convs1[0])
# testセットの準備
test_input = test_preprocessed['article'][news_id:news_id+1]
test_tensor = torch.LongTensor(test_input).cuda()
test_target = torch.LongTensor(np.argmax(test_preprocessed['label'][news_id:news_id+1], axis=1))
# モデルでの順伝搬
model_output = grad_cam.forward(test_tensor)
predicted_label = model_output.argmax().item()
# 予測したラベルに対する逆伝搬を計算
grad_cam.backward_on_target(model_output, predicted_label)
# 各特徴マップの要素に対する勾配を取得して平滑化
feature_grad = grad_cam.feature_grad.cpu().data.numpy()[0]
weights = np.mean(feature_grad, axis=(1, 2))
# 重みを特徴マップに掛け合わせ後にReLUを適用
feature_map = grad_cam.feature_map.cpu().data.numpy()
cam = np.sum((weights * feature_map.T), axis=2).T
cam = np.maximum(cam, 0)
# hookの初期化
grad_cam.clear_hook()

# seaornでヒートマップの表示
import seaborn as sns
sns.heatmap(cam)

f:id:ymym3412:20190319013220p:plain

ヒートマップで可視化してみると、冒頭に強く反応している部分があります。Grad-CAMの数値が高い上位10個の3-gramを抽出してみます。

『プロメテウス』
『グラディエーター』
『エイリアン』
824公開
弾が公開
また、劇場
』『グラディエーター
3d作品
エイリアン』『
』の巨匠

このデータでは「movie-enter」という映画に関するメディアの記事を正しく予測できたため、特に映画に関する単語に強く反応して予測を行えたことが分かりました。
また3, 4, 5, 7-gram全てのGrad-CAMのスコアを足し合わせて平均を取ったヒートマップも作成してみます。

f:id:ymym3412:20190319021710p:plain

全体的にヒートマップがぼやける形になりましたが3-gramの時も強く反応していた冒頭の部分は依然として高い値を保持してることが分かります。

続いて予測が失敗したケース、「dokujo-tsushin」を「peachy」と誤予測してしまった場合の判断根拠を可視化します。前者はエンタメや芸能、後者は主に食べ物系について触れているメディアです。

f:id:ymym3412:20190319013956p:plain

こちらでも一部のブロックに強く反応している部分があります。こちらもGrad-CAMの数値が高い上位10個の3-gramを抽出してみます。

「独通信
は独、
未婚女性のみ
未婚女性の
シングル女性の
独身女性が
の恋や
の恋や
と恋に
女性の恋

どうやら「女性」に関する単語に強く反応して「dokujo-tsushin」を「peachy」と誤予測しているようです。peachyでは女性向けの食事の記事などの割合が多いのでしょうか?
これに基づいてコーパスをもっと深く分析して、前処理にさらなる工夫を加えることができそうです。

まとめ

NLPのDNNモデルの判断根拠の可視化のために、画像分野でよく使われているGrad-CAMを使いました。
CNN系のモデルは画像分野でもよく使われているので、ノウハウを輸入するのが低コストでいいですね。
NLPで解釈性というとAttentionがありますが、Attention層に対する勾配ベースの解釈性の導入はすでに研究として行われつつあります。
Interpreting Recurrent and Attention-Based Neural Models: a Case Study on Natural Language Inference

Attentionのヒートマップではモデルの解釈には不十分で、勾配ベースのスコアなら解釈として妥当というのは自分の中でまだもやっとしているので、ここらへん詳しい方がいらっしゃればTwitterなどでコメント頂けると嬉しいです。

参考文献