Source code for discopat.nn_training.torch_detection_utils.engine

import math
import sys

import torch

from .utils import MetricLogger, SmoothedValue, reduce_dict


def train_one_epoch(
    model, optimizer, data_loader, device, epoch, print_freq, scaler=None
):
    """Train ``model`` for one pass over ``data_loader``.

    The model is switched to training mode and called as
    ``model(images, targets)``; it is expected to return a dict of loss
    tensors, whose sum is back-propagated. During epoch 0 a linear
    learning-rate warmup (starting at 1/1000 of the optimizer's lr) is
    applied per iteration.

    Args:
        model: Module returning a dict of losses when called with a list
            of images and a list of target dicts.
        optimizer: Optimizer updating the model parameters.
        data_loader: Iterable yielding ``(images, targets)`` batches, where
            ``targets`` is a sequence of dicts. Must support ``len()`` when
            ``epoch == 0``.
        device: Device (``torch.device`` or string such as ``"cuda:0"``)
            the batch is moved to.
        epoch: Epoch index; used in the log header and to enable warmup
            when equal to 0.
        print_freq: Logging frequency forwarded to
            ``metric_logger.log_every``.
        scaler: Optional gradient scaler. When not ``None``, the forward
            pass runs under autocast and the scaler drives the backward
            pass and optimizer step.

    Returns:
        The ``MetricLogger`` updated with the reduced losses and the
        learning rate of the first parameter group at every iteration.

    Raises:
        SystemExit: With code 1 if the reduced total loss is not finite.
    """
    model.train()
    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        "lr", SmoothedValue(window_size=1, fmt="{value:.6f}")
    )
    header = f"Epoch: [{epoch}]"

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1.0 / 1000
        # At least one warmup iteration: with total_iters=0 LinearLR scales
        # the lr down to start_factor and never ramps it back up.
        warmup_iters = max(1, min(1000, len(data_loader) - 1))
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    # autocast expects a bare device type ("cuda", "cpu"), not an indexed
    # device string such as "cuda:0", so strip the index here.
    device_type = torch.device(device).type
    use_amp = scaler is not None

    for images, targets in metric_logger.log_every(
        data_loader, print_freq, header
    ):
        input_images = [image.float().to(device) for image in images]
        # NOTE(review): every tensor target is cast to int64, including any
        # floating-point entries — confirm this is intended for all keys.
        train_targets = [
            {
                k: v.to(torch.int64).to(device)
                if isinstance(v, torch.Tensor)
                else v
                for k, v in t.items()
            }
            for t in targets
        ]

        with torch.amp.autocast(device_type, enabled=use_amp):
            loss_dict = model(input_images, train_targets)
            losses = sum(loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = reduce_dict(loss_dict)
        losses_reduced = sum(loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if use_amp:
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger