Source code for discopat.projects.tokam

import json
from pathlib import Path

import torch

from discopat.core import Frame, Model, Movie
from discopat.datasets.torch import TorchBoxDataset
from discopat.nn_models import FasterRCNNModel
from discopat.repositories.hdf5 import HDF5Repository
from discopat.repositories.local import DISCOPATH, LocalNNModelRepository
from discopat.utils import get_device

# Maps a short movie name to the identifier that load_movie() passes to
# MOVIE_REPO.read().
# NOTE(review): the identifiers look like YYMMDD_HHMMSS timestamps -- confirm.
MOVIE_TABLE = {
    "blob_dwi_512": "250610_103200",
    "blob_i_512": "250605_164500",
    "turb_dwi_256": "250603_111000",
    "turb_dwi_512": "250610_110800",
    "turb_i_256": "250603_105600",
    "turb_i_512": "250715_150500",
}

# Maps a dataset split name to the movie it is built from ("movie_name", a key
# of MOVIE_TABLE) and to the annotation task ("annotation_task") whose
# "annotated_movie.json" file load_set() reads the frame annotations from.
SET_TABLE = {
    "train": {"movie_name": "blob_i_512", "annotation_task": "250606_110200"},
    "val": {"movie_name": "blob_dwi_512", "annotation_task": "250610_211600"},
    "test": {"movie_name": "turb_dwi_512", "annotation_task": "250610_220000"},
}

# Repositories read by load_movie() and load_model() respectively.
MOVIE_REPO = HDF5Repository("tokam2d")
MODEL_REPO = LocalNNModelRepository("models")

# Device that load_model() moves models to. Requested with allow_mps=False.
COMPUTING_DEVICE = get_device(allow_mps=False)
# NOTE: executed at import time, so importing this module writes to stdout.
print("Computing device:", COMPUTING_DEVICE)


def load_movie(movie_name: str) -> Movie:
    """Load a movie from the movie repository by its short name.

    Args:
        movie_name: Key of ``MOVIE_TABLE`` identifying the movie to load.

    Returns:
        The movie read from ``MOVIE_REPO``.

    Raises:
        ValueError: If ``movie_name`` is not a key of ``MOVIE_TABLE``.
    """
    if movie_name not in MOVIE_TABLE:
        msg = (
            f"Unknown movie name: {movie_name}. "
            f"Allowed names: {sorted(MOVIE_TABLE)}"
        )
        raise ValueError(msg)

    movie = MOVIE_REPO.read(MOVIE_TABLE[movie_name])
    print(f"Loaded {len(movie)} frames.")
    return movie
def load_set(set_name: str) -> Movie:
    """Load the annotated movie backing a dataset split.

    The movie named in ``SET_TABLE[set_name]`` is loaded, frames without an
    entry in the split's annotation file are dropped, and the remaining
    frames receive their annotations.

    Args:
        set_name: Key of ``SET_TABLE`` identifying the split.

    Returns:
        The movie restricted to its annotated frames, each carrying the
        annotations read from the annotation file.

    Raises:
        ValueError: If ``set_name`` is not a key of ``SET_TABLE``.
    """
    if set_name not in SET_TABLE:
        msg = f"Unknown set name: {set_name}. Allowed names: {list(SET_TABLE)}"
        raise ValueError(msg)

    movie_name = SET_TABLE[set_name]["movie_name"]
    annotation_task = SET_TABLE[set_name]["annotation_task"]

    # Load movie
    movie = load_movie(movie_name)

    # Load annotations
    annotation_path = (
        DISCOPATH / "annotations" / annotation_task / "annotated_movie.json"
    )
    # JSON text is UTF-8; an explicit encoding keeps decoding independent of
    # the platform's locale default.
    with annotation_path.open(encoding="utf-8") as f:
        annotated_movie = Movie.from_dict(json.load(f))
    annotation_dict = {
        frame.name: frame.annotations for frame in annotated_movie.frames
    }

    # Keep only the annotated frames and attach their annotations
    annotated_frames = []
    for frame in movie.frames:
        if frame.name in annotation_dict:
            frame.annotations = annotation_dict[frame.name]
            annotated_frames.append(frame)
    movie.frames = annotated_frames

    print(f"Loaded {len(movie)} annotated frames.")
    return movie
def load_model(model_name: str) -> Model:
    """Read a stored model and move it to the computing device.

    Args:
        model_name: Name under which the model is stored in ``MODEL_REPO``.

    Returns:
        A ``FasterRCNNModel`` built from the stored data, with its device set
        to ``COMPUTING_DEVICE``.
    """
    stored = MODEL_REPO.read(model_name)
    loaded = FasterRCNNModel.from_dict(stored)
    loaded.set_device(COMPUTING_DEVICE)
    return loaded
def collate_fn(batch: list[tuple]) -> tuple:
    """Collate function to make dataloaders.

    Regroups a list of samples field by field without stacking anything,
    e.g. ``[(a0, b0), (a1, b1)]`` becomes ``((a0, a1), (b0, b1))``.

    Args:
        batch: The samples of one batch, as a ``DataLoader`` passes them to
            its ``collate_fn``. Each sample must be iterable.
            NOTE(review): presumably (image, target) pairs from
            ``TorchBoxDataset`` -- confirm against the dataset.

    Returns:
        One tuple per sample field; an empty tuple for an empty batch.
    """
    return tuple(zip(*batch))
def make_dataloaders(
    train_frames: list[Frame],
    val_frames: list[Frame],
    label_map: dict[int, str],
    train_batch_size: int,
    val_batch_size: int,
) -> tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Build data loaders from train- and val frames.

    Args:
        train_frames: Frames wrapped into the training dataset.
        val_frames: Frames wrapped into the validation dataset.
        label_map: Label map handed to both ``TorchBoxDataset`` instances.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.

    Returns:
        ``(train_dataloader, val_dataloader)``. Only the training loader
        shuffles; both use ``collate_fn`` to assemble batches.
    """
    loader_cls = torch.utils.data.DataLoader

    # Both datasets are created before either loader.
    training_set = TorchBoxDataset(train_frames, label_map)
    validation_set = TorchBoxDataset(val_frames, label_map)

    training_loader = loader_cls(
        training_set,
        batch_size=train_batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
    validation_loader = loader_cls(
        validation_set,
        batch_size=val_batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
    return training_loader, validation_loader