# Source code for discopat.nn_training.detr.collate

from discopat.nn_training.detr.nested_tensor import (
    nested_tensor_from_tensor_list,
)


def collate_fn(batch):
    """Regroup a batch of per-sample tuples into per-field tuples.

    The incoming ``batch`` is transposed with ``zip``. Each field position
    across all samples is gathered together. The first field is packed with
    ``nested_tensor_from_tensor_list``, and the remaining fields are passed
    through unchanged as tuples.

    Presumably each sample is ``(image, target)`` as produced by a DETR
    dataset; verify against the dataset in use.

    Args:
        batch: Iterable of equal-length sample tuples. ``zip`` silently
            truncates to the shortest sample if their lengths differ.

    Returns:
        A tuple whose first element is the result of
        ``nested_tensor_from_tensor_list`` on the first-field values, and
        whose remaining elements are tuples of the other fields.

    Raises:
        IndexError: If ``batch`` is empty (there is no first field to pack).
    """
    # Transpose: one tuple per field instead of one tuple per sample.
    fields = list(zip(*batch))
    # Indexing (rather than unpacking) keeps the IndexError on empty input.
    first_field = fields[0]
    packed = nested_tensor_from_tensor_list(first_field)
    return (packed, *fields[1:])