Skip to content

Commit 0d0af3c

Browse files
Resolve "$def" for nested tool arg type (stanfordnlp#8534)
* resolve schema for nested type * add testing * fix test
1 parent 6ead3da commit 0d0af3c

File tree

2 files changed

+47
-34
lines changed

2 files changed

+47
-34
lines changed

dspy/adapters/types/tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _parse_function(self, func: Callable, arg_desc: dict[str, str] | None = None
103103
v_json_schema = _resolve_json_schema_reference(v.model_json_schema())
104104
args[k] = v_json_schema
105105
else:
106-
args[k] = TypeAdapter(v).json_schema()
106+
args[k] = _resolve_json_schema_reference(TypeAdapter(v).json_schema())
107107
if default_values[k] is not inspect.Parameter.empty:
108108
args[k]["default"] = default_values[k]
109109
if arg_desc and k in arg_desc:

tests/adapters/test_tool.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ class UserProfile(BaseModel):
5050
tags: list[str] = []
5151

5252

53-
def complex_dummy_function(profile: UserProfile, priority: int, notes: str | None = None) -> dict[str, Any]:
53+
class Note(BaseModel):
54+
content: str
55+
author: str
56+
57+
58+
def complex_dummy_function(profile: UserProfile, priority: int, notes: list[Note] | None = None) -> dict[str, Any]:
5459
"""Process user profile with complex nested structure.
5560
5661
Args:
@@ -89,7 +94,9 @@ async def async_dummy_with_pydantic(model: DummyModel) -> str:
8994

9095

9196
async def async_complex_dummy_function(
92-
profile: UserProfile, priority: int, notes: str | None = None
97+
profile: UserProfile,
98+
priority: int,
99+
notes: list[Note] | None = None,
93100
) -> dict[str, Any]:
94101
"""Process user profile with complex nested structure asynchronously.
95102
@@ -167,6 +174,7 @@ def test_tool_from_function_with_pydantic_nesting():
167174
tool = Tool(complex_dummy_function)
168175

169176
assert tool.name == "complex_dummy_function"
177+
170178
assert "profile" in tool.args
171179
assert "priority" in tool.args
172180
assert "notes" in tool.args
@@ -177,6 +185,13 @@ def test_tool_from_function_with_pydantic_nesting():
177185
assert tool.args["profile"]["properties"]["contact"]["type"] == "object"
178186
assert tool.args["profile"]["properties"]["contact"]["properties"]["email"]["type"] == "string"
179187

188+
# Reference should be resolved for nested pydantic models
189+
assert "$defs" not in str(tool.args["notes"])
190+
assert tool.args["notes"]["anyOf"][0]["type"] == "array"
191+
assert tool.args["notes"]["anyOf"][0]["items"]["type"] == "object"
192+
assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["content"]["type"] == "string"
193+
assert tool.args["notes"]["anyOf"][0]["items"]["properties"]["author"]["type"] == "string"
194+
180195

181196
def test_tool_callable():
182197
tool = Tool(dummy_function)
@@ -319,11 +334,11 @@ async def test_async_tool_with_complex_pydantic():
319334
),
320335
)
321336

322-
result = await tool.acall(profile=profile, priority=1, notes="Test note")
337+
result = await tool.acall(profile=profile, priority=1, notes=[Note(content="Test note", author="Test author")])
323338
assert result["user_id"] == 1
324339
assert result["name"] == "Test User"
325340
assert result["priority"] == 1
326-
assert result["notes"] == "Test note"
341+
assert result["notes"] == [Note(content="Test note", author="Test author")]
327342
assert result["primary_address"]["street"] == "123 Main St"
328343

329344

@@ -382,42 +397,39 @@ def test_async_tool_call_in_sync_mode():
382397
([], [{"type": "tool_calls", "tool_calls": []}]),
383398
(
384399
[{"name": "search", "args": {"query": "hello"}}],
385-
[{
386-
"type": "tool_calls",
387-
"tool_calls": [{
388-
"type": "function",
389-
"function": {"name": "search", "arguments": {"query": "hello"}}
390-
}]
391-
}],
400+
[
401+
{
402+
"type": "tool_calls",
403+
"tool_calls": [{"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}}],
404+
}
405+
],
392406
),
393407
(
394408
[
395409
{"name": "search", "args": {"query": "hello"}},
396-
{"name": "translate", "args": {"text": "world", "lang": "fr"}}
410+
{"name": "translate", "args": {"text": "world", "lang": "fr"}},
411+
],
412+
[
413+
{
414+
"type": "tool_calls",
415+
"tool_calls": [
416+
{"type": "function", "function": {"name": "search", "arguments": {"query": "hello"}}},
417+
{
418+
"type": "function",
419+
"function": {"name": "translate", "arguments": {"text": "world", "lang": "fr"}},
420+
},
421+
],
422+
}
397423
],
398-
[{
399-
"type": "tool_calls",
400-
"tool_calls": [
401-
{
402-
"type": "function",
403-
"function": {"name": "search", "arguments": {"query": "hello"}}
404-
},
405-
{
406-
"type": "function",
407-
"function": {"name": "translate", "arguments": {"text": "world", "lang": "fr"}}
408-
}
409-
]
410-
}],
411424
),
412425
(
413426
[{"name": "get_time", "args": {}}],
414-
[{
415-
"type": "tool_calls",
416-
"tool_calls": [{
417-
"type": "function",
418-
"function": {"name": "get_time", "arguments": {}}
419-
}]
420-
}],
427+
[
428+
{
429+
"type": "tool_calls",
430+
"tool_calls": [{"type": "function", "function": {"name": "get_time", "arguments": {}}}],
431+
}
432+
],
421433
),
422434
]
423435

@@ -431,11 +443,12 @@ def test_tool_calls_format_basic(tool_calls_data, expected):
431443

432444
assert result == expected
433445

446+
434447
def test_tool_calls_format_from_dict_list():
435448
"""Test format works with ToolCalls created from from_dict_list."""
436449
tool_calls_dicts = [
437450
{"name": "search", "args": {"query": "hello"}},
438-
{"name": "translate", "args": {"text": "world", "lang": "fr"}}
451+
{"name": "translate", "args": {"text": "world", "lang": "fr"}},
439452
]
440453

441454
tool_calls = ToolCalls.from_dict_list(tool_calls_dicts)

0 commit comments

Comments
 (0)