Source code for aitoolbox.nlp.experiment_evaluation.attention_heatmap

import os
import shutil
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.font_manager import FontProperties
import seaborn as sns

from aitoolbox.experiment.core_metrics.abstract_metric import AbstractBaseMetric


[docs]class AttentionHeatMap(AbstractBaseMetric): def __init__(self, attention_matrices, source_sentences, target_sentences, plot_save_dir): """Neural attention heatmap plotting Args: attention_matrices (numpy.array or list): list of attention 2D matrices source_sentences (list): list of corresponding source sentence text tokens target_sentences (list): list of corresponding target sentence text tokens plot_save_dir (str): folder path on local drive where the plots should be saved """ if len(attention_matrices) != len(source_sentences) != len(target_sentences): raise ValueError(f'Lengths of attention_matrices, source_sentences and target_sentences are not the same. ' f'Their lengths are: {len(attention_matrices)}, {len(source_sentences)}, {len(target_sentences)}') self.attention_matrices = attention_matrices self.source_sentences = source_sentences self.target_sentences = target_sentences self.plot_save_dir = plot_save_dir AbstractBaseMetric.__init__(self, None, None, metric_name='Attention_HeatMap', np_array=False)
[docs] def calculate_metric(self): dir_path = self.prepare_folder_for_saving(self.plot_save_dir) output_plot_paths = [] for i, (attn_matrix, source_sent, target_sent) in \ enumerate(zip(self.attention_matrices, self.source_sentences, self.target_sentences)): plot_file_path = os.path.join(dir_path, f'attn_plot_{i}.png') output_plot_paths.append(plot_file_path) attn_matrix = attn_matrix[:len(target_sent)] self.plot_sentence_attention(attn_matrix, source_sent, target_sent, plot_file_path) return output_plot_paths
[docs] @staticmethod def plot_sentence_attention(attention_matrix, sentence_source, sentence_target, plot_file_path=None): """Plot the provided attention matrix Args: attention_matrix (np.array): 2D attention matrix sentence_source (list): corresponding source sentence text tokens sentence_target (list): corresponding target sentence text tokens plot_file_path (str): local drive file path where to save the plotted attention matrix heatmap Returns: None """ # alpha_arr /= np.max(np.abs(alpha_arr),axis=0) fig = plt.figure() fig.set_size_inches(8, 8) gs = gridspec.GridSpec(2, 2, width_ratios=[12, 1], height_ratios=[12, 1]) ax = plt.subplot(gs[0]) ax_c = plt.subplot(gs[1]) cmap = sns.light_palette((200, 75, 60), input="husl", as_cmap=True) # prop = FontProperties(fname='fonts/IPAfont00303/ipam.ttf', size=12) ax = sns.heatmap(attention_matrix, xticklabels=sentence_source, yticklabels=sentence_target, ax=ax, cmap=cmap, cbar_ax=ax_c) ax.xaxis.tick_top() ax.yaxis.tick_right() ax.set_xticklabels(sentence_target, minor=True, rotation=60, size=12) for label in ax.get_xticklabels(minor=False): label.set_fontsize(12) # label.set_font_properties(prop) for label in ax.get_yticklabels(minor=False): label.set_fontsize(12) label.set_rotation(-90) label.set_horizontalalignment('left') ax.set_xlabel("Source", size=20) ax.set_ylabel("Hypothesis", size=20) if plot_file_path: fig.savefig(plot_file_path, format="png") plt.close()
[docs] @staticmethod def prepare_folder_for_saving(output_plot_dir): """Create attention heatmaps local folder where the heatmaps will be saved Args: output_plot_dir (str): Returns: str: path to the created folder """ if os.path.exists(output_plot_dir): shutil.rmtree(output_plot_dir) os.mkdir(output_plot_dir) dir_path = os.path.join(output_plot_dir, 'attention_heatmaps') os.mkdir(dir_path) return dir_path