Molecule generation model (GeoDiff) #54


Closed · wants to merge 27 commits into from

Changes from 1 commit (of 27 commits)
Commits (27)

f405837 create filestructure for science application (natolambert, Jun 30, 2022)
e4a2ddf rebase molecule gen (natolambert, Oct 3, 2022)
71753ef make style (natolambert, Jul 1, 2022)
9064eda add property to self in init for colab (natolambert, Jul 5, 2022)
7c15d6b small fix to types in forward() (natolambert, Jul 11, 2022)
120af84 rebase main, small updates (natolambert, Oct 3, 2022)
9dd023a add helper function to trim colab (natolambert, Jul 12, 2022)
2d1f748 remove unused code (natolambert, Jul 12, 2022)
a4513e2 clean API for colab (natolambert, Jul 13, 2022)
ce71e2f remove unused code (natolambert, Jul 21, 2022)
3865892 weird rebase (natolambert, Oct 3, 2022)
127f72a tests pass (natolambert, Jul 22, 2022)
2f0ac21 make style and fix-copies (natolambert, Jul 22, 2022)
25ec89d rename model and file (natolambert, Jul 25, 2022)
79f25d6 update API, update tests, rename class (natolambert, Jul 25, 2022)
7a85d04 clean model & tests (natolambert, Jul 25, 2022)
a90d1be add checking for imports (natolambert, Jul 26, 2022)
4d23976 minor formatting nit (natolambert, Jul 26, 2022)
506eb3c add attribution of original codebase (natolambert, Jul 27, 2022)
4d158a3 style and readibility improvements (natolambert, Aug 1, 2022)
7e73190 fixes post large rebase (natolambert, Oct 3, 2022)
682eb47 fix tests (natolambert, Oct 3, 2022)
77569dc Merge remote-tracking branch 'origin/main' into molecule_gen (natolambert, Oct 3, 2022)
2ef3727 make quality and style (natolambert, Oct 3, 2022)
47af5ce only import moleculegnn when ready (natolambert, Oct 3, 2022)
f5f2576 fix torch_geometric check (natolambert, Oct 3, 2022)
104ec26 remove dummy tranformers objects (natolambert, Oct 3, 2022)
remove unused code
natolambert committed Oct 3, 2022
commit 2d1f748303bf5c07408d396189c759c3e2d2a2e2
269 changes: 10 additions & 259 deletions src/diffusers/models/dualencoder_gfn.py
@@ -679,38 +679,6 @@ def get_loss(
extend_order=True,
extend_radius=True,
is_sidechain=None,
):
return self.get_loss_diffusion(
atom_type,
pos,
bond_index,
bond_type,
batch,
num_nodes_per_graph,
num_graphs,
anneal_power,
return_unreduced_loss,
return_unreduced_edge_loss,
extend_order,
extend_radius,
is_sidechain,
)

def get_loss_diffusion(
self,
atom_type,
pos,
bond_index,
bond_type,
batch,
num_nodes_per_graph,
num_graphs,
anneal_power=2.0,
return_unreduced_loss=False,
return_unreduced_edge_loss=False,
extend_order=True,
extend_radius=True,
is_sidechain=None,
):
N = atom_type.size(0)
node2graph = batch
@@ -787,66 +755,13 @@ def get_loss_diffusion(
else:
return loss

def langevin_dynamics_sample(
self,
atom_type,
pos_init,
bond_index,
bond_type,
batch,
num_graphs,
extend_order,
extend_radius=True,
n_steps=100,
step_lr=0.0000010,
clip=1000,
clip_local=None,
clip_pos=None,
min_sigma=0,
is_sidechain=None,
global_start_sigma=float("inf"),
w_global=0.2,
w_reg=1.0,
**kwargs,
):
return self.langevin_dynamics_sample_diffusion(
atom_type,
pos_init,
bond_index,
bond_type,
batch,
num_graphs,
extend_order,
extend_radius,
n_steps,
step_lr,
clip,
clip_local,
clip_pos,
min_sigma,
is_sidechain,
global_start_sigma,
w_global,
w_reg,
sampling_type=kwargs.get("sampling_type", "ddpm_noisy"),
eta=kwargs.get("eta", 1.0),
)

def get_residual_params(
self,
t,
batch,
extend_order=False,
extend_radius=True,
step_lr=0.0000010,
clip=1000,
clip_local=None,
clip_pos=None,
min_sigma=0,
is_sidechain=None,
global_start_sigma=0.5,
w_global=1.0,
**kwargs,
):
atom_type = batch.atom_type
bond_index = batch.edge_index
@@ -879,15 +794,20 @@ def get_residual(
self,
pos,
sigma,
edge_inv_global,
local_edge_mask,
edge_index,
edge_length,
node_eq_local,
model_outputs,
global_start_sigma=0.5,
w_global=1.0,
clip=1000.0,
):
(
edge_inv_global,
edge_inv_local,
edge_index,
edge_type,
edge_length,
local_edge_mask,
node_eq_local,
) = model_outputs

# Global
if sigma < global_start_sigma:
@@ -901,172 +821,13 @@ def get_residual(
eps_pos = node_eq_local + node_eq_global * w_global
return -eps_pos

def langevin_dynamics_sample_diffusion(
self,
atom_type,
pos_init,
bond_index,
bond_type,
batch,
num_graphs,
extend_order,
extend_radius=True,
n_steps=100,
clip=1000,
clip_local=None,
clip_pos=None,
global_start_sigma=float("inf"),
w_global=0.2,
w_reg=1.0,
**kwargs,
):
def compute_alpha(beta, t):
beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
a = (1 - beta).cumprod(dim=0).index_select(0, t + 1) # .view(-1, 1, 1, 1)
return a

sigmas = (1.0 - self.alphas).sqrt() / self.alphas.sqrt()
pos_traj = []
if is_sidechain is not None:
assert pos_gt is not None, "need crd of backbone for sidechain prediction"
with torch.no_grad():
# skip = self.num_timesteps // n_steps
# seq = range(0, self.num_timesteps, skip)

## to test sampling with less intermediate diffusion steps
# n_steps: the num of steps
seq = range(self.num_timesteps - n_steps, self.num_timesteps)
seq_next = [-1] + list(seq[:-1])

pos = pos_init * sigmas[-1]
if is_sidechain is not None:
pos[~is_sidechain] = pos_gt[~is_sidechain]
for i, j in tqdm(zip(reversed(seq), reversed(seq_next)), desc="sample"):
t = torch.full(size=(num_graphs,), fill_value=i, dtype=torch.long, device=pos.device)

edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self(
atom_type=atom_type,
pos=pos,
bond_index=bond_index,
bond_type=bond_type,
batch=batch,
time_step=t,
return_edges=True,
extend_order=extend_order,
extend_radius=extend_radius,
is_sidechain=is_sidechain,
) # (E_global, 1), (E_local, 1)

# Local
node_eq_local = eq_transform(
edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]
)
if clip_local is not None:
node_eq_local = clip_norm(node_eq_local, limit=clip_local)
# Global
if sigmas[i] < global_start_sigma:
edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())
node_eq_global = eq_transform(edge_inv_global, pos, edge_index, edge_length)
node_eq_global = clip_norm(node_eq_global, limit=clip)
else:
node_eq_global = 0
# Sum
eps_pos = node_eq_local + node_eq_global * w_global # + eps_pos_reg * w_reg

# Update

sampling_type = kwargs.get("sampling_type", "ddpm_noisy") # types: generalized, ddpm_noisy, ld

noise = torch.randn_like(pos) # center_pos(torch.randn_like(pos), batch)
if sampling_type == "generalized" or sampling_type == "ddpm_noisy":
b = self.betas
t = t[0]
next_t = (torch.ones(1) * j).to(pos.device)
at = compute_alpha(b, t.long())
at_next = compute_alpha(b, next_t.long())
if sampling_type == "generalized":
eta = kwargs.get("eta", 1.0)
et = -eps_pos
## original
# pos0_t = (pos - et * (1 - at).sqrt()) / at.sqrt()
## reweighted
# pos0_t = pos - et * (1 - at).sqrt() / at.sqrt()
c1 = eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
c2 = ((1 - at_next) - c1**2).sqrt()
# pos_next = at_next.sqrt() * pos0_t + c1 * noise + c2 * et
# pos_next = pos0_t + c1 * noise / at_next.sqrt() + c2 * et / at_next.sqrt()

# pos_next = pos + et * (c2 / at_next.sqrt() - (1 - at).sqrt() / at.sqrt()) + noise * c1 / at_next.sqrt()
step_size_pos_ld = step_lr * (sigmas[i] / 0.01) ** 2 / sigmas[i]
step_size_pos_generalized = 5 * ((1 - at).sqrt() / at.sqrt() - c2 / at_next.sqrt())
step_size_pos = (
step_size_pos_ld
if step_size_pos_ld < step_size_pos_generalized
else step_size_pos_generalized
)

step_size_noise_ld = torch.sqrt((step_lr * (sigmas[i] / 0.01) ** 2) * 2)
step_size_noise_generalized = 3 * (c1 / at_next.sqrt())
step_size_noise = (
step_size_noise_ld
if step_size_noise_ld < step_size_noise_generalized
else step_size_noise_generalized
)

pos_next = pos - et * step_size_pos + noise * step_size_noise

elif sampling_type == "ddpm_noisy":
atm1 = at_next
beta_t = 1 - at / atm1
e = -eps_pos
pos0_from_e = (1.0 / at).sqrt() * pos - (1.0 / at - 1).sqrt() * e
mean_eps = (
(atm1.sqrt() * beta_t) * pos0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * pos
) / (1.0 - at)
mean = mean_eps
mask = 1 - (t == 0).float()
logvar = beta_t.log()
pos_next = mean + mask * torch.exp(0.5 * logvar) * noise
elif sampling_type == "ld":
step_size = step_lr * (sigmas[i] / 0.01) ** 2
pos_next = pos + step_size * eps_pos / sigmas[i] + noise * torch.sqrt(step_size * 2)

pos = pos_next

if is_sidechain is not None:
pos[~is_sidechain] = pos_gt[~is_sidechain]

if torch.isnan(pos).any():
print("NaN detected. Please restart.")
raise FloatingPointError()
pos = pos - scatter_mean(pos, batch, dim=0)[batch] # center_pos(pos, batch)
if clip_pos is not None:
pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)
pos_traj.append(pos.clone().cpu())

return pos, pos_traj


def clip_norm(vec, limit, p=2):
norm = torch.norm(vec, dim=-1, p=2, keepdim=True)
denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))
return vec * denom


def is_bond(edge_type):
return torch.logical_and(edge_type < len(BOND_TYPES), edge_type > 0)


def is_angle_edge(edge_type):
return edge_type == len(BOND_TYPES) + 1 - 1


def is_dihedral_edge(edge_type):
return edge_type == len(BOND_TYPES) + 2 - 1


def is_radius_edge(edge_type):
return edge_type == 0


def is_local_edge(edge_type):
@@ -1080,13 +841,3 @@ def is_train_edge(edge_index, is_sidechain):
is_sidechain = is_sidechain.bool()
return torch.logical_or(is_sidechain[edge_index[0]], is_sidechain[edge_index[1]])


def regularize_bond_length(edge_type, edge_length, rng=5.0):
mask = is_bond(edge_type).float().reshape(-1, 1)
d = -torch.clamp(edge_length - rng, min=0.0, max=float("inf")) * mask
return d


# def center_pos(pos, batch):
# pos_center = pos - scatter_mean(pos, batch, dim=0)[batch]
# return pos_center
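
For reviewers tracing the refactor: this commit deletes the inlined `langevin_dynamics_sample` / `langevin_dynamics_sample_diffusion` loop, leaving the smaller `get_residual_params` / `get_residual` pair so that a sampling loop outside the model can own the update rule. Below is a minimal sketch of the resulting call pattern, mirroring the deleted `ld` sampling branch; the driver loop and the `model`, `batch`, `pos_init`, and `sigmas` names are illustrative assumptions, not code from this PR.

```python
import torch


@torch.no_grad()
def langevin_sample(model, batch, pos_init, sigmas, num_graphs, step_lr=1e-6):
    # Hypothetical driver loop: only get_residual_params / get_residual come
    # from the diff above; the update rule mirrors the deleted "ld" branch.
    # `sigmas` is assumed to be a 1-D tensor of noise levels.
    pos = pos_init * sigmas[-1]
    for i in reversed(range(len(sigmas))):
        t = torch.full((num_graphs,), i, dtype=torch.long, device=pos.device)
        # Forward pass plus edge bookkeeping; returns the tuple that
        # get_residual unpacks as `model_outputs`.
        model_outputs = model.get_residual_params(t, batch)
        # get_residual returns -eps_pos (local score plus weighted global
        # score, negated), so flip the sign to recover eps_pos.
        eps_pos = -model.get_residual(pos, sigmas[i], model_outputs)
        step_size = step_lr * (sigmas[i] / 0.01) ** 2
        noise = torch.randn_like(pos)
        pos = pos + step_size * eps_pos / sigmas[i] + noise * torch.sqrt(2 * step_size)
    return pos
```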