Skip to content

Commit 2021643

Browse files
committed
ci: Selectively run tests
Signed-off-by: oliver könig <[email protected]>
1 parent c62a613 commit 2021643

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

.github/scripts/nemo_dependencies.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,16 @@ def find_top_level_packages(nemo_root: str) -> List[str]:
8787
"""Find all top-level packages under nemo directory."""
8888
packages: List[str] = []
8989
nemo_dir = os.path.join(nemo_root, 'nemo')
90+
tests_dir = os.path.join(nemo_root, 'tests')
9091

9192
if not os.path.exists(nemo_dir):
9293
print(f"Warning: nemo directory not found at {nemo_dir}")
9394
return packages
95+
if not os.path.exists(tests_dir):
96+
print(f"Warning: nemo directory not found at {nemo_dir}")
97+
return packages
9498

95-
for item in os.listdir(nemo_dir):
99+
for item in os.listdir(nemo_dir) + os.listdir(tests_dir):
96100
item_path = os.path.join(nemo_dir, item)
97101
if os.path.isdir(item_path) and not item.startswith('__'):
98102
packages.append(item)
@@ -125,17 +129,18 @@ def build_dependency_graph(nemo_root: str) -> Dict[str, List[str]]:
125129

126130
dependencies: Dict[str, List[str]] = {}
127131

128-
# Second pass: analyze imports and build reverse dependencies
129132
for file_path in find_python_files(nemo_root):
130133
relative_path = os.path.relpath(file_path, nemo_root)
131134
parts = relative_path.split(os.sep)
132135

133-
if len(parts) == 1 or parts[-1] == "__init__.py" or parts[0] != "nemo":
136+
if len(parts) == 1 or parts[-1] == "__init__.py" or (parts[0] != "nemo" and parts[0] != "tests"):
134137
continue
135138

136139
module_path = relative_path.replace(".py", "").replace("/", ".")
137140
if parts[1] in top_level_packages and parts[1] != 'collections':
138141
dependencies[module_path] = list(set(analyze_imports(nemo_root, file_path)))
142+
elif parts[0] == 'tests':
143+
dependencies[module_path] = [relative_path]
139144
elif parts[1] == 'collections':
140145
dependencies[module_path] = list(set(analyze_imports(nemo_root, file_path)))
141146

@@ -181,7 +186,7 @@ def build_dependency_graph(nemo_root: str) -> Dict[str, List[str]]:
181186
simplified_dependencies: Dict[str, List[str]] = {}
182187
for package, deps in dependencies.items():
183188
package_parts = package.split('.')
184-
print(f"{os.path.join(*package_parts[:-1])}.py")
189+
185190
if os.path.isfile((file_path := f"{os.path.join(*package_parts[:-1])}.py")):
186191
simplified_package_path = file_path
187192
elif os.path.isdir((file_path := f"{os.path.join(*package_parts[:-1])}")):
@@ -221,13 +226,17 @@ def build_dependency_graph(nemo_root: str) -> Dict[str, List[str]]:
221226
if "asr" in dep or "tts" in dep or "speechlm" in dep or "audio" in dep:
222227
new_deps.append("speech")
223228

224-
if "export" in dep or "deploy" in dep:
229+
elif "export" in dep or "deploy" in dep:
225230
new_deps.append("export-deploy")
226231

227-
if "llm" in dep or "vlm" in dep or "automodel" in dep:
232+
elif "llm" in dep or "vlm" in dep or "automodel" in dep:
228233
new_deps.append("automodel")
229234

230-
if "collections" in dep and not ("asr" in dep or "tts" in dep or "speechlm" in dep or "audio" in dep):
235+
elif "tests/collections" in dep:
236+
new_deps.append("unit-tests")
237+
continue
238+
239+
else:
231240
new_deps.append("nemo2")
232241

233242
bucket_deps[package] = sorted(list(set(new_deps)))

0 commit comments

Comments
 (0)