captum/attr/_core/lrp.py (26 additions, 0 deletions)
@@ -309,13 +309,39 @@ def _register_forward_hooks(self) -> None:
                )
                self.backward_handles.append(backward_handle)
            else:
# adding "no op" node in between layers, should be removed after
# transitioning to module backward hooks
                no_op_handle = layer.register_forward_pre_hook(self._no_op_pre_hook)
                self.forward_handles.append(no_op_handle)

                forward_handle = layer.register_forward_hook(
                    layer.rule.forward_hook  # type: ignore
                )
                self.forward_handles.append(forward_handle)
                if self.verbose:
                    print(f"Applied {layer.rule} on layer {layer}")
    @staticmethod
    def _no_op_pre_hook(
        module: Module,
        inputs: TensorOrTupleOfTensorsGeneric,
    ) -> TensorOrTupleOfTensorsGeneric:
"""Pre Hook for adding no op nodes in between modules, in order to fix
the hook ordering issues.

Args:
module (nn.Module): module in question, not used but required to be
part of the signature by pytorch
inputs (TensorOrTupleOfTensorsGeneric): inputs of the module

Returns:
TensorOrTupleOfTensorsGeneric: cloned inputs
"""
        if isinstance(inputs, Tensor):
            return inputs.clone()
        else:
            return tuple(x.clone() for x in inputs)

    def _register_weight_hooks(self) -> None:
        for layer in self.layers:
            if layer.rule is not None:
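
For context (not part of this diff): a minimal standalone sketch of the idea behind the no-op pre-hook, assuming a toy nn.Sequential model with illustrative names. Cloning a module's inputs inside a forward pre-hook inserts an extra node (CloneBackward0) into the autograd graph between consecutive layers, so hooks on those layers no longer attach to the same graph node.

import torch
import torch.nn as nn

# Illustrative stand-in for LRP._no_op_pre_hook: clone the inputs so that an
# extra autograd node sits between this module and the module that produced
# its inputs.
def no_op_pre_hook(module, inputs):
    return tuple(x.clone() for x in inputs)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
handle = model[1].register_forward_pre_hook(no_op_pre_hook)

out = model(torch.randn(2, 4, requires_grad=True))
print(out.grad_fn)                       # ReluBackward0
print(out.grad_fn.next_functions[0][0])  # CloneBackward0 added by the pre-hook
handle.remove()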