removed alpha beta rule v2
rGure committed Jun 29, 2021
commit ca2387cd3561ecb60b3e5e38528062c8314643c7
26 changes: 26 additions & 0 deletions captum/attr/_core/lrp.py
@@ -309,13 +309,39 @@ def _register_forward_hooks(self) -> None:
)
self.backward_handles.append(backward_handle)
else:
# add a "no op" node in between layers; this should be removed after
# transitioning to module backward hooks
no_op_handle = layer.register_forward_pre_hook(self._no_op_pre_hook)
self.forward_handles.append(no_op_handle)

forward_handle = layer.register_forward_hook(
layer.rule.forward_hook # type: ignore
)
self.forward_handles.append(forward_handle)
if self.verbose:
print(f"Applied {layer.rule} on layer {layer}")

@staticmethod
def _no_op_pre_hook(
module: Module,
inputs: TensorOrTupleOfTensorsGeneric,
) -> TensorOrTupleOfTensorsGeneric:
"""Pre Hook for adding no op nodes in between modules, in order to fix
the hook ordering issues.

Args:
module (nn.Module): the module in question; not used, but required
by PyTorch to be part of the hook signature
inputs (TensorOrTupleOfTensorsGeneric): inputs of the module

Returns:
TensorOrTupleOfTensorsGeneric: cloned inputs
"""
if isinstance(inputs, tuple):
return tuple(x.clone() for x in inputs)
else:
return inputs.clone()

def _register_weight_hooks(self) -> None:
for layer in self.layers:
if layer.rule is not None:
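
The clone performed by _no_op_pre_hook above works because Tensor.clone() adds an extra node to the autograd graph: cloning a module's inputs places a distinct "no op" node between two adjacent layers, so backward hooks registered on the next layer's inputs and on the previous layer's output end up on different tensors, which appears to be what fixes the hook ordering issue mentioned in the code comment. A minimal standalone sketch of the same idea in plain PyTorch (illustrative model and names, not Captum's API):

import torch
import torch.nn as nn

def no_op_pre_hook(module, inputs):
    # Cloning each input inserts a CloneBackward node between this module
    # and the module that produced the inputs.
    return tuple(x.clone() for x in inputs)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
handle = model[1].register_forward_pre_hook(no_op_pre_hook)

x = torch.randn(2, 4, requires_grad=True)
model(x).sum().backward()
handle.remove()
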
150 changes: 9 additions & 141 deletions captum/attr/_utils/lrp_rules.py
@@ -203,7 +203,7 @@ def _create_backward_hook_output(
self, outputs: Tensor
) -> Callable[[Tensor], Optional[Tensor]]:
def _backward_hook_output(grad: Tensor) -> None:
self.relevance_output = grad.data
self.relevance_output[grad.device] = grad.data

return _backward_hook_output
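
The assignment above stores the relevance gradient in a mapping keyed by grad.device instead of a plain tensor attribute, presumably so that replicas running on different devices (e.g. under DataParallel) do not overwrite each other's values; this hunk does not show how the mapping is initialised or read back. A rough sketch of that storage pattern, with illustrative names:

from typing import Dict

import torch

class _RelevanceStoreSketch:
    """Illustrative only: keep one relevance tensor per device."""

    def __init__(self) -> None:
        self.relevance_output: Dict[torch.device, torch.Tensor] = {}

    def _backward_hook_output(self, grad: torch.Tensor) -> None:
        # Each replica writes into the slot for its own device.
        self.relevance_output[grad.device] = grad.data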

@@ -225,7 +225,7 @@ class WSquaredRule(PropagationRule):
"""

def __init__(self) -> None:
self._input_shapes = tuple()
self._input_shapes: Tuple[torch.Size, ...] = tuple()
self._denominator = torch.Tensor()

def forward_hook(
@@ -343,7 +343,6 @@ def forward_hook(
inputs = _format_tensor_into_tuples(inputs)
self._has_single_input = len(inputs) == 1
self._handle_input_hooks = list()
self.relevance_input = list()
for input_index, input_ in enumerate(inputs):
if not hasattr(input_, "hook_registered"):
input_hook = self._create_backward_hook_input(input_.data, input_index)
@@ -421,17 +420,20 @@ def _compute_signed_contributions(
rescaled_relevance = grad / denominator

# getting contractions with transposed Jacobian
res_wp = torch.autograd.grad(
positive_weight_contraction = torch.autograd.grad(
outputs=mod_pos_out, inputs=mod_pos_in, grad_outputs=rescaled_relevance
)
res_wm = torch.autograd.grad(
negative_weight_contraction = torch.autograd.grad(
outputs=mod_neg_out, inputs=mod_neg_in, grad_outputs=rescaled_relevance
)

out = tuple(
(pos_in * jac_pos) + (neg_in * jac_neg)
for jac_pos, pos_in, jac_neg, neg_in in zip(
res_wp, mod_pos_in, res_wm, mod_neg_in
positive_weight_contraction,
mod_pos_in,
negative_weight_contraction,
mod_neg_in,
)
)
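
The torch.autograd.grad calls in this hunk compute vector-Jacobian products: for y = f(x) and a vector v passed as grad_outputs, torch.autograd.grad(outputs=y, inputs=x, grad_outputs=v) returns J^T v, i.e. the "contraction with the transposed Jacobian" mentioned in the comment. A self-contained illustration with made-up shapes and names:

import torch

x = torch.randn(3, requires_grad=True)
W = torch.randn(5, 3)
y = W @ x                      # the Jacobian of y w.r.t. x is W
v = torch.randn(5)             # stands in for the rescaled relevance

# Contracts v with the transposed Jacobian, i.e. returns W.T @ v.
(jtv,) = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=v)
assert torch.allclose(jtv, W.T @ v)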

@@ -473,140 +475,6 @@ def _manipulate_weights(
self._separate_weights_by_sign(module)


class AlphaBetaRuleV2(AlphaBetaRule):
r"""
AlphaBetaRule with decreased execution times but increased memory consumption.
"""

def _create_backward_hook_input(
self, input_: Tensor, input_index: int
) -> Callable[[Tensor], Optional[Tensor]]:
def _backward_hook_input(grad: Tensor) -> Tensor:
out = input_ * grad
out += self.out[input_index]
return out

return _backward_hook_input

def _create_backward_hook_output(
self, output: Tensor
) -> Callable[[Tensor], Optional[Tensor]]:
def _backward_hook_output(grad: Tensor) -> Tensor:
pos_denominator = self._compute_pos_denominator(output)
if self.beta:
neg_denominator = self.og_outputs[0] - pos_denominator
if self.set_bias_to_zero and self._bias_contrib is not None:
neg_denominator -= torch.cat(
tuple(
self._bias_contrib for _ in range(neg_denominator.shape[0])
)
)

neg_denominator -= self.STABILITY_FACTOR
beta_rel = (self.beta * grad) / neg_denominator

res_wp_beta = torch.autograd.grad(
inputs=self.inputs_pos,
outputs=self.wp_xp,
grad_outputs=beta_rel,
retain_graph=False,
)
res_wm_beta = torch.autograd.grad(
inputs=self.inputs_neg,
outputs=self.wm_xm,
grad_outputs=beta_rel,
retain_graph=True,
)

out_beta = list(
xm * wp_beta + xp * wm_beta
for xm, wp_beta, xp, wm_beta in zip(
self.inputs_neg, res_wp_beta, self.inputs_pos, res_wm_beta
)
)

pos_denominator += self.STABILITY_FACTOR
alpha_rel = (self.alpha * grad) / pos_denominator

res_wm_alpha = torch.autograd.grad(
inputs=self.inputs_neg,
outputs=self.wm_xm,
grad_outputs=alpha_rel,
retain_graph=False,
)

out_alpha = list(
xm * wm_alpha for xm, wm_alpha in zip(self.inputs_neg, res_wm_alpha)
)

if self.beta:
out_alpha = list(x + y for x, y in zip(out_alpha, out_beta))

self.out = out_alpha

return alpha_rel

return _backward_hook_output

def _compute_pos_denominator(self, output: Tensor) -> Tensor:
if not self.beta:
with torch.no_grad():
pos_denominator = output
with torch.autograd.set_grad_enabled(True):
self.wm_xm = self._module_neg.forward(*self.inputs_neg)

pos_denominator += self.wm_xm
else:
with torch.autograd.set_grad_enabled(True):
self.wp_xp = self._module_pos.forward(*self.inputs_pos)
self.wm_xm = self._module_neg.forward(*self.inputs_neg)

pos_denominator = self.wp_xp + self.wm_xm

if not self.set_bias_to_zero and self._bias_contrib is not None:
pos_denominator += torch.cat(
tuple(self._bias_contrib for _ in range(pos_denominator.shape[0]))
).clamp(min=0)

return pos_denominator.detach()

def forward_hook_weights(
self,
module: Module,
inputs: Tuple[Tensor, ...],
outputs: Tensor,
) -> None:
super().forward_hook_weights(module, inputs, outputs)
self.og_outputs = outputs

def _manipulate_weights(
self,
module: Module,
inputs: Tuple[Tensor, ...],
outputs: Tensor,
) -> None:
if hasattr(module, "bias"):
if module.bias is not None:
if self._bias_contrib is None:
with torch.no_grad():
self._bias_contrib = module.forward(
*(
torch.zeros(input_.shape[1:]).unsqueeze(dim=0)
for input_ in inputs
)
)
module.bias.data = torch.zeros_like(module.bias.data)

self._separate_weights_by_sign(module)

def forward_pre_hook_activations(
self, module: Module, inputs: Tuple[Tensor, ...]
) -> Tuple[Tensor, ...]:
for input_, activation in zip(inputs, module.activations):
input_.data = activation.clamp(min=0)
return inputs
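
For context on the class removed here: both AlphaBetaRule and the deleted AlphaBetaRuleV2 implement the LRP alpha-beta rule, which splits positive and negative pre-activation contributions and weights them separately. In the standard formulation from the LRP literature (not taken from this diff):

R_j = \sum_k \left( \alpha \, \frac{(a_j w_{jk})^{+}}{\sum_{j'} (a_{j'} w_{j'k})^{+}}
                  - \beta \, \frac{(a_j w_{jk})^{-}}{\sum_{j'} (a_{j'} w_{j'k})^{-}} \right) R_k ,
\qquad \alpha - \beta = 1 .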


class ZBoundRule(PropagationRule):
def __init__(
self,
@@ -631,7 +499,7 @@ def __init__(
self._bias_contrib = None

self._denominator_bound_contribution = torch.Tensor()
self._input_shapes = tuple()
self._input_shapes: Tuple[torch.Size, ...] = tuple()

def forward_hook(
self, module: Module, inputs: Tuple[Tensor, ...], outputs: Tensor