diff --git a/ci/test_python.sh b/ci/test_python.sh index 0ae369a6..11cc546b 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -52,5 +52,14 @@ rapids-logger "pytest dask-cuda" rapids-logger "Run local benchmark" ./ci/run_benchmarks.sh +# Run rapids-dask-dependency tests without `distributed-ucxx`, ensuring the protocol +# selection mechanism works also on "legacy" environments where only `ucx-py` is +# installed. +# TODO: remove as part of https://github.com/rapidsai/dask-cuda/issues/1517 +mamba remove -y distributed-ucxx +./ci/run_pytest.sh \ + --junitxml="${RAPIDS_TESTS_DIR}/junit-dask-cuda-rdd-protocol-selection.xml" \ + -k "test_rdd_protocol" + rapids-logger "Test script exiting with latest error code: $EXITCODE" exit ${EXITCODE} diff --git a/ci/test_wheel.sh b/ci/test_wheel.sh index 21c7dc8d..d4bc26f7 100755 --- a/ci/test_wheel.sh +++ b/ci/test_wheel.sh @@ -18,9 +18,31 @@ rapids-logger "Installing test dependencies" # echo to expand wildcard rapids-pip-retry install -v --prefer-binary -r /tmp/requirements-test.txt "$(echo "${DASK_CUDA_WHEELHOUSE}"/dask_cuda*.whl)" +EXITCODE=0 +# shellcheck disable=SC2317 +set_exit_code() { + EXITCODE=$? + rapids-logger "Test failed with error ${EXITCODE}" +} +trap set_exit_code ERR +set +e + rapids-logger "pytest dask-cuda" ./ci/run_pytest.sh \ --junitxml="${RAPIDS_TESTS_DIR}/junit-dask-cuda.xml" rapids-logger "Run local benchmark" ./ci/run_benchmarks.sh + +# Run rapids-dask-dependency tests without `distributed-ucxx`, ensuring the protocol +# selection mechanism works also on "legacy" environments where only `ucx-py` is +# installed. +# TODO: remove as part of https://github.com/rapidsai/dask-cuda/issues/1517 +distributed_ucxx_package_name="$(pip list | grep distributed-ucxx | awk '{print $1}')" +pip uninstall -y "${distributed_ucxx_package_name}" +./ci/run_pytest.sh \ + --junitxml="${RAPIDS_TESTS_DIR}/junit-dask-cuda-rdd-protocol-selection.xml" \ + -k "test_rdd_protocol" + +rapids-logger "Test script exiting with latest error code: $EXITCODE" +exit ${EXITCODE} diff --git a/dask_cuda/tests/test_rdd_ucx.py b/dask_cuda/tests/test_rdd_ucx.py new file mode 100644 index 00000000..172a53b4 --- /dev/null +++ b/dask_cuda/tests/test_rdd_ucx.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + + +import importlib +import io +import multiprocessing as mp +import sys + +import pytest + +from dask_cuda import LocalCUDACluster + +mp = mp.get_context("spawn") # type: ignore + + +def _has_distributed_ucxx() -> bool: + return bool(importlib.util.find_spec("distributed_ucxx")) + + +def _test_protocol_ucx(): + with LocalCUDACluster(protocol="ucx") as cluster: + assert cluster.scheduler_comm.address.startswith("ucx://") + + if _has_distributed_ucxx(): + import distributed_ucxx + + assert all( + isinstance(batched_send.comm, distributed_ucxx.ucxx.UCXX) + for batched_send in cluster.scheduler.stream_comms.values() + ) + else: + import rapids_dask_dependency + + assert all( + isinstance( + batched_send.comm, + rapids_dask_dependency.patches.distributed.comm.__rdd_patch_ucx.UCX, + ) + for batched_send in cluster.scheduler.stream_comms.values() + ) + + +def _test_protocol_ucxx(): + if _has_distributed_ucxx(): + with LocalCUDACluster(protocol="ucxx") as cluster: + assert cluster.scheduler_comm.address.startswith("ucxx://") + import distributed_ucxx + + assert all( + isinstance(batched_send.comm, distributed_ucxx.ucxx.UCXX) + for batched_send in cluster.scheduler.stream_comms.values() + ) + else: + with pytest.raises(RuntimeError, match="Cluster failed to start"): + LocalCUDACluster(protocol="ucxx") + + +def _test_protocol_ucx_old(): + with LocalCUDACluster(protocol="ucx-old") as cluster: + assert cluster.scheduler_comm.address.startswith("ucx-old://") + + import rapids_dask_dependency + + assert all( + isinstance( + batched_send.comm, + rapids_dask_dependency.patches.distributed.comm.__rdd_patch_ucx.UCX, + ) + for batched_send in cluster.scheduler.stream_comms.values() + ) + + +def _run_test_with_output_capture(test_func_name, conn): + """Run a test function in a subprocess and capture stdout/stderr.""" + # Redirect stdout and stderr to capture output + old_stdout = sys.stdout + old_stderr = sys.stderr + captured_output = io.StringIO() + sys.stdout = sys.stderr = captured_output + + try: + # Import and run the test function + if test_func_name == "_test_protocol_ucx": + _test_protocol_ucx() + elif test_func_name == "_test_protocol_ucxx": + _test_protocol_ucxx() + elif test_func_name == "_test_protocol_ucx_old": + _test_protocol_ucx_old() + else: + raise ValueError(f"Unknown test function: {test_func_name}") + + output = captured_output.getvalue() + conn.send((True, output)) # True = success + except Exception as e: + output = captured_output.getvalue() + output += f"\nException: {e}" + import traceback + + output += f"\nTraceback:\n{traceback.format_exc()}" + conn.send((False, output)) # False = failure + finally: + # Restore original stdout/stderr + sys.stdout = old_stdout + sys.stderr = old_stderr + conn.close() + + +@pytest.mark.parametrize("protocol", ["ucx", "ucxx", "ucx-old"]) +def test_rdd_protocol(protocol): + """Test rapids-dask-dependency protocol selection""" + if protocol == "ucx": + test_func_name = "_test_protocol_ucx" + elif protocol == "ucxx": + test_func_name = "_test_protocol_ucxx" + else: + test_func_name = "_test_protocol_ucx_old" + + # Create a pipe for communication between parent and child processes + parent_conn, child_conn = mp.Pipe() + p = mp.Process( + target=_run_test_with_output_capture, args=(test_func_name, child_conn) + ) + + p.start() + p.join(timeout=60) + + if p.is_alive(): + p.kill() + p.close() + raise TimeoutError("Test process timed out") + + # Get the result from the child process + success, output = parent_conn.recv() + + # Check that the test passed + assert success, f"Test failed in subprocess. Output:\n{output}" + + # For the ucx protocol, check if warnings are printed when distributed_ucxx is not + # available + if protocol == "ucx" and not _has_distributed_ucxx(): + # Check if the warning about protocol='ucx' is printed + print(f"Output for {protocol} protocol:\n{output}") + assert ( + "you have requested protocol='ucx'" in output + ), f"Expected warning not found in output: {output}" + assert ( + "'distributed-ucxx' is not installed" in output + ), f"Expected warning about distributed-ucxx not found in output: {output}" + elif protocol == "ucx" and _has_distributed_ucxx(): + # When distributed_ucxx is available, the warning should NOT be printed + assert "you have requested protocol='ucx'" not in output, ( + "Warning should not be printed when distributed_ucxx is available: " + f"{output}" + ) + elif protocol == "ucx-old": + # The ucx-old protocol should not generate warnings + assert ( + "you have requested protocol='ucx'" not in output + ), f"Warning should not be printed for ucx-old protocol: {output}"