Skip to content

Commit e7505a1

Browse files
Update chain_of_thought.py (stanfordnlp#8006)
* Update chain_of_thought.py Allow use of custom CoT representations (e.g., List of strings) instead of forcing CoT be a string output. Previously, even if the user passed in a custom definition for the reasoning field, the type of the reasoning field output would be a string. This change ensures that the type of the field is consistent with its annotation, and allows for users to specify a custom type for reasoning without creating a custom field for it. Also, this change introduces a docstring and type hints into the ChainOfThought module. * Update chain_of_thought.py Fixed error due to missing import. * Update chain_of_thought.py Avoid circular import and ensure there are two newlines between imports and initial code. * Update chain_of_thought.py Avoided circular import caused by referencing `dspy.OutputField`
1 parent 48cb347 commit e7505a1

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

dspy/predict/chain_of_thought.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,36 @@
11
import dspy
22
from dspy.primitives.program import Module
3-
from dspy.signatures.signature import ensure_signature
3+
from dspy.signatures.field import OutputField
4+
from dspy.signatures.signature import ensure_signature, Signature
5+
from pydantic.fields import FieldInfo
6+
from typing import Optional, Union, Type
47

58

69
class ChainOfThought(Module):
7-
def __init__(self, signature, rationale_type=None, **config):
10+
11+
def __init__(
12+
self,
13+
signature: Type[Signature],
14+
rationale_field: Optional[Union[OutputField, FieldInfo]] = None,
15+
rationale_field_type: Type = str,
16+
**config
17+
):
18+
"""
19+
A module that reasons step by step in order to predict the output of a task.
20+
21+
Args:
22+
signature (Type[dspy.Signature]): The signature of the module.
23+
rationale_field (Optional[Union[dspy.OutputField, pydantic.fields.FieldInfo]]): The field that will contain the reasoning.
24+
rationale_field_type (Type): The type of the rationale field.
25+
**config: The configuration for the module.
26+
"""
827
super().__init__()
9-
1028
signature = ensure_signature(signature)
11-
1229
prefix = "Reasoning: Let's think step by step in order to"
1330
desc = "${reasoning}"
14-
rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc)
15-
extended_signature = signature.prepend("reasoning", rationale_type, type_=str)
16-
31+
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
32+
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
33+
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
1734
self.predict = dspy.Predict(extended_signature, **config)
1835

1936
def forward(self, **kwargs):

0 commit comments

Comments
 (0)