Skip to content

Commit f1a7909

Browse files
authored
[Examples] Improve the summarization solution extraction example (camel-ai#164)
1 parent 2e3c771 commit f1a7909

File tree

1 file changed

+26
-19
lines changed

1 file changed

+26
-19
lines changed

examples/summarization/gpt_solution_extraction.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+
import argparse
1415
import concurrent.futures
1516
import itertools
1617
import json
1718
import os
1819
import random
19-
from typing import Dict
20+
from typing import Dict, Tuple
2021

2122
import numpy as np
2223

@@ -25,6 +26,17 @@
2526
from camel.prompts import SolutionExtractionPromptTemplateDict
2627
from camel.typing import ModelType, RoleType
2728

29+
parser = argparse.ArgumentParser(
30+
description='Arguments for conversation summarization.')
31+
parser.add_argument('--json_dir', type=str,
32+
help='Directory containing original json files',
33+
default='../camel/camel_data/ai_society')
34+
parser.add_argument(
35+
'--solution_dir', type=str, help='Directory for solution json files',
36+
default='../camel/camel_data/ai_society_solution_extraction')
37+
parser.add_argument('--seed', type=int, help='Seed for reproducibility',
38+
default=10)
39+
2840

2941
def flatten_conversation(conversation: Dict) -> str:
3042
r"""Format a conversation into a string.
@@ -64,6 +76,7 @@ def flatten_conversation(conversation: Dict) -> str:
6476
"""
6577

6678
num_messages = conversation['num_messages']
79+
assert num_messages >= 2
6780
role_1 = conversation['message_1']['role_name']
6881
role_2 = conversation['message_2']['role_name']
6982
task = conversation['specified_task']
@@ -87,7 +100,7 @@ def flatten_conversation(conversation: Dict) -> str:
87100
return formatted_data
88101

89102

90-
def format_combination(combination):
103+
def format_combination(combination: Tuple[int, int, int]):
91104
assistant_role, user_role, task = combination
92105
assistant_role = str(assistant_role).zfill(3)
93106
user_role = str(user_role).zfill(3)
@@ -96,7 +109,7 @@ def format_combination(combination):
96109

97110

98111
def solution_extraction(conversation: Dict, flattened_conversation: str,
99-
file_name: str) -> None:
112+
file_name: str, args: argparse.Namespace) -> None:
100113

101114
solution_extraction_template = SolutionExtractionPromptTemplateDict()
102115
assistant_sys_msg_prompt = solution_extraction_template[RoleType.ASSISTANT]
@@ -115,24 +128,22 @@ def solution_extraction(conversation: Dict, flattened_conversation: str,
115128
print(assistant_response.msg.content)
116129

117130
# Create folder to write solution_extraction to
118-
if not os.path.exists(
119-
"../camel/camel_data/ai_society_solution_extraction"):
120-
os.makedirs("../camel/camel_data/ai_society_solution_extraction")
131+
if not os.path.exists(args.solution_dir):
132+
os.makedirs(args.solution_dir)
121133

122134
# Append to the original JSON conversation file
123135
conversation['solution_extraction'] = assistant_response.msg.content
124136

125137
# Save new dictionary as JSON file
126-
with open(f"./camel_data/ai_society_solution_extraction/{file_name}.json",
127-
"w") as f:
138+
save_path = os.path.join(args.solution_dir, f'{file_name}.json')
139+
with open(save_path, "w") as f:
128140
json.dump(conversation, f)
129141

130142

131143
def main():
132-
133-
# Seed for reproducibility
134-
np.random.seed(10)
135-
random.seed(10)
144+
args = parser.parse_args()
145+
np.random.seed(args.seed)
146+
random.seed(args.seed)
136147

137148
total_num_assistant_roles = 50
138149
total_num_user_roles = 50
@@ -165,27 +176,23 @@ def main():
165176
format_combination(combination) for combination in file_names
166177
]
167178

168-
# Directory containing original json files
169-
# we want to extract solutions from AI Society dataset
170-
json_dir = "../camel/camel_data/ai_society/"
171-
172179
# Check that all files exist
173180
for file_name in file_names:
174-
json_file = os.path.join(json_dir, file_name + ".json")
181+
json_file = os.path.join(args.json_dir, f"{file_name}.json")
175182
if not os.path.exists(json_file):
176183
raise ValueError(f"File {json_file} does not exist.")
177184

178185
# Read in json files and extract solutions
179186
with concurrent.futures.ProcessPoolExecutor(max_workers=16) as executor:
180187
futures = []
181188
for file_name in file_names:
182-
json_file = os.path.join(json_dir, file_name + ".json")
189+
json_file = os.path.join(args.json_dir, f"{file_name}.json")
183190
with open(json_file) as f:
184191
conversation = json.load(f)
185192
flattened_conversation = flatten_conversation(conversation)
186193
futures.append(
187194
executor.submit(solution_extraction, conversation,
188-
flattened_conversation, file_name))
195+
flattened_conversation, file_name, args))
189196

190197
for future in concurrent.futures.as_completed(futures):
191198
try:

0 commit comments

Comments
 (0)