@@ -1,60 +1,9 @@
-import json
 import os
 import subprocess
 
-from tools.stats.s3_stat_parser import (
-    get_previous_reports_for_branch,
-    Report,
-    Version2Report,
-    HAVE_BOTO3,
-)
 from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests
 
-from typing import Any, Dict, List, Optional, Tuple, cast
-from typing_extensions import TypedDict
-
-
-class JobTimeJSON(TypedDict):
-    commit: str
-    JOB_BASE_NAME: str
-    job_times: Dict[str, float]
-
-
-def _get_stripped_CI_job() -> str:
-    return os.environ.get("BUILD_ENVIRONMENT", "")
-
-
-def _get_job_times_json(job_times: Dict[str, float]) -> JobTimeJSON:
-    return {
-        "commit": subprocess.check_output(
-            ["git", "rev-parse", "HEAD"], encoding="ascii"
-        ).strip(),
-        "JOB_BASE_NAME": _get_stripped_CI_job(),
-        "job_times": job_times,
-    }
-
-
-def _calculate_job_times(reports: List["Report"]) -> Dict[str, float]:
-    """Compute test runtime by filename: ("test_file_name" -> (current_avg, # values))"""
-    jobs_to_times: Dict[str, Tuple[float, int]] = dict()
-    for report in reports:
-        v_report = cast(Version2Report, report)
-        assert (
-            "format_version" in v_report.keys() and v_report.get("format_version") == 2
-        ), "S3 format currently handled is version 2 only"
-        files: Dict[str, Any] = v_report["files"]
-        for name, test_file in files.items():
-            if name not in jobs_to_times:
-                jobs_to_times[name] = (test_file["total_seconds"], 1)
-            else:
-                curr_avg, curr_count = jobs_to_times[name]
-                new_count = curr_count + 1
-                new_avg = (
-                    curr_avg * curr_count + test_file["total_seconds"]
-                ) / new_count
-                jobs_to_times[name] = (new_avg, new_count)
-
-    return {job: time for job, (time, _) in jobs_to_times.items()}
+from typing import Dict, List, Tuple
 
 
 def calculate_shards(
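For reference, the removed _calculate_job_times aggregated per-file runtimes across reports with an incremental mean. A minimal standalone sketch of that update, with illustrative names and data (not from the codebase):

from typing import Dict, List, Tuple

def average_runtimes(reports: List[Dict[str, float]]) -> Dict[str, float]:
    # Illustrative stand-in for the removed helper: each report maps a test
    # file name to its runtime in seconds for one CI run.
    averages: Dict[str, Tuple[float, int]] = {}
    for report in reports:
        for name, seconds in report.items():
            if name not in averages:
                averages[name] = (seconds, 1)
            else:
                avg, count = averages[name]
                # Incremental mean: fold the new sample into the running average.
                averages[name] = ((avg * count + seconds) / (count + 1), count + 1)
    return {name: avg for name, (avg, _) in averages.items()}

# average_runtimes([{"test_nn.py": 120.0}, {"test_nn.py": 100.0}]) -> {"test_nn.py": 110.0}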
@@ -91,63 +40,6 @@ def calculate_shards(
     return sharded_jobs
 
 
-def _pull_job_times_from_S3() -> Dict[str, float]:
-    if HAVE_BOTO3:
-        ci_job_prefix = _get_stripped_CI_job()
-        s3_reports: List["Report"] = get_previous_reports_for_branch(
-            "origin/viable/strict", ci_job_prefix
-        )
-    else:
-        print(
-            "Uh oh, boto3 is not found. Either it is not installed or we failed to import s3_stat_parser."
-        )
-        print(
-            "If not installed, please install boto3 for automatic sharding and test categorization."
-        )
-        s3_reports = []
-
-    if len(s3_reports) == 0:
-        print("::warning:: Gathered no reports from S3. Please proceed without them.")
-        return dict()
-
-    return _calculate_job_times(s3_reports)
-
-
-def _query_past_job_times(test_times_file: Optional[str] = None) -> Dict[str, float]:
-    """Read historic test job times from a file.
-
-    If the file doesn't exist or isn't matching current commit. It will download data from S3 and exported it.
-    """
-    if test_times_file and os.path.exists(test_times_file):
-        with open(test_times_file) as file:
-            test_times_json: JobTimeJSON = json.load(file)
-
-        curr_commit = subprocess.check_output(
-            ["git", "rev-parse", "HEAD"], encoding="ascii"
-        ).strip()
-        file_commit = test_times_json.get("commit", "")
-        curr_ci_job = _get_stripped_CI_job()
-        file_ci_job = test_times_json.get("JOB_BASE_NAME", "N/A")
-        if curr_commit != file_commit:
-            print(f"Current test times file is from different commit {file_commit}.")
-        elif curr_ci_job != file_ci_job:
-            print(f"Current test times file is for different CI job {file_ci_job}.")
-        else:
-            print(
-                f"Found stats for current commit: {curr_commit} and job: {curr_ci_job}. Proceeding with those values."
-            )
-            return test_times_json.get("job_times", {})
-
-        # Found file, but commit or CI job in JSON doesn't match
-        print(
-            f"Overwriting current file with stats based on current commit: {curr_commit} and CI job: {curr_ci_job}"
-        )
-
-    job_times = export_S3_test_times(test_times_file)
-
-    return job_times
-
-
 def _query_changed_test_files() -> List[str]:
     default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'master')}"
     cmd = ["git", "diff", "--name-only", default_branch, "HEAD"]
@@ -161,47 +53,6 @@ def _query_changed_test_files() -> List[str]:
     return lines
 
 
-# Get sharded test allocation based on historic S3 data.
-def get_shard_based_on_S3(
-    which_shard: int, num_shards: int, tests: List[str], test_times_file: str
-) -> List[str]:
-    # Short circuit and don't do any work if there's only 1 shard
-    if num_shards == 1:
-        return tests
-
-    jobs_to_times = _query_past_job_times(test_times_file)
-
-    # Got no stats from S3, returning early to save runtime
-    if len(jobs_to_times) == 0:
-        print(
-            "::warning:: Gathered no stats from S3. Proceeding with default sharding plan."
-        )
-        return tests[which_shard - 1 :: num_shards]
-
-    shards = calculate_shards(num_shards, tests, jobs_to_times)
-    _, tests_from_shard = shards[which_shard - 1]
-    return tests_from_shard
-
-
-def get_slow_tests_based_on_S3(
-    test_list: List[str], td_list: List[str], slow_test_threshold: int
-) -> List[str]:
-    """Get list of slow tests based on historic S3 data."""
-    jobs_to_times: Dict[str, float] = _query_past_job_times()
-
-    # Got no stats from S3, returning early to save runtime
-    if len(jobs_to_times) == 0:
-        print("::warning:: Gathered no stats from S3. No new slow tests calculated.")
-        return []
-
-    slow_tests: List[str] = []
-    for test in test_list:
-        if test in jobs_to_times and test not in td_list:
-            if jobs_to_times[test] > slow_test_threshold:
-                slow_tests.append(test)
-    return slow_tests
-
-
 def get_reordered_tests(tests: List[str]) -> List[str]:
     """Get the reordered test filename list based on github PR history or git changed file."""
     prioritized_tests: List[str] = []
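The removed get_shard_based_on_S3 fell back to a position-based round-robin split whenever no timing data was available. A minimal sketch of that fallback, with a usage example (names illustrative):

from typing import List

def round_robin_shard(which_shard: int, num_shards: int, tests: List[str]) -> List[str]:
    # which_shard is 1-based, matching the removed helper's convention.
    return tests[which_shard - 1 :: num_shards]

# round_robin_shard(2, 3, ["a", "b", "c", "d", "e"]) -> ["b", "e"]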
@@ -242,20 +93,6 @@ def get_reordered_tests(tests: List[str]) -> List[str]:
     return tests
 
 
-# TODO Refactor this and unify with tools.stats.export_slow_tests
-def export_S3_test_times(test_times_filename: Optional[str] = None) -> Dict[str, float]:
-    test_times: Dict[str, float] = _pull_job_times_from_S3()
-    if test_times_filename is not None:
-        print(f"Exporting S3 test stats to {test_times_filename}.")
-        if os.path.exists(test_times_filename):
-            print(f"Overwriting existent file: {test_times_filename}")
-        with open(test_times_filename, "w+") as file:
-            job_times_json = _get_job_times_json(test_times)
-            json.dump(job_times_json, file, indent=" ", separators=(",", ": "))
-            file.write("\n")
-    return test_times
-
-
 def get_test_case_configs(dirpath: str) -> None:
     get_slow_tests(dirpath=dirpath)
     get_disabled_tests(dirpath=dirpath)