Model Inference

This example shows how to detect structures in a movie using a pretrained convolutional neural network (a Faster R-CNN model) with discopat.

Imports

import torch
from tqdm import tqdm

from discopat.core import Movie
from discopat.display import plot_frame
from discopat.nn_models import FasterRCNNModel
from discopat.repositories import repository_factory

Definitions

# What to analyse: the input movie and the trained detector to apply to it.
movie_name, model_name = "blob_i/density", "faster_rcnn_241113_131447"

# Runtime settings: compute device, data source and framework identifiers.
computing_device, data_source, framework = "cpu", "osf", "torch"

Load the images

# Open the "input_movies" repository of the configured data source and read
# the requested movie from it.
movie_repository = repository_factory(data_source, "input_movies")
movie = movie_repository.read(movie_name)

# Plot every raw frame so the input can be inspected before detection.
for raw_frame in movie.frames:
    plot_frame(raw_frame)
  0%|          | 0.00/361 [00:00<?, ?bytes/s]
100%|██████████| 361/361 [00:00<00:00, 3.55Mbytes/s]

  0%|          | 0.00/6.55M [00:00<?, ?bytes/s]
  0%|          | 32.8k/6.55M [00:00<00:24, 262kbytes/s]
  1%|          | 65.5k/6.55M [00:00<00:22, 285kbytes/s]
  2%|▏         | 98.3k/6.55M [00:00<00:22, 284kbytes/s]
  3%|▎         | 180k/6.55M [00:00<00:13, 468kbytes/s]
  5%|▌         | 328k/6.55M [00:00<00:07, 805kbytes/s]
  9%|▉         | 606k/6.55M [00:00<00:04, 1.43Mbytes/s]
 18%|█▊        | 1.15M/6.55M [00:00<00:02, 2.66Mbytes/s]
 34%|███▍      | 2.21M/6.55M [00:00<00:00, 5.12Mbytes/s]
 60%|██████    | 3.96M/6.55M [00:00<00:00, 8.91Mbytes/s]
 91%|█████████ | 5.95M/6.55M [00:01<00:00, 12.2Mbytes/s]
100%|██████████| 6.55M/6.55M [00:01<00:00, 6.05Mbytes/s]

  0%|          | 0.00/6.55M [00:00<?, ?bytes/s]
  0%|          | 32.8k/6.55M [00:00<00:25, 253kbytes/s]
  1%|          | 65.5k/6.55M [00:00<00:22, 289kbytes/s]
  2%|▏         | 98.3k/6.55M [00:00<00:22, 286kbytes/s]
  3%|▎         | 180k/6.55M [00:00<00:13, 469kbytes/s]
  5%|▌         | 344k/6.55M [00:00<00:07, 844kbytes/s]
 10%|▉         | 623k/6.55M [00:00<00:04, 1.46Mbytes/s]
 18%|█▊        | 1.16M/6.55M [00:00<00:02, 2.66Mbytes/s]
 36%|███▌      | 2.33M/6.55M [00:00<00:00, 5.40Mbytes/s]
 66%|██████▌   | 4.31M/6.55M [00:00<00:00, 9.77Mbytes/s]
 93%|█████████▎| 6.09M/6.55M [00:01<00:00, 12.2Mbytes/s]
100%|██████████| 6.55M/6.55M [00:01<00:00, 6.05Mbytes/s]

  0%|          | 0.00/6.55M [00:00<?, ?bytes/s]
  0%|          | 32.8k/6.55M [00:00<00:24, 262kbytes/s]
  1%|          | 65.5k/6.55M [00:00<00:22, 285kbytes/s]
  2%|▏         | 98.3k/6.55M [00:00<00:22, 283kbytes/s]
  3%|▎         | 180k/6.55M [00:00<00:13, 465kbytes/s]
  5%|▌         | 344k/6.55M [00:00<00:07, 838kbytes/s]
 10%|▉         | 623k/6.55M [00:00<00:04, 1.45Mbytes/s]
 18%|█▊        | 1.16M/6.55M [00:00<00:02, 2.68Mbytes/s]
 35%|███▍      | 2.28M/6.55M [00:00<00:00, 5.28Mbytes/s]
 65%|██████▌   | 4.26M/6.55M [00:00<00:00, 9.70Mbytes/s]
 95%|█████████▌| 6.23M/6.55M [00:01<00:00, 12.7Mbytes/s]
100%|██████████| 6.55M/6.55M [00:01<00:00, 5.95Mbytes/s]
Figure(640x480)
Figure(640x480)
Figure(640x480)

Load the model

# Open the "models" repository of the same data source used for the movies.
model_repository = repository_factory(data_source, "models")

