import json
from pathlib import Path
import torch
from discopat.core import Frame, Model, Movie
from discopat.datasets.torch import TorchBoxDataset
from discopat.nn_models import FasterRCNNModel
from discopat.repositories.hdf5 import HDF5Repository
from discopat.repositories.local import DISCOPATH, LocalNNModelRepository
from discopat.utils import get_device
# Human-readable movie name -> identifier passed to ``MOVIE_REPO.read``.
# NOTE(review): identifiers look like YYMMDD_HHMMSS timestamps -- confirm.
MOVIE_TABLE = {
    "blob_dwi_512": "250610_103200",
    "blob_i_512": "250605_164500",
    "turb_dwi_256": "250603_111000",
    "turb_dwi_512": "250610_110800",
    "turb_i_256": "250603_105600",
    "turb_i_512": "250715_150500",
}

# Data split -> movie it is built from (a key of MOVIE_TABLE) and the
# annotation task whose directory holds the matching annotations.
# Each row below reads: (split, movie name, annotation task).
SET_TABLE = {
    split: {"movie_name": movie, "annotation_task": task}
    for split, movie, task in (
        ("train", "blob_i_512", "250606_110200"),
        ("val", "blob_dwi_512", "250610_211600"),
        ("test", "turb_dwi_512", "250610_220000"),
    )
}
# Repository that ``load_movie`` reads movies from.
MOVIE_REPO = HDF5Repository("tokam2d")
# Repository that ``load_model`` reads raw model data from.
MODEL_REPO = LocalNNModelRepository("models")
# Device every loaded model is moved to (see ``load_model``).
# NOTE(review): allow_mps=False presumably rules out Apple's MPS backend --
# confirm against ``discopat.utils.get_device``.
COMPUTING_DEVICE = get_device(allow_mps=False)
# Printed once at import time so the selected device shows up in the output.
print("Computing device:", COMPUTING_DEVICE)
[docs]
def load_movie(movie_name: str) -> Movie:
    """Read one of the known movies from the movie repository.

    Args:
        movie_name: Key of ``MOVIE_TABLE`` naming the movie to load.

    Returns:
        Whatever ``MOVIE_REPO.read`` returns for the identifier that
        ``MOVIE_TABLE`` associates with ``movie_name``.

    Raises:
        ValueError: If ``movie_name`` is not a key of ``MOVIE_TABLE``.

    """
    if movie_name not in MOVIE_TABLE:
        msg = (
            f"Unknown movie name: {movie_name}. "
            f"Allowed names: {sorted(MOVIE_TABLE)}"
        )
        raise ValueError(msg)

    movie = MOVIE_REPO.read(MOVIE_TABLE[movie_name])
    print(f"Loaded {len(movie)} frames.")
    return movie
[docs]
def load_set(set_name: str) -> Movie:
    """Load the movie of a data split, reduced to its annotated frames.

    The split is looked up in ``SET_TABLE``. Its movie is loaded with
    ``load_movie`` and its annotations are read from
    ``DISCOPATH / "annotations" / <annotation_task> / "annotated_movie.json"``.
    Frames whose name does not appear in the annotation file are dropped;
    the remaining frames get the annotations stored under their name.

    Args:
        set_name: Key of ``SET_TABLE`` naming the split to load.

    Returns:
        The loaded movie, with ``frames`` replaced by the annotated subset.

    Raises:
        ValueError: If ``set_name`` is not a key of ``SET_TABLE`` (or if the
            movie name it refers to is rejected by ``load_movie``).

    """
    if set_name not in SET_TABLE:
        msg = f"Unknown set name: {set_name}. Allowed names: {list(SET_TABLE)}"
        raise ValueError(msg)

    movie_name = SET_TABLE[set_name]["movie_name"]
    annotation_task = SET_TABLE[set_name]["annotation_task"]

    # Load movie
    movie = load_movie(movie_name)

    # Load annotations
    annotation_path = (
        DISCOPATH / "annotations" / annotation_task / "annotated_movie.json"
    )
    # JSON text is UTF-8; state it explicitly instead of relying on the
    # platform's locale-dependent default encoding.
    with annotation_path.open(encoding="utf-8") as f:
        annotation_dict = {
            frame.name: frame.annotations
            for frame in Movie.from_dict(json.load(f)).frames
        }

    # Filter out non-annotated frames
    movie.frames = [
        frame for frame in movie.frames if frame.name in annotation_dict
    ]

    # Add annotations to frames
    for frame in movie.frames:
        frame.annotations = annotation_dict[frame.name]

    print(f"Loaded {len(movie)} annotated frames.")
    return movie
[docs]
def load_model(model_name: str) -> Model:
    """Load a Faster R-CNN model by name and move it to the computing device.

    Args:
        model_name: Name under which the model is stored in ``MODEL_REPO``.

    Returns:
        The ``FasterRCNNModel`` built from the stored data, after
        ``set_device(COMPUTING_DEVICE)`` has been called on it.

    """
    detector = FasterRCNNModel.from_dict(MODEL_REPO.read(model_name))
    detector.set_device(COMPUTING_DEVICE)
    return detector
[docs]
def collate_fn(batch: list[tuple]) -> tuple[tuple, ...]:
    """Collate function to make dataloaders.

    Regroups a batch of per-sample tuples into one tuple per field, without
    stacking anything, so samples of differing sizes can share a batch.

    Args:
        batch: The samples of one batch, as a ``DataLoader`` passes them to
            its ``collate_fn``: a list with one tuple per sample.

    Returns:
        A tuple with one entry per sample field; entry ``i`` is the tuple of
        the ``i``-th field of every sample. An empty batch gives ``()``.

    """
    return tuple(zip(*batch))
[docs]
def make_dataloaders(
    train_frames: list[Frame],
    val_frames: list[Frame],
    label_map: dict[int, str],
    train_batch_size: int,
    val_batch_size: int,
) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Wrap training and validation frames into torch data loaders.

    Args:
        train_frames: Frames for the training dataset.
        val_frames: Frames for the validation dataset.
        label_map: Label map handed to both ``TorchBoxDataset`` instances.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.

    Returns:
        ``(train_dataloader, val_dataloader)``. Only the training loader
        shuffles; both use ``collate_fn`` to assemble batches.

    """
    loader_cls = torch.utils.data.DataLoader

    train_set = TorchBoxDataset(train_frames, label_map)
    val_set = TorchBoxDataset(val_frames, label_map)

    return (
        loader_cls(
            train_set,
            batch_size=train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
        ),
        loader_cls(
            val_set,
            batch_size=val_batch_size,
            shuffle=False,
            collate_fn=collate_fn,
        ),
    )