@@ -226,6 +226,53 @@ def forward(
226226 return self .skip_tp (message , node_attrs ) # [n_nodes, irreps]
227227
228228
229+ class ResidualElementDependentInteractionBlock (InteractionBlock ):
230+ def _setup (self ) -> None :
231+ self .linear_up = o3 .Linear (self .node_feats_irreps ,
232+ self .node_feats_irreps ,
233+ internal_weights = True ,
234+ shared_weights = True )
235+ # TensorProduct
236+ irreps_mid , instructions = tp_out_irreps_with_instructions (self .node_feats_irreps , self .edge_attrs_irreps ,
237+ self .target_irreps )
238+ self .conv_tp = o3 .TensorProduct (self .node_feats_irreps ,
239+ self .edge_attrs_irreps ,
240+ irreps_mid ,
241+ instructions = instructions ,
242+ shared_weights = False ,
243+ internal_weights = False )
244+ self .conv_tp_weights = TensorProductWeightsBlock (num_elements = self .node_attrs_irreps .num_irreps ,
245+ num_edge_feats = self .edge_feats_irreps .num_irreps ,
246+ num_feats_out = self .conv_tp .weight_numel )
247+
248+ # Linear
249+ irreps_mid = irreps_mid .simplify ()
250+ self .irreps_out = linear_out_irreps (irreps_mid , self .target_irreps )
251+ self .irreps_out = self .irreps_out .simplify ()
252+ self .linear = o3 .Linear (irreps_mid , self .irreps_out , internal_weights = True , shared_weights = True )
253+
254+ # Selector TensorProduct
255+ self .skip_tp = o3 .FullyConnectedTensorProduct (self .node_feats_irreps , self .node_attrs_irreps , self .irreps_out )
256+
257+ def forward (
258+ self ,
259+ node_attrs : torch .Tensor ,
260+ node_feats : torch .Tensor ,
261+ edge_attrs : torch .Tensor ,
262+ edge_feats : torch .Tensor ,
263+ edge_index : torch .Tensor ,
264+ ) -> torch .Tensor :
265+ sender , receiver = edge_index
266+ num_nodes = node_feats .shape [0 ]
267+ sc = self .skip_tp (node_feats , node_attrs )
268+ node_feats = self .linear_up (node_feats )
269+ tp_weights = self .conv_tp_weights (node_attrs [sender ], edge_feats )
270+ mji = self .conv_tp (node_feats [sender ], edge_attrs , tp_weights ) # [n_edges, irreps]
271+ message = scatter_sum (src = mji , index = receiver , dim = 0 , dim_size = num_nodes ) # [n_nodes, irreps]
272+ message = self .linear (message ) / self .avg_num_neighbors
273+ return message + sc # [n_nodes, irreps]
274+
275+
229276def init_layer (layer : torch .nn .Linear , w_scale = 1.0 ) -> torch .nn .Linear :
230277 torch .nn .init .orthogonal_ (layer .weight .data )
231278 layer .weight .data .mul_ (w_scale ) # type: ignore
@@ -289,7 +336,7 @@ def forward(
289336 ) -> torch .Tensor :
290337 sender , receiver = edge_index
291338 num_nodes = node_feats .shape [0 ]
292-
339+
293340 tp_weights = self .conv_tp_weights (torch .cat ([node_attrs [sender ], edge_feats ], dim = - 1 ))
294341 mji = self .conv_tp (node_feats [sender ], edge_attrs , tp_weights ) # [n_edges, irreps]
295342 message = scatter_sum (src = mji , index = receiver , dim = 0 , dim_size = num_nodes ) # [n_nodes, irreps]
@@ -333,7 +380,6 @@ def forward(
333380 ) -> torch .Tensor :
334381 sender , receiver = edge_index
335382 num_nodes = node_feats .shape [0 ]
336-
337383 tp_weights = self .conv_tp_weights (edge_feats )
338384 mji = self .conv_tp (node_feats [sender ], edge_attrs , tp_weights ) # [n_edges, irreps]
339385 message = scatter_sum (src = mji , index = receiver , dim = 0 , dim_size = num_nodes ) # [n_nodes, irreps]
0 commit comments