| |
| |
| |
| |
| |
| |
| """ |
| Start multiple processes locally for DDP. |
| """ |
|
|
| import logging |
| import subprocess as sp |
| import sys |
|
|
| from hydra import utils |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class ChildrenManager:
    """Context manager that supervises a group of child processes.

    Children are registered with `add` inside the `with` block. On exit the
    manager polls them until every child has finished or one of them has
    failed, then terminates whatever is still registered.

    Attributes:
        children: children that have not yet been observed to exit, in the
            order they were added.
        failed: True once a child exited with a non-zero code, the `with`
            body raised, or a KeyboardInterrupt arrived while waiting.
    """

    def __init__(self):
        self.failed = False
        self.children = []

    def add(self, child):
        """Register `child`, tagging it with its position as `child.rank`."""
        rank = len(self.children)
        child.rank = rank
        self.children.append(child)

    def __enter__(self):
        return self

    def _poll_once(self):
        """Give each registered child up to 0.1s to exit.

        Children that exited are dropped from `self.children`; a non-zero exit
        code marks the manager as failed. The pass is never cut short, so
        every child is checked even after a failure was detected.
        """
        # Iterate over a snapshot because finished children are removed
        # from the live list during the pass.
        for worker in tuple(self.children):
            try:
                status = worker.wait(0.1)
            except sp.TimeoutExpired:
                # Still running, look at the next child.
                continue
            self.children.remove(worker)
            if status:
                logger.error(f"Worker {worker.rank} died, killing all workers")
                self.failed = True

    def __exit__(self, exc_type, exc_value, traceback):
        """Wait for the children, then terminate any that remain.

        Returns None, so an exception raised in the `with` body still
        propagates to the caller after the children have been terminated.
        """
        if exc_value is not None:
            logger.error("An exception happened while starting workers %r", exc_value)
            self.failed = True
        try:
            while self.children and not self.failed:
                self._poll_once()
        except KeyboardInterrupt:
            logger.error("Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        # Only children not yet seen exiting are left in the list at this point.
        for worker in self.children:
            worker.terminate()
        if self.failed:
            return
        logger.info("All workers completed successfully")
|
|
|
|
def start_ddp_workers():
    """Re-launch the current command once per CUDA device and wait for all copies.

    Each worker runs `sys.executable` with the current `sys.argv` plus
    `world_size=<n>` and `rank=<i>` overrides, from the directory returned by
    `utils.get_original_cwd()`. Workers with rank > 0 have stdin, stdout and
    stderr redirected to DEVNULL and get the hydra log file name suffixed
    with `.<rank>`.

    This function never returns: it exits with status 1 when
    `th.cuda.device_count()` reports no device, otherwise with
    `int(manager.failed)` once the workers are done.
    """
    import torch as th

    n_workers = th.cuda.device_count()
    if not n_workers:
        logger.error(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        sys.exit(1)
    logger.info(f"Starting {n_workers} worker processes for DDP.")
    with ChildrenManager() as manager:
        for rank in range(n_workers):
            command = [sys.executable, *sys.argv, f"world_size={n_workers}", f"rank={rank}"]
            popen_options = {}
            if rank > 0:
                # Only rank 0 keeps the terminal; the others are silenced and
                # log to their own file instead.
                popen_options = dict.fromkeys(("stdin", "stdout", "stderr"), sp.DEVNULL)
                log_path = utils.HydraConfig().hydra.job_logging.handlers.file.filename
                log_path = log_path + f".{rank}"
                command.append("hydra.job_logging.handlers.file.filename=" + log_path)
            worker = sp.Popen(command, cwd=utils.get_original_cwd(), **popen_options)
            manager.add(worker)
    sys.exit(int(manager.failed))
|
|