
Commit 7a877d1

Use exponential_backoff_retry for completion call (stanfordnlp#8023)
* use exponential_backoff_retry for completion call
* add test for exponential_backoff_retry
1 parent 2f2e2f3 commit 7a877d1

2 files changed: +22 -0 lines changed

dspy/clients/lm.py
Lines changed: 1 addition & 0 deletions

@@ -289,6 +289,7 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
 def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     retry_kwargs = dict(
         retry_policy=_get_litellm_retry_policy(num_retries),
+        retry_strategy="exponential_backoff_retry",
         # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
         # to completion()), the default value of max_retries is non-zero for certain providers, and
         # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
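For context, here is a minimal sketch of how these kwargs plausibly flow into litellm.completion(). The body of _get_litellm_retry_policy is an illustrative guess rather than the repository's exact code; litellm's RetryPolicy model and the retry_policy/retry_strategy arguments to completion() (LiteLLM >= 1.55.3) are taken from the diff comment above.

```python
import litellm
from litellm import RetryPolicy
from typing import Any, Dict

def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
    # Illustrative guess: retry transient errors num_retries times, and
    # never retry errors that cannot succeed on a second attempt.
    return RetryPolicy(
        TimeoutErrorRetries=num_retries,
        RateLimitErrorRetries=num_retries,
        InternalServerErrorRetries=num_retries,
        BadRequestErrorRetries=0,
        AuthenticationErrorRetries=0,
    )

def litellm_completion(request: Dict[str, Any], num_retries: int,
                       cache={"no-cache": True, "no-store": True}):
    retry_kwargs = dict(
        retry_policy=_get_litellm_retry_policy(num_retries),
        retry_strategy="exponential_backoff_retry",  # double the wait between attempts
        # max_retries=0 keeps provider-level retries from stacking on the policy.
        max_retries=0,
    )
    return litellm.completion(cache=cache, **retry_kwargs, **request)
```

With this in place, a rate-limited request is retried up to num_retries times, with the wait doubling between attempts instead of staying constant.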

tests/clients/test_lm.py
Lines changed: 21 additions & 0 deletions

@@ -1,8 +1,10 @@
 from unittest import mock

+import time
 import litellm
 import pydantic
 import pytest
+from openai import RateLimitError

 import dspy
 from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs

@@ -250,3 +252,22 @@ def test_dump_state():
         "launch_kwargs": { "temperature": 1 },
         "train_kwargs": { "temperature": 5 },
     }
+
+
+def test_exponential_backoff_retry():
+    time_counter = []
+    def mock_create(*args, **kwargs):
+        time_counter.append(time.time())
+        # These fields are called during the error handling
+        mock_response = mock.Mock()
+        mock_response.headers = {}
+        mock_response.status_code = 429
+        raise RateLimitError(response=mock_response, message="message", body="error")
+    lm = dspy.LM(model='openai/gpt-3.5-turbo', max_tokens=250, num_retries=3)
+    with mock.patch.object(litellm.OpenAIChatCompletion, "completion", side_effect=mock_create):
+        with pytest.raises(RateLimitError):
+            lm("question")
+
+    # The first retry happens immediately regardless of the configuration
+    for i in range(1, len(time_counter)-1):
+        assert time_counter[i+1] - time_counter[i] >= 2**(i-1)
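The assertion encodes the timing contract of exponential backoff: the first retry fires immediately, and each subsequent gap is at least double the previous one (1s, 2s, 4s, ...). A minimal sketch of that schedule, purely illustrative and not LiteLLM's internal implementation:

```python
import time

def retry_with_exponential_backoff(fn, num_retries: int, base_delay: float = 1.0):
    """Illustrative sketch of the schedule the test asserts; not LiteLLM's code."""
    for attempt in range(num_retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == num_retries:
                raise  # retries exhausted; surface the last error
            if attempt >= 1:
                # The first retry happens immediately; after that the
                # wait doubles each time: 1s, 2s, 4s, ...
                time.sleep(base_delay * 2 ** (attempt - 1))
```

With num_retries=3 and base_delay=1, the four attempts land at roughly t = 0, 0, 1, and 3 seconds, so consecutive gaps satisfy time_counter[i+1] - time_counter[i] >= 2**(i-1) for i >= 1, which is exactly what the loop in the test checks.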
