Grad-CAMを使ったNLPモデルの判断根拠の可視化
機械学習モデルの解釈性は業務で使う上ではなぜそのような予測を行ったかの判断根拠の可視化として、また学習させたモデルをデバックする際にどんな入力に反応して誤予測を引き起こしてしまったか分析する上で非常に重要な要素です。
画像分野ではGrad-CAMと呼ばれる勾配を使った予測根拠の可視化手法が提案されており、今回はその手法を使ってNLP向けのCNNモデルの判断根拠を可視化していきます。
実験で使用したノートブックはGithub上で公開しています。
機械学習モデルの解釈性
機械学習モデルに対する解釈性は近年では特に重要なトピックです。例えば
といったときに機械学習モデルの解釈性は必要になります。
機械学習モデルの解釈性についてはステアラボ人工知能セミナーでの原聡先生の資料がとても分かりやすいです。
機械学習の解釈性には「大域的な説明(Global Interpretability)」と「局所的な説明(Local Interpretability)」のふたつに大きく分けられます。
大域的な説明
大域的な説明は複雑なモデルを決定木や線形回帰といった解釈が容易なモデルで近似することでモデルを説明する方法です。
説明したいモデルの全体を解釈しやすいモデルで近似することで、モデルがどのように予測を行うかというモデルの内部を説明しているのが特徴です。
局所的な説明
モデル全体を説明する大域的な説明とは異なり、特定の入力に対する予測結果の説明を行うのが局所的な説明です。
これは入力のどの特徴量、あるいはどの訓練データによってその予測が行われたかの根拠を提示するために用いられます。
有名な手法では「LIME(Local Interpretable Model-agnostic Explanations)」があります。
LIMEは説明を行いたい予測への入力データに摂動を加えたデータセットを生成し、それに対して説明を行いたいモデルがどのように振る舞うかを解釈可能なモデル (e.g. 決定木やLasso)で近似することで、特定の予測に対する特徴量の寄与度を測る方法です。
説明したいモデルを、近似するモデルを、説明した入力データxと摂動を加えたデータ間の距離を測る関数、近似するモデルの複雑度を表す関数を使って以下の式を求めることでLIMEによって説明を行うモデルを構築します。
深層学習のモデルの説明には局所的な説明を行うものが多く、後述するGrad-CAMも局所的な説明を行う手法の一種です。
Grad-CAM
Grad-CAMは画像認識の分野で使われている、分類の根拠を提示する局所的な説明手法です。
Grad-CAMではVGGやResNetのようにConvolutionやpoolingを繰り返し最後に全結合層に接続してクラス分類を行うようなモデルに対して、全結合層の前のConvolution層で生成された特徴マップが、予測したラベルに対してどれくらい影響を与えているかを以下のように勾配を使って計算します。
式(1)では、クラスに対して特徴マップの各要素が微小に変化した際にクラスの確率がどれくらい変化するかを計算して特徴マップないで平滑化を行い、そして式(2)でその値を使って特徴マップ領域内の重要度を計算しています。
これにより、モデルの予測に対して入力画像のどの領域が影響を与えるかを計算して可視化することが可能となります。
またこれは予測に対する各特徴マップの勾配が計算できれば良いため、画像分類に限らず、画像キャプション生成やVisual Question Answeringなどにも利用することができます。
CNNの分類モデルの判断根拠の可視化
では画像分野で使われているモデルの説明手法をどうやってNLPで使うのか。
答えはシンプルで、NLP向けのCNNをベースにしたモデルを使えばよいのです。
幸い、以前に日本語の記事のカテゴリ分類を行うCNNモデルを作ったので、それを使ってGrad-CAMによる予測の判断根拠の可視化を行います。
学習
モデルは以前使用したKim[2014]のCNNによる文書分類モデルに少し改良を加えたものを使用します。
import torch import torch.nn as nn import torch.nn.functional as F class CNN_Text(nn.Module): def __init__(self, pretrained_wv, output_dim, kernel_num, kernel_sizes=[3,4,5], dropout=0.5, static=False): super(CNN_Text,self).__init__() weight = torch.from_numpy(pretrained_wv) self.embed = nn.Embedding.from_pretrained(weight, freeze=False) self.convs1 = nn.ModuleList([nn.Conv2d(1, kernel_num, (k, self.embed.weight.shape[1])) for k in kernel_sizes]) self.bns1 = nn.ModuleList([nn.BatchNorm2d(kernel_num) for _ in kernel_sizes]) self.dropout = nn.Dropout(dropout) self.fc1 = nn.Linear(len(kernel_sizes)*kernel_num, output_dim) self.static = static def conv_and_pool(self, x, conv): x = F.relu(conv(x)).squeeze(3) #(N,Co,W) x = F.max_pool1d(x, x.size(2)).squeeze(2) return x def forward(self, x): x = self.embed(x) # (N,W,D) if self.static: x = x.detach() x = x.unsqueeze(1) # (N,Ci,W,D) x = x.float() x = [F.relu(bn(conv(x))).squeeze(3) for conv, bn in zip(self.convs1, self.bns1)] #[(N,Co,W), ...]*len(Ks) x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks) x = torch.cat(x, 1) x = self.dropout(x) # (N,len(Ks)*Co) logit = self.fc1(x) # (N,C) return logit
最適化にはAdaBoundを使用し、400epoch学習させたあとはSGDで学習率を下げながら学習させました。
from tensorboardX import SummaryWriter import torch import numpy as np import adabound from sklearn.metrics import accuracy_score def calc_accuray(model, target_preprocessed, writer, ite, mode, use_cuda): model.eval() feature = torch.LongTensor(target_preprocessed['article']) if use_cuda: feature = feature.cuda() forward = model(feature) predicted_label = forward.argmax(dim=1).cpu() test_target = torch.LongTensor(np.argmax(target_preprocessed['label'], axis=1)) accuracy = accuracy_score(test_target.numpy(), predicted_label.numpy()) writer.add_scalar('data/{}_accuracy'.format(mode), accuracy, ite) model.train() return accuracy writer = SummaryWriter() output_dim = 9 kernel_num = 200 kernel_sizes = [3,4,5,7] dropout = 0.5 model = CNN_Text(embedding, output_dim, kernel_num, kernel_sizes, dropout) use_cuda = True opt = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1, weight_decay=5e-4) # opt = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1) model.train() if use_cuda: model = model.cuda() for ite, b in enumerate(dp.iterate(preprocessed, batch_size=64, epoch=400)): feature = torch.LongTensor(b['article']) target = torch.LongTensor(np.argmax(b['label'], axis=1)) if use_cuda: feature = feature.cuda() target = target.cuda() opt.zero_grad() logit = model(feature) loss = F.nll_loss(F.log_softmax(logit), target) loss.backward() opt.step() writer.add_scalar('data/training_loss', loss.item(), ite) # check training accuray calc_accuray(model, b, writer, ite, 'training', use_cuda) # check validation accuracy if ite % 100 == 0: # calc training accuracy calc_accuray(model, val_preprocessed, writer, ite, 'validation', use_cuda) writer.close()
前処理の工夫やモデルの改善により、以前の実験より大幅にtestセットでの性能が改善しました。
precision recall f1-score support 0 0.88 0.82 0.85 131 1 0.75 0.91 0.82 131 2 0.81 0.74 0.77 130 3 0.74 0.62 0.68 77 4 0.94 0.91 0.92 130 5 0.75 0.90 0.82 126 6 0.99 0.84 0.91 131 7 0.93 0.93 0.93 135 8 0.84 0.88 0.86 115 micro avg 0.85 0.85 0.85 1106 macro avg 0.85 0.84 0.84 1106 weighted avg 0.86 0.85 0.85 1106
accuracy: 0.849005424954792
Grad-CAMによる可視化
学習したCNNモデルの予測根拠をGrad-CAMを使って可視化していきます。
PyTorchでのGrad-CAMの実装はCNNを使った分類問題の判断根拠(画像編)のコードをお借りしました。
class GradCAM: def __init__(self, model, feature_layer): self.model = model self.feature_layer = feature_layer self.model.eval() self.feature_grad = None self.feature_map = None self.hooks = [] def save_feature_grad(module, in_grad, out_grad): self.feature_grad = out_grad[0] self.hooks.append(self.feature_layer.register_backward_hook(save_feature_grad)) def save_feature_map(module, inp, outp): self.feature_map = outp[0] self.hooks.append(self.feature_layer.register_forward_hook(save_feature_map)) def forward(self, x): return self.model(x) def backward_on_target(self, output, target): self.model.zero_grad() one_hot_output = torch.zeros([1, output.size()[-1]]).cuda() one_hot_output[0][target] = 1 output.backward(gradient=one_hot_output, retain_graph=True) def clear_hook(self): for hook in self.hooks: hook.remove()
まずは予測ラベルがあっていた場合に、どの単語が予測に影響を与えているかを見ていきます。
モデルはカーネルサイズを3, 4, 5, 7で設定しており、文書中の3-gram, 4-gram, 5-gram, 7-gramの関係を見て畳込んでいます。まずは3-gram単位で重要度のヒートマップを出してみます。
news_id = 20 # Grad-CAMのインスタンスを作成 grad_cam = GradCAM(model=model, feature_layer=model.convs1[0]) # testセットの準備 test_input = test_preprocessed['article'][news_id:news_id+1] test_tensor = torch.LongTensor(test_input).cuda() test_target = torch.LongTensor(np.argmax(test_preprocessed['label'][news_id:news_id+1], axis=1)) # モデルでの順伝搬 model_output = grad_cam.forward(test_tensor) predicted_label = model_output.argmax().item() # 予測したラベルに対する逆伝搬を計算 grad_cam.backward_on_target(model_output, predicted_label) # 各特徴マップの要素に対する勾配を取得して平滑化 feature_grad = grad_cam.feature_grad.cpu().data.numpy()[0] weights = np.mean(feature_grad, axis=(1, 2)) # 重みを特徴マップに掛け合わせ後にReLUを適用 feature_map = grad_cam.feature_map.cpu().data.numpy() cam = np.sum((weights * feature_map.T), axis=2).T cam = np.maximum(cam, 0) # hookの初期化 grad_cam.clear_hook() # seaornでヒートマップの表示 import seaborn as sns sns.heatmap(cam)
ヒートマップで可視化してみると、冒頭に強く反応している部分があります。Grad-CAMの数値が高い上位10個の3-gramを抽出してみます。
『プロメテウス』 『グラディエーター』 『エイリアン』 824公開 弾が公開 また、劇場 』『グラディエーター 3d作品 エイリアン』『 』の巨匠
このデータでは「movie-enter」という映画に関するメディアの記事を正しく予測できたため、特に映画に関する単語に強く反応して予測を行えたことが分かりました。
また3, 4, 5, 7-gram全てのGrad-CAMのスコアを足し合わせて平均を取ったヒートマップも作成してみます。
全体的にヒートマップがぼやける形になりましたが3-gramの時も強く反応していた冒頭の部分は依然として高い値を保持してることが分かります。
続いて予測が失敗したケース、「dokujo-tsushin」を「peachy」と誤予測してしまった場合の判断根拠を可視化します。前者はエンタメや芸能、後者は主に食べ物系について触れているメディアです。
こちらでも一部のブロックに強く反応している部分があります。こちらもGrad-CAMの数値が高い上位10個の3-gramを抽出してみます。
「独通信 は独、 未婚女性のみ 未婚女性の シングル女性の 独身女性が の恋や の恋や と恋に 女性の恋
どうやら「女性」に関する単語に強く反応して「dokujo-tsushin」を「peachy」と誤予測しているようです。peachyでは女性向けの食事の記事などの割合が多いのでしょうか?
これに基づいてコーパスをもっと深く分析して、前処理にさらなる工夫を加えることができそうです。
まとめ
NLPのDNNモデルの判断根拠の可視化のために、画像分野でよく使われているGrad-CAMを使いました。
CNN系のモデルは画像分野でもよく使われているので、ノウハウを輸入するのが低コストでいいですね。
NLPで解釈性というとAttentionがありますが、Attention層に対する勾配ベースの解釈性の導入はすでに研究として行われつつあります。
Interpreting Recurrent and Attention-Based Neural Models: a Case Study on Natural Language Inference
Attentionのヒートマップではモデルの解釈には不十分で、勾配ベースのスコアなら解釈として妥当というのは自分の中でまだもやっとしているので、ここらへん詳しい方がいらっしゃればTwitterなどでコメント頂けると嬉しいです。
参考文献
- 機械学習モデルの判断根拠の説明
- 機械学習と解釈可能性 / Machine Learning and Interpretability
- Interpretable Machine Learning A Guide for Making Black Box Models Explainable
- “Why Should I Trust You?” Explaining the Predictions of Any Classifier
- Convolutional Neural Networks for Sentence Classification
- Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
- CNNを使った分類問題の判断根拠(画像編)| JXPRESS developer's blog
- 深層学習は画像のどこを見ている!? CNNで「お好み焼き」と「ピザ」の違いを検証 | Pratinum Data Blog