Skip to content

Commit 3c0cfdd

Browse files
authored
Merge pull request atong01#80 from atong01/add_tests
Add tests on lambda_t, low sigma values and guided functions
2 parents e96c52c + 3f08702 commit 3c0cfdd

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

.github/workflows/test.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,10 @@ jobs:
6262
pip install -e .
6363
6464
- name: Run tests and collect coverage
65-
run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/
65+
run: pytest . --cov torchcfm --ignore=runner --ignore=examples --ignore=torchcfm/models/ --cov-fail-under=30
6666

6767
- name: Upload coverage to Codecov
6868
uses: codecov/codecov-action@v3
6969
with:
7070
name: codecov-torchcfm
7171
verbose: true
72-
fail_ci_if_error: true

.github/workflows/test_runner.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,10 @@ jobs:
6464
pip install -e .
6565
6666
- name: Run tests and collect coverage
67-
run: pytest runner --cov runner # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
67+
run: pytest runner --cov runner --cov-fail-under=30 # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
6868

6969
- name: Upload coverage to Codecov
7070
uses: codecov/codecov-action@v3
7171
with:
7272
name: codecov-runner
7373
verbose: true
74-
fail_ci_if_error: true

tests/test_conditional_flow_matcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def sample_plan(method, x0, x1, sigma):
9292

9393
@pytest.mark.parametrize("method", ["vp_cfm", "t_cfm", "sb_cfm", "exact_ot_cfm", "i_cfm"])
9494
# Test both integer and floating sigma
95-
@pytest.mark.parametrize("sigma", [0.0, 0.5, 1.5, 0, 1])
95+
@pytest.mark.parametrize("sigma", [0.0, 5e-4, 0.5, 1.5, 0, 1])
9696
@pytest.mark.parametrize("shape", [[1], [2], [1, 2], [3, 4, 5]])
9797
def test_fm(method, sigma, shape):
9898
batch_size = TEST_BATCH_SIZE
@@ -107,6 +107,7 @@ def test_fm(method, sigma, shape):
107107
torch.manual_seed(TEST_SEED)
108108
np.random.seed(TEST_SEED)
109109
t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True)
110+
_ = FM.compute_lambda(t)
110111

111112
if method in ["sb_cfm", "exact_ot_cfm"]:
112113
torch.manual_seed(TEST_SEED)

tests/test_time_t.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def test_random_Tensor_t(FM):
4949
SchrodingerBridgeConditionalFlowMatcher(sigma=0.1),
5050
],
5151
)
52-
def test_guided_random_Tensor_t(FM):
52+
@pytest.mark.parametrize("return_noise", [True, False])
53+
def test_guided_random_Tensor_t(FM, return_noise):
5354
# Test guided_sample_location_and_conditional_flow functions
5455
x0 = torch.randn(batch_size, 2)
5556
y0 = torch.randint(high=10, size=(batch_size, 1))
@@ -58,13 +59,13 @@ def test_guided_random_Tensor_t(FM):
5859

5960
torch.manual_seed(seed)
6061
t_given = torch.rand(batch_size)
61-
t_given, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow(
62-
x0, x1, y0=y0, y1=y1, t=t_given
63-
)
62+
t_given = FM.guided_sample_location_and_conditional_flow(
63+
x0, x1, y0=y0, y1=y1, t=t_given, return_noise=return_noise
64+
)[0]
6465

6566
torch.manual_seed(seed)
66-
t_random, xt, ut, y0, y1 = FM.guided_sample_location_and_conditional_flow(
67-
x0, x1, y0=y0, y1=y1, t=None
68-
)
67+
t_random = FM.guided_sample_location_and_conditional_flow(
68+
x0, x1, y0=y0, y1=y1, t=None, return_noise=return_noise
69+
)[0]
6970

7071
assert any(t_given == t_random)

0 commit comments

Comments
 (0)