# Read the stored model (handed to `from_dict`, so presumably a dict-like
# description — TODO confirm) and rebuild a Faster R-CNN detector from it.
raw_model = model_repository.read(model_name)
model = FasterRCNNModel.from_dict(raw_model)
# Place the detector on the configured compute device before inference.
model.set_device(computing_device)
  0%|          | 0.00/14.0 [00:00<?, ?bytes/s]
100%|██████████| 14.0/14.0 [00:00<00:00, 124kbytes/s]

  0%|          | 0.00/112 [00:00<?, ?bytes/s]
100%|██████████| 112/112 [00:00<00:00, 1.23Mbytes/s]

  0%|          | 0.00/166M [00:00<?, ?bytes/s]
  0%|          | 32.8k/166M [00:00<10:31, 262kbytes/s]
  0%|          | 65.5k/166M [00:00<09:41, 285kbytes/s]
  0%|          | 98.3k/166M [00:00<09:39, 286kbytes/s]
  0%|          | 180k/166M [00:00<05:53, 469kbytes/s]
  0%|          | 344k/166M [00:00<03:16, 842kbytes/s]
  0%|          | 623k/166M [00:00<01:53, 1.46Mbytes/s]
  1%|          | 1.16M/166M [00:00<01:01, 2.67Mbytes/s]
  1%|▏         | 2.29M/166M [00:00<00:30, 5.28Mbytes/s]
  2%|▏         | 2.83M/166M [00:00<00:31, 5.25Mbytes/s]
  3%|▎         | 5.62M/166M [00:01<00:13, 11.8Mbytes/s]
  5%|▌         | 8.57M/166M [00:01<00:09, 16.9Mbytes/s]
  7%|▋         | 11.4M/166M [00:01<00:07, 20.2Mbytes/s]
  9%|▊         | 14.3M/166M [00:01<00:06, 22.6Mbytes/s]
 10%|█         | 17.2M/166M [00:01<00:06, 24.4Mbytes/s]
 12%|█▏        | 19.7M/166M [00:01<00:06, 23.7Mbytes/s]
 13%|█▎        | 22.1M/166M [00:01<00:06, 22.8Mbytes/s]
 15%|█▍        | 24.4M/166M [00:01<00:06, 22.1Mbytes/s]
 16%|█▋        | 27.2M/166M [00:01<00:05, 23.9Mbytes/s]
 18%|█▊        | 30.0M/166M [00:02<00:05, 24.9Mbytes/s]
 20%|█▉        | 32.7M/166M [00:02<00:05, 25.7Mbytes/s]
 21%|██▏       | 35.6M/166M [00:02<00:04, 26.3Mbytes/s]
 23%|██▎       | 38.3M/166M [00:02<00:05, 24.3Mbytes/s]
 25%|██▍       | 41.1M/166M [00:02<00:04, 25.0Mbytes/s]
 27%|██▋       | 44.0M/166M [00:02<00:04, 25.6Mbytes/s]
 28%|██▊       | 46.6M/166M [00:02<00:04, 24.7Mbytes/s]
 30%|██▉       | 49.1M/166M [00:02<00:06, 19.2Mbytes/s]
 31%|███       | 51.8M/166M [00:02<00:05, 20.9Mbytes/s]
 33%|███▎      | 54.7M/166M [00:03<00:04, 22.7Mbytes/s]
 34%|███▍      | 57.1M/166M [00:03<00:05, 20.3Mbytes/s]
 36%|███▌      | 59.3M/166M [00:03<00:05, 20.4Mbytes/s]
 37%|███▋      | 61.8M/166M [00:03<00:04, 21.7Mbytes/s]
 39%|███▉      | 64.4M/166M [00:03<00:04, 22.8Mbytes/s]
 41%|████      | 67.1M/166M [00:03<00:04, 24.0Mbytes/s]
 42%|████▏     | 69.9M/166M [00:03<00:03, 25.0Mbytes/s]
 44%|████▍     | 72.7M/166M [00:03<00:03, 25.7Mbytes/s]
 46%|████▌     | 75.5M/166M [00:03<00:03, 26.2Mbytes/s]
 47%|████▋     | 78.5M/166M [00:04<00:03, 26.4Mbytes/s]
 49%|████▉     | 81.4M/166M [00:04<00:03, 27.0Mbytes/s]
 51%|█████     | 84.3M/166M [00:04<00:02, 27.3Mbytes/s]
 53%|█████▎    | 87.0M/166M [00:04<00:03, 25.3Mbytes/s]
 54%|█████▍    | 89.6M/166M [00:04<00:03, 24.5Mbytes/s]
 56%|█████▌    | 92.4M/166M [00:04<00:02, 25.3Mbytes/s]
 57%|█████▋    | 94.9M/166M [00:04<00:02, 25.2Mbytes/s]
 59%|█████▉    | 97.6M/166M [00:04<00:02, 25.5Mbytes/s]
 61%|██████    | 100M/166M [00:04<00:02, 26.2Mbytes/s]
 62%|██████▏   | 103M/166M [00:05<00:02, 26.1Mbytes/s]
 64%|██████▍   | 106M/166M [00:05<00:02, 26.3Mbytes/s]
 65%|██████▌   | 109M/166M [00:05<00:02, 26.2Mbytes/s]
 67%|██████▋   | 111M/166M [00:05<00:02, 24.5Mbytes/s]
 69%|██████▉   | 114M/166M [00:05<00:02, 25.7Mbytes/s]
 71%|███████   | 117M/166M [00:05<00:01, 26.3Mbytes/s]
 72%|███████▏  | 120M/166M [00:05<00:01, 26.9Mbytes/s]
 74%|███████▍  | 123M/166M [00:05<00:01, 27.2Mbytes/s]
 76%|███████▌  | 126M/166M [00:05<00:01, 27.5Mbytes/s]
 78%|███████▊  | 129M/166M [00:05<00:01, 27.7Mbytes/s]
 79%|███████▉  | 131M/166M [00:06<00:01, 27.8Mbytes/s]
 81%|████████  | 134M/166M [00:06<00:01, 27.9Mbytes/s]
 83%|████████▎ | 137M/166M [00:06<00:01, 22.9Mbytes/s]
 84%|████████▍ | 140M/166M [00:06<00:01, 23.1Mbytes/s]
 86%|████████▌ | 142M/166M [00:06<00:00, 23.8Mbytes/s]
 87%|████████▋ | 145M/166M [00:06<00:00, 24.9Mbytes/s]
 89%|████████▉ | 148M/166M [00:06<00:00, 25.9Mbytes/s]
 91%|█████████ | 150M/166M [00:06<00:00, 25.0Mbytes/s]
 92%|█████████▏| 153M/166M [00:06<00:00, 24.2Mbytes/s]
 94%|█████████▍| 155M/166M [00:07<00:00, 23.4Mbytes/s]
 95%|█████████▌| 158M/166M [00:07<00:00, 22.2Mbytes/s]
 97%|█████████▋| 160M/166M [00:07<00:00, 22.2Mbytes/s]
 98%|█████████▊| 162M/166M [00:07<00:00, 20.1Mbytes/s]
