Source code for discopat.nn_training.nn_trainer

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from discopat.core import ComputingDevice, Dataset, NeuralNet


class NNTrainer(ABC):
    """Abstract base class wiring a network, its datasets and training settings.

    Concrete subclasses must implement :meth:`train`,
    :meth:`set_default_optimiser` and :meth:`set_default_lr_scheduler`.
    """

    def __init__(
        self,
        net: NeuralNet,
        dataset: Dataset,
        val_dataset: Dataset,
        parameters: dict[str, Any],
        device: ComputingDevice,
        callbacks: list | None = None,
    ):
        """Store the training components and build the optimiser and scheduler.

        Args:
            net: Network stored as ``self.net``.
            dataset: Dataset stored as ``self.dataset``.
            val_dataset: Dataset stored as ``self.val_dataset``.
            parameters: Mapping that must contain the keys ``"optimiser"``,
                ``"lr_scheduler"`` and ``"training_loop"``.
            device: Computing device stored as ``self.device``.
            callbacks: Optional list of callbacks; ``None`` is replaced by a
                fresh empty list.

        Raises:
            KeyError: If ``parameters`` lacks one of the required keys.

        """
        self.net = net
        self.dataset = dataset
        self.val_dataset = val_dataset
        self.device = device

        # A new list per instance avoids sharing one mutable default.
        self.callbacks = [] if callbacks is None else callbacks

        (
            self.optimiser_params,
            self.lr_scheduler_params,
            self.training_loop_params,
        ) = self.parse_parameters(parameters)

        # The parameter groups are assigned before these hooks run so that
        # subclass implementations can read them. The optimiser is created
        # first, then the learning-rate scheduler.
        self.optimiser = self.set_default_optimiser()
        self.lr_scheduler = self.set_default_lr_scheduler()

    @abstractmethod
    def train(self) -> None:
        """Run the training procedure (implemented by subclasses)."""

    @staticmethod
    def parse_parameters(
        parameters: dict[str, Any],
    ) -> tuple[dict, dict, dict]:
        """Split the parameter mapping into its three sub-configurations.

        Args:
            parameters: Mapping with the keys ``"optimiser"``,
                ``"lr_scheduler"`` and ``"training_loop"``.

        Returns:
            The values stored under ``"optimiser"``, ``"lr_scheduler"`` and
            ``"training_loop"``, in that order.

        Raises:
            KeyError: If any of the three keys is missing.

        """
        optimiser_params = parameters["optimiser"]
        lr_scheduler_params = parameters["lr_scheduler"]
        training_loop_params = parameters["training_loop"]
        return optimiser_params, lr_scheduler_params, training_loop_params

    @abstractmethod
    def set_default_optimiser(self) -> Any:
        """Return the object stored as ``self.optimiser`` by ``__init__``."""

    @abstractmethod
    def set_default_lr_scheduler(self) -> Any:
        """Return the object stored as ``self.lr_scheduler`` by ``__init__``."""