attention_heatmap

class aitoolbox.nlp.experiment_evaluation.attention_heatmap.AttentionHeatMap(attention_matrices, source_sentences, target_sentences, plot_save_dir)

Bases: AbstractBaseMetric

Neural attention heatmap plotting

Parameters:
  • attention_matrices (numpy.array or list) – list of 2D attention matrices, one per sentence pair

  • source_sentences (list) – list of the corresponding source sentences, each given as a list of text tokens

  • target_sentences (list) – list of the corresponding target sentences, each given as a list of text tokens

  • plot_save_dir (str) – path to the local folder where the plots will be saved
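The three list parameters must line up element by element. The sketch below shows one way to check that before plotting. It is not part of aitoolbox: `check_attention_inputs` is a hypothetical helper, and it assumes the common convention that each matrix has one row per target token and one column per source token.

```python
from typing import List, Sequence


def check_attention_inputs(
    attention_matrices: Sequence[Sequence[Sequence[float]]],
    source_sentences: List[List[str]],
    target_sentences: List[List[str]],
) -> None:
    """Hypothetical helper: raise ValueError if the inputs do not line up.

    Assumes each matrix has shape (len(target tokens), len(source tokens)).
    Works for nested lists as well as 2D numpy arrays.
    """
    if not (len(attention_matrices) == len(source_sentences) == len(target_sentences)):
        raise ValueError("attention_matrices, source_sentences and target_sentences must have equal length")
    for i, (matrix, src, tgt) in enumerate(zip(attention_matrices, source_sentences, target_sentences)):
        # One row per target token (assumed orientation)
        if len(matrix) != len(tgt):
            raise ValueError(f"example {i}: expected {len(tgt)} rows, got {len(matrix)}")
        # One column per source token (assumed orientation)
        for row in matrix:
            if len(row) != len(src):
                raise ValueError(f"example {i}: expected {len(src)} columns, got {len(row)}")


# Usage: one sentence pair with a 2x3 matrix (2 target tokens, 3 source tokens)
check_attention_inputs(
    attention_matrices=[[[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]],
    source_sentences=[["das", "ist", "gut"]],
    target_sentences=[["that", "good"]],
)
```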

calculate_metric()

Perform the metric calculation and return its result.

Returns:

the calculated metric result (metric_result)

Return type:

float or dict
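The documentation above does not say how the heatmaps are produced for the whole dataset. As an illustration only, the sketch below shows one plausible loop that pairs each attention matrix with its source and target tokens and saves one plot per sentence pair into `plot_save_dir`. The function name `plot_all_attentions`, the `plot_fn` callback and the `attention_heatmap_{idx}.png` naming scheme are all hypothetical; `plot_fn` is injected so the sketch runs without a plotting library, and its argument order matches the documented `plot_sentence_attention(attention_matrix, sentence_source, sentence_target, plot_file_path)`.

```python
import os
from typing import Callable, List, Optional, Sequence

# Callback with the same argument order as plot_sentence_attention
PlotFn = Callable[[Sequence, List[str], List[str], Optional[str]], None]


def plot_all_attentions(attention_matrices, source_sentences, target_sentences,
                        plot_save_dir: str, plot_fn: PlotFn) -> List[str]:
    """Hypothetical loop: plot every sentence pair and return the file paths used.

    The file naming scheme is an assumption made for illustration only.
    """
    saved_paths = []
    for idx, (matrix, src, tgt) in enumerate(zip(attention_matrices, source_sentences, target_sentences)):
        file_path = os.path.join(plot_save_dir, f"attention_heatmap_{idx}.png")
        plot_fn(matrix, src, tgt, file_path)
        saved_paths.append(file_path)
    return saved_paths


# Usage with a stub plot function that only records the requested paths
calls = []
paths = plot_all_attentions(
    [[[1.0]], [[0.4, 0.6]]], [["a"], ["b", "c"]], [["x"], ["y"]], "heatmaps",
    plot_fn=lambda m, s, t, p: calls.append(p),
)
```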

static plot_sentence_attention(attention_matrix, sentence_source, sentence_target, plot_file_path=None)

Plot the provided attention matrix as a heatmap

Parameters:
  • attention_matrix (np.array) – 2D attention matrix

  • sentence_source (list) – corresponding source sentence text tokens

  • sentence_target (list) – corresponding target sentence text tokens

  • plot_file_path (str) – local file path where the attention heatmap plot should be saved

Returns:

None
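A minimal matplotlib sketch of what a function with this signature could look like is shown below. It is not aitoolbox's implementation: it assumes rows correspond to target tokens and columns to source tokens, and it assumes that when `plot_file_path` is `None` the figure is shown instead of saved. Only standard matplotlib and numpy calls (`imshow`, `set_xticks`, `set_xticklabels`, `colorbar`, `savefig`) are used.

```python
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np


def plot_sentence_attention(attention_matrix, sentence_source: List[str],
                            sentence_target: List[str],
                            plot_file_path: Optional[str] = None) -> None:
    """Sketch: draw a (target x source) attention matrix as a heatmap."""
    matrix = np.asarray(attention_matrix, dtype=float)
    fig, ax = plt.subplots()
    image = ax.imshow(matrix, cmap="viridis", aspect="auto")
    # Columns are labelled with source tokens, rows with target tokens (assumed orientation)
    ax.set_xticks(range(len(sentence_source)))
    ax.set_xticklabels(sentence_source, rotation=90)
    ax.set_yticks(range(len(sentence_target)))
    ax.set_yticklabels(sentence_target)
    ax.set_xlabel("source")
    ax.set_ylabel("target")
    fig.colorbar(image, ax=ax)  # attention weight scale
    fig.tight_layout()
    if plot_file_path is not None:
        fig.savefig(plot_file_path)
    else:
        plt.show()  # assumed behaviour when no path is given
    plt.close(fig)  # free the figure so repeated calls do not leak memory
```

Closing the figure after each call matters when the function is called once per sentence pair over a large dataset.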

static prepare_folder_for_saving(output_plot_dir)

Create the local folder where the attention heatmaps will be saved

Parameters:

output_plot_dir (str) – local output directory path for the heatmap plots

Returns:

path to the created folder

Return type:

str
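A minimal sketch of folder preparation using only the standard library follows. It assumes the heatmaps are written directly into `output_plot_dir`; the actual method may instead create a differently named subfolder inside it. `os.makedirs(..., exist_ok=True)` also creates missing parent directories and does not fail if the folder already exists, so the function can safely be called more than once.

```python
import os


def prepare_folder_for_saving(output_plot_dir: str) -> str:
    """Sketch: make sure the heatmap folder exists and return its path.

    Assumes the plots go directly into output_plot_dir (not a subfolder).
    """
    # Creates intermediate directories too; no error if the folder already exists
    os.makedirs(output_plot_dir, exist_ok=True)
    return output_plot_dir
```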