import datetime
import errno
import os
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
class SmoothedValue:
    """Track a series of values and provide access to smoothed values over a window or the global series average.

    The most recent ``window_size`` values feed the windowed statistics
    (``median``, ``avg``, ``max``, ``value``), while ``total`` and ``count``
    accumulate over every update and back ``global_avg``. A meter that has
    received no data reports NaN for these statistics instead of raising, so
    it can still be formatted with ``str()``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value``, weighting it by ``n`` in the global statistics.

        The window stores ``value`` once; ``count`` grows by ``n`` and
        ``total`` by ``value * n``.
        """
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across all ranks (no-op when not distributed).

        Warning: does not synchronize the deque! Afterwards only ``global_avg``
        reflects every process; the windowed statistics remain local.
        """
        if not is_dist_avail_and_initialized():
            return
        # NCCL can only reduce CUDA tensors; the other backends (gloo, mpi)
        # support CPU tensors, so don't hard-code "cuda" and break CPU-only runs.
        device = "cuda" if dist.get_backend() == "nccl" else "cpu"
        t = torch.tensor(
            [self.count, self.total], dtype=torch.float64, device=device
        )
        dist.barrier()
        dist.all_reduce(t)
        count, total = t.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        """Median of the values in the current window (NaN if empty)."""
        if not self.deque:
            return float("nan")
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        """Mean of the values in the current window, computed in float32 (NaN if empty)."""
        if not self.deque:
            return float("nan")
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        """``total / count`` over every update (NaN before any weighted update)."""
        if self.count == 0:
            return float("nan")
        return self.total / self.count

    @property
    def max(self):
        """Largest value in the current window (NaN if empty)."""
        return max(self.deque, default=float("nan"))

    @property
    def value(self):
        """Most recently recorded value (NaN if empty)."""
        return self.deque[-1] if self.deque else float("nan")

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
def all_gather(data):
    """Collect one arbitrary picklable object from every rank.

    Args:
        data: any picklable object contributed by this process.

    Returns:
        list: one entry per rank, ordered by rank. Outside a multi-process run
        this is simply ``[data]``.
    """
    n_ranks = get_world_size()
    if n_ranks != 1:
        # all_gather_object fills the pre-sized list in place, one slot per rank.
        gathered = [None for _ in range(n_ranks)]
        dist.all_gather_object(gathered, data)
        return gathered
    return [data]
def reduce_dict(input_dict, average=True):
    """Reduce the tensor values of ``input_dict`` across all processes.

    Every rank receives the same result: the element-wise sum of each entry
    over all processes, divided by the world size when ``average`` is true.

    Args:
        input_dict (dict): mapping of names to tensors. The values are stacked
            into one tensor, so they must share shape and dtype, and every rank
            must pass the same keys.
        average (bool): whether to return the mean over processes (True) or
            the sum (False).

    Returns:
        dict: a dict with the same keys as ``input_dict`` holding the reduced
        values. With fewer than two processes, or an empty dict, ``input_dict``
        itself is returned unchanged.
    """
    world_size = get_world_size()
    # An empty dict has nothing to reduce; torch.stack([]) would raise here.
    if world_size < 2 or not input_dict:
        return input_dict
    with torch.inference_mode():
        # Sort the keys so that every rank stacks the tensors in the same order.
        names = sorted(input_dict)
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        return dict(zip(names, values))
class MetricLogger:
    """Collection of named ``SmoothedValue`` meters with periodic console logging.

    Meters are created on first use through a ``defaultdict`` and can also be
    read as attributes, e.g. ``logger.loss``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword into the meter of the same name.

        Tensor values are converted with ``.item()``, so they must hold a
        single element. Every value must then be a Python ``float`` or ``int``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. Read ``meters``
        # straight from the instance dict: while an instance is being copied
        # or unpickled it does not exist yet, and ``self.meters`` would
        # re-enter __getattr__ and recurse until RecursionError.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{attr}'"
        )

    def __str__(self):
        return self.delimiter.join(
            f"{name}: {meter!s}" for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Synchronize the global totals of every meter across ranks."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) the meter stored under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while periodically printing progress.

        A status line is printed on iteration 0, on every ``print_freq``-th
        iteration and on the last one. A total-time summary is printed once the
        iterable is exhausted. ``iterable`` must support ``len()``.

        ``data`` times the wait for the next item. ``time`` times a whole
        iteration, including the caller's work on the yielded item.
        """
        header = header or ""
        total = len(iterable)
        show_memory = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the iteration index to the width of the total so columns line up.
        space_fmt = ":" + str(len(str(total))) + "d"
        fields = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if show_memory:
            fields.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                extra = {}
                if show_memory:
                    extra["memory"] = torch.cuda.max_memory_allocated() / MB
                print(
                    log_msg.format(
                        i,
                        total,
                        eta=eta_string,
                        meters=str(self),
                        time=str(iter_time),
                        data=str(data_time),
                        **extra,
                    )
                )
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # An empty iterable previously raised ZeroDivisionError here.
        per_iter = total_time / total if total else 0.0
        print(f"{header} Total time: {total_time_str} ({per_iter:.4f} s / it)")
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``. Elements are
    grouped as-is; nothing is stacked or converted.
    """
    columns = zip(*batch)
    return tuple(columns)
def mkdir(path):
    """Create directory ``path`` along with any missing parents.

    An existing directory at ``path`` is accepted silently. If ``path`` exists
    but is not a directory, ``FileExistsError`` is raised. The previous errno
    check swallowed that case, so callers could proceed as if a regular file
    were a usable directory.
    """
    os.makedirs(path, exist_ok=True)
def setup_for_distributed(is_master):
    """Disable printing when not in master process.

    Replaces the built-in ``print`` with a wrapper that forwards to the
    original only when ``is_master`` is true or the caller passes
    ``force=True``.
    """
    import builtins

    original_print = builtins.print

    def _gated_print(*args, **kwargs):
        # Always pop ``force`` so the keyword never reaches the real print().
        if kwargs.pop("force", False) or is_master:
            original_print(*args, **kwargs)

    builtins.print = _gated_print
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and a process group is initialized."""
    # Short-circuits: is_initialized() is consulted only when the package is available.
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Number of processes in the default group; 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """This process's rank in the default group; 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
def is_main_process():
    """True on rank 0, which is also the rank reported outside distributed runs."""
    rank = get_rank()
    return rank == 0