Skip to content

Commit 2a17d22

Browse files
authored
fix(embeddings): Fix embeddings names (camel-ai#384)
1 parent e826fc3 commit 2a17d22

File tree

4 files changed

+34
-32
lines changed

4 files changed

+34
-32
lines changed

camel/embeddings/base.py

Lines changed: 11 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -12,22 +12,24 @@
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
1414
from abc import ABC, abstractmethod
15-
from typing import Any, List
15+
from typing import Any, Generic, List, TypeVar
1616

17+
T = TypeVar('T')
1718

18-
class BaseEmbedding(ABC):
19+
20+
class BaseEmbedding(ABC, Generic[T]):
1921
r"""Abstract base class for text embedding functionalities."""
2022

2123
@abstractmethod
22-
def embed_texts(
24+
def embed_list(
2325
self,
24-
texts: List[str],
26+
objs: List[T],
2527
**kwargs: Any,
2628
) -> List[List[float]]:
2729
r"""Generates embeddings for the given texts.
2830
2931
Args:
30-
texts (List[str]): The texts for which to generate the embeddings.
32+
objs (List[T]): The objects for which to generate the embeddings.
3133
**kwargs (Any): Extra kwargs passed to the embedding API.
3234
3335
Returns:
@@ -36,22 +38,22 @@ def embed_texts(
3638
"""
3739
pass
3840

39-
def embed_text(
41+
def embed(
4042
self,
41-
text: str,
43+
obj: T,
4244
**kwargs: Any,
4345
) -> List[float]:
4446
r"""Generates an embedding for the given text.
4547
4648
Args:
47-
text (str): The text for which to generate the embedding.
49+
obj (T): The object for which to generate the embedding.
4850
**kwargs (Any): Extra kwargs passed to the embedding API.
4951
5052
Returns:
5153
List[float]: A list of floating-point numbers representing the
5254
generated embedding.
5355
"""
54-
return self.embed_texts([text], **kwargs)[0]
56+
return self.embed_list([obj], **kwargs)[0]
5557

5658
@abstractmethod
5759
def get_output_dim(self) -> int:

camel/embeddings/openai_embedding.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -20,20 +20,20 @@
2020
from camel.utils import openai_api_key_required
2121

2222

23-
class OpenAIEmbedding(BaseEmbedding):
23+
class OpenAIEmbedding(BaseEmbedding[str]):
2424
r"""Provides text embedding functionalities using OpenAI's models.
2525
2626
Args:
2727
model (OpenAiEmbeddingModel, optional): The model type to be used for
28-
generating embeddings. (default: :obj:`ModelType.ADA2`)
28+
generating embeddings. (default: :obj:`ModelType.ADA_2`)
2929
3030
Raises:
3131
RuntimeError: If an unsupported model type is specified.
3232
"""
3333

3434
def __init__(
3535
self,
36-
model_type: EmbeddingModelType = EmbeddingModelType.ADA2,
36+
model_type: EmbeddingModelType = EmbeddingModelType.ADA_2,
3737
) -> None:
3838
if not model_type.is_openai:
3939
raise ValueError("Invalid OpenAI embedding model type.")
@@ -42,15 +42,15 @@ def __init__(
4242
self.client = OpenAI()
4343

4444
@openai_api_key_required
45-
def embed_texts(
45+
def embed_list(
4646
self,
47-
texts: List[str],
47+
objs: List[str],
4848
**kwargs: Any,
4949
) -> List[List[float]]:
5050
r"""Generates embeddings for the given texts.
5151
5252
Args:
53-
texts (List[str]): The texts for which to generate the embeddings.
53+
objs (List[str]): The texts for which to generate the embeddings.
5454
**kwargs (Any): Extra kwargs passed to the embedding API.
5555
5656
Returns:
@@ -59,7 +59,7 @@ def embed_texts(
5959
"""
6060
# TODO: count tokens
6161
response = self.client.embeddings.create(
62-
input=texts,
62+
input=objs,
6363
model=self.model_type.value,
6464
**kwargs,
6565
)

camel/types/enums.py

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -114,34 +114,34 @@ def validate_model_name(self, model_name: str) -> bool:
114114

115115

116116
class EmbeddingModelType(Enum):
117-
ADA2 = "text-embedding-ada-002"
118-
ADA1 = "text-embedding-ada-001"
119-
BABBAGE1 = "text-embedding-babbage-001"
120-
CURIE1 = "text-embedding-curie-001"
121-
DAVINCI1 = "text-embedding-davinci-001"
117+
ADA_2 = "text-embedding-ada-002"
118+
ADA_1 = "text-embedding-ada-001"
119+
BABBAGE_1 = "text-embedding-babbage-001"
120+
CURIE_1 = "text-embedding-curie-001"
121+
DAVINCI_1 = "text-embedding-davinci-001"
122122

123123
@property
124124
def is_openai(self) -> bool:
125125
r"""Returns whether this type of models is an OpenAI-released model."""
126126
return self in {
127-
EmbeddingModelType.ADA1,
128-
EmbeddingModelType.ADA2,
129-
EmbeddingModelType.BABBAGE1,
130-
EmbeddingModelType.CURIE1,
131-
EmbeddingModelType.DAVINCI1,
127+
EmbeddingModelType.ADA_2,
128+
EmbeddingModelType.ADA_1,
129+
EmbeddingModelType.BABBAGE_1,
130+
EmbeddingModelType.CURIE_1,
131+
EmbeddingModelType.DAVINCI_1,
132132
}
133133

134134
@property
135135
def output_dim(self) -> int:
136-
if self is EmbeddingModelType.ADA2:
136+
if self is EmbeddingModelType.ADA_2:
137137
return 1536
138-
elif self is EmbeddingModelType.ADA1:
138+
elif self is EmbeddingModelType.ADA_1:
139139
return 1024
140-
elif self is EmbeddingModelType.BABBAGE1:
140+
elif self is EmbeddingModelType.BABBAGE_1:
141141
return 2048
142-
elif self is EmbeddingModelType.CURIE1:
142+
elif self is EmbeddingModelType.CURIE_1:
143143
return 4096
144-
elif self is EmbeddingModelType.DAVINCI1:
144+
elif self is EmbeddingModelType.DAVINCI_1:
145145
return 12288
146146
else:
147147
raise ValueError(f"Unknown model type {self}.")

test/embeddings/test_openai_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,5 +19,5 @@
1919
@pytest.mark.parametrize("embedding_model", [OpenAIEmbedding()])
2020
def test_embedding(embedding_model: BaseEmbedding):
2121
text = "test embedding text."
22-
vector = embedding_model.embed_text(text)
22+
vector = embedding_model.embed(text)
2323
assert len(vector) == embedding_model.get_output_dim()

0 commit comments

Comments
 (0)