Skip to content

Commit b4525b5

Browse files
authored
added multithreading to OTPlanSampler for "exact" solver (atong01#131)
* added multithreading to OTPlanSampler for "exact" solver * changed type hinting
1 parent f07c5cd commit b4525b5

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

torchcfm/optimal_transport.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
import warnings
33
from functools import partial
4-
from typing import Optional
4+
from typing import Optional, Union
55

66
import numpy as np
77
import ot as pot
@@ -18,6 +18,7 @@ def __init__(
1818
reg: float = 0.05,
1919
reg_m: float = 1.0,
2020
normalize_cost: bool = False,
21+
num_threads: Union[int, str] = 1,
2122
warn: bool = True,
2223
) -> None:
2324
"""Initialize the OTPlanSampler class.
@@ -36,13 +37,16 @@ def __init__(
3637
normalizes the cost matrix so that the maximum cost is 1. Helps
3738
stabilize Sinkhorn-based solvers. Should not be used in the vast
3839
majority of cases.
40+
num_threads: int or str, optional
41+
number of threads to use for the "exact" OT solver. If "max", uses
42+
the maximum number of threads.
3943
warn: bool, optional
4044
if True, raises a warning if the algorithm does not converge
4145
"""
4246
# ot_fn should take (a, b, M) as arguments where a, b are marginals and
4347
# M is a cost matrix
4448
if method == "exact":
45-
self.ot_fn = pot.emd
49+
self.ot_fn = partial(pot.emd, numThreads=num_threads)
4650
elif method == "sinkhorn":
4751
self.ot_fn = partial(pot.sinkhorn, reg=reg)
4852
elif method == "unbalanced":

0 commit comments

Comments
 (0)