Skip to content
Next Next commit
add extra optional dependency options
  • Loading branch information
SarahG-579462 committed Jun 6, 2024
commit ff50220aaa3fdc11cad2611f0830f923d34e7f30
File renamed without changes.
14 changes: 14 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
sdk_path = subprocess.check_output(['xcrun', '--show-sdk-path'])
os.environ['CFLAGS'] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8"))

with open('requirements_opt.txt') as f:
optional_requirements = f.read().splitlines()

setup(
name='POT',
version=__version__,
Expand All @@ -71,6 +74,17 @@
data_files=[],
setup_requires=["oldest-supported-numpy", "cython>=0.23"],
install_requires=["numpy>=1.16", "scipy>=1.6"],
extras_require={
'backend-numpy': [], # in requirements.
'backend-jax': ['jax<=0.4.24', 'jaxlib<=0.4.24'],
'backend-cupy': [], # should be installed with conda, not pip, or figure out what CUDA version above.
'backend-tf': ['tensorflow'],
'backend-torch': ['torch_geometric'],
'cvxopt': ['cvxopt'], # on it's own to prevent accidental GPL violations
'dr': ['scikit-learn', 'pymanopt', 'autograd'],
'gnn': ['torch', 'torch_geometric'],
'all': optional_requirements
},
python_requires=">=3.6",
classifiers=[
'Development Status :: 5 - Production/Stable',
Expand Down