Source code for discopat.nn_training.detr.collate
from discopat.nn_training.detr.nested_tensor import (
nested_tensor_from_tensor_list,
)
[docs]
def collate_fn(batch):
batch = list(zip(*batch))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)