Skip to content

Commit 45572c2

Browse files
yiyixuxuyiyixuxu
andauthored
fix the get_indices function (huggingface#2418)
Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent 5f65ef4 commit 45572c2

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
4848
>>> # use get_indices function to find out indices of the tokens you want to alter
4949
>>> pipe.get_indices(prompt)
50+
{0: '<|startoftext|>', 1: 'a</w>', 2: 'cat</w>', 3: 'and</w>', 4: 'a</w>', 5: 'frog</w>', 6: '<|endoftext|>'}
5051
5152
>>> token_indices = [2, 5]
5253
>>> seed = 6141
@@ -662,7 +663,7 @@ def register_attention_control(self):
662663
def get_indices(self, prompt: str) -> Dict[str, int]:
663664
"""Utility function to list the indices of the tokens you wish to alte"""
664665
ids = self.tokenizer(prompt).input_ids
665-
indices = {tok: i for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))}
666+
indices = {i: tok for tok, i in zip(self.tokenizer.convert_ids_to_tokens(ids), range(len(ids)))}
666667
return indices
667668

668669
@torch.no_grad()

0 commit comments

Comments
 (0)