-
Notifications
You must be signed in to change notification settings - Fork 171
FEAT: add text embedder #1694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
FEAT: add text embedder #1694
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've left some comments! I recommend reading first the official docs, then update the implementation and the tests. I also recommend adding an example (and test it out) or an integration test. LMK if something is not clear! 😉
@@ -23,7 +23,7 @@ classifiers = [ | |||
"Programming Language :: Python :: Implementation :: CPython", | |||
"Programming Language :: Python :: Implementation :: PyPy", | |||
] | |||
dependencies = ["haystack-ai>=2.9.0", "google-generativeai>=0.3.1"] | |||
dependencies = ["haystack-ai>=2.9.0", "google-generativeai>=0.3.1", "google-genai==1.13.0"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding google-genai==1.13.0
with exact version pinning could cause conflicts. What about google-genai>=1.13.0
?
:param model: The name of the Google AI embedding model to use. | ||
Defaults to "models/embedding-001". | ||
:param api_key: The Google AI API key. It can be explicitly provided or automatically read from the | ||
`GOOGLE_API_KEY` environment variable. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: above you're initializing this from GEMINI_API_KEY
, but here you are referring to GOOGLE_API_KEY
.
configs.title = self.title | ||
elif self.title and self.task_type != "retrieval_document": | ||
warnings.warn( | ||
UserWarning("Warning: Title 'Should Be Ignored' is ignored because task_type is 'retrieval_query'"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this be f"Warning: title '{self.title}' is ignored..."
?
raise RuntimeError(msg) from e | ||
|
||
# Extract embeddings - result.embedding should be the list of lists | ||
embeddings = result.get("embedding") # Use .get for safety, returns None if key missing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to docs, result should be an object and not a dict
, so you should do result.embeddings
.
I see that in the tests you're mocking this response (so tests are actually passing), but have you tried it outside tests (e.g. in an integration tests or an example?)
texts = ["text 1", "text 2"] | ||
expected_embeddings = [[0.1, 0.2], [0.3, 0.4]] | ||
# Configure the mock embed_content method to return a successful response | ||
mock_client_instance.models.embed_content.return_value = {"embedding": expected_embeddings} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the wrong mocking I was mentioning above.
According to docs, a correct mock should be something like:
mock_response = MagicMock()
mock_response.embeddings = None # or e.g. [[0.1, 0.2], [0.3, 0.4]]
mock_client_instance.models.embed_content.return_value = mock_response
Can you please update tests accordingly?
Related Issues
Proposed Changes:
Added the Google Text Embedder
How did you test it?
Unit tests
Notes for the reviewer
Didn't add the Document Embedder, is it needed?
Checklist
fix:
,feat:
,build:
,chore:
,ci:
,docs:
,style:
,refactor:
,perf:
,test:
.