
test allreduce failures for diloco #226

Merged · 1 commit · Jun 30, 2025
88 changes: 85 additions & 3 deletions torchft/local_sgd_integ_test.py
@@ -28,7 +28,11 @@
MyModel,
Runner,
)
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
from torchft.process_group import (
FakeProcessGroupWrapper,
ProcessGroupBabyNCCL,
ProcessGroupGloo,
)

logger: logging.Logger = logging.getLogger(__name__)

@@ -188,9 +192,11 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

if device.type == "cuda":
pg = ProcessGroupBabyNCCL()
pg = FakeProcessGroupWrapper(ProcessGroupBabyNCCL())
else:
pg = ProcessGroupGloo(timeout=timedelta(seconds=10))
pg = FakeProcessGroupWrapper(
ProcessGroupGloo(timeout=timedelta(seconds=10))
)
manager = Manager(
pg=pg,
min_replica_size=2,
@@ -210,6 +216,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-ignore[53]
# pyre-fixme[6]: Incompatible parameter type
**runner.manager_args,
)
runner.event_injector.set_pg(pg)
stack.callback(manager.shutdown)
# initialize default group for device mesh to work
if not torch.distributed.is_initialized():
@@ -669,3 +676,78 @@ def test_streaming_diloco_upscale(

for event_injector in event_injectors:
self.assertEqual(event_injectors[1].count[EventInjectorEvent.Barrier], 1)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
@skipIf(sys.platform == "darwin", "not reliable on mac")
@parameterized.expand(CONFIG)
def test_streaming_diloco_commit_failure(
self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int
) -> None:
# Skip the test if use_cuda is True and there are not enough GPUs
if use_cuda and torch.cuda.device_count() < 2:
self.skipTest("Not enough GPUs for CUDA test")

lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=2,
)
num_replicas = 2
futures = []
executors = []

event_injectors = [
EventInjector().fail_allreduce_at(0, 1),
EventInjector().fail_allreduce_at(0, 1),
]

torch.manual_seed(42)
# Initialize the model so we can pass in the state_dict
m: nn.Module = MultiMyModel(2, 3, n_fragments)

for replica_id, event_injector in zip(range(num_replicas), event_injectors):
executor = ThreadPoolExecutor(max_workers=1)
executors.append(executor)
runner = Runner(
replica_id=replica_id,
num_replicas=num_replicas,
lighthouse_address=lighthouse.address(),
event_injector=event_injector,
train_loop=diloco_train_loop,
train_loop_args={
"model_state_dict": m.state_dict(),
"n_fragments": n_fragments,
"diloco_args": {
"fragment_sync_delay": fragment_sync_delay,
"sync_every": 4,
},
},
)
futures.append(executor.submit(runner.run_replica))

state_dicts = []

for fut in as_completed(futures):
continue

for fut in futures:
try:
state_dicts.append(fut.result()[0])
except Exception as e:
print(e)
raise

lighthouse.shutdown()

rep0, rep1 = state_dicts

assert_equal_global_state(rep0, rep1)

for step in rep0.keys():
self.assertEqual(
rep0[step]["user"]["local_step"], rep1[step]["user"]["local_step"]
)

for event_injector in event_injectors:
self.assertEqual(
event_injector.count[EventInjectorEvent.AllreduceFailure], 1
)
30 changes: 29 additions & 1 deletion torchft/manager_integ_test.py
@@ -34,7 +34,11 @@
from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager
from torchft.optim import OptimizerWrapper
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
from torchft.process_group import (
FakeProcessGroupWrapper,
ProcessGroupBabyNCCL,
ProcessGroupGloo,
)

logger: logging.Logger = logging.getLogger(__name__)

@@ -76,10 +80,13 @@ class InjectedFailure(Exception):


class EventInjectorEvent(Enum):
# Crashes a rank
Failure = auto()
# Used to wait for a rank to reach a certain step before continuing.
# Users need to make sure the size of the barrier is appropriately set.
Barrier = auto()
# Fails the allreduce call made by a rank
AllreduceFailure = auto()
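
For orientation, here is a minimal sketch, not part of the PR, of how each event type is registered on an EventInjector. The builder methods all return self, so they chain; the snippet assumes torchft is installed with its test module importable as torchft.manager_integ_test, and the (rank, step) values are illustrative.

import threading

from torchft.manager_integ_test import EventInjector

# Illustrative only: register one event of each kind on a single injector.
injector = (
    EventInjector()
    .fail_at(rank=0, step=2)  # raises InjectedFailure when rank 0 reaches step 2
    .barrier_at(rank=1, step=3, barrier=threading.Barrier(2))  # blocks rank 1 at step 3
    .fail_allreduce_at(rank=0, step=1)  # new in this PR: poisons rank 0's allreduce at step 1
)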


class EventInjectorInfo:
@@ -90,10 +97,15 @@ def __init__(self, event: EventInjectorEvent, data: object) -> None:

class EventInjector:
def __init__(self) -> None:
self._pg: Optional[FakeProcessGroupWrapper] = None
self._lock = threading.Lock()
self._events: Dict[Tuple[int, int], EventInjectorInfo] = {}
self.count: dict[EventInjectorEvent, int] = defaultdict(int)

def set_pg(self, pg: FakeProcessGroupWrapper) -> None:
with self._lock:
self._pg = pg

def fail_at(self, rank: int, step: int) -> "EventInjector":
with self._lock:
assert (rank, step) not in self._events
@@ -102,6 +114,14 @@ def fail_at(self, rank: int, step: int) -> "EventInjector":
)
return self

def fail_allreduce_at(self, rank: int, step: int) -> "EventInjector":
with self._lock:
assert (rank, step) not in self._events
self._events[(rank, step)] = EventInjectorInfo(
EventInjectorEvent.AllreduceFailure, None
)
return self

def barrier_at(
self, rank: int, step: int, barrier: threading.Barrier
) -> "EventInjector":
@@ -124,6 +144,14 @@ def check(self, rank: int, step: int) -> None:
print(f"injecting failure {rank=} {step=}")
raise InjectedFailure(f"injected failure {rank=} {step=}")

if event_info.event == EventInjectorEvent.AllreduceFailure:
print(f"injecting allreduce failure {rank=} {step=}")
assert self._pg is not None
self._pg.report_future_error(
RuntimeError("injected allreduce error")
)
return

if event_info.event == EventInjectorEvent.Barrier:
print(f"waiting for barrier {rank=} {step=}")
cast(threading.Barrier, event_info.data).wait()
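
To make the flow above concrete, here is a minimal sketch, not from the PR, of how fail_allreduce_at, set_pg, and check cooperate. It assumes an installed torchft whose test module is importable; the injector and pg names are illustrative, and the final count check mirrors what the new test asserts. The training loop calls check(rank, step) on every step; when the registered (rank, step) pair is hit, the injector forwards an error to the wrapped process group via report_future_error instead of raising in the trainer thread.

from datetime import timedelta

from torchft.manager_integ_test import EventInjector, EventInjectorEvent
from torchft.process_group import FakeProcessGroupWrapper, ProcessGroupGloo

# Wrap a real process group so errors can be injected into it.
pg = FakeProcessGroupWrapper(ProcessGroupGloo(timeout=timedelta(seconds=10)))

# Register an allreduce failure for rank 0 at step 1 and hand the injector the wrapped PG.
injector = EventInjector().fail_allreduce_at(0, 1)
injector.set_pg(pg)

# Called by the training loop; at the matching (rank, step) it invokes
# pg.report_future_error(RuntimeError("injected allreduce error")) and records the event.
injector.check(rank=0, step=1)

# Per the new test's assertion, the injector counts one AllreduceFailure event.
assert injector.count[EventInjectorEvent.AllreduceFailure] == 1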
56 changes: 56 additions & 0 deletions torchft/process_group.py
@@ -378,6 +378,13 @@ def __init__(
self._pg: Optional[BaseProcessGroup] = pg
self._timeout = timeout

def getBackendName(self) -> str:
pg = self._pg
if isinstance(pg, ProcessGroup):
return pg.getBackendName()

raise NotImplementedError("not implemented")

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
pg = self._pg
if isinstance(pg, ProcessGroup):
@@ -1016,6 +1023,55 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
return _DummyWork(tensors)


class FakeProcessGroupWrapper(ProcessGroupWrapper):
Review comment (Member): nit: could maybe move this class into tests?

Review comment (Contributor, Author): yeah there's a bunch of other fakes in this file. maybe take this up separately? we also want to upstream process_group to pytorch

"""
This is a wrapper around any ProcessGroup that can be used to inject
errors into the process group at various points.

This is intended to be used for tests so that they can test cases
in which process group operations error out.
"""

def __init__(self, pg: ProcessGroup) -> None:
super().__init__(pg=pg)

self._future_error: Optional[Exception] = None

def configure(self, store_addr: str, rank: int, world_size: int) -> None:
self._future_error = None

super().configure(store_addr, rank, world_size)

def report_future_error(self, e: Exception) -> None:
"""
Report an error to this process group. This will cause the
future attached to the next operation to error out.

Args:
e: exception to report
"""
self._future_error = e

def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
work = super().allreduce(tensors, opts)

if self._future_error is None:
return work

fut = work.get_future()

def callback(
fut: torch.futures.Future[List[torch.Tensor]],
) -> List[torch.Tensor]:
future_error, self._future_error = self._future_error, None
assert future_error is not None
raise future_error

fut = fut.then(callback)

return work


class _ManagedWork(Work):
def __init__(self, manager: "Manager", work: Work, default_result: object) -> None:
super().__init__()
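
Note how FakeProcessGroupWrapper delivers the injected exception: it does not fail the collective itself, but attaches a callback with fut.then(...) that raises the reported error, so the failure travels through the future chain. Below is a standalone sketch of that torch.futures pattern, not taken from the PR; the names base, poison, and chained are illustrative and no process group is involved.

from typing import List

import torch

# A future standing in for the one returned by work.get_future().
base: torch.futures.Future[List[torch.Tensor]] = torch.futures.Future()

def poison(fut: torch.futures.Future[List[torch.Tensor]]) -> List[torch.Tensor]:
    # Mirrors the wrapper's callback: raise the injected error.
    raise RuntimeError("injected allreduce error")

chained = base.then(poison)

# The underlying operation still completes successfully...
base.set_result([torch.ones(2)])

# ...but a consumer of the chained future observes the injected failure.
try:
    chained.wait()
except RuntimeError as err:
    print(f"future errored as intended: {err}")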