Skip to content

Commit c30b4aa

Browse files
authored
Add gpt-3.5-turbo-16k support (camel-ai#171)
1 parent 1eeda03 commit c30b4aa

File tree

11 files changed

+125
-62
lines changed

11 files changed

+125
-62
lines changed

camel/agents/chat_agent.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,7 @@
2323
from camel.messages import ChatMessage, MessageType, SystemMessage
2424
from camel.models import BaseModelBackend, ModelFactory
2525
from camel.typing import ModelType, RoleType
26-
from camel.utils import (
27-
get_model_token_limit,
28-
num_tokens_from_messages,
29-
openai_api_key_required,
30-
)
26+
from camel.utils import num_tokens_from_messages, openai_api_key_required
3127

3228

3329
@dataclass(frozen=True)
@@ -84,11 +80,11 @@ def __init__(
8480
self.model: ModelType = (model if model is not None else
8581
ModelType.GPT_3_5_TURBO)
8682
self.model_config: ChatGPTConfig = model_config or ChatGPTConfig()
87-
self.model_token_limit: int = get_model_token_limit(self.model)
8883
self.message_window_size: Optional[int] = message_window_size
8984

9085
self.model_backend: BaseModelBackend = ModelFactory.create(
9186
self.model, self.model_config.__dict__)
87+
self.model_token_limit: int = self.model_backend.token_limit
9288

9389
self.terminated: bool = False
9490
self.init_messages()

camel/models/base_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,24 @@
1414
from abc import ABC, abstractmethod
1515
from typing import Any, Dict, List
1616

17+
from camel.typing import ModelType
18+
1719

1820
class BaseModelBackend(ABC):
1921
r"""Base class for different model backends.
2022
May be OpenAI API, a local LLM, a stub for unit tests, etc."""
2123

24+
def __init__(self, model_type: ModelType,
25+
model_config_dict: Dict[str, Any]) -> None:
26+
r"""Constructor for the model backend.
27+
28+
Args:
29+
model_type (ModelType): Model for which a backend is created.
30+
model_config_dict (Dict[str, Any]): A config dictionary.
31+
"""
32+
self.model_type = model_type
33+
self.model_config_dict = model_config_dict
34+
2235
@abstractmethod
2336
def run(self, messages: List[Dict]) -> Dict[str, Any]:
2437
r"""Runs the query to the backend model.
@@ -35,3 +48,11 @@ def run(self, messages: List[Dict]) -> Dict[str, Any]:
3548
Dict[str, Any]: All backends must return a dict in OpenAI format.
3649
"""
3750
pass
51+
52+
@property
53+
def token_limit(self) -> int:
54+
r"""Returns the maximum token limit for a given model.
55+
Returns:
56+
int: The maximum token limit for the given model.
57+
"""
58+
return self.model_type.token_limit

camel/models/model_factory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ def create(model_type: ModelType,
4242
"""
4343
model_class: Any
4444
if model_type in {
45-
ModelType.GPT_3_5_TURBO, ModelType.GPT_4, ModelType.GPT_4_32k
45+
ModelType.GPT_3_5_TURBO,
46+
ModelType.GPT_3_5_TURBO_16K,
47+
ModelType.GPT_4,
48+
ModelType.GPT_4_32k,
4649
}:
4750
model_class = OpenAIModel
4851
elif model_type == ModelType.STUB:

camel/models/openai_model.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,16 @@ def __init__(self, model_type: ModelType,
2828
Args:
2929
model_type (ModelType): Model for which a backend is created,
3030
one of GPT_* series.
31-
model_config_dict (Dict[str, Any]): a dictionary that will
31+
model_config_dict (Dict[str, Any]): A dictionary that will
3232
be fed into openai.ChatCompletion.create().
3333
"""
34-
super().__init__()
35-
self.model_type = model_type
36-
self.model_config_dict = model_config_dict
34+
super().__init__(model_type, model_config_dict)
3735

3836
def run(self, messages: List[Dict]) -> Dict[str, Any]:
3937
r"""Run inference of OpenAI chat completion.
4038
4139
Args:
42-
messages (List[Dict]): message list with the chat history
40+
messages (List[Dict]): Message list with the chat history
4341
in OpenAI API format.
4442
4543
Returns:

camel/models/stub_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414
from typing import Any, Dict, List
1515

1616
from camel.models import BaseModelBackend
17+
from camel.typing import ModelType
1718

1819

1920
class StubModel(BaseModelBackend):
2021
r"""A dummy model used for unit tests."""
22+
model_type = ModelType.STUB
2123

22-
def __init__(self, *args, **kwargs) -> None:
24+
def __init__(self, model_type: ModelType,
25+
model_config_dict: Dict[str, Any]) -> None:
2326
r"""All arguments are unused for the dummy model."""
24-
super().__init__()
27+
pass
2528

2629
def run(self, messages: List[Dict]) -> Dict[str, Any]:
2730
r"""Run fake inference by returning a fixed string.

camel/typing.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,34 @@ class RoleType(Enum):
2424

2525
class ModelType(Enum):
2626
GPT_3_5_TURBO = "gpt-3.5-turbo"
27+
GPT_3_5_TURBO_16K = "gpt-3.5-turbo-16k"
2728
GPT_4 = "gpt-4"
2829
GPT_4_32k = "gpt-4-32k"
2930
STUB = "stub"
3031

3132
@property
32-
def value_for_tiktoken(self):
33+
def value_for_tiktoken(self) -> str:
3334
return self.value if self.name != "STUB" else "gpt-3.5-turbo"
3435

36+
@property
37+
def token_limit(self) -> int:
38+
r"""Returns the maximum token limit for a given model.
39+
Returns:
40+
int: The maximum token limit for the given model.
41+
"""
42+
if self is ModelType.GPT_3_5_TURBO:
43+
return 4096
44+
elif self is ModelType.GPT_3_5_TURBO_16K:
45+
return 16384
46+
elif self is ModelType.GPT_4:
47+
return 8192
48+
elif self is ModelType.GPT_4_32k:
49+
return 32768
50+
elif self is ModelType.STUB:
51+
return 4096
52+
else:
53+
raise ValueError("Unknown model type")
54+
3555

3656
class TaskType(Enum):
3757
AI_SOCIETY = "ai_society"

camel/utils.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -29,27 +29,28 @@
2929

3030
def count_tokens_openai_chat_models(
3131
messages: List[OpenAIMessage],
32-
encoding: Any,
32+
encoding: tiktoken.Encoding,
33+
tokens_per_message: int,
34+
tokens_per_name: int,
3335
) -> int:
3436
r"""Counts the number of tokens required to generate an OpenAI chat based
3537
on a given list of messages.
3638
3739
Args:
3840
messages (List[OpenAIMessage]): The list of messages.
39-
encoding (Any): The encoding method to use.
41+
encoding (tiktoken.Encoding): The encoding method to use.
4042
4143
Returns:
4244
int: The number of tokens required.
4345
"""
4446
num_tokens = 0
4547
for message in messages:
46-
# message follows <im_start>{role/name}\n{content}<im_end>\n
47-
num_tokens += 4
48+
num_tokens += tokens_per_message
4849
for key, value in message.items():
4950
num_tokens += len(encoding.encode(value))
5051
if key == "name": # if there's a name, the role is omitted
51-
num_tokens += -1 # role is always 1 token
52-
num_tokens += 2 # every reply is primed with <im_start>assistant
52+
num_tokens += tokens_per_name
53+
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
5354
return num_tokens
5455

5556

@@ -81,11 +82,26 @@ def num_tokens_from_messages(
8182
except KeyError:
8283
encoding = tiktoken.get_encoding("cl100k_base")
8384

84-
if model in {
85-
ModelType.GPT_3_5_TURBO, ModelType.GPT_4, ModelType.GPT_4_32k,
86-
ModelType.STUB
87-
}:
88-
return count_tokens_openai_chat_models(messages, encoding)
85+
if model.value_for_tiktoken.startswith("gpt-3.5-turbo"):
86+
# Every message follows <|start|>{role/name}\n{content}<|end|>\n
87+
tokens_per_message = 4
88+
# If there's a name, the role is omitted
89+
tokens_per_name = -1
90+
return count_tokens_openai_chat_models(
91+
messages,
92+
encoding,
93+
tokens_per_message,
94+
tokens_per_name,
95+
)
96+
elif model.value_for_tiktoken.startswith("gpt-4"):
97+
tokens_per_message = 3
98+
tokens_per_name = 1
99+
return count_tokens_openai_chat_models(
100+
messages,
101+
encoding,
102+
tokens_per_message,
103+
tokens_per_name,
104+
)
89105
else:
90106
raise NotImplementedError(
91107
f"`num_tokens_from_messages`` is not presently implemented "
@@ -97,27 +113,6 @@ def num_tokens_from_messages(
97113
f"for information about openai chat models.")
98114

99115

100-
def get_model_token_limit(model: ModelType) -> int:
101-
r"""Returns the maximum token limit for a given model.
102-
103-
Args:
104-
model (ModelType): The type of the model.
105-
106-
Returns:
107-
int: The maximum token limit for the given model.
108-
"""
109-
if model == ModelType.GPT_3_5_TURBO:
110-
return 4096
111-
elif model == ModelType.GPT_4:
112-
return 8192
113-
elif model == ModelType.GPT_4_32k:
114-
return 32768
115-
elif model == ModelType.STUB:
116-
return 4096
117-
else:
118-
raise ValueError("Unknown model type")
119-
120-
121116
def openai_api_key_required(func: F) -> F:
122117
r"""Decorator that checks if the OpenAI API key is available in the
123118
environment variables.

test/agents/test_chat_agent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from camel.generators import SystemMessageGenerator
1919
from camel.messages import ChatMessage, SystemMessage
2020
from camel.typing import ModelType, RoleType, TaskType
21-
from camel.utils import get_model_token_limit
2221

2322
parametrize = pytest.mark.parametrize('model', [
2423
ModelType.STUB,
@@ -28,7 +27,7 @@
2827

2928

3029
@parametrize
31-
def test_chat_agent(model):
30+
def test_chat_agent(model: ModelType):
3231

3332
model_config = ChatGPTConfig()
3433
system_msg = SystemMessageGenerator(
@@ -54,7 +53,7 @@ def test_chat_agent(model):
5453
assert assistant_response.info['id'] is not None
5554

5655
assistant.reset()
57-
token_limit = get_model_token_limit(model)
56+
token_limit = assistant.model_token_limit
5857
user_msg = ChatMessage(role_name="Patient", role_type=RoleType.USER,
5958
meta_dict=dict(), role="user",
6059
content="token" * (token_limit + 1))

test/messages/test_message_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_base_message_contains_operator(base_message: BaseMessage):
5151
def test_base_message_token_len(base_message: BaseMessage):
5252
token_len = base_message.token_len()
5353
assert isinstance(token_len, int)
54-
assert token_len == 9
54+
assert token_len == 10
5555

5656

5757
def test_extract_text_and_code_prompts():

test/models/test_model_factory.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
from camel.typing import ModelType
1919

2020
parametrize = pytest.mark.parametrize('model', [
21-
ModelType.STUB,
2221
pytest.param(ModelType.GPT_3_5_TURBO, marks=pytest.mark.model_backend),
22+
pytest.param(ModelType.GPT_3_5_TURBO_16K, marks=pytest.mark.model_backend),
2323
pytest.param(ModelType.GPT_4, marks=pytest.mark.model_backend),
24+
ModelType.STUB,
2425
])
2526

2627

@@ -30,18 +31,12 @@ def test_model_factory(model):
3031
model_inst = ModelFactory.create(model, model_config_dict)
3132
messages = [
3233
{
33-
'role': 'system',
34-
'content': 'You can make a task more specific.'
34+
"role": "system",
35+
"content": "Initialize system",
3536
},
3637
{
37-
'role':
38-
'user',
39-
'content': ('Here is a task that Python Programmer will help '
40-
'Stock Trader to complete: Develop a trading bot '
41-
'for the stock market.\nPlease make it more specific.'
42-
' Be creative and imaginative.\nPlease reply with '
43-
'the specified task in 50 words or less. '
44-
'Do not add anything else.')
38+
"role": "user",
39+
"content": "Hello",
4540
},
4641
]
4742
response = model_inst.run(messages)

0 commit comments

Comments (0)