| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import os |
|
|
| import torch |
| from torch.utils.data.distributed import DistributedSampler |
| from torch.utils.data import DataLoader, Subset |
| from torch.nn.parallel.distributed import DistributedDataParallel |
|
|
| logger = logging.getLogger(__name__) |
| rank = 0 |
| world_size = 1 |
|
|
|
|
| def init(args): |
| """init. |
| |
| Initialize DDP using the given rendezvous file. |
| """ |
| global rank, world_size |
| if args.ddp: |
| assert args.rank is not None and args.world_size is not None |
| rank = args.rank |
| world_size = args.world_size |
| if world_size == 1: |
| return |
| torch.cuda.set_device(rank) |
| torch.distributed.init_process_group( |
| backend=args.ddp_backend, |
| init_method='file://' + os.path.abspath(args.rendezvous_file), |
| world_size=world_size, |
| rank=rank) |
| logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size) |
|
|
|
|
| def average(metrics, count=1.): |
| """average. |
| |
| Average all the relevant metrices across processes |
| `metrics`should be a 1D float32 fector. Returns the average of `metrics` |
| over all hosts. You can use `count` to control the weight of each worker. |
| """ |
| if world_size == 1: |
| return metrics |
| tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32) |
| tensor *= count |
| torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) |
| return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist() |
|
|
|
|
| def wrap(model): |
| """wrap. |
| |
| Wrap a model with DDP if distributed training is enabled. |
| """ |
| if world_size == 1: |
| return model |
| else: |
| return DistributedDataParallel( |
| model, |
| device_ids=[torch.cuda.current_device()], |
| output_device=torch.cuda.current_device()) |
|
|
|
|
| def barrier(): |
| if world_size > 1: |
| torch.distributed.barrier() |
|
|
|
|
| def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs): |
| """loader. |
| |
| Create a dataloader properly in case of distributed training. |
| If a gradient is going to be computed you must set `shuffle=True`. |
| |
| :param dataset: the dataset to be parallelized |
| :param args: relevant args for the loader |
| :param shuffle: shuffle examples |
| :param klass: loader class |
| :param kwargs: relevant args |
| """ |
|
|
| if world_size == 1: |
| return klass(dataset, *args, shuffle=shuffle, **kwargs) |
|
|
| if shuffle: |
| |
| sampler = DistributedSampler(dataset) |
| |
| return klass(dataset, *args, **kwargs, sampler=sampler) |
| else: |
| |
| dataset = Subset(dataset, list(range(rank, len(dataset), world_size))) |
| return klass(dataset, *args, shuffle=shuffle) |
|
|