We need to extract the features (of shape transformer.width) at the EOT token, which sits at a different position in each sequence of the batch, so the result has shape [batch_size, transformer.width].
If we did x[:, text.argmax(dim=-1)], it would select the features at every one of the argmax indices for every batch element, so the shape would become [batch_size, batch_size, transformer.width].
It might be simpler to use index_select instead and still achieve the same effect, but I found myself making fewer mistakes by explicitly indexing with torch.arange().
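For concreteness, here is a minimal sketch of the shape difference (the dimensions below are placeholders, not CLIP's actual configuration):

```python
import torch

batch_size, seq_len, width = 4, 77, 512
x = torch.randn(batch_size, seq_len, width)          # transformer output
eot_idx = torch.randint(1, seq_len, (batch_size,))   # stand-in for text.argmax(dim=-1)

# Advanced indexing pairs the i-th batch row with the i-th EOT position.
eot_features = x[torch.arange(batch_size), eot_idx]
print(eot_features.shape)  # torch.Size([4, 512]) -> [batch_size, width]

# Slicing the batch dimension instead applies every EOT index to every row.
all_pairs = x[:, eot_idx]
print(all_pairs.shape)     # torch.Size([4, 4, 512]) -> [batch_size, batch_size, width]

# The two agree only along the diagonal of the broadcast result.
diag = torch.arange(batch_size)
assert torch.equal(eot_features, all_pairs[diag, diag])
```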
Hello,
This is a minor question about the code; I want to be sure I don't let any subtleties slip by.
In L354 of model.py, there is the final step to extract the text features: self.text_projection is the last projection to obtain the text features, and text.argmax(dim=-1) picks the features of the EOT token. Why is there a torch.arange(x.shape[0])? It could be x[:, text.argmax(dim=-1)], right?
Thanks for the work, code and model.