Skip to content

Molecule generation model (GeoDiff) #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f405837
create filestructure for science application
natolambert Jun 30, 2022
e4a2ddf
rebase molecule gen
natolambert Oct 3, 2022
71753ef
make style
natolambert Jul 1, 2022
9064eda
add property to self in init for colab
natolambert Jul 5, 2022
7c15d6b
small fix to types in forward()
natolambert Jul 11, 2022
120af84
rebase main, small updates
natolambert Oct 3, 2022
9dd023a
add helper function to trim colab
natolambert Jul 12, 2022
2d1f748
remove unused code
natolambert Jul 12, 2022
a4513e2
clean API for colab
natolambert Jul 13, 2022
ce71e2f
remove unused code
natolambert Jul 21, 2022
3865892
weird rebase
natolambert Oct 3, 2022
127f72a
tests pass
natolambert Jul 22, 2022
2f0ac21
make style and fix-copies
natolambert Jul 22, 2022
25ec89d
rename model and file
natolambert Jul 25, 2022
79f25d6
update API, update tests, rename class
natolambert Jul 25, 2022
7a85d04
clean model & tests
natolambert Jul 25, 2022
a90d1be
add checking for imports
natolambert Jul 26, 2022
4d23976
minor formatting nit
natolambert Jul 26, 2022
506eb3c
add attribution of original codebase
natolambert Jul 27, 2022
4d158a3
style and readibility improvements
natolambert Aug 1, 2022
7e73190
fixes post large rebase
natolambert Oct 3, 2022
682eb47
fix tests
natolambert Oct 3, 2022
77569dc
Merge remote-tracking branch 'origin/main' into molecule_gen
natolambert Oct 3, 2022
2ef3727
make quality and style
natolambert Oct 3, 2022
47af5ce
only import moleculegnn when ready
natolambert Oct 3, 2022
f5f2576
fix torch_geometric check
natolambert Oct 3, 2022
104ec26
remove dummy tranformers objects
natolambert Oct 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add checking for imports
  • Loading branch information
natolambert committed Oct 3, 2022
commit a90d1be482f766781ae87e0948020d0ade8f50c9
8 changes: 8 additions & 0 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
is_onnx_available,
is_scipy_available,
is_torch_available,
is_inflect_available,
is_torch_geometric_available,
is_transformers_available,
is_unidecode_available,
)
Expand Down Expand Up @@ -83,3 +85,9 @@
from .pipelines import FlaxStableDiffusionPipeline
else:
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
from .utils.dummy_transformers_objects import *

if is_torch_geometric_available():
from .models import MoleculeGNN
else:
from .utils.dummy_torch_geometric_objects import *
145 changes: 145 additions & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,148 @@
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
<<<<<<< HEAD
=======


_transformers_available = importlib.util.find_spec("transformers") is not None
try:
_transformers_version = importlib_metadata.version("transformers")
logger.debug(f"Successfully imported transformers version {_transformers_version}")
except importlib_metadata.PackageNotFoundError:
_transformers_available = False


_inflect_available = importlib.util.find_spec("inflect") is not None
try:
_inflect_version = importlib_metadata.version("inflect")
logger.debug(f"Successfully imported inflect version {_inflect_version}")
except importlib_metadata.PackageNotFoundError:
_inflect_available = False


_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
_unidecode_version = importlib_metadata.version("unidecode")
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
except importlib_metadata.PackageNotFoundError:
_unidecode_available = False


_modelcards_available = importlib.util.find_spec("modelcards") is not None
try:
_modelcards_version = importlib_metadata.version("modelcards")
logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
except importlib_metadata.PackageNotFoundError:
_modelcards_available = False

_torch_scatter_available = importlib.util.find_spec("torch_scatter") is not None
try:
_torch_scatter_version = importlib_metadata.version("torch_scatter")
logger.debug(f"Successfully imported torch_scatter version {_torch_scatter_version}")
except importlib_metadata.PackageNotFoundError:
_torch_scatter_available = False

_torch_scatter_available = importlib.util.find_spec("torch_geometric") is not None
try:
_torch_geometric_version = importlib_metadata.version("torch_geometric")
logger.debug(f"Successfully imported torch_geometric version {_torch_geometric_version}")
except importlib_metadata.PackageNotFoundError:
_torch_geometric_available = False


def is_transformers_available():
return _transformers_available


def is_inflect_available():
return _inflect_available


def is_unidecode_available():
return _unidecode_available


def is_modelcards_available():
return _modelcards_available


def is_torch_scatter_available():
return _torch_scatter_available


def is_torch_geometric_available():
# the model source of the Molecule Generation GNN requires a specific torch geometric version
# for more info, see the original repo https://github.com/MinkaiXu/GeoDiff or our colab in readme
return _torch_geometric_version == "1.7.2"


class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
not have access to.
"""


class EntryNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""


class RevisionNotFoundError(HTTPError):
"""Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""


TRANSFORMERS_IMPORT_ERROR = """
{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
"""


UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
Unidecode`
"""


INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
inflect`
"""

TORCH_GEOMETRIC_IMPORT_ERROR = """
{0} requires version 1.7.2 of torch_geometric but it was not found in your environment. You can install it with conda:
`conda install -c rusty1s pytorch-geometric=1.7.2`, given pytorch 1.8
"""


BACKENDS_MAPPING = OrderedDict(
[
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
]
)


def requires_backends(obj, backends):
if not isinstance(backends, (list, tuple)):
backends = [backends]

name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
checks = (BACKENDS_MAPPING[backend] for backend in backends)
failed = [msg.format(name) for available, msg in checks if not available()]
if failed:
raise ImportError("".join(failed))


class DummyObject(type):
"""
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
`requires_backend` each time a user tries to access any method of that class.
"""

def __getattr__(cls, key):
if key.startswith("_"):
return super().__getattr__(cls, key)
requires_backends(cls, cls._backends)
>>>>>>> bf87817 (add checking for imports)
10 changes: 10 additions & 0 deletions src/diffusers/utils/dummy_torch_geometric_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends


class MoleculeGNN(metaclass=DummyObject):
_backends = ["torch_geometric"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch_geometric"])
9 changes: 8 additions & 1 deletion tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,21 @@
DDPMScheduler,
LDMPipeline,
LDMTextToImagePipeline,
MoleculeGNN,
PNDMPipeline,
PNDMScheduler,
ScoreSdeVePipeline,
ScoreSdeVeScheduler,
UNet2DModel,
VQModel,
)
from diffusers.utils import is_torch_geometric_available


if is_torch_geometric_available():
from diffusers import MoleculeGNN
else:
from diffusers.utils.dummy_torch_geometric_objects import *

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
Expand Down