Skip to content

Commit 48687a6

Browse files
feat: Jina reranker toolkit (camel-ai#2170)
Co-authored-by: Wendong-Fan <[email protected]>
1 parent eb3f3cc commit 48687a6

File tree

5 files changed

+525
-0
lines changed

5 files changed

+525
-0
lines changed

camel/toolkits/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from .pyautogui_toolkit import PyAutoGUIToolkit
6868
from .openai_agent_toolkit import OpenAIAgentToolkit
6969
from .searxng_toolkit import SearxNGToolkit
70+
from .jina_reranker_toolkit import JinaRerankerToolkit
7071

7172

7273
__all__ = [
@@ -122,4 +123,5 @@
122123
'PyAutoGUIToolkit',
123124
'OpenAIAgentToolkit',
124125
'SearxNGToolkit',
126+
'JinaRerankerToolkit',
125127
]
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14+
from typing import List, Optional, Tuple
15+
16+
import torch
17+
from transformers import AutoModel
18+
19+
from camel.toolkits import FunctionTool
20+
from camel.toolkits.base import BaseToolkit
21+
from camel.utils import MCPServer
22+
23+
24+
@MCPServer()
class JinaRerankerToolkit(BaseToolkit):
    r"""A class representing a toolkit for reranking documents
    using Jina Reranker.

    This class provides methods for reranking documents (text or images)
    based on their relevance to a given query (text or image) using the
    ``jinaai/jina-reranker-m0`` model loaded locally through
    ``transformers``.
    """

    def __init__(
        self,
        timeout: Optional[float] = None,
        device: Optional[str] = None,
    ) -> None:
        r"""Initializes a new instance of the JinaRerankerToolkit class.

        Args:
            timeout (Optional[float]): The timeout value in seconds that is
                forwarded to the base toolkit. If None, no timeout is
                applied. (default: :obj:`None`)
            device (Optional[str]): Device to load the model on. If None,
                will use CUDA if available, otherwise CPU.
                (default: :obj:`None`)
        """
        super().__init__(timeout=timeout)

        # NOTE(security): `trust_remote_code=True` lets the model repository
        # supply Python code that is executed locally while loading.
        self.model = AutoModel.from_pretrained(
            'jinaai/jina-reranker-m0',
            torch_dtype="auto",
            trust_remote_code=True,
        )
        target_device = device
        if target_device is None:
            target_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(target_device)
        # Switch to evaluation mode; the toolkit only runs inference.
        self.model.eval()

    def _sort_documents(
        self, documents: List[str], scores: List[float]
    ) -> List[Tuple[str, float]]:
        r"""Sort documents by their scores in descending order.

        Args:
            documents (List[str]): List of documents to sort.
            scores (List[float]): Corresponding scores for each document.

        Returns:
            List[Tuple[str, float]]: Sorted list of (document, score) pairs.
                Documents with equal scores keep their input order.

        Raises:
            ValueError: If documents and scores have different lengths.
        """
        if len(documents) != len(scores):
            raise ValueError("Number of documents must match number of scores")
        return sorted(
            zip(documents, scores), key=lambda pair: pair[1], reverse=True
        )

    def _rerank(
        self,
        query: str,
        documents: List[str],
        max_length: int,
        doc_type: str,
        query_type: Optional[str] = None,
    ) -> List[Tuple[str, float]]:
        r"""Score every (query, document) pair and sort the documents.

        Shared implementation behind the four public rerank methods.

        Args:
            query (str): The query (text, or an image URL/path).
            documents (List[str]): Documents to score against the query.
            max_length (int): Maximum token length passed to the model.
            doc_type (str): Value passed to the model as ``doc_type``.
            query_type (Optional[str]): Value passed to the model as
                ``query_type``. If None, the argument is omitted so the
                model's own default applies. (default: :obj:`None`)

        Returns:
            List[Tuple[str, float]]: (document, score) pairs sorted by
                score in descending order. Empty if `documents` is empty.

        Raises:
            ValueError: If the model is not available, or if the number of
                returned scores does not match the number of documents.
        """
        # `getattr` keeps the guard meaningful even when `__init__` did not
        # complete and the attribute was never assigned.
        model = getattr(self, "model", None)
        if model is None:
            raise ValueError(
                "Model has not been initialized or failed to initialize."
            )

        # Nothing to score; skip the model call entirely.
        if not documents:
            return []

        pairs = [[query, doc] for doc in documents]
        extra = {} if query_type is None else {"query_type": query_type}

        with torch.inference_mode():
            scores = model.compute_score(
                pairs, max_length=max_length, doc_type=doc_type, **extra
            )

        # Defensive normalisation: the remote model code may return a
        # tensor/array or a bare scalar for a single pair -- TODO confirm.
        if hasattr(scores, "tolist"):
            scores = scores.tolist()
        if not isinstance(scores, (list, tuple)):
            scores = [scores]

        return self._sort_documents(
            documents, [float(score) for score in scores]
        )

    def rerank_text_documents(
        self,
        query: str,
        documents: List[str],
        max_length: int = 1024,
    ) -> List[Tuple[str, float]]:
        r"""Reranks text documents based on their relevance to a text query.

        Args:
            query (str): The text query for reranking.
            documents (List[str]): List of text documents to be reranked.
            max_length (int): Maximum token length for processing.
                (default: :obj:`1024`)

        Returns:
            List[Tuple[str, float]]: A list of tuples containing
                the reranked documents and their relevance scores.

        Raises:
            ValueError: If the model is not available.
        """
        return self._rerank(query, documents, max_length, doc_type="text")

    def rerank_image_documents(
        self,
        query: str,
        documents: List[str],
        max_length: int = 2048,
    ) -> List[Tuple[str, float]]:
        r"""Reranks image documents based on their relevance to a text query.

        Args:
            query (str): The text query for reranking.
            documents (List[str]): List of image URLs or paths to be reranked.
            max_length (int): Maximum token length for processing.
                (default: :obj:`2048`)

        Returns:
            List[Tuple[str, float]]: A list of tuples containing
                the reranked image URLs/paths and their relevance scores.

        Raises:
            ValueError: If the model is not available.
        """
        return self._rerank(query, documents, max_length, doc_type="image")

    def image_query_text_documents(
        self,
        image_query: str,
        documents: List[str],
        max_length: int = 2048,
    ) -> List[Tuple[str, float]]:
        r"""Reranks text documents based on their relevance to an image query.

        Args:
            image_query (str): The image URL or path used as query.
            documents (List[str]): List of text documents to be reranked.
            max_length (int): Maximum token length for processing.
                (default: :obj:`2048`)

        Returns:
            List[Tuple[str, float]]: A list of tuples containing
                the reranked documents and their relevance scores.

        Raises:
            ValueError: If the model is not available.
        """
        return self._rerank(
            image_query,
            documents,
            max_length,
            doc_type="text",
            query_type="image",
        )

    def image_query_image_documents(
        self,
        image_query: str,
        documents: List[str],
        max_length: int = 2048,
    ) -> List[Tuple[str, float]]:
        r"""Reranks image documents based on their relevance to an image query.

        Args:
            image_query (str): The image URL or path used as query.
            documents (List[str]): List of image URLs or paths to be reranked.
            max_length (int): Maximum token length for processing.
                (default: :obj:`2048`)

        Returns:
            List[Tuple[str, float]]: A list of tuples containing
                the reranked image URLs/paths and their relevance scores.

        Raises:
            ValueError: If the model is not available.
        """
        return self._rerank(
            image_query,
            documents,
            max_length,
            doc_type="image",
            query_type="image",
        )

    def get_tools(self) -> List[FunctionTool]:
        r"""Returns a list of FunctionTool objects representing the
        functions in the toolkit.

        Returns:
            List[FunctionTool]: A list of FunctionTool objects
                representing the functions in the toolkit.
        """
        return [
            FunctionTool(self.rerank_text_documents),
            FunctionTool(self.rerank_image_documents),
            FunctionTool(self.image_query_text_documents),
            FunctionTool(self.image_query_image_documents),
        ]

docs/key_modules/tools.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ CAMEL provides a variety of built-in toolkits that you can use right away. Here'
166166
| GoogleScholarToolkit | A toolkit for retrieving information about authors and their publications from Google Scholar. |
167167
| HumanToolkit | A toolkit for facilitating human-in-the-loop interactions and feedback in AI systems. |
168168
| ImageAnalysisToolkit | A toolkit for comprehensive image analysis and understanding using vision-capable language models. |
169+
| JinaRerankerToolkit | A toolkit for reranking documents (text or images) based on their relevance to a given query using the Jina Reranker model. |
169170
| LinkedInToolkit | A toolkit for LinkedIn operations including creating posts, deleting posts, and retrieving user profile information. |
170171
| MathToolkit | A toolkit for performing basic mathematical operations such as addition, subtraction, and multiplication. |
171172
| MCPToolkit | A toolkit for interacting with external tools using the Model Context Protocol (MCP). |
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14+
15+
from camel.agents import ChatAgent
16+
from camel.models import ModelFactory
17+
from camel.toolkits import JinaRerankerToolkit
18+
from camel.types import ModelPlatformType, ModelType
19+
20+
# Create the default chat model that will drive the agent.
model = ModelFactory.create(
    model_platform=ModelPlatformType.DEFAULT,
    model_type=ModelType.DEFAULT,
)

# Load the reranker on CPU and expose its methods as agent tools.
reranker_toolkit = JinaRerankerToolkit(device="cpu")
reranker_tool = reranker_toolkit.get_tools()

agent = ChatAgent(model=model, tools=reranker_tool)

documents = [
    "Markdown is a lightweight markup language with plain-text "
    "formatting syntax.",
    "Python is a high-level, interpreted programming language known for "
    "its readability.",
    "SLM (Small Language Models) are compact AI models designed for "
    "specific tasks.",
    "JavaScript is a scripting language primarily used for "
    "creating interactive web pages.",
]

query = "How to use markdown with small language models"

response = agent.step(
    f"Can you rerank these documents {documents} against the query {query}"
)
# Only the first 1000 characters of the tool-call record are printed.
print(str(response.info['tool_calls'])[:1000])
"""
===========================================================================
[ToolCallingRecord(tool_name='rerank_text_documents', args={'query': 'How to
use markdown with small language models', 'documents': ['Markdown is a
lightweight markup language with plain-text formatting syntax.', 'Python is a
high-level, interpreted programming language known for its readability.', 'SLM
(Small Language Models) are compact AI models designed for specific tasks.',
'JavaScript is a scripting language primarily used for creating interactive
web pages.'], 'max_length': 1024}, result=[('Markdown is a lightweight markup
language with plain-text formatting syntax.', 0.7915633916854858), ('SLM
(Small Language Models) are compact AI models designed for specific tasks.',
0.7915633916854858), ('Python is a high-level, interpreted programming
language known for its readability.', 0.43936243653297424), ('JavaScript is a
scripting language primarily used for creating interactive web pages.',
0.3716837763786316)], tool_call_id='call_JKnuvTO1fUQP7PWhyCSQCK7N')]
===========================================================================
"""

0 commit comments

Comments
 (0)