Source code for discopat.nn_training.torch_detection_utils.coco_utils

import torch
import torch.utils.data
import torchvision
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO


[docs] def convert_to_coco_api(ds): coco_ds = COCO() # annotation IDs need to start at 1, not 0, see torchvision issue #1530 ann_id = 1 dataset = {"images": [], "categories": [], "annotations": [], "info": ""} categories = set() for img_idx in range(len(ds)): # find better way to get target # targets = ds.get_annotations(img_idx) img, targets = ds[img_idx] image_id = targets["image_id"] if isinstance(image_id, torch.Tensor): image_id = image_id.item() img_dict = {} img_dict["id"] = image_id img_dict["height"] = img.shape[-2] img_dict["width"] = img.shape[-1] dataset["images"].append(img_dict) bboxes = targets["boxes"].clone() bboxes[:, 2:] -= bboxes[:, :2] bboxes = bboxes.tolist() labels = targets["labels"].tolist() areas = targets["area"].tolist() iscrowd = targets["iscrowd"].tolist() if "masks" in targets: masks = targets["masks"] # make masks Fortran contiguous for coco_mask masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) if "keypoints" in targets: keypoints = targets["keypoints"] keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() num_objs = len(bboxes) for i in range(num_objs): ann = {} ann["image_id"] = image_id ann["bbox"] = bboxes[i] ann["category_id"] = labels[i] categories.add(labels[i]) ann["area"] = areas[i] ann["iscrowd"] = iscrowd[i] ann["id"] = ann_id if "masks" in targets: ann["segmentation"] = coco_mask.encode(masks[i].numpy()) if "keypoints" in targets: ann["keypoints"] = keypoints[i] ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3]) dataset["annotations"].append(ann) ann_id += 1 dataset["categories"] = [{"id": i} for i in sorted(categories)] coco_ds.dataset = dataset coco_ds.createIndex() return coco_ds
[docs] def get_coco_api_from_dataset(dataset): # FIXME: This is... awful? for _ in range(10): if isinstance(dataset, torchvision.datasets.CocoDetection): break if isinstance(dataset, torch.utils.data.Subset): dataset = dataset.dataset if isinstance(dataset, torchvision.datasets.CocoDetection): return dataset.coco return convert_to_coco_api(dataset)