discopat.nn_training package

Subpackages

Submodules

discopat.nn_training.nn_trainer module

class discopat.nn_training.nn_trainer.NNTrainer(net, dataset, val_dataset, parameters, device, callbacks=None)

Bases: ABC

Parameters:

net
dataset
val_dataset
parameters
device
callbacks (default: None)

static parse_parameters(parameters)
Parameters:

parameters (dict[str, Any])

Return type:

tuple[dict, dict, dict]

abstractmethod set_default_lr_scheduler()
Return type:

Any

abstractmethod set_default_optimiser()
Return type:

Any

abstractmethod train()
Return type:

None
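Because set_default_lr_scheduler, set_default_optimiser and train are abstract, a concrete trainer has to override all three before it can be instantiated. The sketch below illustrates that contract with Python's abc module. It does not import discopat: SketchTrainer is a hypothetical stand-in that only mirrors the three documented abstract methods (not NNTrainer's constructor), and ToyTrainer returns plain dicts where a real subclass would return an optimiser and a scheduler.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchTrainer(ABC):
    """Hypothetical stand-in with the same abstract methods as NNTrainer."""

    @abstractmethod
    def set_default_lr_scheduler(self) -> Any: ...

    @abstractmethod
    def set_default_optimiser(self) -> Any: ...

    @abstractmethod
    def train(self) -> None: ...


class IncompleteTrainer(SketchTrainer):
    # Only train() is overridden, so this class is still abstract.
    def train(self) -> None:
        pass


class ToyTrainer(SketchTrainer):
    """Toy subclass: plain dicts stand in for a real optimiser and scheduler."""

    def __init__(self) -> None:
        self.log: list[str] = []

    def set_default_optimiser(self) -> dict[str, float]:
        return {"lr": 0.01}

    def set_default_lr_scheduler(self) -> dict[str, int]:
        return {"step_size": 10}

    def train(self) -> None:
        self.log.append("trained")


def can_instantiate(cls: type) -> bool:
    """Return True if cls can be instantiated without arguments."""
    try:
        cls()
    except TypeError:  # raised for classes with unimplemented abstract methods
        return False
    return True
```

With these definitions, can_instantiate(IncompleteTrainer) is False and can_instantiate(ToyTrainer) is True: the check happens at construction time, so a missing override surfaces immediately rather than when train() is first called.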

discopat.nn_training.torch module

class discopat.nn_training.torch.TorchNNTrainer(net, dataset, val_dataset, parameters, device, callbacks=None)

Bases: NNTrainer

Parameters:

net
dataset
val_dataset
parameters
device
callbacks (default: None)

set_default_lr_scheduler()
Return type:

LRScheduler

set_default_optimiser()
Return type:

Optimizer

train(num_epochs)
Parameters:

num_epochs (int)

discopat.nn_training.utils module

class discopat.nn_training.utils.MetricLogger(delimiter='\t')

Bases: object

add_meter(name, meter)
log_every(iterable, print_freq, header=None)
synchronize_between_processes()
update(**kwargs)
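The method names above suggest a logger that keeps a collection of named meters, feeds them through update(**kwargs), and reports on them while wrapping an iterable in log_every. The sketch below shows one way such a logger can work; it is an illustration under assumptions, not discopat's implementation. SketchMetricLogger, RunningMean and the lines attribute are hypothetical, log_every assumes the iterable supports len(), and synchronize_between_processes is omitted because it needs a distributed backend.

```python
from collections import defaultdict
from typing import Any, Iterator, Optional, Sequence


class RunningMean:
    """Hypothetical minimal meter: running mean of every value it has seen."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        self.total += value * n
        self.count += n

    def __str__(self) -> str:
        # Guard against division by zero before the first update.
        return f"{self.total / max(self.count, 1):.4f}"


class SketchMetricLogger:
    """Illustrative logger using the documented method names; not discopat's code."""

    def __init__(self, delimiter: str = "\t") -> None:
        # Unknown names get a RunningMean meter on first update.
        self.meters: defaultdict[str, Any] = defaultdict(RunningMean)
        self.delimiter = delimiter
        self.lines: list[str] = []  # hypothetical: keeps emitted lines for inspection

    def add_meter(self, name: str, meter: Any) -> None:
        # Register a custom meter (anything with update() and __str__) under name.
        self.meters[name] = meter

    def update(self, **kwargs: float) -> None:
        # Each keyword argument feeds the meter of the same name.
        for name, value in kwargs.items():
            self.meters[name].update(float(value))

    def __str__(self) -> str:
        return self.delimiter.join(f"{name}: {meter}" for name, meter in self.meters.items())

    def log_every(
        self, iterable: Sequence[Any], print_freq: int, header: Optional[str] = None
    ) -> Iterator[Any]:
        # Yield items unchanged; log every print_freq items and after the last one.
        total = len(iterable)  # assumption: the iterable supports len()
        for i, item in enumerate(iterable):
            yield item  # the caller's loop body (and its update() calls) runs here
            if i % print_freq == 0 or i == total - 1:
                parts = [header, f"[{i}/{total}]", str(self)]
                line = self.delimiter.join(p for p in parts if p)
                self.lines.append(line)
                print(line)
```

Typical use wraps the batch loop, for example `for batch in logger.log_every(batches, print_freq=10, header="Epoch 0"): logger.update(loss=...)`. Because log_every is a generator, each log line is produced after the loop body has run, so it already reflects that iteration's update() call.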
class discopat.nn_training.utils.SmoothedValue(window_size=20, fmt=None)

Bases: object

Track a series of values and provide access to smoothed values over a window or the global series average.

property avg
property global_avg
property max
property median
synchronize_between_processes()

Warning: does not synchronize the deque.

update(value, n=1)
property value
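The distinction between smoothed values over a window and the global series average can be illustrated with a bounded deque for the window plus a running total and count for the whole series. The sketch below is a minimal illustration of that idea, not discopat's code. SketchSmoothedValue is a hypothetical name, the default format string is chosen only for illustration, and update(value, n) assumes that n is the number of samples that value averages (for example a batch size).

```python
import statistics
from collections import deque
from typing import Optional


class SketchSmoothedValue:
    """Illustrative windowed/global average tracker; not discopat's implementation."""

    def __init__(self, window_size: int = 20, fmt: Optional[str] = None) -> None:
        # Assumed default format string, for illustration only.
        self.fmt = fmt or "{median:.4f} ({global_avg:.4f})"
        self.window: deque[float] = deque(maxlen=window_size)  # last window_size values only
        self.total = 0.0  # running sum over the whole series
        self.count = 0    # running number of samples over the whole series

    def update(self, value: float, n: int = 1) -> None:
        # Assumption: n is how many samples `value` averages, so it weights the global stats.
        self.window.append(value)
        self.total += value * n
        self.count += n

    @property
    def median(self) -> float:
        return statistics.median(self.window)

    @property
    def avg(self) -> float:
        return statistics.fmean(self.window)  # mean over the window only

    @property
    def global_avg(self) -> float:
        return self.total / self.count  # mean over every sample ever seen

    @property
    def max(self) -> float:
        return max(self.window)  # resolves to the builtin max, not this property

    @property
    def value(self) -> float:
        return self.window[-1]  # most recent value

    def __str__(self) -> str:
        return self.fmt.format(
            median=self.median, avg=self.avg, global_avg=self.global_avg,
            max=self.max, value=self.value,
        )
```

With window_size=3 and updates 1, 2, 3, 4, the window holds [2, 3, 4], so avg is 3.0 while global_avg is 2.5. In a design like this sketch, only total and count describe the whole series; if a synchronisation step reduced just those two across processes, global_avg would become a cross-process figure while the window-based avg, median, max and value would stay local, which is consistent with the warning that the deque is not synchronised.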
discopat.nn_training.utils.get_world_size()
discopat.nn_training.utils.is_dist_avail_and_initialized()

Module contents