Skip to content

Conversation

@samwaltonnorwood
Copy link

Adds the "lifted_skip" function which acts as a direct residual connection, zero-padding as necessary to match new irreps in the output of EquivariantProductBasisBlock.
Also adds arguments to the InteractionBlock class to support edge gating for OCP.

@samwaltonnorwood
Copy link
Author

samwaltonnorwood commented Sep 19, 2023

@abhshkdz The lifted_skip operation needs to be revised to fix the bug you're seeing. The setup I developed this on was slightly different and ensured node_feats_irreps.dim <= hidden_irreps.dim, so it errs in the opposite case. This should fix it:

``` class lifted_skip(torch.nn.Module):
    def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps) -> None:
        super().__init__()
        self.in_dim = o3.Irreps(irreps_in).dim
        self.out_dim = o3.Irreps(irreps_out).dim
        self.pad = True
        if self.in_dim >= self.out_dim:
            self.pad = False
        
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        batch, _ = tensor.shape
        if self.pad:
            template = torch.zeros(batch, self.out_dim, device=tensor.device)
            template[:, :self.in_dim] = tensor
            return template
        else:
            return tensor[:, :self.out_dim]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant