diff --git a/torchft/local_sgd_integ_test.py b/torchft/local_sgd_integ_test.py
index 3cc72634..f0edcec7 100644
--- a/torchft/local_sgd_integ_test.py
+++ b/torchft/local_sgd_integ_test.py
@@ -28,7 +28,11 @@
     MyModel,
     Runner,
 )
-from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+from torchft.process_group import (
+    FakeProcessGroupWrapper,
+    ProcessGroupBabyNCCL,
+    ProcessGroupGloo,
+)
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -188,9 +192,11 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
     print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
 
     if device.type == "cuda":
-        pg = ProcessGroupBabyNCCL()
+        pg = FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
     else:
-        pg = ProcessGroupGloo(timeout=timedelta(seconds=10))
+        pg = FakeProcessGroupWrapper(
+            ProcessGroupGloo(timeout=timedelta(seconds=10))
+        )
     manager = Manager(
         pg=pg,
         min_replica_size=2,
@@ -210,6 +216,7 @@ def state_dict() -> Dict[str, Dict[str, object]]:  # pyre-ignore[53]
         # pyre-fixme[6]: Incompatible parameter type
         **runner.manager_args,
     )
+    runner.event_injector.set_pg(pg)
     stack.callback(manager.shutdown)
     # initialize default group for device mesh to work
     if not torch.distributed.is_initialized():
@@ -669,3 +676,78 @@ def test_streaming_diloco_upscale(
 
         for event_injector in event_injectors:
             self.assertEqual(event_injectors[1].count[EventInjectorEvent.Barrier], 1)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not reliable on mac")
+    @parameterized.expand(CONFIG)
+    def test_streaming_diloco_commit_failure(
+        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int
+    ) -> None:
+        # Skip the test if use_cuda is True and there are not enough GPUs
+        if use_cuda and torch.cuda.device_count() < 2:
+            self.skipTest("Not enough GPUs for CUDA test")
+
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+        executors = []
+
+        event_injectors = [
+            EventInjector().fail_allreduce_at(0, 1),
+            EventInjector().fail_allreduce_at(0, 1),
+        ]
+
+        torch.manual_seed(42)
+        # Initialize the model so we can pass in the state_dict
+        m: nn.Module = MultiMyModel(2, 3, n_fragments)
+
+        for replica_id, event_injector in zip(range(num_replicas), event_injectors):
+            executor = ThreadPoolExecutor(max_workers=1)
+            executors.append(executor)
+            runner = Runner(
+                replica_id=replica_id,
+                num_replicas=num_replicas,
+                lighthouse_address=lighthouse.address(),
+                event_injector=event_injector,
+                train_loop=diloco_train_loop,
+                train_loop_args={
+                    "model_state_dict": m.state_dict(),
+                    "n_fragments": n_fragments,
+                    "diloco_args": {
+                        "fragment_sync_delay": fragment_sync_delay,
+                        "sync_every": 4,
+                    },
+                },
+            )
+            futures.append(executor.submit(runner.run_replica))
+
+        state_dicts = []
+
+        for fut in as_completed(futures):
+            continue
+
+        for fut in futures:
+            try:
+                state_dicts.append(fut.result()[0])
+            except Exception as e:
+                print(e)
+                raise
+
+        lighthouse.shutdown()
+
+        rep0, rep1 = state_dicts
+
+        assert_equal_global_state(rep0, rep1)
+
+        for step in rep0.keys():
+            self.assertEqual(
+                rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
+            )
+
+        for event_injector in event_injectors:
+            self.assertEqual(
+                event_injector.count[EventInjectorEvent.AllreduceFailure], 1
+            )
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index 322a4aa1..6bdab583 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -34,7 +34,11 @@
 from torchft.local_sgd import DiLoCo, LocalSGD
 from torchft.manager import Manager
 from torchft.optim import OptimizerWrapper
-from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+from torchft.process_group import (
+    FakeProcessGroupWrapper,
+    ProcessGroupBabyNCCL,
+    ProcessGroupGloo,
+)
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -76,10 +80,13 @@ class InjectedFailure(Exception):
 
 
 class EventInjectorEvent(Enum):
+    # Crashes a rank
     Failure = auto()
     # Used to wait for a rank to reach a certain step before continuing.
     # Users need to make sure the size of the barrier is appropriately set.
     Barrier = auto()
+    # Fails the allreduce call made by a rank
+    AllreduceFailure = auto()
 
 
 class EventInjectorInfo:
@@ -90,10 +97,15 @@ def __init__(self, event: EventInjectorEvent, data: object) -> None:
 
 class EventInjector:
     def __init__(self) -> None:
+        self._pg: Optional[FakeProcessGroupWrapper] = None
         self._lock = threading.Lock()
         self._events: Dict[Tuple[int, int], EventInjectorInfo] = {}
         self.count: dict[EventInjectorEvent, int] = defaultdict(int)
 
+    def set_pg(self, pg: FakeProcessGroupWrapper) -> None:
+        with self._lock:
+            self._pg = pg
+
     def fail_at(self, rank: int, step: int) -> "EventInjector":
         with self._lock:
             assert (rank, step) not in self._events
@@ -102,6 +114,14 @@ def fail_at(self, rank: int, step: int) -> "EventInjector":
             )
             return self
 
