55from abc import ABC
66from copy import copy
77from dataclasses import dataclass
8- from textwrap import dedent , shorten
8+ from textwrap import shorten
99
1010from typing import (
1111 Any ,
1212 Callable ,
1313 cast ,
1414 Dict ,
15+ Generic ,
1516 List ,
1617 Optional ,
1718 Tuple ,
1819 Type ,
1920 TYPE_CHECKING ,
21+ TypeVar ,
2022 Union ,
2123)
2224
5658 "temperature" : None ,
5759 "top_p" : None ,
5860}
# Type of each ablated input value (e.g. a token string for LLM attribution).
TInputValue = TypeVar("TInputValue")
# Type of each per-target value (e.g. a token probability as a float).
TTargetValue = TypeVar("TTargetValue")
5963
6064
61- @dataclass
62- class LLMAttributionResult :
65+ @dataclass ( kw_only = True )
66+ class BaseLLMAttributionResult ( ABC , Generic [ TInputValue , TTargetValue ]) :
6367 """
6468 Data class for the return result of LLMAttribution,
6569 which includes the necessary properties of the attribution.
6670 It also provides utilities to help present and plot the result in different forms.
6771 """
6872
69- input_tokens : List [str ]
70- output_tokens : List [str ]
71- # pyre-ignore[13]: initialized via a property setter
72- _seq_attr : Tensor
73- _token_attr : Optional [Tensor ] = None
74- _output_probs : Optional [Tensor ] = None
73+ input_values : List [TInputValue ] # ablated values
74+ target_names : List [str ] # names of each target, e.g. judge name or tokens
75+ _target_values : Optional [
76+ List [TTargetValue ]
77+ ] # value for each target name e.g. token prob
78+ _aggregate_attr : Tensor # 1D [# input_values]
79+ _element_attr : Optional [Tensor ] = None # 2D [# target_names, # input_values]
80+ aggregate_descriptor : str = "Aggregate"
81+ element_descriptor : str = "Element"
7582
7683 def __init__ (
7784 self ,
7885 * ,
79- input_tokens : List [str ],
80- output_tokens : List [str ],
81- seq_attr : npt .ArrayLike ,
82- token_attr : Optional [npt .ArrayLike ] = None ,
83- output_probs : Optional [npt .ArrayLike ] = None ,
86+ input_values : List [TInputValue ],
87+ target_names : List [str ],
88+ target_values : Optional [npt .ArrayLike ] = None ,
89+ aggregate_attr : npt .ArrayLike ,
90+ element_attr : Optional [npt .ArrayLike ] = None ,
91+ aggregate_descriptor : str = "Aggregate" ,
92+ element_descriptor : str = "Element" ,
8493 ) -> None :
85- self .input_tokens = input_tokens
86- self .output_tokens = output_tokens
87- self .seq_attr = seq_attr
88- self .token_attr = token_attr
89- self .output_probs = output_probs
94+ self .input_values = input_values
95+ self .target_names = target_names
96+ self .target_values = target_values
97+ self .aggregate_attr = aggregate_attr
98+ self .element_attr = element_attr
99+ self .aggregate_descriptor = aggregate_descriptor
100+ self .element_descriptor = element_descriptor
90101
91102 @property
92- def seq_attr (self ) -> Tensor :
93- return self ._seq_attr
103+ def aggregate_attr (self ) -> Tensor :
104+ return self ._aggregate_attr
94105
95- @seq_attr .setter
96- def seq_attr (self , seq_attr : npt .ArrayLike ) -> None :
106+ @aggregate_attr .setter
107+ def aggregate_attr (self , seq_attr : npt .ArrayLike ) -> None :
97108 if isinstance (seq_attr , Tensor ):
98- self ._seq_attr = seq_attr
109+ self ._aggregate_attr = seq_attr
99110 else :
100- self ._seq_attr = torch .tensor (seq_attr )
111+ self ._aggregate_attr = torch .tensor (seq_attr )
101112 # IDEA: in the future we might want to support higher dim seq_attr
102113 # (e.g. attention w.r.t. multiple layers, gradients w.r.t. different classes)
103- assert len (self ._seq_attr .shape ) == 1 , "seq_attr must be a 1D tensor"
114+ assert len (self ._aggregate_attr .shape ) == 1 , "seq_attr must be a 1D tensor"
104115 assert (
105- len (self .input_tokens ) == self ._seq_attr .shape [0 ]
116+ len (self .input_values ) == self ._aggregate_attr .shape [0 ]
106117 ), "seq_attr and input_tokens must have the same length"
107118
108119 @property
109- def token_attr (self ) -> Optional [Tensor ]:
110- return self ._token_attr
120+ def element_attr (self ) -> Optional [Tensor ]:
121+ return self ._element_attr
111122
112- @token_attr .setter
113- def token_attr (self , token_attr : Optional [npt .ArrayLike ]) -> None :
123+ @element_attr .setter
124+ def element_attr (self , token_attr : Optional [npt .ArrayLike ]) -> None :
114125 if token_attr is None :
115- self ._token_attr = None
126+ self ._element_attr = None
116127 elif isinstance (token_attr , Tensor ):
117- self ._token_attr = token_attr
128+ self ._element_attr = token_attr
118129 else :
119- self ._token_attr = torch .tensor (token_attr )
130+ self ._element_attr = torch .tensor (token_attr )
120131
121- if self ._token_attr is not None :
132+ if self ._element_attr is not None :
122133 # IDEA: in the future we might want to support higher dim seq_attr
123- assert len (self ._token_attr .shape ) == 2 , "token_attr must be a 2D tensor"
124- assert self ._token_attr .shape == (
125- len (self .output_tokens ),
126- len (self .input_tokens ),
127- ), dedent (
128- f"""\
129- Expect token_attr to have shape
130- { len (self .output_tokens ), len (self .input_tokens )} ,
131- got { self ._token_attr .shape }
132- """
134+ assert len (self ._element_attr .shape ) == 2 , "token_attr must be a 2D tensor"
135+ assert self ._element_attr .shape == (
136+ len (self .target_names ),
137+ len (self .input_values ),
138+ ), (
139+ "Expect token_attr to have shape "
140+ f"({ len (self .target_names ), len (self .input_values )} ), "
141+ f"got { self ._element_attr .shape } "
133142 )
134143
135144 @property
136- def output_probs (self ) -> Optional [Tensor ]:
137- return self ._output_probs
138-
139- @output_probs .setter
140- def output_probs (self , output_probs : Optional [npt .ArrayLike ]) -> None :
141- if output_probs is None :
142- self ._output_probs = None
143- elif isinstance (output_probs , Tensor ):
144- self ._output_probs = output_probs
145+ def target_values (self ) -> Optional [List [ TTargetValue ] ]:
146+ return self ._target_values
147+
148+ @target_values .setter
149+ def target_values (self , target_values : Optional [npt .ArrayLike ]) -> None :
150+ if target_values is None :
151+ self ._target_values = None
152+ elif isinstance (target_values , ( Tensor , np . ndarray ) ):
153+ self ._target_values = target_values . tolist ()
145154 else :
146- self ._output_probs = torch .tensor (output_probs )
155+ # pyre-ignore[6]: should be iterable
156+ self ._target_values = list (target_values )
147157
148- if self ._output_probs is not None :
149- assert (
150- len (self ._output_probs .shape ) == 1
151- ), "output_probs must be a 1D tensor"
152- assert (
153- len (self .output_tokens ) == self ._output_probs .shape [0 ]
154- ), "seq_attr and input_tokens must have the same length"
158+ if self ._target_values is not None :
159+ assert len (self ._target_values ) == len (
160+ self .target_names
161+ ), f"{ len (self ._target_values )= } and { len (self .target_names )= } must have the same length"
155162
156163 @property
157- def seq_attr_dict (self ) -> Dict [str , float ]:
158- return {k : v for v , k in zip (self .seq_attr .cpu ().tolist (), self .input_tokens )}
164+ def aggregate_attr_dict (self ) -> Dict [TInputValue , float ]:
165+ return {
166+ k : v for v , k in zip (self .aggregate_attr .cpu ().tolist (), self .input_values )
167+ }
159168
160- def plot_token_attr (
169+ def plot_element_attr (
161170 self , show : bool = False
162171 ) -> Union [None , Tuple ["Figure" , "Axes" ]]:
163172 """
164173 Generate a matplotlib plot for visualising the attribution
165- of the output tokens .
174+ of the output elements .
166175
167176 Args:
168177 show (bool): whether to show the plot directly or return the figure and axis
169178 Default: False
170179 """
171180
172- if self .token_attr is None :
181+ if self .element_attr is None :
173182 raise ValueError (
174- "token_attr is None (no token-level attribution was performed), please "
175- "use plot_seq_attr instead for the sequence-level attribution plot"
183+ f"element_attr is None (no { self .element_descriptor .lower ()} -level attribution was "
184+ "performed), please use plot_aggregate_attr instead for the "
185+ f"{ self .aggregate_descriptor } -level attribution plot"
176186 )
177- token_attr = self .token_attr .cpu ()
187+ element_attr = self .element_attr .cpu ()
178188
179189 # maximum absolute attribution value
180190 # used as the boundary of normalization
181191 # always keep 0 as the mid point to differentiate pos/neg attr
182- max_abs_attr_val = token_attr .abs ().max ().item ()
192+ max_abs_attr_val = element_attr .abs ().max ().item ()
183193
184194 import matplotlib .pyplot as plt
185195
@@ -189,7 +199,7 @@ def plot_token_attr(
189199 ax .grid (False )
190200
191201 # Plot the heatmap
192- data = token_attr .numpy ()
202+ data = element_attr .numpy ()
193203
194204 fig .set_size_inches (
195205 max (data .shape [1 ] * 1.3 , 6.4 ), max (data .shape [0 ] / 2.5 , 4.8 )
@@ -219,17 +229,19 @@ def plot_token_attr(
219229
220230 # Create colorbar
221231 cbar = fig .colorbar (im , ax = ax ) # type: ignore
222- cbar .ax .set_ylabel ("Token Attribution" , rotation = - 90 , va = "bottom" )
232+ cbar .ax .set_ylabel (
233+ f"{ self .element_descriptor } Attribution" , rotation = - 90 , va = "bottom"
234+ )
223235
224236 # Show all ticks and label them with the respective list entries.
225- shortened_tokens = [
237+ shortened_values = [
226238 shorten (repr (t )[1 :- 1 ], width = 50 , placeholder = "..." )
227- for t in self .input_tokens
239+ for t in self .input_values
228240 ]
229- ax .set_xticks (np .arange (data .shape [1 ]), labels = shortened_tokens )
241+ ax .set_xticks (np .arange (data .shape [1 ]), labels = shortened_values )
230242 ax .set_yticks (
231243 np .arange (data .shape [0 ]),
232- labels = [repr (token )[1 :- 1 ] for token in self .output_tokens ],
244+ labels = [repr (name )[1 :- 1 ] for name in self .target_names ],
233245 )
234246
235247 # Let the horizontal axes labeling appear on top.
@@ -259,10 +271,12 @@ def plot_token_attr(
259271 else :
260272 return fig , ax
261273
262- def plot_seq_attr (self , show : bool = False ) -> Union [None , Tuple ["Figure" , "Axes" ]]:
274+ def plot_aggregated_attr (
275+ self , show : bool = False
276+ ) -> Union [None , Tuple ["Figure" , "Axes" ]]:
263277 """
264278 Generate a matplotlib plot for visualising the attribution
265- of the output sequence .
279+ of the aggregated output .
266280
267281 Args:
268282 show (bool): whether to show the plot directly or return the figure and axis
@@ -273,15 +287,15 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes
273287
274288 fig , ax = plt .subplots ()
275289
276- data = self .seq_attr .cpu ().numpy ()
290+ data = self .aggregate_attr .cpu ().numpy ()
277291
278292 fig .set_size_inches (max (data .shape [0 ] / 2 , 6.4 ), max (data .shape [0 ] / 4 , 4.8 ))
279293
280- shortened_tokens = [
294+ shortened_values = [
281295 shorten (repr (t )[1 :- 1 ], width = 50 , placeholder = "..." )
282- for t in self .input_tokens
296+ for t in self .input_values
283297 ]
284- ax .set_xticks (range (data .shape [0 ]), labels = shortened_tokens )
298+ ax .set_xticks (range (data .shape [0 ]), labels = shortened_values )
285299
286300 ax .tick_params (top = True , bottom = False , labeltop = True , labelbottom = False )
287301
@@ -309,14 +323,95 @@ def plot_seq_attr(self, show: bool = False) -> Union[None, Tuple["Figure", "Axes
309323 color = "#d0365b" ,
310324 )
311325
312- ax .set_ylabel ("Sequence Attribution" , rotation = 90 , va = "bottom" )
326+ ax .set_ylabel (
327+ f"{ self .aggregate_descriptor } Attribution" , rotation = 90 , va = "bottom"
328+ )
313329
314330 if show :
315331 plt .show ()
316332 return None # mypy wants this
317333 else :
318334 return fig , ax
319335
336+ # Aliases
337+
338+ @property
339+ def input_tokens (self ) -> List [TInputValue ]:
340+ return self .input_values
341+
342+ @input_tokens .setter
343+ def input_tokens (self , input_tokens : List [TInputValue ]) -> None :
344+ self .input_values = input_tokens
345+
346+ @property
347+ def output_tokens (self ) -> List [str ]:
348+ return self .target_names
349+
350+ @output_tokens .setter
351+ def output_tokens (self , output_tokens : List [str ]) -> None :
352+ self .target_names = output_tokens
353+
354+ @property
355+ def output_probs (self ) -> Optional [List [TTargetValue ]]:
356+ return self .target_values
357+
358+ @output_probs .setter
359+ def output_probs (self , output_probs : Optional [npt .ArrayLike ]) -> None :
360+ self .target_values = output_probs
361+
362+ @property
363+ def seq_attr (self ) -> Tensor :
364+ return self .aggregate_attr
365+
366+ @seq_attr .setter
367+ def seq_attr (self , seq_attr : npt .ArrayLike ) -> None :
368+ self .aggregate_attr = seq_attr
369+
370+ @property
371+ def token_attr (self ) -> Optional [Tensor ]:
372+ return self .element_attr
373+
374+ @token_attr .setter
375+ def token_attr (self , token_attr : Optional [npt .ArrayLike ]) -> None :
376+ self .element_attr = token_attr
377+
378+ @property
379+ def seq_attr_dict (self ) -> Dict [TInputValue , float ]:
380+ return self .aggregate_attr_dict
381+
382+ def plot_token_attr (
383+ self , show : bool = False
384+ ) -> Union [None , Tuple ["Figure" , "Axes" ]]:
385+ return self .plot_element_attr (show = show )
386+
387+ def plot_seq_attr (self , show : bool = False ) -> Union [None , Tuple ["Figure" , "Axes" ]]:
388+ return self .plot_aggregated_attr (show = show )
389+
390+
@dataclass(kw_only=True)
# pyre-ignore[13]: _aggregate_attr and _target_values initialized via setters
class LLMAttributionResult(BaseLLMAttributionResult[str, float]):
    """LLM Attribution Result for the captum.attr API.

    Thin wrapper over the generic base result that keeps the historical
    token/sequence vocabulary: inputs are token strings, targets are output
    tokens, and the aggregate attribution is the sequence-level attribution.
    """

    def __init__(
        self,
        *,
        input_tokens: List[str],
        output_tokens: List[str],
        seq_attr: npt.ArrayLike,
        token_attr: Optional[npt.ArrayLike] = None,
        output_probs: Optional[npt.ArrayLike] = None,
    ) -> None:
        """Map the legacy token/sequence keyword arguments onto the
        generic input/target fields of the base class."""
        super().__init__(
            input_values=input_tokens,
            target_names=output_tokens,
            target_values=output_probs,
            aggregate_attr=seq_attr,
            element_attr=token_attr,
            aggregate_descriptor="Sequence",
            element_descriptor="Token",
        )
414+
320415
321416def _clean_up_pretty_token (token : str ) -> str :
322417 """Remove newlines and leading/trailing whitespace from token."""
0 commit comments