11import math
22import warnings
33from functools import partial
4- from typing import Optional
4+ from typing import Optional , Union
55
66import numpy as np
77import ot as pot
@@ -18,6 +18,7 @@ def __init__(
1818 reg : float = 0.05 ,
1919 reg_m : float = 1.0 ,
2020 normalize_cost : bool = False ,
21+ num_threads : Union [int , str ] = 1 ,
2122 warn : bool = True ,
2223 ) -> None :
2324 """Initialize the OTPlanSampler class.
@@ -36,13 +37,16 @@ def __init__(
3637 normalizes the cost matrix so that the maximum cost is 1. Helps
3738 stabilize Sinkhorn-based solvers. Should not be used in the vast
3839 majority of cases.
40+ num_threads: int or str, optional
41+ number of threads to use for the "exact" OT solver. If "max", uses
42+ the maximum number of threads.
3943 warn: bool, optional
4044 if True, raises a warning if the algorithm does not converge
4145 """
4246 # ot_fn should take (a, b, M) as arguments where a, b are marginals and
4347 # M is a cost matrix
4448 if method == "exact" :
45- self .ot_fn = pot .emd
49+ self .ot_fn = partial ( pot .emd , numThreads = num_threads )
4650 elif method == "sinkhorn" :
4751 self .ot_fn = partial (pot .sinkhorn , reg = reg )
4852 elif method == "unbalanced" :
0 commit comments