
feat(wren-ai-service): add context window size handling in LLMProvider and related components #1693

Merged · 4 commits · May 27, 2025
3 changes: 3 additions & 0 deletions deployment/kustomizations/base/cm.yaml
@@ -55,18 +55,21 @@ data:
     models:
     - alias: default
       model: gpt-4.1-nano-2025-04-14
+      context_window_size: 1000000
       kwargs:
         max_tokens: 4096
         n: 1
         seed: 0
         temperature: 0
     - model: gpt-4.1-mini-2025-04-14
+      context_window_size: 1000000
       kwargs:
         max_tokens: 4096
         n: 1
         seed: 0
         temperature: 0
     - model: gpt-4.1-2025-04-14
+      context_window_size: 1000000
       kwargs:
         max_tokens: 4096
         n: 1
3 changes: 3 additions & 0 deletions docker/config.example.yaml
@@ -4,18 +4,21 @@ timeout: 120
 models:
 - alias: default
   model: gpt-4.1-nano-2025-04-14
+  context_window_size: 1000000
   kwargs:
     max_tokens: 4096
     n: 1
     seed: 0
     temperature: 0
 - model: gpt-4.1-mini-2025-04-14
+  context_window_size: 1000000
   kwargs:
     max_tokens: 4096
     n: 1
     seed: 0
     temperature: 0
 - model: gpt-4.1-2025-04-14
+  context_window_size: 1000000
   kwargs:
     max_tokens: 4096
     n: 1
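For orientation, a minimal sketch of reading the new field from a config shaped like the one above; only the context_window_size key itself comes from this PR, the PyYAML loading code is illustrative.

# Minimal sketch: parse a models entry like those above and read the new field.
import yaml

doc = """
models:
- alias: default
  model: gpt-4.1-nano-2025-04-14
  context_window_size: 1000000
  kwargs:
    max_tokens: 4096
"""
config = yaml.safe_load(doc)
print(config["models"][0]["context_window_size"])  # 1000000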
3 changes: 3 additions & 0 deletions wren-ai-service/src/core/provider.py
@@ -14,6 +14,9 @@ def get_model(self):
     def get_model_kwargs(self):
         return self._model_kwargs
 
+    def get_context_window_size(self):
+        return self._context_window_size
+
 
 class EmbedderProvider(metaclass=ABCMeta):
     @abstractmethod
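The new getter on LLMProvider follows the same pattern as get_model and get_model_kwargs: the base class exposes an attribute that concrete providers set in __init__. A minimal sketch of that contract; DummyProvider and the stripped-down base class are hypothetical stand-ins, not the real classes.

from abc import ABCMeta

class LLMProvider(metaclass=ABCMeta):
    def get_context_window_size(self):
        # Concrete providers are expected to set self._context_window_size
        return self._context_window_size

class DummyProvider(LLMProvider):  # hypothetical subclass for illustration
    def __init__(self, context_window_size: int = 100000):
        self._context_window_size = context_window_size

print(DummyProvider().get_context_window_size())  # 100000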
@@ -230,6 +230,7 @@ def check_using_db_schemas_without_pruning(
     dbschema_retrieval: list[Document],
     encoding: tiktoken.Encoding,
     enable_column_pruning: bool,
+    context_window_size: int,
 ) -> dict:
     retrieval_results = []
     has_calculated_field = False
@@ -269,7 +270,7 @@ def check_using_db_schemas_without_pruning(
         retrieval_result["table_ddl"] for retrieval_result in retrieval_results
     ]
     _token_count = len(encoding.encode(" ".join(table_ddls)))
-    if _token_count > 100_000 or enable_column_pruning:
+    if _token_count > context_window_size or enable_column_pruning:
         return {
             "db_schemas": [],
             "tokens": _token_count,
@@ -465,6 +466,7 @@ def __init__(
 
         self._configs = {
             "encoding": _encoding,
+            "context_window_size": llm_provider.get_context_window_size(),
         }
 
         super().__init__(
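The hunks above replace the hard-coded 100_000-token ceiling with the provider's configured window. A standalone sketch of the same gate, assuming the cl100k_base encoding and made-up DDL strings:

import tiktoken

def exceeds_window(table_ddls: list[str], context_window_size: int) -> bool:
    # Count tokens across all table DDLs, as in
    # check_using_db_schemas_without_pruning above.
    encoding = tiktoken.get_encoding("cl100k_base")
    token_count = len(encoding.encode(" ".join(table_ddls)))
    return token_count > context_window_size

print(exceeds_window(["CREATE TABLE orders (id INTEGER, total REAL)"], 100000))  # False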
@@ -18,6 +18,7 @@
 def preprocess(
     sql_data: Dict,
     encoding: tiktoken.Encoding,
+    context_window_size: int,
 ) -> Dict:
     def reduce_data_size(data: list, reduction_step: int = 50) -> list:
         """Reduce the size of data by removing elements from the end.
@@ -48,8 +49,8 @@ def reduce_data_size(data: list, reduction_step: int = 50) -> list:
     _token_count = len(encoding.encode(str(sql_data)))
     num_rows_used_in_llm = len(sql_data.get("data", []))
     iteration = 0
 
-    while _token_count > 100_000:
+    while _token_count > context_window_size:
         if iteration > 1000:
             """
             Avoid infinite loop
@@ -89,6 +90,7 @@ def __init__(
 
         self._configs = {
             "encoding": _encoding,
+            "context_window_size": llm_provider.get_context_window_size(),
         }
 
         super().__init__(Driver({}, sys.modules[__name__], adapter=base.DictResult()))
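Same idea in the preprocess step: rows are dropped from the end of the result set until the serialized payload fits the window. A simplified sketch of that loop under the same assumptions as above (cl100k_base encoding; the helper and its shape are illustrative, not the PR's exact code):

import tiktoken

def shrink_to_window(sql_data: dict, context_window_size: int,
                     reduction_step: int = 50) -> dict:
    # Mirror the while-loop above: trim rows until the token count fits.
    encoding = tiktoken.get_encoding("cl100k_base")
    iteration = 0
    while len(encoding.encode(str(sql_data))) > context_window_size:
        if iteration > 1000:  # avoid an infinite loop, as in the diff
            break
        # Drop reduction_step rows from the end on each pass.
        sql_data["data"] = sql_data.get("data", [])[:-reduction_step]
        iteration += 1
    return sql_data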
9 changes: 7 additions & 2 deletions wren-ai-service/src/providers/__init__.py
@@ -34,7 +34,8 @@ def llm_processor(entry: dict) -> dict:
                     "n": 1,
                     "max_tokens": 4096,
                     "response_format": {"type": "json_object"}
-                }
+                },
+                "context_window_size": 100000
             }
         ],
         "api_base": "https://api.openai.com/v1"
@@ -52,6 +53,7 @@ def llm_processor(entry: dict) -> dict:
                 "max_tokens": 4096,
                 "response_format": {"type": "json_object"}
             },
+            "context_window_size": 100000,
             "api_base": "https://api.openai.com/v1"
         }
     }
@@ -70,12 +72,15 @@ def llm_processor(entry: dict) -> dict:
     for model in entry.get("models", []):
         model_name = f"{entry.get('provider')}.{model.get('alias', model.get('model'))}"
         model_additional_params = {
-            k: v for k, v in model.items() if k not in ["model", "kwargs", "alias"]
+            k: v
+            for k, v in model.items()
+            if k not in ["model", "kwargs", "alias", "context_window_size"]
         }
         returned[model_name] = {
             "provider": entry["provider"],
             "model": model["model"],
             "kwargs": model["kwargs"],
+            "context_window_size": model.get("context_window_size", 100000),
             **model_additional_params,
             **others,
         }
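Note that llm_processor excludes context_window_size from model_additional_params and surfaces it as a first-class key instead, defaulting to 100000 when the config omits it. A trimmed, runnable sketch of that mapping; the entry dict is adapted from the docstring example above:

entry = {
    "provider": "litellm_llm",
    "models": [
        {"model": "gpt-4.1-nano-2025-04-14", "kwargs": {"temperature": 0}},
        {"model": "gpt-4.1-2025-04-14", "kwargs": {}, "context_window_size": 1000000},
    ],
}

returned = {}
for model in entry.get("models", []):
    model_name = f"{entry.get('provider')}.{model.get('alias', model.get('model'))}"
    returned[model_name] = {
        "provider": entry["provider"],
        "model": model["model"],
        "kwargs": model["kwargs"],
        # Falls back to 100000 when the config omits the field.
        "context_window_size": model.get("context_window_size", 100000),
    }

print(returned["litellm_llm.gpt-4.1-nano-2025-04-14"]["context_window_size"])  # 100000
print(returned["litellm_llm.gpt-4.1-2025-04-14"]["context_window_size"])       # 1000000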
8 changes: 6 additions & 2 deletions wren-ai-service/src/providers/llm/litellm.py
@@ -16,7 +16,7 @@
     connect_chunks,
 )
 from src.providers.loader import provider
-from src.utils import remove_trailing_slash, extract_braces_content
+from src.utils import extract_braces_content, remove_trailing_slash
 
 
 @provider("litellm_llm")
@@ -31,6 +31,7 @@ def __init__(
         api_version: Optional[str] = None,
         kwargs: Optional[Dict[str, Any]] = None,
         timeout: float = 120.0,
+        context_window_size: int = 100000,
         **_,
     ):
         self._model = model
@@ -39,6 +40,7 @@ def __init__(
         self._api_version = api_version
         self._model_kwargs = kwargs
         self._timeout = timeout
+        self._context_window_size = context_window_size
 
     def get_generator(
         self,
@@ -113,7 +115,9 @@ async def _run(
         check_finish_reason(response)
 
         return {
-            "replies": [extract_braces_content(message.content) for message in completions],
+            "replies": [
+                extract_braces_content(message.content) for message in completions
+            ],
             "meta": [message.meta for message in completions],
         }
 
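End to end, a pipeline can now size its prompt budget from the provider instead of a constant. A hypothetical usage sketch: the stub below stands in for the class decorated with @provider("litellm_llm") above (whose actual name is not shown in this diff), keeping only the constructor arguments the PR touches.

class LitellmLLMProvider:  # hypothetical stub for the decorated class above
    def __init__(self, model, kwargs=None, timeout=120.0,
                 context_window_size=100000, **_):
        self._model = model
        self._model_kwargs = kwargs
        self._timeout = timeout
        self._context_window_size = context_window_size

    def get_context_window_size(self):
        return self._context_window_size

llm_provider = LitellmLLMProvider(
    model="gpt-4.1-nano-2025-04-14",
    kwargs={"max_tokens": 4096, "temperature": 0},
    context_window_size=1000000,
)
# Pipelines read the window through the provider, as in self._configs above.
configs = {"context_window_size": llm_provider.get_context_window_size()}
print(configs)  # {'context_window_size': 1000000}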