
Test rapids-dask-dependency UCX protocol selection #1520


Merged · 6 commits · Jul 9, 2025
9 changes: 9 additions & 0 deletions ci/test_python.sh
@@ -52,5 +52,14 @@ rapids-logger "pytest dask-cuda"
rapids-logger "Run local benchmark"
./ci/run_benchmarks.sh

# Run rapids-dask-dependency tests without `distributed-ucxx`, ensuring the protocol
# selection mechanism also works in "legacy" environments where only `ucx-py` is
# installed.
# TODO: remove as part of https://github.com/rapidsai/dask-cuda/issues/1517
mamba remove -y distributed-ucxx
./ci/run_pytest.sh \
--junitxml="${RAPIDS_TESTS_DIR}/junit-dask-cuda-rdd-protocol-selection.xml" \
-k "test_rdd_protocol"

rapids-logger "Test script exiting with latest error code: $EXITCODE"
exit ${EXITCODE}
22 changes: 22 additions & 0 deletions ci/test_wheel.sh
@@ -18,9 +18,31 @@ rapids-logger "Installing test dependencies"
# echo to expand wildcard
rapids-pip-retry install -v --prefer-binary -r /tmp/requirements-test.txt "$(echo "${DASK_CUDA_WHEELHOUSE}"/dask_cuda*.whl)"

EXITCODE=0
# shellcheck disable=SC2317
set_exit_code() {
EXITCODE=$?
rapids-logger "Test failed with error ${EXITCODE}"
}
trap set_exit_code ERR
set +e
Member Author commented on lines +21 to +28:

Along with lines 47-48, this reproduces the same error handling mechanism as that of ci/test_python.sh.
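
For reference, a minimal standalone sketch of the same pattern (the failing `false` step and the echo messages are placeholders, not part of the CI script): the ERR trap records the exit status of any failing command, `set +e` keeps the script running past the failure, and the script exits with the last recorded code.

EXITCODE=0
set_exit_code() {
    EXITCODE=$?                 # remember the failing command's exit status
    echo "Step failed with error ${EXITCODE}"
}
trap set_exit_code ERR          # run the handler whenever a command fails
set +e                          # do not abort on the first failure

false                           # placeholder failing step; EXITCODE becomes 1
echo "later steps still run"

exit ${EXITCODE}                # propagate the last recorded failure to the caller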


rapids-logger "pytest dask-cuda"
./ci/run_pytest.sh \
--junitxml="${RAPIDS_TESTS_DIR}/junit-dask-cuda.xml"

rapids-logger "Run local benchmark"
./ci/run_benchmarks.sh

# Run rapids-dask-dependency tests without `distributed-ucxx`, ensuring the protocol
# selection mechanism also works in "legacy" environments where only `ucx-py` is
# installed.
# TODO: remove as part of https://github.com/rapidsai/dask-cuda/issues/1517
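# The installed wheel name may carry a CUDA suffix (e.g. distributed-ucxx-cu12),
# so look up the exact installed package name before uninstalling.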
distributed_ucxx_package_name="$(pip list | grep distributed-ucxx | awk '{print $1}')"
pip uninstall -y "${distributed_ucxx_package_name}"
./ci/run_pytest.sh \
--junitxml="${RAPIDS_TESTS_DIR}/junit-dask-cuda-rdd-protocol-selection.xml" \
-k "test_rdd_protocol"

rapids-logger "Test script exiting with latest error code: $EXITCODE"
exit ${EXITCODE}
160 changes: 160 additions & 0 deletions dask_cuda/tests/test_rdd_ucx.py
@@ -0,0 +1,160 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
@pentschev (Member Author) commented on Jul 9, 2025:

This test was intended for rapids-dask-dependency but was added here instead because the rapids-dask-dependency repository has no GPU build/test jobs. Adding those there would require too much effort for a temporary test, so testing it here is preferable.

# SPDX-License-Identifier: Apache-2.0


import importlib
import io
import multiprocessing as mp
import sys

import pytest

from dask_cuda import LocalCUDACluster

mp = mp.get_context("spawn") # type: ignore


def _has_distributed_ucxx() -> bool:
return bool(importlib.util.find_spec("distributed_ucxx"))


def _test_protocol_ucx():
with LocalCUDACluster(protocol="ucx") as cluster:
assert cluster.scheduler_comm.address.startswith("ucx://")

if _has_distributed_ucxx():
import distributed_ucxx

assert all(
isinstance(batched_send.comm, distributed_ucxx.ucxx.UCXX)
for batched_send in cluster.scheduler.stream_comms.values()
)
else:
import rapids_dask_dependency

assert all(
isinstance(
batched_send.comm,
rapids_dask_dependency.patches.distributed.comm.__rdd_patch_ucx.UCX,
)
for batched_send in cluster.scheduler.stream_comms.values()
)


def _test_protocol_ucxx():
if _has_distributed_ucxx():
with LocalCUDACluster(protocol="ucxx") as cluster:
assert cluster.scheduler_comm.address.startswith("ucxx://")
import distributed_ucxx

assert all(
isinstance(batched_send.comm, distributed_ucxx.ucxx.UCXX)
for batched_send in cluster.scheduler.stream_comms.values()
)
else:
with pytest.raises(RuntimeError, match="Cluster failed to start"):
LocalCUDACluster(protocol="ucxx")


def _test_protocol_ucx_old():
with LocalCUDACluster(protocol="ucx-old") as cluster:
assert cluster.scheduler_comm.address.startswith("ucx-old://")

import rapids_dask_dependency

assert all(
isinstance(
batched_send.comm,
rapids_dask_dependency.patches.distributed.comm.__rdd_patch_ucx.UCX,
)
for batched_send in cluster.scheduler.stream_comms.values()
)


def _run_test_with_output_capture(test_func_name, conn):
"""Run a test function in a subprocess and capture stdout/stderr."""
# Redirect stdout and stderr to capture output
old_stdout = sys.stdout
old_stderr = sys.stderr
captured_output = io.StringIO()
sys.stdout = sys.stderr = captured_output

try:
# Import and run the test function
if test_func_name == "_test_protocol_ucx":
_test_protocol_ucx()
elif test_func_name == "_test_protocol_ucxx":
_test_protocol_ucxx()
elif test_func_name == "_test_protocol_ucx_old":
_test_protocol_ucx_old()
else:
raise ValueError(f"Unknown test function: {test_func_name}")

output = captured_output.getvalue()
conn.send((True, output)) # True = success
except Exception as e:
output = captured_output.getvalue()
output += f"\nException: {e}"
import traceback

output += f"\nTraceback:\n{traceback.format_exc()}"
conn.send((False, output)) # False = failure
finally:
# Restore original stdout/stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
conn.close()


@pytest.mark.parametrize("protocol", ["ucx", "ucxx", "ucx-old"])
def test_rdd_protocol(protocol):
"""Test rapids-dask-dependency protocol selection"""
if protocol == "ucx":
test_func_name = "_test_protocol_ucx"
elif protocol == "ucxx":
test_func_name = "_test_protocol_ucxx"
else:
test_func_name = "_test_protocol_ucx_old"

# Create a pipe for communication between parent and child processes
parent_conn, child_conn = mp.Pipe()
p = mp.Process(
target=_run_test_with_output_capture, args=(test_func_name, child_conn)
)

p.start()
p.join(timeout=60)

if p.is_alive():
p.kill()
p.close()
raise TimeoutError("Test process timed out")

# Get the result from the child process
success, output = parent_conn.recv()

# Check that the test passed
assert success, f"Test failed in subprocess. Output:\n{output}"

# For the ucx protocol, check if warnings are printed when distributed_ucxx is not
# available
if protocol == "ucx" and not _has_distributed_ucxx():
# Check if the warning about protocol='ucx' is printed
print(f"Output for {protocol} protocol:\n{output}")
assert (
"you have requested protocol='ucx'" in output
), f"Expected warning not found in output: {output}"
assert (
"'distributed-ucxx' is not installed" in output
), f"Expected warning about distributed-ucxx not found in output: {output}"
elif protocol == "ucx" and _has_distributed_ucxx():
# When distributed_ucxx is available, the warning should NOT be printed
assert "you have requested protocol='ucx'" not in output, (
"Warning should not be printed when distributed_ucxx is available: "
f"{output}"
)
elif protocol == "ucx-old":
# The ucx-old protocol should not generate warnings
assert (
"you have requested protocol='ucx'" not in output
), f"Warning should not be printed for ucx-old protocol: {output}"