
Commit 9ff3cbf

craymichael authored and facebook-github-bot committed
Attribution API refactor: Base LLMAttributionResult class + refactor (#1657)
Summary: Refactor LLMAttributionResult into a generic abstract base class, BaseLLMAttributionResult. LLMAttributionResult becomes a concrete child that preserves the legacy captum.attr API (input_tokens, output_tokens, seq_attr, token_attr, output_probs) through aliases. These changes enable more generalized use beyond logprob-based attribution. Differential Revision: D84721127
1 parent ee5b695 commit 9ff3cbf
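To illustrate the generalization this refactor enables, here is a minimal sketch (not part of the commit) of a concrete child of the new generic base for judge-score rather than logprob-based attribution. Only BaseLLMAttributionResult and its constructor parameters come from this diff; the JudgeAttributionResult name and every "judge" identifier below are hypothetical:

# Hypothetical sketch: a concrete child of the new generic base, parameterized
# over input values (str) and target values (float judge scores).
from typing import List, Optional

import numpy.typing as npt

from captum.attr._core.llm_attr import BaseLLMAttributionResult


class JudgeAttributionResult(BaseLLMAttributionResult[str, float]):
    """Attribution of ablated input spans to per-judge scores (hypothetical)."""

    def __init__(
        self,
        *,
        input_spans: List[str],
        judge_names: List[str],
        judge_scores: Optional[npt.ArrayLike] = None,
        aggregate_attr: npt.ArrayLike,
        element_attr: Optional[npt.ArrayLike] = None,
    ) -> None:
        super().__init__(
            input_values=input_spans,        # ablated values
            target_names=judge_names,        # one entry per judge
            target_values=judge_scores,      # e.g. each judge's score
            aggregate_attr=aggregate_attr,   # 1D: [# input_spans]
            element_attr=element_attr,       # 2D: [# judges, # input_spans]
            aggregate_descriptor="Overall",  # axis/colorbar labels in plots
            element_descriptor="Judge",
        )

This mirrors how the commit's own LLMAttributionResult subclasses the base (see the end of the diff below), just with non-token semantics.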

File tree

1 file changed: +179, -84 lines

captum/attr/_core/llm_attr.py

@@ -5,18 +5,20 @@
 from abc import ABC
 from copy import copy
 from dataclasses import dataclass
-from textwrap import dedent, shorten
+from textwrap import shorten

 from typing import (
     Any,
     Callable,
     cast,
     Dict,
+    Generic,
     List,
     Optional,
     Tuple,
     Type,
     TYPE_CHECKING,
+    TypeVar,
     Union,
 )

@@ -56,130 +58,138 @@
     "temperature": None,
     "top_p": None,
 }
+TInputValue = TypeVar("TInputValue")
+TTargetValue = TypeVar("TTargetValue")


-@dataclass
-class LLMAttributionResult:
+@dataclass(kw_only=True)
+class BaseLLMAttributionResult(ABC, Generic[TInputValue, TTargetValue]):
     """
     Data class for the return result of LLMAttribution,
     which includes the necessary properties of the attribution.
     It also provides utilities to help present and plot the result in different forms.
     """

-    input_tokens: List[str]
-    output_tokens: List[str]
-    # pyre-ignore[13]: initialized via a property setter
-    _seq_attr: Tensor
-    _token_attr: Optional[Tensor] = None
-    _output_probs: Optional[Tensor] = None
+    input_values: List[TInputValue]  # ablated values
+    target_names: List[str]  # names of each target, e.g. judge name or tokens
+    _target_values: Optional[
+        List[TTargetValue]
+    ]  # value for each target name e.g. token prob
+    _aggregate_attr: Tensor  # 1D [# input_values]
+    _element_attr: Optional[Tensor] = None  # 2D [# target_names, # input_values]
+    aggregate_descriptor: str = "Aggregate"
+    element_descriptor: str = "Element"

     def __init__(
         self,
         *,
-        input_tokens: List[str],
-        output_tokens: List[str],
-        seq_attr: npt.ArrayLike,
-        token_attr: Optional[npt.ArrayLike] = None,
-        output_probs: Optional[npt.ArrayLike] = None,
+        input_values: List[TInputValue],
+        target_names: List[str],
+        target_values: Optional[npt.ArrayLike] = None,
+        aggregate_attr: npt.ArrayLike,
+        element_attr: Optional[npt.ArrayLike] = None,
+        aggregate_descriptor: str = "Aggregate",
+        element_descriptor: str = "Element",
     ) -> None:
-        self.input_tokens = input_tokens
-        self.output_tokens = output_tokens
-        self.seq_attr = seq_attr
-        self.token_attr = token_attr
-        self.output_probs = output_probs
+        self.input_values = input_values
+        self.target_names = target_names
+        self.target_values = target_values
+        self.aggregate_attr = aggregate_attr
+        self.element_attr = element_attr
+        self.aggregate_descriptor = aggregate_descriptor
+        self.element_descriptor = element_descriptor

     @property
-    def seq_attr(self) -> Tensor:
-        return self._seq_attr
+    def aggregate_attr(self) -> Tensor:
+        return self._aggregate_attr

-    @seq_attr.setter
-    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+    @aggregate_attr.setter
+    def aggregate_attr(self, seq_attr: npt.ArrayLike) -> None:
         if isinstance(seq_attr, Tensor):
-            self._seq_attr = seq_attr
+            self._aggregate_attr = seq_attr
         else:
-            self._seq_attr = torch.tensor(seq_attr)
+            self._aggregate_attr = torch.tensor(seq_attr)
         # IDEA: in the future we might want to support higher dim seq_attr
         # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
-        assert len(self._seq_attr.shape) == 1, "seq_attr must be a 1D tensor"
+        assert len(self._aggregate_attr.shape) == 1, "seq_attr must be a 1D tensor"
         assert (
-            len(self.input_tokens) == self._seq_attr.shape[0]
+            len(self.input_values) == self._aggregate_attr.shape[0]
         ), "seq_attr and input_tokens must have the same length"

     @property
-    def token_attr(self) -> Optional[Tensor]:
-        return self._token_attr
+    def element_attr(self) -> Optional[Tensor]:
+        return self._element_attr

-    @token_attr.setter
-    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+    @element_attr.setter
+    def element_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
         if token_attr is None:
-            self._token_attr = None
+            self._element_attr = None
         elif isinstance(token_attr, Tensor):
-            self._token_attr = token_attr
+            self._element_attr = token_attr
         else:
-            self._token_attr = torch.tensor(token_attr)
+            self._element_attr = torch.tensor(token_attr)

-        if self._token_attr is not None:
+        if self._element_attr is not None:
             # IDEA: in the future we might want to support higher dim seq_attr
-            assert len(self._token_attr.shape) == 2, "token_attr must be a 2D tensor"
-            assert self._token_attr.shape == (
-                len(self.output_tokens),
-                len(self.input_tokens),
-            ), dedent(
-                f"""\
-                Expect token_attr to have shape
-                {len(self.output_tokens), len(self.input_tokens)},
-                got {self._token_attr.shape}
-                """
+            assert len(self._element_attr.shape) == 2, "token_attr must be a 2D tensor"
+            assert self._element_attr.shape == (
+                len(self.target_names),
+                len(self.input_values),
+            ), (
+                "Expect token_attr to have shape "
+                f"({len(self.target_names), len(self.input_values)}), "
+                f"got {self._element_attr.shape}"
             )

     @property
-    def output_probs(self) -> Optional[Tensor]:
-        return self._output_probs
-
-    @output_probs.setter
-    def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None:
-        if output_probs is None:
-            self._output_probs = None
-        elif isinstance(output_probs, Tensor):
-            self._output_probs = output_probs
+    def target_values(self) -> Optional[List[TTargetValue]]:
+        return self._target_values
+
+    @target_values.setter
+    def target_values(self, target_values: Optional[npt.ArrayLike]) -> None:
+        if target_values is None:
+            self._target_values = None
+        elif isinstance(target_values, (Tensor, np.ndarray)):
+            self._target_values = target_values.tolist()
         else:
-            self._output_probs = torch.tensor(output_probs)
+            # pyre-ignore[6]: should be iterable
+            self._target_values = list(target_values)

-        if self._output_probs is not None:
-            assert (
-                len(self._output_probs.shape) == 1
-            ), "output_probs must be a 1D tensor"
-            assert (
-                len(self.output_tokens) == self._output_probs.shape[0]
-            ), "seq_attr and input_tokens must have the same length"
+        if self._target_values is not None:
+            assert len(self._target_values) == len(
+                self.target_names
+            ), f"{len(self._target_values)=} and {len(self.target_names)=} must have the same length"

     @property
-    def seq_attr_dict(self) -> Dict[str, float]:
-        return {k: v for v, k in zip(self.seq_attr.cpu().tolist(), self.input_tokens)}
+    def aggregate_attr_dict(self) -> Dict[TInputValue, float]:
+        return {
+            k: v for v, k in zip(self.aggregate_attr.cpu().tolist(), self.input_values)
+        }

-    def plot_token_attr(
+    def plot_element_attr(
         self, show: bool = False
     ) -> Union[None, Tuple["Figure", "Axes"]]:
         """
         Generate a matplotlib plot for visualising the attribution
-        of the output tokens.
+        of the output elements.

         Args:
             show (bool): whether to show the plot directly or return the figure and axis
                 Default: False
         """

-        if self.token_attr is None:
+        if self.element_attr is None:
             raise ValueError(
-                "token_attr is None (no token-level attribution was performed), please "
-                "use plot_seq_attr instead for the sequence-level attribution plot"
+                f"element_attr is None (no {self.element_descriptor.lower()}-level attribution was "
+                "performed), please use plot_aggregate_attr instead for the "
+                f"{self.aggregate_descriptor}-level attribution plot"
             )
-        token_attr = self.token_attr.cpu()
+        element_attr = self.element_attr.cpu()

         # maximum absolute attribution value
         # used as the boundary of normalization
         # always keep 0 as the mid point to differentiate pos/neg attr
-        max_abs_attr_val = token_attr.abs().max().item()
+        max_abs_attr_val = element_attr.abs().max().item()

         import matplotlib.pyplot as plt

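Before the remaining plotting hunks, a quick sketch of what the setters above enforce (not from the commit; values arbitrary, using the concrete LLMAttributionResult defined at the end of this diff): array-likes are converted to tensors, aggregate_attr must be 1D with one entry per input value, and element_attr must be 2D with shape (len(target_names), len(input_values)).

# Sketch of the setter conversion/validation above (arbitrary values).
import torch

from captum.attr import LLMAttributionResult

res = LLMAttributionResult(
    input_tokens=["The", "movie", "was", "great"],
    output_tokens=["Pos"],
    seq_attr=[0.1, 0.2, -0.05, 0.9],      # list -> 1D tensor, len == # input tokens
    token_attr=[[0.1, 0.2, -0.05, 0.9]],  # 2D: (# output tokens, # input tokens)
)
assert isinstance(res.aggregate_attr, torch.Tensor)
assert res.aggregate_attr.shape == (4,)
assert res.element_attr is not None and res.element_attr.shape == (1, 4)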
@@ -189,7 +199,7 @@ def plot_token_attr(
         ax.grid(False)

         # Plot the heatmap
-        data = token_attr.numpy()
+        data = element_attr.numpy()

         fig.set_size_inches(
             max(data.shape[1] * 1.3, 6.4), max(data.shape[0] / 2.5, 4.8)
@@ -219,17 +229,19 @@ def plot_token_attr(

         # Create colorbar
         cbar = fig.colorbar(im, ax=ax)  # type: ignore
-        cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
+        cbar.ax.set_ylabel(
+            f"{self.element_descriptor} Attribution", rotation=-90, va="bottom"
+        )

         # Show all ticks and label them with the respective list entries.
-        shortened_tokens = [
+        shortened_values = [
             shorten(repr(t)[1:-1], width=50, placeholder="...")
-            for t in self.input_tokens
+            for t in self.input_values
         ]
-        ax.set_xticks(np.arange(data.shape[1]), labels=shortened_tokens)
+        ax.set_xticks(np.arange(data.shape[1]), labels=shortened_values)
         ax.set_yticks(
             np.arange(data.shape[0]),
-            labels=[repr(token)[1:-1] for token in self.output_tokens],
+            labels=[repr(name)[1:-1] for name in self.target_names],
         )

         # Let the horizontal axes labeling appear on top.
@@ -259,10 +271,12 @@ def plot_token_attr(
         else:
             return fig, ax

-    def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes"]]:
+    def plot_aggregated_attr(
+        self, show: bool = False
+    ) -> Union[None, Tuple["Figure", "Axes"]]:
         """
         Generate a matplotlib plot for visualising the attribution
-        of the output sequence.
+        of the aggregated output.

         Args:
             show (bool): whether to show the plot directly or return the figure and axis
@@ -273,15 +287,15 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes

         fig, ax = plt.subplots()

-        data = self.seq_attr.cpu().numpy()
+        data = self.aggregate_attr.cpu().numpy()

         fig.set_size_inches(max(data.shape[0] / 2, 6.4), max(data.shape[0] / 4, 4.8))

-        shortened_tokens = [
+        shortened_values = [
             shorten(repr(t)[1:-1], width=50, placeholder="...")
-            for t in self.input_tokens
+            for t in self.input_values
         ]
-        ax.set_xticks(range(data.shape[0]), labels=shortened_tokens)
+        ax.set_xticks(range(data.shape[0]), labels=shortened_values)

         ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

@@ -309,14 +323,95 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes
             color="#d0365b",
         )

-        ax.set_ylabel("Sequence Attribution", rotation=90, va="bottom")
+        ax.set_ylabel(
+            f"{self.aggregate_descriptor} Attribution", rotation=90, va="bottom"
+        )

         if show:
             plt.show()
             return None  # mypy wants this
         else:
             return fig, ax

+    # Aliases
+
+    @property
+    def input_tokens(self) -> List[TInputValue]:
+        return self.input_values
+
+    @input_tokens.setter
+    def input_tokens(self, input_tokens: List[TInputValue]) -> None:
+        self.input_values = input_tokens
+
+    @property
+    def output_tokens(self) -> List[str]:
+        return self.target_names
+
+    @output_tokens.setter
+    def output_tokens(self, output_tokens: List[str]) -> None:
+        self.target_names = output_tokens
+
+    @property
+    def output_probs(self) -> Optional[List[TTargetValue]]:
+        return self.target_values
+
+    @output_probs.setter
+    def output_probs(self, output_probs: Optional[npt.ArrayLike]) -> None:
+        self.target_values = output_probs
+
+    @property
+    def seq_attr(self) -> Tensor:
+        return self.aggregate_attr
+
+    @seq_attr.setter
+    def seq_attr(self, seq_attr: npt.ArrayLike) -> None:
+        self.aggregate_attr = seq_attr
+
+    @property
+    def token_attr(self) -> Optional[Tensor]:
+        return self.element_attr
+
+    @token_attr.setter
+    def token_attr(self, token_attr: Optional[npt.ArrayLike]) -> None:
+        self.element_attr = token_attr
+
+    @property
+    def seq_attr_dict(self) -> Dict[TInputValue, float]:
+        return self.aggregate_attr_dict
+
+    def plot_token_attr(
+        self, show: bool = False
+    ) -> Union[None, Tuple["Figure", "Axes"]]:
+        return self.plot_element_attr(show=show)
+
+    def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes"]]:
+        return self.plot_aggregated_attr(show=show)
+
+
+@dataclass(kw_only=True)
+# pyre-ignore[13]: _aggregate_attr and _target_values initialized via setters
+class LLMAttributionResult(BaseLLMAttributionResult[str, float]):
+    """LLM Attribution Result for the captum.attr API"""
+
+    def __init__(
+        self,
+        *,
+        input_tokens: List[str],
+        output_tokens: List[str],
+        seq_attr: npt.ArrayLike,
+        token_attr: Optional[npt.ArrayLike] = None,
+        output_probs: Optional[npt.ArrayLike] = None,
+    ) -> None:
+        super().__init__(
+            input_values=input_tokens,
+            target_names=output_tokens,
+            target_values=output_probs,
+            aggregate_attr=seq_attr,
+            element_attr=token_attr,
+            aggregate_descriptor="Sequence",
+            element_descriptor="Token",
+        )
+

 def _clean_up_pretty_token(token: str) -> str:
     """Remove newlines and leading/trailing whitespace from token."""
