torch_collate_fns

aitoolbox.nlp.torch_collate_fns.qa_concat_ctx_span_collate_fn(data)

QA system collate function

Creates mini-batch tensors from a list of (src_seq, trg_seq) tuples. A custom collate_fn is needed instead of the default one because the default collate_fn does not support merging variable-length sequences, which requires padding. Sequences are padded to the length of the longest sequence in the mini-batch (dynamic padding).
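The dynamic padding idea can be sketched with PyTorch's pad_sequence. This is a minimal illustration, not the aitoolbox implementation: the function name pad_collate and the example batch are hypothetical, and zero is assumed to be the padding index. It returns the same four values listed under Returns.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch):
    """Hypothetical collate_fn sketch: dynamically pad (src_seq, trg_seq) pairs.

    batch: list of (src_seq, trg_seq) tuples of 1-D LongTensors.
    Returns (src_seqs, src_lengths, trg_seqs, trg_lengths).
    """
    src_list, trg_list = zip(*batch)
    # Valid (unpadded) length of every sequence, recorded before padding
    src_lengths = [len(s) for s in src_list]
    trg_lengths = [len(t) for t in trg_list]
    # Pad only up to the longest sequence in THIS mini-batch (assumes pad index 0)
    src_seqs = pad_sequence(list(src_list), batch_first=True, padding_value=0)
    trg_seqs = pad_sequence(list(trg_list), batch_first=True, padding_value=0)
    return src_seqs, src_lengths, trg_seqs, trg_lengths


batch = [
    (torch.tensor([1, 2, 3]), torch.tensor([4, 5])),
    (torch.tensor([6]), torch.tensor([7, 8, 9, 10])),
]
src_seqs, src_lengths, trg_seqs, trg_lengths = pad_collate(batch)
print(src_seqs.shape)  # torch.Size([2, 3])
print(trg_seqs.shape)  # torch.Size([2, 4])
```

A function like this is passed to torch.utils.data.DataLoader through its collate_fn argument, so each batch is padded only as much as its own longest sequence requires.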

Parameters:

data – list of tuples (src_seq, trg_seq), where:
- src_seq: torch tensor of shape (?); variable length.
- trg_seq: torch tensor of shape (?); variable length.

Returns:

- src_seqs: torch tensor of shape (batch_size, padded_length).
- src_lengths: list of length batch_size; the valid (unpadded) length of each padded source sequence.
- trg_seqs: torch tensor of shape (batch_size, padded_length).
- trg_lengths: list of length batch_size; the valid (unpadded) length of each padded target sequence.

Return type:

tuple (src_seqs, src_lengths, trg_seqs, trg_lengths)
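One common use of the returned valid lengths is building a padding mask, so that padded positions can be ignored in attention or loss computations. The sketch below assumes a padded batch shaped like src_seqs; the helper lengths_to_mask and the example values are hypothetical and not part of aitoolbox.

```python
import torch


def lengths_to_mask(lengths, max_len):
    """Hypothetical helper: boolean mask of shape (batch_size, max_len).

    True marks real tokens, False marks padding positions.
    """
    lengths = torch.as_tensor(lengths)       # shape (batch_size,)
    positions = torch.arange(max_len)        # shape (max_len,)
    # Broadcast (1, max_len) against (batch_size, 1) -> (batch_size, max_len)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


src_seqs = torch.tensor([[1, 2, 3], [6, 0, 0]])  # padded batch (example values)
src_lengths = [3, 1]                             # valid length of each row
mask = lengths_to_mask(src_lengths, src_seqs.size(1))
print(mask)
```

The same lengths can alternatively be passed to torch.nn.utils.rnn.pack_padded_sequence so that recurrent layers skip the padded steps.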