1111# See the License for the specific language governing permissions and
1212# limitations under the License.
1313# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14+ import argparse
1415import concurrent .futures
1516import itertools
1617import json
1718import os
1819import random
19- from typing import Dict
20+ from typing import Dict , Tuple
2021
2122import numpy as np
2223
2526from camel .prompts import SolutionExtractionPromptTemplateDict
2627from camel .typing import ModelType , RoleType
2728
29+ parser = argparse .ArgumentParser (
30+ description = 'Arguments for conversation summarization.' )
31+ parser .add_argument ('--json_dir' , type = str ,
32+ help = 'Directory containing original json files' ,
33+ default = '../camel/camel_data/ai_society' )
34+ parser .add_argument (
35+ '--solution_dir' , type = str , help = 'Directory for solution json files' ,
36+ default = '../camel/camel_data/ai_society_solution_extraction' )
37+ parser .add_argument ('--seed' , type = int , help = 'Seed for reproducibility' ,
38+ default = 10 )
39+
2840
2941def flatten_conversation (conversation : Dict ) -> str :
3042 r"""Format a conversation into a string.
@@ -64,6 +76,7 @@ def flatten_conversation(conversation: Dict) -> str:
6476 """
6577
6678 num_messages = conversation ['num_messages' ]
79+ assert num_messages >= 2
6780 role_1 = conversation ['message_1' ]['role_name' ]
6881 role_2 = conversation ['message_2' ]['role_name' ]
6982 task = conversation ['specified_task' ]
@@ -87,7 +100,7 @@ def flatten_conversation(conversation: Dict) -> str:
87100 return formatted_data
88101
89102
90- def format_combination (combination ):
103+ def format_combination (combination : Tuple [ int , int , int ] ):
91104 assistant_role , user_role , task = combination
92105 assistant_role = str (assistant_role ).zfill (3 )
93106 user_role = str (user_role ).zfill (3 )
@@ -96,7 +109,7 @@ def format_combination(combination):
96109
97110
98111def solution_extraction (conversation : Dict , flattened_conversation : str ,
99- file_name : str ) -> None :
112+ file_name : str , args : argparse . Namespace ) -> None :
100113
101114 solution_extraction_template = SolutionExtractionPromptTemplateDict ()
102115 assistant_sys_msg_prompt = solution_extraction_template [RoleType .ASSISTANT ]
@@ -115,24 +128,22 @@ def solution_extraction(conversation: Dict, flattened_conversation: str,
115128 print (assistant_response .msg .content )
116129
117130 # Create folder to write solution_extraction to
118- if not os .path .exists (
119- "../camel/camel_data/ai_society_solution_extraction" ):
120- os .makedirs ("../camel/camel_data/ai_society_solution_extraction" )
131+ if not os .path .exists (args .solution_dir ):
132+ os .makedirs (args .solution_dir )
121133
122134 # Append to the original JSON conversation file
123135 conversation ['solution_extraction' ] = assistant_response .msg .content
124136
125137 # Save new dictionary as JSON file
126- with open ( f"./camel_data/ai_society_solution_extraction/ { file_name } .json" ,
127- "w" ) as f :
138+ save_path = os . path . join ( args . solution_dir , f' { file_name } .json' )
139+ with open ( save_path , "w" ) as f :
128140 json .dump (conversation , f )
129141
130142
131143def main ():
132-
133- # Seed for reproducibility
134- np .random .seed (10 )
135- random .seed (10 )
144+ args = parser .parse_args ()
145+ np .random .seed (args .seed )
146+ random .seed (args .seed )
136147
137148 total_num_assistant_roles = 50
138149 total_num_user_roles = 50
@@ -165,27 +176,23 @@ def main():
165176 format_combination (combination ) for combination in file_names
166177 ]
167178
168- # Directory containing original json files
169- # we want to extract solutions from AI Society dataset
170- json_dir = "../camel/camel_data/ai_society/"
171-
172179 # Check that all files exist
173180 for file_name in file_names :
174- json_file = os .path .join (json_dir , file_name + " .json" )
181+ json_file = os .path .join (args . json_dir , f" { file_name } .json" )
175182 if not os .path .exists (json_file ):
176183 raise ValueError (f"File { json_file } does not exist." )
177184
178185 # Read in json files and extract solutions
179186 with concurrent .futures .ProcessPoolExecutor (max_workers = 16 ) as executor :
180187 futures = []
181188 for file_name in file_names :
182- json_file = os .path .join (json_dir , file_name + " .json" )
189+ json_file = os .path .join (args . json_dir , f" { file_name } .json" )
183190 with open (json_file ) as f :
184191 conversation = json .load (f )
185192 flattened_conversation = flatten_conversation (conversation )
186193 futures .append (
187194 executor .submit (solution_extraction , conversation ,
188- flattened_conversation , file_name ))
195+ flattened_conversation , file_name , args ))
189196
190197 for future in concurrent .futures .as_completed (futures ):
191198 try :
0 commit comments