Molecule generation model (GeoDiff) #54

Closed · natolambert wants to merge 27 commits

Changes shown below are from 1 commit. All 27 commits:
f405837
create filestructure for science application
natolambert Jun 30, 2022
e4a2ddf
rebase molecule gen
natolambert Oct 3, 2022
71753ef
make style
natolambert Jul 1, 2022
9064eda
add property to self in init for colab
natolambert Jul 5, 2022
7c15d6b
small fix to types in forward()
natolambert Jul 11, 2022
120af84
rebase main, small updates
natolambert Oct 3, 2022
9dd023a
add helper function to trim colab
natolambert Jul 12, 2022
2d1f748
remove unused code
natolambert Jul 12, 2022
a4513e2
clean API for colab
natolambert Jul 13, 2022
ce71e2f
remove unused code
natolambert Jul 21, 2022
3865892
weird rebase
natolambert Oct 3, 2022
127f72a
tests pass
natolambert Jul 22, 2022
2f0ac21
make style and fix-copies
natolambert Jul 22, 2022
25ec89d
rename model and file
natolambert Jul 25, 2022
79f25d6
update API, update tests, rename class
natolambert Jul 25, 2022
7a85d04
clean model & tests
natolambert Jul 25, 2022
a90d1be
add checking for imports
natolambert Jul 26, 2022
4d23976
minor formatting nit
natolambert Jul 26, 2022
506eb3c
add attribution of original codebase
natolambert Jul 27, 2022
4d158a3
style and readibility improvements
natolambert Aug 1, 2022
7e73190
fixes post large rebase
natolambert Oct 3, 2022
682eb47
fix tests
natolambert Oct 3, 2022
77569dc
Merge remote-tracking branch 'origin/main' into molecule_gen
natolambert Oct 3, 2022
2ef3727
make quality and style
natolambert Oct 3, 2022
47af5ce
only import moleculegnn when ready
natolambert Oct 3, 2022
f5f2576
fix torch_geometric check
natolambert Oct 3, 2022
104ec26
remove dummy tranformers objects
natolambert Oct 3, 2022
rebase molecule gen
natolambert committed Oct 3, 2022
commit e4a2ddf7a44ea649d001de132e58c679e923b331
45 changes: 23 additions & 22 deletions src/diffusers/models/dualencoder_gfn.py
@@ -506,43 +506,44 @@ def sigmoid(x):


 class DualEncoderEpsNetwork(ModelMixin, ConfigMixin):

Review comment (Contributor):

Since this is the first graph neural network in this library, we should be extra careful with the naming.
Is this a universally understandable name? Do you think other graph networks would also use this architecture? Should we make the name more generic in that case? Can we link to a paper here that defined this model architecture?

Reply (Author):

The link to the original code is above now too. I followed up with those authors asking whether theirs was the original: https://github.com/DeepGraphLearning/ConfGF

-    def __init__(self, config):
+    def __init__(self, hidden_dim, num_convs, num_convs_local, cutoff, mlp_act, beta_schedule, beta_start, beta_end, num_diffusion_timesteps, edge_order, edge_encoder, smooth_conv):
         super().__init__()
-        self.config = config
+        self.cutoff = cutoff
+        self.edge_encoder = edge_encoder

         """
         edge_encoder: Takes both edge type and edge length as input and outputs a vector
         [Note]: node embedding is done in SchNetEncoder
         """
-        self.edge_encoder_global = MLPEdgeEncoder(config.hidden_dim, config.mlp_act)  # get_edge_encoder(config)
-        self.edge_encoder_local = MLPEdgeEncoder(config.hidden_dim, config.mlp_act)  # get_edge_encoder(config)
+        self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)
+        self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)

         """
         The graph neural network that extracts node-wise features.
         """
         self.encoder_global = SchNetEncoder(
-            hidden_channels=config.hidden_dim,
-            num_filters=config.hidden_dim,
-            num_interactions=config.num_convs,
+            hidden_channels=hidden_dim,
+            num_filters=hidden_dim,
+            num_interactions=num_convs,
             edge_channels=self.edge_encoder_global.out_channels,
-            cutoff=config.cutoff,
-            smooth=config.smooth_conv,
+            cutoff=cutoff,
+            smooth=smooth_conv,
         )
         self.encoder_local = GINEncoder(
-            hidden_dim=config.hidden_dim,
-            num_convs=config.num_convs_local,
+            hidden_dim=hidden_dim,
+            num_convs=num_convs_local,
         )

         """
         `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs
         gradients w.r.t. edge_length (out_dim = 1).
         """
         self.grad_global_dist_mlp = MultiLayerPerceptron(
-            2 * config.hidden_dim, [config.hidden_dim, config.hidden_dim // 2, 1], activation=config.mlp_act
+            2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act
         )

         self.grad_local_dist_mlp = MultiLayerPerceptron(
-            2 * config.hidden_dim, [config.hidden_dim, config.hidden_dim // 2, 1], activation=config.mlp_act
+            2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act
         )

"""
Expand All @@ -551,15 +552,15 @@ def __init__(self, config):
         self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])
         self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])

-        self.model_type = config.type  # config.type  # 'diffusion'; 'dsm'
+        self.model_type = type  # type  # 'diffusion'; 'dsm'

         # denoising diffusion
         ## betas
         betas = get_beta_schedule(
-            beta_schedule=config.beta_schedule,
-            beta_start=config.beta_start,
-            beta_end=config.beta_end,
-            num_diffusion_timesteps=config.num_diffusion_timesteps,
+            beta_schedule=beta_schedule,
+            beta_start=beta_start,
+            beta_end=beta_end,
+            num_diffusion_timesteps=num_diffusion_timesteps,
         )
         betas = torch.from_numpy(betas).float()
         self.betas = nn.Parameter(betas, requires_grad=False)
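
For context on what this refactor means at the call site, here is a minimal, hypothetical usage sketch. It assumes the new explicit-keyword signature above; the import path is inferred from this PR's file location, and all argument values are illustrative placeholders, not project defaults.

# Hypothetical usage sketch for the refactored constructor.
# Import path inferred from src/diffusers/models/dualencoder_gfn.py in this PR.
from diffusers.models.dualencoder_gfn import DualEncoderEpsNetwork

# Before this commit, hyperparameters were read off a single config object:
#     model = DualEncoderEpsNetwork(config)
# After it, the same hyperparameters are passed explicitly:
model = DualEncoderEpsNetwork(
    hidden_dim=128,                # width shared by edge encoders, SchNet, and GIN
    num_convs=6,                   # SchNet interaction blocks (global encoder)
    num_convs_local=4,             # GIN layers (local encoder)
    cutoff=10.0,                   # radius cutoff for global edges
    mlp_act="relu",                # activation for the edge encoders and output MLPs
    beta_schedule="sigmoid",       # passed through to get_beta_schedule
    beta_start=1e-7,
    beta_end=2e-3,
    num_diffusion_timesteps=5000,
    edge_order=3,                  # bond-graph extension order used in forward()
    edge_encoder="mlp",            # assumed option; the loss only special-cases "gaussian"
    smooth_conv=True,
)

One detail worth noting in the hunk above: the betas are registered as an nn.Parameter with requires_grad=False, so the noise schedule moves with .to(device) and is saved in the state dict without ever being trained.
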
@@ -599,8 +600,8 @@ def forward(
             edge_index=bond_index,
             edge_type=bond_type,
             batch=batch,
-            order=self.config.edge_order,
-            cutoff=self.config.cutoff,
+            order=self.edge_order,
+            cutoff=self.cutoff,
             extend_order=extend_order,
             extend_radius=extend_radius,
             is_sidechain=is_sidechain,
@@ -743,14 +744,14 @@ def get_loss_diffusion(
         train_edge_mask = is_train_edge(edge_index, is_sidechain)
         d_perturbed = torch.where(train_edge_mask.unsqueeze(-1), d_perturbed, d_gt)

-        if self.config.edge_encoder == "gaussian":
+        if self.edge_encoder == "gaussian":
             # Distances must be greater than 0
             d_sgn = torch.sign(d_perturbed)
             d_perturbed = torch.clamp(d_perturbed * d_sgn, min=0.01, max=float("inf"))
         d_target = (d_gt - d_perturbed) / (1.0 - a_edge).sqrt() * a_edge.sqrt()  # (E_global, 1), denoising direction

         global_mask = torch.logical_and(
-            torch.logical_or(d_perturbed <= self.config.cutoff, local_edge_mask.unsqueeze(-1)),
+            torch.logical_or(d_perturbed <= self.cutoff, local_edge_mask.unsqueeze(-1)),
             ~local_edge_mask.unsqueeze(-1),
         )
         target_d_global = torch.where(global_mask, d_target, torch.zeros_like(d_target))
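
To make the two numeric guards in this hunk concrete, here is a small self-contained sketch with toy values; the shapes mirror the (E, 1) edge-length tensors above, and all numbers are illustrative.

import torch

# 1) The "gaussian" branch keeps perturbed distances strictly positive:
#    take the absolute value, then floor it at 0.01.
d = torch.tensor([[-0.5], [0.001], [3.0]])
d = torch.clamp(d * torch.sign(d), min=0.01, max=float("inf"))
print(d.squeeze(-1))  # tensor([0.5000, 0.0100, 3.0000])

# 2) global_mask keeps edges that are within the cutoff radius (or marked
#    local), then drops the local edges themselves, so the global branch is
#    supervised only on non-local edges within range.
cutoff = 10.0
d_perturbed = torch.tensor([[4.0], [12.0], [6.0]])    # (E, 1) edge lengths
local_edge_mask = torch.tensor([False, False, True])  # (E,) bonded edges

global_mask = torch.logical_and(
    torch.logical_or(d_perturbed <= cutoff, local_edge_mask.unsqueeze(-1)),
    ~local_edge_mask.unsqueeze(-1),
)
print(global_mask.squeeze(-1))  # tensor([ True, False, False])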