Source code for discopat.display

from __future__ import annotations

import time
from typing import TYPE_CHECKING

import imageio
import matplotlib as mpl
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
from tqdm import tqdm

from discopat.repositories.local import DISCOPATH

if TYPE_CHECKING:
    from matplotlib.axes._axes import Axes

    from discopat.core import Annotation, Box, Frame, Keypoint, Movie


[docs] def to_int(image_array: np.array, eps: float = 1e-10) -> np.array: """Convert array to int [0, 255]. Warning: values will be scaled even if the input array is already of type int. """ min_val = image_array.min() max_val = image_array.max() return ( np.round((image_array - min_val) / (max_val - min_val + eps) * 255) ).astype(np.uint8)
[docs] def get_center_path(track: np.array) -> np.array: """Get the trajectory of the center of the boxes corresponding to one track. Args: track: Array of shape (num_frames, 5) where each line corresponds to a box: - track[:, 0] = id of the frame, - track[:, 1] = xmin, - track[:, 2] = ymin, - track[:, 3] = xmax, - track[:, 4] = ymax. Returns: Array of shape (num_frame, 2), where each line corresponds to the (x, y) coordinates of the center of the box. """ xmin_array = track[:, 1] ymin_array = track[:, 2] xmax_array = track[:, 3] ymax_array = track[:, 4] id_col = np.expand_dims(track[:, 0], axis=1) x_col = np.expand_dims((xmin_array + xmax_array) / 2, axis=1) y_col = np.expand_dims((ymin_array + ymax_array) / 2, axis=1) return np.hstack([id_col, x_col, y_col])
[docs] def frame_to_pil( frame: Frame, tracks: np.array, max_track_length: int = 0, persistence: int = 10, cmap: str = "gray", track_color: str = "red", ) -> Image: """Make a PIL image from a frame.""" i = int(frame.name) color_map = mpl.colormaps.get_cmap(cmap) image_array = to_int(color_map(to_int(frame.image_array))) pil_image = Image.fromarray(image_array) pil_image = pil_image.convert("RGB") draw = ImageDraw.Draw(pil_image) for track in tracks.values(): if int(track[-1, 0]) < i - persistence: continue current_track = track[track[:, 0] <= i] if len(current_track) <= 1: continue center_path = get_center_path(current_track) draw.line( xy=[(pos[1], pos[2]) for pos in center_path[-max_track_length:]], fill=track_color, width=3, ) for box in frame.annotations: draw.rectangle( [box.xmin, box.ymin, box.xmax, box.ymax], outline=track_color ) return pil_image
[docs] def make_movie( movie: Movie, persistence: int, cmap: str, track_color: str, fps: int, output_format: str, ): mpl.use("agg") time_str = time.strftime("%y%m%d_%H%M%S") movie_path = DISCOPATH / f"misc/{movie.name}_{time_str}.{output_format}" with imageio.get_writer(movie_path, fps=fps) as writer: for frame in tqdm(movie.frames): writer.append_data( np.array( frame_to_pil( frame, movie.tracks, persistence=persistence, cmap=cmap, track_color=track_color, ) ) )
[docs] def plot_frame( # noqa: RET503 frame: Frame, cmap: str = "gray", annotation_color: str = "tab:red", show_figure: bool = True, return_figure: bool = False, figure_size: tuple[float, float] or None = None, figure_dpi: int or None = None, ): mpl.use("inline") image_array = frame.image_array fig, ax = plt.subplots(1, 1, figsize=figure_size, dpi=figure_dpi) fig.subplots_adjust(top=1, bottom=0, left=0, right=1) ax.imshow(image_array, cmap=cmap) ax.axis("off") for annotation in frame.annotations: plot_annotation(ax, annotation, color=annotation_color) if show_figure: plt.show() if return_figure: return fig
[docs] def plot_annotation(ax: Axes, annotation: Annotation, color: str): annotation_type_dict = {"box": plot_box, "keypoint": plot_keypoint} plot_function = annotation_type_dict[annotation.type] plot_function(ax, annotation, color)
[docs] def plot_box(ax: Axes, box: Box, color: str): ax.add_patch( plt.Rectangle( xy=(box.xmin, box.ymin), width=box.width, height=box.height, edgecolor=color, facecolor="none", ) )
[docs] def plot_keypoint(ax: Axes, keypoint: Keypoint, color: str): point_list = keypoint.point_list for i, point_1 in enumerate(point_list[:-1]): point_2 = point_list[i + 1] x1, y1 = point_1 x2, y2 = point_2 ax.plot([x1, x2], [y1, y2], color=color)