Skip to content

Commit 9f58d5d

Browse files
suo authored and pytorchmergebot committed
[test stats] use published test stats for sharding (#81116)
Use the nightly-published test stats to perform sharding, instead of calculating it in every build job. Pull Request resolved: #81116 Approved by: https://github.com/janeyx99
1 parent fb93c39 commit 9f58d5d

File tree

7 files changed

+34
-37
lines changed

7 files changed

+34
-37
lines changed

.github/workflows/_linux-build.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ jobs:
135135
- name: Archive artifacts into zip
136136
if: inputs.build-generates-artifacts
137137
run: |
138-
zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
138+
zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin
139139
140140
- name: Store PyTorch Build Artifacts on S3
141141
uses: seemethere/upload-artifact-s3@v5

.jenkins/pytorch/build.sh

-6
Original file line numberDiff line numberDiff line change
@@ -296,10 +296,4 @@ else
296296
fi
297297
fi
298298

299-
if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then
300-
# export test times so that potential sharded tests that'll branch off this build will use consistent data
301-
# don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
302-
python test/run_test.py --export-past-test-times
303-
fi
304-
305299
print_sccache_stats

.jenkins/pytorch/win-test-helpers/build_pytorch.bat

-3
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,6 @@ python setup.py install --cmake && sccache --show-stats && (
146146
if errorlevel 1 exit /b
147147
if not errorlevel 0 exit /b
148148

149-
:: export test times so that potential sharded tests that'll branch off this build will use consistent data
150-
python test/run_test.py --export-past-test-times %PYTORCH_FINAL_PACKAGE_DIR%/.pytorch-test-times.json
151-
152149
:: Also save build/.ninja_log as an artifact
153150
copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\"
154151
)

.jenkins/pytorch/win-test-helpers/test_python_jit_legacy.bat

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat
22

33
echo Copying over test times file
4-
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%"
54

65
pushd test
76

.jenkins/pytorch/win-test-helpers/test_python_shard.bat

-3
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ if "%SHARD_NUMBER%" == "1" (
2121
)
2222
)
2323

24-
echo Copying over test times file
25-
copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%"
26-
2724
echo Run nn tests
2825
python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
2926
if ERRORLEVEL 1 goto fail

test/run_test.py

+17-23
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@
3232
try:
3333
# using tools/ to optimize test run.
3434
sys.path.append(str(REPO_ROOT))
35+
from tools.stats.import_test_stats import get_test_times
3536
from tools.testing.test_selections import (
36-
export_S3_test_times,
37-
get_shard_based_on_S3,
3837
get_reordered_tests,
3938
get_test_case_configs,
39+
calculate_shards,
4040
)
4141
HAVE_TEST_SELECTION_TOOLS = True
4242
except ImportError:
@@ -677,13 +677,6 @@ def parse_args():
677677
help="additional arguments passed through to unittest, e.g., "
678678
"python run_test.py -i sparse -- TestSparse.test_factory_size_check",
679679
)
680-
parser.add_argument(
681-
"--export-past-test-times",
682-
nargs="?",
683-
type=str,
684-
const=TEST_TIMES_FILE,
685-
help="dumps test times from previous S3 stats into a file, format JSON",
686-
)
687680
parser.add_argument(
688681
"--shard",
689682
nargs=2,
@@ -838,11 +831,21 @@ def get_selected_tests(options):
838831
assert num_shards <= len(
839832
selected_tests
840833
), f"Number of shards must be less than {len(selected_tests)}"
841-
# TODO: fix this to use test_times_filename, but currently this is not working
842-
# because setting the export arg immeidately halts the test execution.
843-
selected_tests = get_shard_based_on_S3(
844-
which_shard, num_shards, selected_tests, TEST_TIMES_FILE
845-
)
834+
835+
if num_shards == 1:
836+
return selected_tests
837+
838+
# Download previous test times to make sharding decisions
839+
test_file_times = get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE)
840+
if len(test_file_times) == 0:
841+
print(
842+
"::warning:: Gathered no stats from S3. Proceeding with default sharding plan."
843+
)
844+
selected_tests = selected_tests[which_shard - 1 :: num_shards]
845+
else:
846+
shards = calculate_shards(num_shards, selected_tests, test_file_times)
847+
_, tests_from_shard = shards[which_shard - 1]
848+
selected_tests = tests_from_shard
846849

847850
# skip all distributed tests if distributed package is not available.
848851
if not dist.is_available():
@@ -882,15 +885,6 @@ def run_test_module(test: str, test_directory: str, options) -> Optional[str]:
882885
def main():
883886
options = parse_args()
884887

885-
# TODO: move this export & download function in tools/ folder
886-
test_times_filename = options.export_past_test_times
887-
if test_times_filename:
888-
print(
889-
f"Exporting past test times from S3 to {test_times_filename}, no tests will be run."
890-
)
891-
export_S3_test_times(test_times_filename)
892-
return
893-
894888
test_directory = str(REPO_ROOT / "test")
895889
selected_tests = get_selected_tests(options)
896890

tools/stats/import_test_stats.py

+16
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def fetch_and_cache(
4141
This fetch and cache utils allows sharing between different process.
4242
"""
4343
path = os.path.join(dirpath, name)
44+
print(f"Downloading {url} to {path}")
4445

4546
def is_cached_file_valid() -> bool:
4647
# Check if the file is new enough (see: FILE_CACHE_LIFESPAN_SECONDS). A real check
@@ -80,6 +81,21 @@ def get_slow_tests(
8081
return {}
8182

8283

84+
def get_test_times(dirpath: str, filename: str) -> Dict[str, float]:
    """Return the published test times for the current CI job.

    The stats file at ``url`` is fetched through ``fetch_and_cache`` (which
    stores it as ``filename`` under ``dirpath``) and narrowed down to the
    entry selected by the ``BUILD_ENVIRONMENT`` and ``TEST_CONFIG``
    environment variables.

    This is deliberately best-effort: any failure (fetch error, unset
    environment variable, no entry for this job in the published stats)
    is reported on stdout and an empty dict is returned, so the caller can
    fall back to a default sharding plan instead of crashing.

    Args:
        dirpath: Directory passed to ``fetch_and_cache`` for the cached copy.
        filename: File name of the cached copy inside ``dirpath``.

    Returns:
        Whatever ``fetch_and_cache`` yields for this job's stats entry, or
        ``{}`` when the stats could not be obtained.
    """
    url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/test-times.json"

    def process_response(the_response: Dict[str, Any]) -> Any:
        # The response is indexed first by build environment, then by test
        # config; a KeyError here is handled by the caller's fallback below.
        build_environment = os.environ["BUILD_ENVIRONMENT"]
        test_config = os.environ["TEST_CONFIG"]
        return the_response[build_environment][test_config]

    try:
        return fetch_and_cache(dirpath, filename, url, process_response)
    except Exception as e:
        # Keep the fallback, but surface the reason: a missing env var or a
        # missing stats entry is not a download problem and would otherwise
        # be indistinguishable from one.
        print(f"Couldn't download test times... ({type(e).__name__}: {e})")
        return {}
97+
98+
8399
def get_disabled_tests(
84100
dirpath: str, filename: str = DISABLED_TESTS_FILE
85101
) -> Optional[Dict[str, Any]]:

0 commit comments

Comments
 (0)