Source code for discopat.display

from __future__ import annotations

import time
from typing import TYPE_CHECKING

import imageio
import matplotlib as mpl
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw

from discopat.repositories.local import DISCOPATH

if TYPE_CHECKING:
    from matplotlib.axes._axes import Axes

    from discopat.core import Annotation, Box, Frame, Keypoint, Movie


def to_int(image_array: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    """Linearly rescale an array to unsigned 8-bit integers in [0, 255].

    The minimum of the array maps to 0 and the maximum to 255.
    Warning: values will be scaled even if the input array is already of type
    int.

    Args:
        image_array: Array (or array-like) of real values.
        eps: Small value added to the value range so that a constant array
            does not cause a division by zero (it maps to all zeros).

    Returns:
        Array of the same shape as ``image_array`` with dtype ``uint8``.
    """
    # Work in float64: subtracting the minimum in a narrow integer dtype
    # (e.g. int8 data spanning [-128, 127]) would silently wrap around.
    values = np.asarray(image_array, dtype=np.float64)
    if values.size == 0:
        # min()/max() have no identity for empty arrays.
        return np.zeros(values.shape, dtype=np.uint8)
    min_val = values.min()
    max_val = values.max()
    scaled = (values - min_val) / (max_val - min_val + eps) * 255
    return np.round(scaled).astype(np.uint8)
def get_center_path(track: np.ndarray) -> np.ndarray:
    """Get the trajectory of the center of the boxes corresponding to one track.

    Args:
        track: Array of shape (num_frames, 5) where each line corresponds to a
            box:
            - track[:, 0] = id of the frame,
            - track[:, 1] = xmin,
            - track[:, 2] = ymin,
            - track[:, 3] = xmax,
            - track[:, 4] = ymax.

    Returns:
        Array of shape (num_frames, 3), where each line holds the frame id
        followed by the (x, y) coordinates of the center of the box.
    """
    frame_ids = track[:, 0]
    x_centers = (track[:, 1] + track[:, 3]) / 2
    y_centers = (track[:, 2] + track[:, 4]) / 2
    return np.column_stack([frame_ids, x_centers, y_centers])
def frame_to_pil(
    frame: Frame,
    tracks: dict,
    max_track_length: int = 0,
    persistence: int = 10,
    cmap: str = "gray",
    track_color: str = "red",
) -> Image.Image:
    """Make an RGB PIL image of a frame with its tracks and annotations drawn.

    Args:
        frame: Frame to render; ``frame.name`` must be convertible to the
            integer frame index.
        tracks: Mapping whose values are track arrays laid out as described
            in ``get_center_path`` (column 0 = frame id, columns 1-4 = box).
            The key type is not used here.
        max_track_length: Maximum number of most recent center points drawn
            per track; 0 draws the whole path up to this frame.
        persistence: Number of frames a track is still drawn after its last
            box.
        cmap: Name of the matplotlib colormap used to color the frame.
        track_color: Color of the track lines and annotation outlines.

    Returns:
        The rendered RGB PIL image.
    """
    i = int(frame.name)

    # Indexing the colormap with uint8 values looks up its color table and
    # yields RGBA floats in [0, 1]; scale those straight to [0, 255].
    # (Re-stretching them by their min/max would distort colormaps whose
    # colors do not span the full [0, 1] range.) Alpha is dropped.
    color_map = mpl.colormaps.get_cmap(cmap)
    rgba = color_map(to_int(frame.image_array))
    rgb = np.round(rgba[..., :3] * 255).astype(np.uint8)
    pil_image = Image.fromarray(rgb).convert("RGB")
    draw = ImageDraw.Draw(pil_image)

    for track in tracks.values():
        # Skip tracks whose last box is more than `persistence` frames old.
        if int(track[-1, 0]) < i - persistence:
            continue
        current_track = track[track[:, 0] <= i]
        # A line needs at least two points.
        if len(current_track) <= 1:
            continue
        center_path = get_center_path(current_track)
        # center_path[-0:] is the whole path, hence 0 means "no limit".
        draw.line(
            xy=[(pos[1], pos[2]) for pos in center_path[-max_track_length:]],
            fill=track_color,
            width=3,
        )

    for annotation in frame.annotations:
        # Keypoint annotations (see plot_annotation) have no box corners;
        # draw them as a polyline through their points instead.
        if getattr(annotation, "type", None) == "keypoint":
            points = [(x, y) for x, y in annotation.point_list]
            if len(points) >= 2:
                draw.line(xy=points, fill=track_color)
            continue
        draw.rectangle(
            [annotation.xmin, annotation.ymin, annotation.xmax, annotation.ymax],
            outline=track_color,
        )
    return pil_image
def make_movie(
    movie: Movie,
    persistence: int,
    cmap: str,
    track_color: str,
    fps: int,
    output_format: str,
    max_track_length: int = 0,
):
    """Render every frame of a movie with its tracks and write it to disk.

    The file is written to ``DISCOPATH / "misc"`` and named
    ``{movie.name}_{yymmdd_HHMMSS}.{output_format}``.

    Args:
        movie: Movie providing ``name``, ``frames`` and ``tracks``.
        persistence: Forwarded to ``frame_to_pil``.
        cmap: Forwarded to ``frame_to_pil``.
        track_color: Forwarded to ``frame_to_pil``.
        fps: Frame rate passed to the imageio writer.
        output_format: File extension, which selects the imageio writer.
        max_track_length: Forwarded to ``frame_to_pil`` (0 = full track).

    Returns:
        Path of the written movie file.
    """
    # Switch matplotlib to the non-interactive Agg backend (kept for
    # compatibility with existing callers; frames are drawn with PIL).
    mpl.use("agg")
    time_str = time.strftime("%y%m%d_%H%M%S")
    movie_path = DISCOPATH / f"misc/{movie.name}_{time_str}.{output_format}"
    # The writer fails if the target directory does not exist yet.
    movie_path.parent.mkdir(parents=True, exist_ok=True)
    with imageio.get_writer(movie_path, fps=fps) as writer:
        for frame in movie.frames:
            image = frame_to_pil(
                frame,
                movie.tracks,
                max_track_length=max_track_length,
                persistence=persistence,
                cmap=cmap,
                track_color=track_color,
            )
            writer.append_data(np.array(image))
    return movie_path
def plot_frame(
    frame: Frame,
    cmap: str = "gray",
    annotation_color: str = "tab:red",
    show_figure: bool = True,
    return_figure: bool = False,
    figure_size: tuple[float, float] | None = None,
    figure_dpi: int | None = None,
) -> plt.Figure | None:
    """Plot a frame's image with all of its annotations overlaid.

    Args:
        frame: Frame providing ``image_array`` and ``annotations``.
        cmap: Matplotlib colormap used by ``imshow``.
        annotation_color: Color used for every annotation.
        show_figure: Whether to call ``plt.show()``.
        return_figure: Whether to return the created figure.
        figure_size: Figure size in inches, or None for matplotlib's default.
        figure_dpi: Figure resolution, or None for matplotlib's default.

    Returns:
        The figure if ``return_figure`` is True, otherwise None.
    """
    # NOTE(review): this forces the notebook "inline" backend; it is likely to
    # fail outside Jupyter / without matplotlib-inline — confirm intent.
    mpl.use("inline")
    image_array = frame.image_array
    fig, ax = plt.subplots(1, 1, figsize=figure_size, dpi=figure_dpi)
    # Let the image fill the whole figure, without margins.
    fig.subplots_adjust(top=1, bottom=0, left=0, right=1)
    ax.imshow(image_array, cmap=cmap)
    ax.axis("off")
    for annotation in frame.annotations:
        plot_annotation(ax, annotation, color=annotation_color)
    if show_figure:
        plt.show()
    return fig if return_figure else None
def plot_annotation(ax: Axes, annotation: Annotation, color: str):
    """Draw a single annotation on ``ax``, choosing the drawer by its type.

    ``annotation.type`` selects ``plot_box`` ("box") or ``plot_keypoint``
    ("keypoint"); any other value raises ``KeyError``.
    """
    kind = annotation.type
    drawer = {"keypoint": plot_keypoint, "box": plot_box}[kind]
    drawer(ax, annotation, color)
def plot_box(ax: Axes, box: Box, color: str):
    """Add an unfilled rectangle outlining ``box`` to ``ax``.

    The rectangle is anchored at (``box.xmin``, ``box.ymin``) and spans
    ``box.width`` by ``box.height``.
    """
    corner = (box.xmin, box.ymin)
    outline = plt.Rectangle(
        corner,
        box.width,
        box.height,
        edgecolor=color,
        facecolor="none",
    )
    ax.add_patch(outline)
[docs] def plot_keypoint(ax: Axes, keypoint: Keypoint, color: str): point_list = keypoint.point_list for i, point_1 in enumerate(point_list[:-1]): point_2 = point_list[i + 1] x1, y1 = point_1 x2, y2 = point_2 ax.plot([x1, x2], [y1, y2], color=color)