Skip to content

Commit 9e39d12

Browse files
committed
added linear to non residual agnostic
1 parent ad56ac9 commit 9e39d12

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

botnet/modules/blocks.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,10 @@ def forward(
346346

347347
class AgnosticNonlinearInteractionBlock(InteractionBlock):
348348
def _setup(self) -> None:
349+
self.linear_up = o3.Linear(self.node_feats_irreps,
350+
self.node_feats_irreps,
351+
internal_weights=True,
352+
shared_weights=True)
349353
# TensorProduct
350354
irreps_mid, instructions = tp_out_irreps_with_instructions(self.node_feats_irreps, self.edge_attrs_irreps,
351355
self.target_irreps)
@@ -381,6 +385,7 @@ def forward(
381385
sender, receiver = edge_index
382386
num_nodes = node_feats.shape[0]
383387
tp_weights = self.conv_tp_weights(edge_feats)
388+
node_feats = self.linear_up(node_feats)
384389
mji = self.conv_tp(node_feats[sender], edge_attrs, tp_weights) # [n_edges, irreps]
385390
message = scatter_sum(src=mji, index=receiver, dim=0, dim_size=num_nodes) # [n_nodes, irreps]
386391
message = self.linear(message) / self.avg_num_neighbors

scripts/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
'nequip': {
2424
'color': colors[1],
2525
'label': 'NequIP',
26-
'linestyle': dashed,
2726
},
2827
'nequip-linear': {
2928
'color': colors[3],
@@ -65,6 +64,10 @@
6564
'color': colors[0],
6665
'label': 'BOTNet',
6766
},
67+
'botnet-e0': {
68+
'color': colors[0],
69+
'label': 'BOTNet E0',
70+
},
6871
'botnet-ssh': {
6972
'color': colors[8],
7073
'label': 'BOTNet-SSH',

0 commit comments

Comments
 (0)