100%|█████████▉| 165M/166M [00:07<00:00, 21.7Mbytes/s]
100%|██████████| 166M/166M [00:07<00:00, 21.9Mbytes/s]
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/docs/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth

  0%|          | 0.00/97.8M [00:00<?, ?B/s]
 19%|█▉        | 19.0M/97.8M [00:00<00:00, 198MB/s]
 39%|███▉      | 38.0M/97.8M [00:00<00:00, 192MB/s]
 58%|█████▊    | 56.4M/97.8M [00:00<00:00, 184MB/s]
 83%|████████▎ | 81.4M/97.8M [00:00<00:00, 213MB/s]
100%|██████████| 97.8M/97.8M [00:00<00:00, 209MB/s]

Compute predictions

# Run the detector on every frame. Gradient tracking is disabled because this
# is pure inference. This avoids building an autograd graph for each prediction
# (saving memory and time). It also prevents the "Converting a tensor with
# requires_grad=True to a scalar" warning, which is otherwise raised when
# prediction scores are converted with float().
with torch.no_grad():
    analysed_frames = [model.predict(frame) for frame in tqdm(movie.frames)]

# Bundle the analysed frames into a new movie with the same name. No tracking
# is performed in this example, so the track list is empty.
analysed_movie = Movie(name=movie.name, frames=analysed_frames, tracks=[])
  0%|          | 0/3 [00:00<?, ?it/s]/home/docs/checkouts/readthedocs.org/user_builds/discopat/checkouts/stable/discopat/core/entities/annotation.py:25: UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.
Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  self.score = float(score)

 33%|███▎      | 1/3 [00:02<00:04,  2.32s/it]
 67%|██████▋   | 2/3 [00:04<00:02,  2.09s/it]
100%|██████████| 3/3 [00:06<00:00,  1.99s/it]
100%|██████████| 3/3 [00:06<00:00,  2.04s/it]

Display predictions

# Plot each frame returned by model.predict so the detector's output can be
# inspected visually.
for analysed_frame in analysed_movie.frames:
    plot_frame(analysed_frame)
Figure(640x480)
Figure(640x480)
Figure(640x480)

Total running time of the script: (0 minutes 54.744 seconds)

Gallery generated by Sphinx-Gallery