Source code for aitoolbox.nlp.torch_collate_fns

import torch


def qa_concat_ctx_span_collate_fn(data):
    """QA system collate function

    Creates mini-batch tensors from a list of (paragraph, question, span) tuples.

    A custom collate_fn is needed instead of the default one because the default
    collate_fn does not support merging variable-length sequences (including padding).
    Sequences are padded to the maximum length found in the mini-batch (dynamic padding).

    Args:
        data: list of tuples (paragraph, question, span).
            - paragraph: torch tensor of variable length.
            - question: torch tensor of variable length.
            - span: answer span indices for the paragraph.

    Returns:
        paragraph_pad: torch tensor of shape (batch_size, max_paragraph_length); padded paragraphs.
        paragraph_lengths: torch LongTensor of length batch_size; valid length of each paragraph.
        question_pad: torch tensor of shape (batch_size, max_question_length); padded questions.
        question_lengths: torch LongTensor of length batch_size; valid length of each question.
        span: torch LongTensor with the answer spans of the batch.
    """
    def merge(sequences):
        # Pad every sequence in the batch with zeros up to the longest sequence length
        lengths = [len(seq) for seq in sequences]
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, torch.LongTensor(lengths)

    paragraph, question, span = list(zip(*data))

    paragraph_pad, paragraph_lengths = merge(paragraph)
    question_pad, question_lengths = merge(question)

    return paragraph_pad, paragraph_lengths, question_pad, question_lengths, torch.LongTensor(span)
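
A minimal usage sketch (not part of the module itself): the collate function is meant to be passed to a torch DataLoader via its collate_fn argument. The toy dataset below is hypothetical and only illustrates the expected (paragraph, question, span) item structure.

from torch.utils.data import DataLoader

# Hypothetical toy items: variable-length paragraph and question token-id tensors,
# plus a (start, end) answer span into the paragraph.
toy_data = [
    (torch.LongTensor([4, 7, 2, 9]), torch.LongTensor([5, 3]), (1, 3)),
    (torch.LongTensor([8, 1]), torch.LongTensor([6, 2, 7]), (0, 1)),
]

loader = DataLoader(toy_data, batch_size=2, collate_fn=qa_concat_ctx_span_collate_fn)

for paragraph_pad, paragraph_lengths, question_pad, question_lengths, span in loader:
    # paragraph_pad: (2, 4), question_pad: (2, 3), span: (2, 2)
    print(paragraph_pad.shape, question_pad.shape, span.shape)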