Source code for discopat.core.entities.annotation

from __future__ import annotations

from abc import abstractmethod

from typing_extensions import Self

from discopat.core.entities.metadata import Metadata


[docs] class Annotation(Metadata): """Abstract class to represent annotations and model predictions.""" def __init__(self, label: str, score: float): """Initialise the object. Args: label (str): Name of the object modelled by the annotation. score (float): Confidence score for predictions (in [0, 1]). For annotations, the score is 1. For negative annotations (e.g., a caption saying that there is no corresponding object in the image), the score is 0. """ self.label = label self.score = float(score)
[docs] @abstractmethod def rescale(self, w_ratio: float, h_ratio: float) -> None: """Rescale object. Args: w_ratio (float): Width ratio. h_ratio (float): Height ratio. """
@property def type(self) -> str: """Handle for annotation types (box, keypoint, track, ...).""" return type(self).__name__.lower()
[docs] @classmethod def printable_fields(cls) -> list[str]: """List of the relevant fields to serialise the object.""" return ["label", "score"]
[docs] def to_dict(self) -> dict: """Serialise object to a dictionary.""" output = super().to_dict() return {"type": self.type, **output}
[docs] class Box(Annotation): """Class to represent bounding boxes.""" def __init__( self, label: str, x: float, y: float, width: float, height: float, score: float, ): """Initialise the bounding box. Args: label (str): Name of the object localised by the box (e.g., cat). x (float): X-position of the top-left corner. y (float): Y-position of the top-left corner. width (float): Width of the bounding box. height (float): Height of the bounding box. score (float): Confidence score of the detection (in [0, 1]). """ super().__init__(label, score) self.x = float(x) self.y = float(y) self.width = float(width) self.height = float(height) @property def xmin(self) -> float: """Xmin for XYXY format.""" return self.x @property def xmax(self) -> float: """Xmax for XYXY format.""" return self.x + self.width @property def ymin(self) -> float: """Ymin for XYXY format.""" return self.y @property def ymax(self) -> float: """Ymax for XYXY format.""" return self.y + self.height
[docs] @classmethod def printable_fields(cls) -> list[str]: """List of the relevant fields to serialise the object.""" output = super().printable_fields() return [*output, "x", "y", "width", "height"]
[docs] @classmethod def from_dict(cls, data_as_dict: dict) -> Self: """Make object from a dictionary.""" init_params = { k: data_as_dict[k] for k in cls.printable_fields() if k != "type" } return cls(**init_params)
[docs] def rescale(self, w_ratio: float, h_ratio: float) -> None: """Rescale object. Args: w_ratio (float): Width ratio. h_ratio (float): Height ratio. """ self.x = self.x * w_ratio self.y = self.y * h_ratio self.width = self.width * w_ratio self.height = self.height * h_ratio
[docs] class Keypoint(Annotation): """Class to represent keypoint annotations (e.g., for pose estimation).""" def __init__( self, label: str, point_list: list[tuple[float, float]], score: float, ): """Initialise the keypoint object. Args: label (str): Name of the object localised by the keypoint object. point_list (list): List of points for the keypoint annotation in the format [(x0, y0), (x1, y1), ...]. score (float): Confidence of the detection. """ super().__init__(label, score) self.point_list = [ (float(coord[0]), float(coord[1])) for coord in point_list ]
[docs] @classmethod def printable_fields(cls) -> list[str]: """List of the relevant fields to serialise the object.""" output = super().printable_fields() return [*output, "point_list"]
[docs] @classmethod def from_dict(cls, data_as_dict: dict) -> Self: """Make object from a dictionary.""" label = data_as_dict["label"] point_list = [tuple(point) for point in data_as_dict["point_list"]] score = data_as_dict["score"] return cls(label, point_list, score)
[docs] def rescale(self, w_ratio: float, h_ratio: float) -> None: """Rescale object. Args: w_ratio (float): Width ratio. h_ratio (float): Height ratio. """ self.point_list = [ (x * w_ratio, y * h_ratio) for x, y in self.point_list ]
[docs] class Track(Annotation): """Class to represent object tracks accross frames in a movie.""" def __init__(self, track_id: int, box_list: list[tuple[int, Box]]): """Initialise the track. Args: track_id (int): Identifier of the track. box_list (list): List of tuples (frame_id, box) linking the frames where the tracked object appears to its position in these frames. """ self.track_id = track_id self.box_list = box_list
ANNOTATION_TYPE_DICT = {"box": Box, "keypoint": Keypoint, "track": Track}
[docs] def annotation_factory(annotation_as_dict: dict) -> Annotation: annotation_type = annotation_as_dict["type"] annotation_class = ANNOTATION_TYPE_DICT[annotation_type] return annotation_class.from_dict(annotation_as_dict)