Skip to content

Commit 1e29689

Browse files
Merge pull request stanfordnlp#1086 from pedramsalimi/main
Fix the issue of handling input_keys using Dataset class (Issue stanfordnlp#898)
2 parents 4a958ec + f19fccb commit 1e29689

File tree

3 files changed

+80
-25
lines changed

3 files changed

+80
-25
lines changed

dspy/datasets/dataset.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77

88
class Dataset:
9-
def __init__(self, train_seed=0, train_size=None, eval_seed=0, dev_size=None, test_size=None):
9+
def __init__(self, train_seed=0, train_size=None, eval_seed=0, dev_size=None, test_size=None, input_keys=[]):
1010
self.train_size = train_size
1111
self.train_seed = train_seed
1212
self.dev_size = dev_size
1313
self.dev_seed = eval_seed
1414
self.test_size = test_size
1515
self.test_seed = eval_seed
16+
self.input_keys = input_keys
17+
1618
self.do_shuffle = True
1719

1820
self.name = self.__class__.__name__
@@ -73,8 +75,10 @@ def _shuffle_and_sample(self, split, data, size, seed=0):
7375
output = []
7476

7577
for example in data:
76-
output.append(Example(**example, dspy_uuid=str(uuid.uuid4()), dspy_split=split))
77-
78+
example_obj = Example(**example, dspy_uuid=str(uuid.uuid4()), dspy_split=split)
79+
if self.input_keys:
80+
example_obj = example_obj.with_inputs(*self.input_keys)
81+
output.append(example_obj)
7882
# TODO: NOTE: Ideally we use these uuids for dedup internally, for demos and internal train/val splits.
7983
# Now, some tasks (like convQA and Colors) have overlapping examples. Here, we should allow the user to give us
8084
# a uuid field that would respect this in some way. This means that we need a more refined concept that

dspy/primitives/example.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
class Example:
32
def __init__(self, base=None, **kwargs):
43
# Internal storage and other attributes
@@ -16,20 +15,20 @@ def __init__(self, base=None, **kwargs):
1615

1716
# Update with provided kwargs
1817
self._store.update(kwargs)
19-
18+
2019
def __getattr__(self, key):
21-
if key.startswith('__') and key.endswith('__'):
20+
if key.startswith("__") and key.endswith("__"):
2221
raise AttributeError
2322
if key in self._store:
2423
return self._store[key]
2524
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
2625

2726
def __setattr__(self, key, value):
28-
if key.startswith('_') or key in dir(self.__class__):
27+
if key.startswith("_") or key in dir(self.__class__):
2928
super().__setattr__(key, value)
3029
else:
3130
self._store[key] = value
32-
31+
3332
def __getitem__(self, key):
3433
return self._store[key]
3534

@@ -41,55 +40,58 @@ def __delitem__(self, key):
4140

4241
def __contains__(self, key):
4342
return key in self._store
44-
43+
4544
def __len__(self):
46-
return len([k for k in self._store if not k.startswith('dspy_')])
47-
45+
return len([k for k in self._store if not k.startswith("dspy_")])
46+
4847
def __repr__(self):
4948
# return f"Example({self._store})" + f" (input_keys={self._input_keys}, demos={self._demos})"
50-
d = {k: v for k, v in self._store.items() if not k.startswith('dspy_')}
49+
d = {k: v for k, v in self._store.items() if not k.startswith("dspy_")}
5150
return f"Example({d})" + f" (input_keys={self._input_keys})"
52-
51+
5352
def __str__(self):
5453
return self.__repr__()
55-
54+
5655
def __eq__(self, other):
5756
return isinstance(other, Example) and self._store == other._store
58-
57+
5958
def __hash__(self):
6059
return hash(tuple(self._store.items()))
6160

6261
def keys(self, include_dspy=False):
63-
return [k for k in self._store.keys() if not k.startswith('dspy_') or include_dspy]
64-
62+
return [k for k in self._store.keys() if not k.startswith("dspy_") or include_dspy]
63+
6564
def values(self, include_dspy=False):
66-
return [v for k, v in self._store.items() if not k.startswith('dspy_') or include_dspy]
65+
return [v for k, v in self._store.items() if not k.startswith("dspy_") or include_dspy]
6766

6867
def items(self, include_dspy=False):
69-
return [(k, v) for k, v in self._store.items() if not k.startswith('dspy_') or include_dspy]
68+
return [(k, v) for k, v in self._store.items() if not k.startswith("dspy_") or include_dspy]
7069

7170
def get(self, key, default=None):
7271
return self._store.get(key, default)
73-
72+
7473
def with_inputs(self, *keys):
7574
copied = self.copy()
7675
copied._input_keys = set(keys)
7776
return copied
78-
77+
7978
def inputs(self):
8079
if self._input_keys is None:
8180
raise ValueError("Inputs have not been set for this example. Use `example.with_inputs()` to set them.")
8281

8382
# return items that are in input_keys
8483
d = {key: self._store[key] for key in self._store if key in self._input_keys}
85-
return type(self)(d)
86-
84+
# return type(self)(d)
85+
new_instance = type(self)(base=d)
86+
new_instance._input_keys = self._input_keys # Preserve input_keys in new instance
87+
return new_instance
88+
8789
def labels(self):
8890
# return items that are NOT in input_keys
8991
input_keys = self.inputs().keys()
9092
d = {key: self._store[key] for key in self._store if key not in input_keys}
9193
return type(self)(d)
92-
94+
9395
def __iter__(self):
9496
return iter(dict(self._store))
9597

@@ -101,6 +103,6 @@ def without(self, *keys):
101103
for key in keys:
102104
del copied[key]
103105
return copied
104-
106+
105107
def toDict(self):
106108
return self._store.copy()

tests/datasets/test_dataset.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import unittest
2+
import uuid
3+
4+
import pandas as pd
5+
6+
from dspy import Example
7+
from dspy.datasets.dataset import Dataset
8+
9+
dummy_data = """content,question,answer
10+
"This is content 1","What is this?","This is answer 1"
11+
"This is content 2","What is that?","This is answer 2"
12+
"""
13+
14+
with open("dummy.csv", "w") as file:
15+
file.write(dummy_data)
16+
17+
18+
class CSVDataset(Dataset):
19+
def __init__(self, file_path, input_keys=None, *args, **kwargs) -> None:
20+
super().__init__(input_keys=input_keys, *args, **kwargs)
21+
df = pd.read_csv(file_path)
22+
data = df.to_dict(orient="records")
23+
self._train = [
24+
Example(**record, dspy_uuid=str(uuid.uuid4()), dspy_split="train").with_inputs(*input_keys)
25+
for record in data[:1]
26+
]
27+
self._dev = [
28+
Example(**record, dspy_uuid=str(uuid.uuid4()), dspy_split="dev").with_inputs(*input_keys)
29+
for record in data[1:2]
30+
]
31+
32+
33+
class TestCSVDataset(unittest.TestCase):
34+
def test_input_keys(self):
35+
dataset = CSVDataset("dummy.csv", input_keys=["content", "question"])
36+
self.assertIsNotNone(dataset.train)
37+
38+
for example in dataset.train:
39+
print(example)
40+
inputs = example.inputs()
41+
print(f"Example inputs: {inputs}")
42+
self.assertIsNotNone(inputs)
43+
self.assertIn("content", inputs)
44+
self.assertIn("question", inputs)
45+
self.assertEqual(set(example._input_keys), {"content", "question"})
46+
47+
48+
if __name__ == "__main__":
49+
unittest.main()

0 commit comments

Comments
 (0)