Skip to content

Commit 35ecf6c

Browse files
committed
added interaction residual element dependent
1 parent bf8acbb commit 35ecf6c

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

botnet/modules/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from .blocks import (AtomicEnergiesBlock, RadialEmbeddingBlock, LinearReadoutBlock, SimpleInteractionBlock,
44
ElementDependentInteractionBlock, InteractionBlock, NonlinearInteractionBlock,
5-
NonLinearReadoutBlock, AgnosticNonlinearInteractionBlock,
5+
NonLinearReadoutBlock, AgnosticNonlinearInteractionBlock, ResidualElementDependentInteractionBlock,
66
AgnosticResidualNonlinearInteractionBlock, NequIPInteractionBlock, AgnosticNoScNonlinearInteractionBlock)
77
from .loss import EnergyForcesLoss, ACELoss, WeightedEnergyForcesLoss
88
from .models import (BodyOrderedModel, ScaleShiftBodyOrderedModel, SingleReadoutModel,
@@ -14,6 +14,7 @@
1414
interaction_classes: Dict[str, Type[InteractionBlock]] = {
1515
'SimpleInteractionBlock': SimpleInteractionBlock,
1616
'ElementDependentInteractionBlock': ElementDependentInteractionBlock,
17+
'ResidualElementDependentInteractionBlock': ResidualElementDependentInteractionBlock,
1718
'NonlinearInteractionBlock': NonlinearInteractionBlock,
1819
'AgnosticNonlinearInteractionBlock': AgnosticNonlinearInteractionBlock,
1920
'AgnosticNoScNonlinearInteractionBlock': AgnosticNoScNonlinearInteractionBlock,

botnet/modules/blocks.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,53 @@ def forward(
226226
return self.skip_tp(message, node_attrs) # [n_nodes, irreps]
227227

228228

229+
class ResidualElementDependentInteractionBlock(InteractionBlock):
230+
def _setup(self) -> None:
231+
self.linear_up = o3.Linear(self.node_feats_irreps,
232+
self.node_feats_irreps,
233+
internal_weights=True,
234+
shared_weights=True)
235+
# TensorProduct
236+
irreps_mid, instructions = tp_out_irreps_with_instructions(self.node_feats_irreps, self.edge_attrs_irreps,
237+
self.target_irreps)
238+
self.conv_tp = o3.TensorProduct(self.node_feats_irreps,
239+
self.edge_attrs_irreps,
240+
irreps_mid,
241+
instructions=instructions,
242+
shared_weights=False,
243+
internal_weights=False)
244+
self.conv_tp_weights = TensorProductWeightsBlock(num_elements=self.node_attrs_irreps.num_irreps,
245+
num_edge_feats=self.edge_feats_irreps.num_irreps,
246+
num_feats_out=self.conv_tp.weight_numel)
247+
248+
# Linear
249+
irreps_mid = irreps_mid.simplify()
250+
self.irreps_out = linear_out_irreps(irreps_mid, self.target_irreps)
251+
self.irreps_out = self.irreps_out.simplify()
252+
self.linear = o3.Linear(irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True)
253+
254+
# Selector TensorProduct
255+
self.skip_tp = o3.FullyConnectedTensorProduct(self.node_feats_irreps, self.node_attrs_irreps, self.irreps_out)
256+
257+
def forward(
258+
self,
259+
node_attrs: torch.Tensor,
260+
node_feats: torch.Tensor,
261+
edge_attrs: torch.Tensor,
262+
edge_feats: torch.Tensor,
263+
edge_index: torch.Tensor,
264+
) -> torch.Tensor:
265+
sender, receiver = edge_index
266+
num_nodes = node_feats.shape[0]
267+
sc = self.skip_tp(node_feats, node_attrs)
268+
node_feats = self.linear_up(node_feats)
269+
tp_weights = self.conv_tp_weights(node_attrs[sender], edge_feats)
270+
mji = self.conv_tp(node_feats[sender], edge_attrs, tp_weights) # [n_edges, irreps]
271+
message = scatter_sum(src=mji, index=receiver, dim=0, dim_size=num_nodes) # [n_nodes, irreps]
272+
message = self.linear(message) / self.avg_num_neighbors
273+
return message + sc # [n_nodes, irreps]
274+
275+
229276
def init_layer(layer: torch.nn.Linear, w_scale=1.0) -> torch.nn.Linear:
230277
torch.nn.init.orthogonal_(layer.weight.data)
231278
layer.weight.data.mul_(w_scale) # type: ignore
@@ -289,7 +336,7 @@ def forward(
289336
) -> torch.Tensor:
290337
sender, receiver = edge_index
291338
num_nodes = node_feats.shape[0]
292-
339+
293340
tp_weights = self.conv_tp_weights(torch.cat([node_attrs[sender], edge_feats], dim=-1))
294341
mji = self.conv_tp(node_feats[sender], edge_attrs, tp_weights) # [n_edges, irreps]
295342
message = scatter_sum(src=mji, index=receiver, dim=0, dim_size=num_nodes) # [n_nodes, irreps]
@@ -333,7 +380,6 @@ def forward(
333380
) -> torch.Tensor:
334381
sender, receiver = edge_index
335382
num_nodes = node_feats.shape[0]
336-
337383
tp_weights = self.conv_tp_weights(edge_feats)
338384
mji = self.conv_tp(node_feats[sender], edge_attrs, tp_weights) # [n_edges, irreps]
339385
message = scatter_sum(src=mji, index=receiver, dim=0, dim_size=num_nodes) # [n_nodes, irreps]

0 commit comments

Comments
 (0)