|
32 | 32 | try:
|
33 | 33 | # using tools/ to optimize test run.
|
34 | 34 | sys.path.append(str(REPO_ROOT))
|
| 35 | + from tools.stats.import_test_stats import get_test_times |
35 | 36 | from tools.testing.test_selections import (
|
36 |
| - export_S3_test_times, |
37 |
| - get_shard_based_on_S3, |
38 | 37 | get_reordered_tests,
|
39 | 38 | get_test_case_configs,
|
| 39 | + calculate_shards, |
40 | 40 | )
|
41 | 41 | HAVE_TEST_SELECTION_TOOLS = True
|
42 | 42 | except ImportError:
|
@@ -677,13 +677,6 @@ def parse_args():
|
677 | 677 | help="additional arguments passed through to unittest, e.g., "
|
678 | 678 | "python run_test.py -i sparse -- TestSparse.test_factory_size_check",
|
679 | 679 | )
|
680 |
| - parser.add_argument( |
681 |
| - "--export-past-test-times", |
682 |
| - nargs="?", |
683 |
| - type=str, |
684 |
| - const=TEST_TIMES_FILE, |
685 |
| - help="dumps test times from previous S3 stats into a file, format JSON", |
686 |
| - ) |
687 | 680 | parser.add_argument(
|
688 | 681 | "--shard",
|
689 | 682 | nargs=2,
|
@@ -838,11 +831,21 @@ def get_selected_tests(options):
|
838 | 831 | assert num_shards <= len(
|
839 | 832 | selected_tests
|
840 | 833 | ), f"Number of shards must be less than {len(selected_tests)}"
|
841 |
| - # TODO: fix this to use test_times_filename, but currently this is not working |
842 |
| - # because setting the export arg immeidately halts the test execution. |
843 |
| - selected_tests = get_shard_based_on_S3( |
844 |
| - which_shard, num_shards, selected_tests, TEST_TIMES_FILE |
845 |
| - ) |
| 834 | + |
| 835 | + if num_shards == 1: |
| 836 | + return selected_tests |
| 837 | + |
| 838 | + # Download previous test times to make sharding decisions |
| 839 | + test_file_times = get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE) |
| 840 | + if len(test_file_times) == 0: |
| 841 | + print( |
| 842 | + "::warning:: Gathered no stats from S3. Proceeding with default sharding plan." |
| 843 | + ) |
| 844 | + selected_tests = selected_tests[which_shard - 1 :: num_shards] |
| 845 | + else: |
| 846 | + shards = calculate_shards(num_shards, selected_tests, test_file_times) |
| 847 | + _, tests_from_shard = shards[which_shard - 1] |
| 848 | + selected_tests = tests_from_shard |
846 | 849 |
|
847 | 850 | # skip all distributed tests if distributed package is not available.
|
848 | 851 | if not dist.is_available():
|
@@ -882,15 +885,6 @@ def run_test_module(test: str, test_directory: str, options) -> Optional[str]:
|
882 | 885 | def main():
|
883 | 886 | options = parse_args()
|
884 | 887 |
|
885 |
| - # TODO: move this export & download function in tools/ folder |
886 |
| - test_times_filename = options.export_past_test_times |
887 |
| - if test_times_filename: |
888 |
| - print( |
889 |
| - f"Exporting past test times from S3 to {test_times_filename}, no tests will be run." |
890 |
| - ) |
891 |
| - export_S3_test_times(test_times_filename) |
892 |
| - return |
893 |
| - |
894 | 888 | test_directory = str(REPO_ROOT / "test")
|
895 | 889 | selected_tests = get_selected_tests(options)
|
896 | 890 |
|
|
0 commit comments