Skip to content

Commit 4e34fa7

Browse files
added AnswerCorrectness and AnswerFaithfulness auto-eval modules
1 parent e2b7a0c commit 4e34fa7

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

dspy/evaluate/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .evaluate import Evaluate
22
from .metrics import *
3+
from .auto_evaluation import *

dspy/evaluate/auto_evaluation.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import dspy
2+
3+
class AnswerCorrectnessSignature(dspy.Signature):
4+
"""Determines if predicted answer matches the gold answer."""
5+
6+
question = dspy.InputField()
7+
gold_answer = dspy.InputField(desc="correct answer for question")
8+
predicted_answer = dspy.InputField(desc="predicted answer for question")
9+
is_correct = dspy.OutputField(desc='True or False')
10+
11+
class AnswerCorrectness(dspy.Module):
12+
def __init__(self):
13+
super().__init__()
14+
self.evaluate_correctness = dspy.ChainOfThought(AnswerCorrectnessSignature)
15+
16+
def forward(self, question, gold_answer, predicted_answer):
17+
return self.evaluate_correctness(question=question, gold_answer=gold_answer, predicted_answer=predicted_answer)
18+
19+
20+
class AnswerFaithfulnessSignature(dspy.Signature):
21+
"""Checks if answer for question is based on rationale."""
22+
23+
context = dspy.InputField(desc="relevant facts for producing answer")
24+
question = dspy.InputField()
25+
answer = dspy.InputField(desc="often between 1 and 5 words")
26+
faithful = dspy.OutputField(desc='True or False')
27+
28+
class AnswerFaithfulness(dspy.Module):
29+
def __init__(self):
30+
super().__init__()
31+
self.evaluate_faithfulness = dspy.ChainOfThought(AnswerFaithfulnessSignature)
32+
33+
def forward(self, context, question, answer):
34+
return self.evaluate_faithfulness(context=context, question=question, answer=answer)

dspy/signatures/signature.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ def __new__(cls, name, bases, class_dict):
3333

3434
return new_class
3535

36+
@property
37+
def kwargs(cls):
38+
return cls.signature.fields
39+
3640
def __call__(cls, *args, **kwargs):
3741
if len(args) == 1 and isinstance(args[0], str):
3842
instance = super(SignatureMeta, cls).__call__(*args, **kwargs)
@@ -42,7 +46,9 @@ def __call__(cls, *args, **kwargs):
4246

4347
def __getattr__(cls, attr):
4448
# Redirect attribute access to the template object when accessed on the class directly
45-
return getattr(cls._template, attr)
49+
if attr not in cls.__dict__:
50+
return getattr(cls._template, attr)
51+
return super().__getattr__(attr)
4652

4753
class Signature(metaclass=SignatureMeta):
4854
def __init__(self, signature: str = "", instructions: str = ""):
@@ -51,6 +57,11 @@ def __init__(self, signature: str = "", instructions: str = ""):
5157
self.fields = {}
5258
self.parse_structure()
5359

60+
def __getattr__(self, attr):
61+
if attr not in self.__dict__:
62+
return getattr(self.__class__, attr)
63+
return super().__getattr__(attr)
64+
5465
@property
5566
def kwargs(self):
5667
return {k: v for k, v in self.fields.items()}

0 commit comments

Comments
 (0)