Source code for discopat.projects.tokam

import json
from pathlib import Path

from discopat.core import Model, Movie
from discopat.nn_models import FasterRCNNModel
from discopat.repositories.hdf5 import HDF5Repository
from discopat.repositories.local import DISCOPATH, LocalNNModelRepository
from discopat.utils import get_device

MOVIE_TABLE = {
    "blob_dwi_512": "250610_103200",
    "blob_i_512": "250605_164500",
    "turb_dwi_256": "250603_111000",
    "turb_dwi_512": "250610_110800",
    "turb_i_256": "250603_105600",
    "turb_i_512": "250715_150500",
}

SET_TABLE = {
    "train": {"movie_name": "blob_i_512", "annotation_task": "250606_110200"},
    "val": {"movie_name": "blob_dwi_512", "annotation_task": "250610_211600"},
    "test": {"movie_name": "turb_dwi_512", "annotation_task": "250610_220000"},
}

MOVIE_REPO = HDF5Repository("tokam2d")
MODEL_REPO = LocalNNModelRepository("models")

COMPUTING_DEVICE = get_device(allow_mps=False)
print("Computing device:", COMPUTING_DEVICE)


[docs] def load_movie(movie_name: str) -> Movie: if movie_name not in MOVIE_TABLE: msg = ( f"Unkown movie name: {movie_name}. " f"Allowed names: {sorted(MOVIE_TABLE)}" ) raise ValueError(msg) movie = MOVIE_REPO.read(MOVIE_TABLE[movie_name]) print(f"Loaded {len(movie)} frames.") return movie
[docs] def load_set(set_name: str) -> Movie: if set_name not in SET_TABLE: msg = f"Unknown set name: {set_name}. Allowed names: {list(SET_TABLE)}" raise ValueError(msg) movie_name = SET_TABLE[set_name]["movie_name"] annotation_task = SET_TABLE[set_name]["annotation_task"] # Load movie movie = load_movie(movie_name) # Load annotations annotation_path = ( DISCOPATH / "annotations" / annotation_task / "annotated_movie.json" ) with Path.open(annotation_path) as f: annotation_dict = { frame.name: frame.annotations for frame in Movie.from_dict(json.load(f)).frames } # Filter out non-annotated frames movie.frames = [ frame for frame in movie.frames if frame.name in annotation_dict ] # Add annotations to frames for frame in movie.frames: frame.annotations = annotation_dict[frame.name] print(f"Loaded {len(movie)} annotated frames.") return movie
[docs] def load_model(model_name: str) -> Model: raw_model = MODEL_REPO.read(model_name) model = FasterRCNNModel.from_dict(raw_model) model.set_device(COMPUTING_DEVICE) return model