+    def fail_allreduce_at(self, rank: int, step: int) -> "EventInjector":
+        with self._lock:
+            assert (rank, step) not in self._events
+            self._events[(rank, step)] = EventInjectorInfo(
+                EventInjectorEvent.AllreduceFailure, None
+            )
+            return self
+
     def barrier_at(
         self, rank: int, step: int, barrier: threading.Barrier
     ) -> "EventInjector":
@@ -124,6 +144,14 @@ def check(self, rank: int, step: int) -> None:
                 print(f"injecting failure {rank=} {step=}")
                 raise InjectedFailure(f"injected failure {rank=} {step=}")
 
+            if event_info.event == EventInjectorEvent.AllreduceFailure:
+                print(f"injecting allreduce failure {rank=} {step=}")
+                assert self._pg is not None
+                self._pg.report_future_error(
+                    RuntimeError("injected allreduce error")
+                )
+                return
+
             if event_info.event == EventInjectorEvent.Barrier:
                 print(f"waiting for barrier {rank=} {step=}")
                 cast(threading.Barrier, event_info.data).wait()
diff --git a/torchft/process_group.py b/torchft/process_group.py
index f1a3b79f..8f9c27bc 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -378,6 +378,13 @@ def __init__(
         self._pg: Optional[BaseProcessGroup] = pg
         self._timeout = timeout
 
+    def getBackendName(self) -> str:
+        pg = self._pg
+        if isinstance(pg, ProcessGroup):
+            return pg.getBackendName()
+
+        raise NotImplementedError("not implemented")
+
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         pg = self._pg
         if isinstance(pg, ProcessGroup):
@@ -1016,6 +1023,55 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return _DummyWork(tensors)
 
 
+class FakeProcessGroupWrapper(ProcessGroupWrapper):
+    """
+    This is a wrapper around any ProcessGroup that can be used to inject
+    errors into the process group at various points.
+
+    This is intended to be used for tests so that they can test cases
+    in which process group operations error out.
+    """
+
+    def __init__(self, pg: ProcessGroup) -> None:
+        super().__init__(pg=pg)
+
+        self._future_error: Optional[Exception] = None
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        self._future_error = None
+
+        super().configure(store_addr, rank, world_size)
+
+    def report_future_error(self, e: Exception) -> None:
+        """
+        Report an error to this process group. This will cause the
+        future attached to the next operation to error out.
+
+        Args:
+            e: exception to report
+        """
+        self._future_error = e
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        work = super().allreduce(tensors, opts)
+
+        if self._future_error is None:
+            return work
+
+        fut = work.get_future()
+
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> List[torch.Tensor]:
+            future_error, self._future_error = self._future_error, None
+            assert future_error is not None
+            raise future_error
+
+        fut = fut.then(callback)
+
+        return work
+
+
 class _ManagedWork(Work):
     def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
         super().__init__()