torch_collate_fns
- aitoolbox.nlp.torch_collate_fns.qa_concat_ctx_span_collate_fn(data)
QA system collate function
Creates mini-batch tensors from a list of (src_seq, trg_seq) tuples. A custom collate_fn is needed because the default collate_fn cannot merge variable-length sequences, which requires padding. Each sequence is padded to the length of the longest sequence in its mini-batch (dynamic padding).
- Parameters:
data – list of tuples (src_seq, trg_seq).
  - src_seq: torch tensor of shape (?); variable length.
  - trg_seq: torch tensor of shape (?); variable length.
- Returns:
  - src_seqs: torch tensor of shape (batch_size, padded_length).
  - src_lengths: list of length (batch_size); valid length for each padded source sequence.
  - trg_seqs: torch tensor of shape (batch_size, padded_length).
  - trg_lengths: list of length (batch_size); valid length for each padded target sequence.
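The dynamic padding described above can be illustrated with a minimal sketch. This is not the aitoolbox implementation: the function name `pad_collate_sketch` is hypothetical, the padding value of 0 is an assumption, and the return layout simply follows the four values documented above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate_sketch(data):
    """Hypothetical stand-in for a padding collate function.

    data: list of (src_seq, trg_seq) tuples of 1-D tensors.
    """
    # Split the list of pairs into separate source and target groups.
    src_seqs, trg_seqs = zip(*data)

    # Record the valid (unpadded) length of every sequence.
    src_lengths = [len(seq) for seq in src_seqs]
    trg_lengths = [len(seq) for seq in trg_seqs]

    # Pad each group only up to the longest sequence in this mini-batch.
    # Padding with 0 is an assumption made for this sketch.
    src_padded = pad_sequence(list(src_seqs), batch_first=True, padding_value=0)
    trg_padded = pad_sequence(list(trg_seqs), batch_first=True, padding_value=0)

    return src_padded, src_lengths, trg_padded, trg_lengths


# Two toy examples with different source and target lengths.
batch = [
    (torch.tensor([1, 2, 3]), torch.tensor([4, 5])),
    (torch.tensor([6]), torch.tensor([7, 8, 9, 10])),
]
src_padded, src_lengths, trg_padded, trg_lengths = pad_collate_sketch(batch)
```

A function of this shape is typically passed to `torch.utils.data.DataLoader` through its `collate_fn` argument, so that padding is computed per mini-batch rather than once for the whole dataset.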