discopat.datasets.torch

Classes

TorchBoxDataset(frame_list, label_map[, ...])

TorchDataset(frame_list, label_map[, ...])

TorchKeypointDataset(frame_list, label_map[, ...])
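Detection-style targets usually hold a different number of objects per image, so the default batching of `torch.utils.data.DataLoader` cannot stack them into one tensor. A common workaround is a `collate_fn` that keeps images and targets as parallel tuples. The sketch below assumes that indexing any of these datasets yields an `(image, target)` pair, which this page does not state explicitly. `detection_collate` is a hypothetical name, and plain strings and lists stand in for tensors so the example runs without PyTorch.

```python
def detection_collate(batch):
    """Turn [(img0, tgt0), (img1, tgt1), ...] into ((img0, img1, ...), (tgt0, tgt1, ...)).

    Hypothetical helper: nothing is stacked, so targets of different sizes are fine.
    """
    # zip(*batch) transposes the list of pairs into one tuple of images and one of targets.
    images, targets = tuple(zip(*batch))
    return images, targets


# Two samples with different numbers of boxes: stacking would fail, zipping does not.
batch = [
    ("img_a", {"boxes": [[0, 0, 4, 4]], "labels": [1]}),
    ("img_b", {"boxes": [[1, 1, 3, 3], [2, 2, 5, 5]], "labels": [2, 1]}),
]
images, targets = detection_collate(batch)
print(images)                     # ('img_a', 'img_b')
print(len(targets[1]["boxes"]))   # 2
```

With PyTorch installed, such a function would be passed as `DataLoader(dataset, batch_size=2, collate_fn=detection_collate)`.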

class discopat.datasets.torch.TorchBoxDataset(frame_list, label_map, transforms=None, channel_mode='channels_first')

Bases: TorchDataset

Parameters:
  • frame_list (list[Frame])

  • label_map (dict[str, int])

  • transforms (T.Transform or None)

  • channel_mode (str)

make_target(frame)

Parameters:
  • frame (Frame)

Return type:
  dict[str, int]
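The following is a minimal sketch of what a box-detection `make_target` might build, following the target layout that torchvision detection models such as Faster R-CNN expect: `"boxes"` as `[x1, y1, x2, y2]` rows and `"labels"` as integer class ids looked up through `label_map`. The `Box` and `Frame` dataclasses are hypothetical stand-ins for discopat's annotation types, `make_box_target` is a hypothetical name, and plain lists replace the tensors the real method would return.

```python
from dataclasses import dataclass, field


@dataclass
class Box:  # hypothetical stand-in for a discopat box annotation
    label: str
    xmin: float
    ymin: float
    xmax: float
    ymax: float


@dataclass
class Frame:  # hypothetical stand-in for discopat's Frame
    annotations: list = field(default_factory=list)


def make_box_target(frame, label_map):
    """Build a torchvision-style detection target from a frame's boxes.

    Plain lists replace the FloatTensor[N, 4] / Int64Tensor[N] that
    torchvision detection models expect.
    """
    boxes = [[a.xmin, a.ymin, a.xmax, a.ymax] for a in frame.annotations]
    # label_map turns class names into the integer ids the model predicts.
    labels = [label_map[a.label] for a in frame.annotations]
    return {"boxes": boxes, "labels": labels}


frame = Frame([Box("cell", 2, 3, 10, 12), Box("debris", 0, 0, 4, 4)])
target = make_box_target(frame, {"cell": 1, "debris": 2})
print(target)
# {'boxes': [[2, 3, 10, 12], [0, 0, 4, 4]], 'labels': [1, 2]}
```

torchvision's detection models reserve label 0 for the background class, so a `label_map` fed to them would typically start at 1; whether discopat enforces this is not stated here.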

class discopat.datasets.torch.TorchDataset(frame_list, label_map, transforms=None, channel_mode='channels_first')

Bases: Dataset

Parameters:
  • frame_list (list[Frame])

  • label_map (dict[str, int])

  • transforms (T.Transform or None)

  • channel_mode (str)

abstractmethod make_target(frame)

Parameters:
  • frame (Frame)

Return type:
  dict[str, Tensor]
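Because `make_target` is abstract, `TorchDataset` acts as a template: subclasses only decide how a frame's annotations become a target. The sketch below shows that pattern under several assumptions not confirmed by this page: `__getitem__` returns `(image, target)`, and `transforms`, when given, is called jointly on the image and target (as torchvision's `transforms.v2` transforms allow). `SketchDataset` and `LabelOnlyDataset` are hypothetical names, frames are plain dicts, and the real base class subclasses `torch.utils.data.Dataset` rather than a bare `ABC`.

```python
from abc import ABC, abstractmethod


class SketchDataset(ABC):
    """Hypothetical stand-in for TorchDataset (the real one subclasses torch Dataset)."""

    def __init__(self, frame_list, label_map, transforms=None):
        self.frame_list = frame_list
        self.label_map = label_map
        self.transforms = transforms

    def __len__(self):
        return len(self.frame_list)

    def __getitem__(self, idx):
        frame = self.frame_list[idx]
        image = self.prepare_image_tensor(frame)
        target = self.make_target(frame)
        # Assumed: transforms take and return the (image, target) pair together.
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def prepare_image_tensor(self, frame):
        # Placeholder: the real method returns a Tensor built from the frame.
        return frame["image"]

    @abstractmethod
    def make_target(self, frame):
        """Subclasses decide what the annotations become (boxes, keypoints, ...)."""


class LabelOnlyDataset(SketchDataset):
    def make_target(self, frame):
        return {"labels": [self.label_map[name] for name in frame["names"]]}


ds = LabelOnlyDataset([{"image": [[0]], "names": ["cell"]}], {"cell": 1})
print(len(ds), ds[0])  # 1 ([[0]], {'labels': [1]})
```

Instantiating `SketchDataset` directly raises `TypeError`, which mirrors how an abstract `make_target` forces every concrete dataset to define its own target format.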

prepare_image_tensor(frame)

Parameters:
  • frame (Frame)

Return type:
  Tensor
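The default `channel_mode='channels_first'` suggests that `prepare_image_tensor` returns images in the C×H×W layout PyTorch convolution layers expect. The sketch below illustrates that reordering on a nested-list image stored as H×W×C; it is an assumption about what the method does, not its implementation. `to_channel_mode` is a hypothetical helper, and the `"channels_last"` value is assumed for illustration only, since this page documents just the default.

```python
def to_channel_mode(image_hwc, channel_mode="channels_first"):
    """Reorder a nested-list image stored as [H][W][C].

    "channels_first" returns [C][H][W]; "channels_last" (assumed value)
    returns the input unchanged.
    """
    if channel_mode == "channels_last":
        return image_hwc
    if channel_mode != "channels_first":
        raise ValueError(f"unknown channel_mode: {channel_mode!r}")
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    # Build one H x W plane per channel.
    return [
        [[image_hwc[h][w][c] for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


# A 1x2 RGB image: two pixels, three channels each.
img = [[[255, 0, 0], [0, 255, 0]]]
print(to_channel_mode(img))
# [[[255, 0]], [[0, 255]], [[0, 0]]]
```

On a real `torch.Tensor` of shape (H, W, C), the same reordering is `tensor.permute(2, 0, 1)`.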

class discopat.datasets.torch.TorchKeypointDataset(frame_list, label_map, box_w_padding=0.5, box_h_padding=0.5)

Bases: TorchDataset

Parameters:
  • frame_list (list[Frame])

  • label_map (dict[str, int])

  • box_w_padding (float)

  • box_h_padding (float)

make_target(frame)

Parameters:
  • frame (Frame)

Return type:
  dict[str, int]
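torchvision's Keypoint R-CNN expects, besides `"keypoints"` of shape [N, K, 3] with rows `[x, y, visibility]`, a `"boxes"` entry with positive width and height for every instance. A plausible role for `box_w_padding` and `box_h_padding` is to grow a box around each point so it is not degenerate. The sketch assumes exactly that: each padding is added on both sides of the point, and each instance carries one visible keypoint. `Point` and `make_keypoint_target` are hypothetical names, and plain lists replace tensors.

```python
from dataclasses import dataclass


@dataclass
class Point:  # hypothetical stand-in for a discopat keypoint annotation
    label: str
    x: float
    y: float


def make_keypoint_target(points, label_map, box_w_padding=0.5, box_h_padding=0.5):
    """Sketch of a keypoint target in the torchvision Keypoint R-CNN layout.

    Assumption: each point becomes a box extending box_w_padding / box_h_padding
    on either side of it, plus a single visible keypoint [x, y, 1].
    """
    boxes, labels, keypoints = [], [], []
    for p in points:
        boxes.append([p.x - box_w_padding, p.y - box_h_padding,
                      p.x + box_w_padding, p.y + box_h_padding])
        labels.append(label_map[p.label])
        keypoints.append([[p.x, p.y, 1]])  # K = 1 keypoint per instance, visibility 1
    return {"boxes": boxes, "labels": labels, "keypoints": keypoints}


target = make_keypoint_target([Point("cell", 10.0, 20.0)], {"cell": 1})
print(target["boxes"])      # [[9.5, 19.5, 10.5, 20.5]]
print(target["keypoints"])  # [[[10.0, 20.0, 1]]]
```

With the defaults of 0.5, each point would yield a 1×1 box centred on it; larger values give the detector a wider region around each keypoint.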