Skip to content

Commit e0055ca

Browse files
make version check loose (stanfordnlp#1946)
1 parent 422973a commit e0055ca

File tree

4 files changed

+114
-5
lines changed

4 files changed

+114
-5
lines changed

dspy/primitives/module.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,7 @@ def save(self, path, save_program=False):
219219
with open(path, "wb") as f:
220220
cloudpickle.dump(state, f)
221221
else:
222-
raise ValueError(
223-
f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}"
224-
)
222+
raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")
225223

226224
def load(self, path):
227225
"""Load the saved module. You may also want to check out dspy.load, if you want to

dspy/utils/saving.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010

1111

1212
def get_dependency_versions():
13+
cloudpickle_version = '.'.join(cloudpickle.__version__.split('.')[:2])
1314
return {
14-
"python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
15+
"python": f"{sys.version_info.major}.{sys.version_info.minor}",
1516
"dspy": importlib_metadata.version("dspy"),
16-
"cloudpickle": cloudpickle.__version__,
17+
"cloudpickle": cloudpickle_version,
1718
}
1819

1920

tests/primitives/test_module.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import dspy
22
import threading
33
from dspy.utils.dummies import DummyLM
4+
import logging
5+
from unittest.mock import patch
46

57

68
def test_deepcopy_basic():
@@ -106,3 +108,56 @@ def dummy_metric(example, pred, trace=None):
106108

107109
assert str(new_cot.predict.signature) == str(compiled_cot.predict.signature)
108110
assert new_cot.predict.demos == compiled_cot.predict.demos
111+
112+
113+
def test_load_with_version_mismatch(tmp_path):
114+
from dspy.primitives.module import logger
115+
116+
# Mock versions during save
117+
save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"}
118+
119+
# Mock versions during load
120+
load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"}
121+
122+
predict = dspy.Predict("question->answer")
123+
124+
# Create a custom handler to capture log messages
125+
class ListHandler(logging.Handler):
126+
def __init__(self):
127+
super().__init__()
128+
self.messages = []
129+
130+
def emit(self, record):
131+
self.messages.append(record.getMessage())
132+
133+
# Add handler and set level
134+
handler = ListHandler()
135+
original_level = logger.level
136+
logger.addHandler(handler)
137+
logger.setLevel(logging.WARNING)
138+
139+
try:
140+
save_path = tmp_path / "program.pkl"
141+
# Mock version during save
142+
with patch("dspy.primitives.module.get_dependency_versions", return_value=save_versions):
143+
predict.save(save_path)
144+
145+
# Mock version during load
146+
with patch("dspy.primitives.module.get_dependency_versions", return_value=load_versions):
147+
loaded_predict = dspy.Predict("question->answer")
148+
loaded_predict.load(save_path)
149+
150+
# Assert warnings were logged, and one warning for each mismatched dependency.
151+
assert len(handler.messages) == 3
152+
153+
for msg in handler.messages:
154+
assert "There is a mismatch of" in msg
155+
156+
# Verify the model still loads correctly despite version mismatches
157+
assert isinstance(loaded_predict, dspy.Predict)
158+
assert str(predict.signature) == str(loaded_predict.signature)
159+
160+
finally:
161+
# Clean up: restore original level and remove handler
162+
logger.setLevel(original_level)
163+
logger.removeHandler(handler)

tests/utils/test_saving.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import dspy
22
from dspy.utils import DummyLM
3+
from unittest.mock import patch
4+
import pytest
5+
from dspy.utils.saving import get_dependency_versions
6+
import logging
37

48

59
def test_save_predict(tmp_path):
@@ -74,3 +78,54 @@ def dummy_metric(example, pred, trace=None):
7478
loaded_predict = dspy.load(tmp_path)
7579
assert compiled_predict.demos == loaded_predict.demos
7680
assert compiled_predict.signature == loaded_predict.signature
81+
82+
83+
def test_load_with_version_mismatch(tmp_path):
84+
from dspy.utils.saving import logger
85+
86+
# Mock versions during save
87+
save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"}
88+
89+
# Mock versions during load
90+
load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"}
91+
92+
predict = dspy.Predict("question->answer")
93+
94+
# Create a custom handler to capture log messages
95+
class ListHandler(logging.Handler):
96+
def __init__(self):
97+
super().__init__()
98+
self.messages = []
99+
100+
def emit(self, record):
101+
self.messages.append(record.getMessage())
102+
103+
# Add handler and set level
104+
handler = ListHandler()
105+
original_level = logger.level
106+
logger.addHandler(handler)
107+
logger.setLevel(logging.WARNING)
108+
109+
try:
110+
# Mock version during save
111+
with patch("dspy.utils.saving.get_dependency_versions", return_value=save_versions):
112+
predict.save(tmp_path, save_program=True)
113+
114+
# Mock version during load
115+
with patch("dspy.utils.saving.get_dependency_versions", return_value=load_versions):
116+
loaded_predict = dspy.load(tmp_path)
117+
118+
# Assert warnings were logged, and one warning for each mismatched dependency.
119+
assert len(handler.messages) == 3
120+
121+
for msg in handler.messages:
122+
assert "There is a mismatch of" in msg
123+
124+
# Verify the model still loads correctly despite version mismatches
125+
assert isinstance(loaded_predict, dspy.Predict)
126+
assert predict.signature == loaded_predict.signature
127+
128+
finally:
129+
# Clean up: restore original level and remove handler
130+
logger.setLevel(original_level)
131+
logger.removeHandler(handler)

0 commit comments

Comments
 (0)