Skip to content

Molecule generation model (GeoDiff) #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f405837
create filestructure for science application
natolambert Jun 30, 2022
e4a2ddf
rebase molecule gen
natolambert Oct 3, 2022
71753ef
make style
natolambert Jul 1, 2022
9064eda
add property to self in init for colab
natolambert Jul 5, 2022
7c15d6b
small fix to types in forward()
natolambert Jul 11, 2022
120af84
rebase main, small updates
natolambert Oct 3, 2022
9dd023a
add helper function to trim colab
natolambert Jul 12, 2022
2d1f748
remove unused code
natolambert Jul 12, 2022
a4513e2
clean API for colab
natolambert Jul 13, 2022
ce71e2f
remove unused code
natolambert Jul 21, 2022
3865892
weird rebase
natolambert Oct 3, 2022
127f72a
tests pass
natolambert Jul 22, 2022
2f0ac21
make style and fix-copies
natolambert Jul 22, 2022
25ec89d
rename model and file
natolambert Jul 25, 2022
79f25d6
update API, update tests, rename class
natolambert Jul 25, 2022
7a85d04
clean model & tests
natolambert Jul 25, 2022
a90d1be
add checking for imports
natolambert Jul 26, 2022
4d23976
minor formatting nit
natolambert Jul 26, 2022
506eb3c
add attribution of original codebase
natolambert Jul 27, 2022
4d158a3
style and readibility improvements
natolambert Aug 1, 2022
7e73190
fixes post large rebase
natolambert Oct 3, 2022
682eb47
fix tests
natolambert Oct 3, 2022
77569dc
Merge remote-tracking branch 'origin/main' into molecule_gen
natolambert Oct 3, 2022
2ef3727
make quality and style
natolambert Oct 3, 2022
47af5ce
only import moleculegnn when ready
natolambert Oct 3, 2022
f5f2576
fix torch_geometric check
natolambert Oct 3, 2022
104ec26
remove dummy tranformers objects
natolambert Oct 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
weird rebase
  • Loading branch information
natolambert committed Oct 3, 2022
commit 38658929732f66f9b83fc2641f294743e358f971
32 changes: 16 additions & 16 deletions src/diffusers/models/dualencoder_gfn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff
from typing import Callable, Union
from typing import Callable, Union, Dict

import numpy as np
import torch
Expand Down Expand Up @@ -520,7 +520,7 @@ def __init__(
self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])
self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])

def forward(
def _forward(
self,
atom_type,
pos,
Expand Down Expand Up @@ -613,28 +613,28 @@ def forward(
else:
return edge_inv_global, edge_inv_local

def get_residual_params(
def forward(
self,
t,
batch,
sample,
timestep: Union[torch.Tensor, float, int],
extend_order=False,
extend_radius=True,
clip_local=None,
):
atom_type = batch.atom_type
bond_index = batch.edge_index
bond_type = batch.edge_type
num_graphs = batch.num_graphs
pos = batch.pos
)-> Dict[str, torch.FloatTensor]:
atom_type = sample.atom_type
bond_index = sample.edge_index
bond_type = sample.edge_type
num_graphs = sample.num_graphs
pos = sample.pos

timesteps = torch.full(size=(num_graphs,), fill_value=t, dtype=torch.long, device=pos.device)
timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)

edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self.forward(
edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(
atom_type=atom_type,
pos=batch.pos,
pos=sample.pos,
bond_index=bond_index,
bond_type=bond_type,
batch=batch.batch,
batch=sample.batch,
time_step=timesteps,
return_edges=True,
extend_order=extend_order,
Expand Down Expand Up @@ -677,7 +677,7 @@ def get_residual(

# Sum
eps_pos = node_eq_local + node_eq_global * w_global
return -eps_pos
return {"sample": -eps_pos}


def clip_norm(vec, limit, p=2):
Expand Down
Loading