Skip to content

Commit d3fe6f2

Browse files
authored
Merge branch 'main' into forest_flow
2 parents 3c7165f + 81fcb8d commit d3fe6f2

20 files changed

+392
-30
lines changed

.github/workflows/test.yaml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Tests
1+
name: TorchCFM Tests
22

33
on:
44
push:
@@ -14,7 +14,7 @@ jobs:
1414
fail-fast: false
1515
matrix:
1616
os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
17-
python-version: ["3.8", "3.9", "3.10"]
17+
python-version: ["3.8", "3.9", "3.10", "3.11"]
1818

1919
steps:
2020
- name: Checkout
@@ -28,7 +28,6 @@ jobs:
2828
- name: Install dependencies
2929
run: |
3030
python -m pip install --upgrade pip
31-
pip install -r runner-requirements.txt
3231
pip install pytest
3332
pip install sh
3433
pip install -e .
@@ -39,10 +38,10 @@ jobs:
3938
4039
- name: Run pytest
4140
run: |
42-
pytest -v runner
41+
pytest -v --ignore=examples --ignore=runner
4342
4443
# upload code coverage report
45-
code-coverage:
44+
code-coverage-torchcfm:
4645
runs-on: ubuntu-latest
4746

4847
steps:
@@ -57,14 +56,17 @@ jobs:
5756
- name: Install dependencies
5857
run: |
5958
python -m pip install --upgrade pip
60-
pip install -r runner-requirements.txt
6159
pip install pytest
6260
pip install pytest-cov[toml]
6361
pip install sh
6462
pip install -e .
6563
6664
- name: Run tests and collect coverage
67-
run: pytest runner --cov runner # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
65+
run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/
6866

6967
- name: Upload coverage to Codecov
7068
uses: codecov/codecov-action@v3
69+
with:
70+
name: codecov-torchcfm
71+
verbose: true
72+
fail_ci_if_error: true

.github/workflows/test_runner.yaml

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
name: Runner Tests
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main, "release/*"]
8+
9+
jobs:
10+
run_tests_ubuntu:
11+
runs-on: ${{ matrix.os }}
12+
13+
strategy:
14+
fail-fast: false
15+
matrix:
16+
os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
17+
python-version: ["3.8", "3.9", "3.10"]
18+
19+
steps:
20+
- name: Checkout
21+
uses: actions/checkout@v3
22+
23+
- name: Set up Python ${{ matrix.python-version }}
24+
uses: actions/setup-python@v4
25+
with:
26+
python-version: ${{ matrix.python-version }}
27+
28+
- name: Install dependencies
29+
run: |
30+
python -m pip install --upgrade pip
31+
pip install -r runner-requirements.txt
32+
pip install pytest
33+
pip install sh
34+
pip install -e .
35+
36+
- name: List dependencies
37+
run: |
38+
python -m pip list
39+
40+
- name: Run pytest
41+
run: |
42+
pytest -v runner
43+
44+
# upload code coverage report
45+
code-coverage-runner:
46+
runs-on: ubuntu-latest
47+
48+
steps:
49+
- name: Checkout
50+
uses: actions/checkout@v3
51+
52+
- name: Set up Python 3.10
53+
uses: actions/setup-python@v4
54+
with:
55+
python-version: "3.10"
56+
57+
- name: Install dependencies
58+
run: |
59+
python -m pip install --upgrade pip
60+
pip install -r runner-requirements.txt
61+
pip install pytest
62+
pip install pytest-cov[toml]
63+
pip install sh
64+
pip install -e .
65+
66+
- name: Run tests and collect coverage
67+
run: pytest runner --cov runner # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
68+
69+
- name: Upload coverage to Codecov
70+
uses: codecov/codecov-action@v3
71+
with:
72+
name: codecov-runner
73+
verbose: true
74+
fail_ci_if_error: true

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,6 @@ count = true
3535

3636
[tool.bandit]
3737
skips = ["B101", "B311"]
38+
39+
[tool.isort]
40+
known_first_party = ["tests", "src"]

runner/src/datamodules/distribution_datamodule.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from pytorch_lightning import LightningDataModule
88
from pytorch_lightning.trainer.supporters import CombinedLoader
99
from sklearn.preprocessing import StandardScaler
10-
from src import utils
1110
from torch.utils.data import DataLoader, Sampler, TensorDataset, random_split
1211
from torchdyn.datasets import ToyDataset
1312

13+
from src import utils
14+
1415
from .components.base import BaseLightningDataModule
1516
from .components.time_dataset import load_dataset
1617
from .components.tnet_dataset import SCData

runner/src/eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from omegaconf import DictConfig
3939
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
4040
from pytorch_lightning.loggers import LightningLoggerBase
41+
4142
from src import utils
4243

4344
log = utils.get_pylogger(__name__)

runner/src/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from omegaconf import DictConfig
4040
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
4141
from pytorch_lightning.loggers import LightningLoggerBase
42+
4243
from src import utils
4344

4445
log = utils.get_pylogger(__name__)

runner/src/utils/rich_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from omegaconf import DictConfig, OmegaConf, open_dict
99
from pytorch_lightning.utilities import rank_zero_only
1010
from rich.prompt import Prompt
11+
1112
from src.utils import pylogger
1213

1314
log = pylogger.get_pylogger(__name__)

runner/src/utils/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytorch_lightning import Callback
1010
from pytorch_lightning.loggers import LightningLoggerBase
1111
from pytorch_lightning.utilities import rank_zero_only
12+
1213
from src.utils import pylogger, rich_utils
1314

1415
log = pylogger.get_pylogger(__name__)

runner/tests/helpers/run_if.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from packaging.version import Version
1212
from pkg_resources import get_distribution
13+
1314
from tests.helpers.package_available import (
1415
_COMET_AVAILABLE,
1516
_DEEPSPEED_AVAILABLE,

runner/tests/helpers/run_sh_command.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List
22

33
import pytest
4+
45
from tests.helpers.package_available import _SH_AVAILABLE
56

67
if _SH_AVAILABLE:

0 commit comments

Comments
 (0)