Question / Confirmation on arange in encode_text #371

Closed
AwePhD opened this issue Jun 29, 2023 · 2 comments

AwePhD commented Jun 29, 2023

Hello,

This is a minor question about the code; I want to be sure I am not missing any subtleties.

At L354 of model.py, there is the final step that extracts the text features:

# x.shape = [batch_size, n_ctx, transformer.width]
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
  • self.text_projection is the final projection that produces the text features.
  • text.argmax(dim=-1) picks the position of the EOT token in each sequence; this works because the EOT token has the highest id in CLIP's vocabulary (see the sketch below).
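
A tiny sketch of that argmax trick, using CLIP's usual special-token ids (49406 for start-of-text, 49407 for end-of-text) and otherwise made-up token ids:

import torch

# 49407 (EOT) is the largest id in the vocabulary, so argmax over
# the raw token ids returns its position in each sequence.
text = torch.tensor([
    [49406, 320, 1125, 49407, 0, 0],    # EOT at position 3
    [49406, 320, 1125, 539, 49407, 0],  # EOT at position 4
])
print(text.argmax(dim=-1))  # tensor([3, 4])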

Why is there a torch.arange(x.shape[0])? Couldn't it simply be x[:, text.argmax(dim=-1)]?

Thanks for the work, the code, and the model.

jongwook (Collaborator) commented

We need to extract the features (of shape [transformer.width]) at a different EOT-token location for each sequence in the batch, so the result has shape [batch_size, transformer.width].

If we did x[:, text.argmax(dim=-1)], it would select the features at all of the argmax indices for every batch element, so the shape would become [batch_size, batch_size, transformer.width].
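
A minimal sketch of the shape difference, with made-up sizes standing in for the real tensors:

import torch

batch_size, n_ctx, width = 4, 77, 512
x = torch.randn(batch_size, n_ctx, width)         # stand-in for the transformer output
eot_idx = torch.randint(0, n_ctx, (batch_size,))  # stand-in for text.argmax(dim=-1)

# Advanced indexing pairs row i with position eot_idx[i].
print(x[torch.arange(batch_size), eot_idx].shape)  # torch.Size([4, 512])

# Slicing with ":" instead combines every row with every index.
print(x[:, eot_idx].shape)                         # torch.Size([4, 4, 512])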

It could be simpler to use index_select instead and still achieve the same effect, but I found myself making fewer mistakes by explicitly indexing with torch.arange().
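
For reference, a gather-based sketch that performs the same per-row selection without building the arange (not the code the repo uses, just an equivalent):

import torch

batch_size, n_ctx, width = 4, 77, 512
x = torch.randn(batch_size, n_ctx, width)
eot_idx = torch.randint(0, n_ctx, (batch_size,))

# torch.gather reads x[i, eot_idx[i], :] for every row i once the
# index is expanded to [batch_size, 1, width].
pooled = x.gather(1, eot_idx.view(-1, 1, 1).expand(-1, 1, width)).squeeze(1)
print(pooled.shape)                                               # torch.Size([4, 512])
print(torch.equal(pooled, x[torch.arange(batch_size), eot_idx]))  # True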

AwePhD commented Jul 11, 2023

Hello,

You are absolutely right. Thanks for taking the time to clear up my confusion!

AwePhD closed this as completed Jul 11, 2023