@@ -14,6 +14,9 @@
 import traceback
 from concurrent.futures import ProcessPoolExecutor
 import openai
+from llava.action.utils import avion_video_loader, avion_video_render_loader, generate_label_map
+import copy
+import json


 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@@ -37,14 +40,15 @@ def __init__(self,
                  root,
                  annotation_file,
                  clip_length = 4,
-                 debug = False
+                 debug = False,
+                 fraction = 0.2
                  ):
         self.root = root
         self.annotation_file = annotation_file
         self.clip_length = clip_length
         self.debug = debug
         self.question_type = 'gpt-gt-reason'
-
+        self.fraction = fraction
         self.data = self.init_data()

         print(len(self.data))
@@ -56,10 +60,11 @@ def select_train_subset(self):
         header = csv_reader[0]  # Get header
         data = csv_reader[1:]   # Get data
         N = len(data)
-        print('N', N)
+
         # take a random subset of the data (e.g. 20% of the rows) and return its indices
         random.seed(0)
-        indices = random.sample(range(N), int(N * 0.2))
+        indices = random.sample(range(N), int(N * self.fraction))
+        print('indices', len(indices))
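+        # the fixed seed keeps the subset reproducible, so separate inference runs
+        # (e.g. 8-frame vs 1-frame) sample the same clips and can be paired later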
         return indices

     def init_data(self):
@@ -91,11 +96,14 @@ def init_data(self):
                count += 1
         return ret

-    def multi_process_run(self, n_samples = -1):
+
+    def multi_process_run(self, n_samples = -1, filename = 'inference_results.json'):
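+        # results from all worker chunks are combined and checkpointed to 'filename' (see self.checkpoint below)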
         # to initialize it

         if n_samples != -1:
             indices = list(range(len(self.data)))[:n_samples]
+        else:
+            indices = list(range(len(self.data)))

         num_chunks = os.cpu_count() if not self.debug else 1

@@ -114,20 +122,20 @@ def multi_process_run(self, n_samples = -1):
         if self.debug:
             print(combined_results)

+        self.checkpoint(combined_results, filename)

     def predict_images(self, images, parsed_item):
         """
         Predict the action from the images
         """
         from llava.action.utils import format_task_related_prompt
-        options = parsed_item['options']
         start_second = 0
         end_second = parsed_item['end_second'] - parsed_item['start_second']
         temperature = 0
         video_duration = end_second - start_second
         n_frames = len(images)

-        task_related_prompt = format_task_related_prompt(options, self.question_type, perspective = 'first_person')
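+        # 'options' is no longer needed; the caption-style prompt now receives an empty string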
+        task_related_prompt = format_task_related_prompt('', self.question_type, perspective = 'first_person')

         time_instruction = f"The provided video lasts for {video_duration:.3f} seconds. "

@@ -210,23 +218,67 @@ def run(self, indices = None):

             caption = parsed_answer.caption
             print('caption:', caption)
+            print('gt is ', v['gt_answer'])
+
+            ret[k] = copy.deepcopy(v)
+            ret[k]['caption'] = caption
+

+
             if self.debug:
                 break

         return ret
-
-
-
-
-
+
+
+def create_comparison_data(positive_filename, negative_filename, out_filename):
+    """
+    Pair captions from two inference runs over the same clips into preference
+    data: the positive run's caption is 'chosen', the negative run's 'rejected'.
+    """
+    ret = []
+    with open(positive_filename, 'r') as f:
+        positive_data = json.load(f)
+    with open(negative_filename, 'r') as f:
+        negative_data = json.load(f)
+
+    for key in positive_data:
+        pos_data = positive_data[key]
+        neg_data = negative_data[key]
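+        # sanity-check that both runs describe the same clip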
+        assert pos_data['vid_path'] == neg_data['vid_path']
+        assert pos_data['start_second'] == neg_data['start_second']
+        template = {
+            'id': pos_data['vid_path'].replace('/', '-'),
+            'prompt': '',
+            'answer': pos_data['caption'],
+            'chosen': pos_data['caption'],
+            'rejected': neg_data['caption'],
+            'video': pos_data['vid_path'].replace('/', '-'),
+            'split': 'train',
+            'dataset_name': 'EK100',
+            'start_timestamp': pos_data['start_second'],
+            'end_timestamp': pos_data['end_second'],
+            'num_samples': 1,
+            'question_type': 'dpo',
+            'task_instruction': '',
+        }
+        ret.append(template)
+
+    # save to jsonl, one record per line
+    with open(out_filename, 'w') as f:
+        for item in ret:
+            f.write(json.dumps(item) + '\n')

 if __name__ == '__main__':
     video_root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
     anno_root = '/data/shaokai/epic-kitchens-100-annotations/'
     clip_length = 8

-    cap = CaptionInference(video_root, os.path.join(anno_root, 'EPIC_100_train.csv'), clip_length, debug = True)
-
-    #cap.multi_process_run(n_samples = 2)
-    cap.run()
+    # cap = CaptionInference(video_root,
+    #                        os.path.join(anno_root, 'EPIC_100_train.csv'),
+    #                        clip_length,
+    #                        debug = False,
+    #                        fraction = 0.01)
+    # cap.multi_process_run(n_samples = -1, filename = f'gpt4o_inference_{clip_length}frame_1percent.json')
+
+
+    create_comparison_data('gpt4o_inference_8frame_1percent.json',
+                           'gpt4o_inference_1frame_1percent.json',
+                           'comparison_data_1percent.jsonl')
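+
+    # Illustrative shape of one resulting JSONL record (field values here are hypothetical):
+    # {"id": "P01-P01_103.MP4", "prompt": "", "answer": "<8-frame caption>",
+    #  "chosen": "<8-frame caption>", "rejected": "<1-frame caption>",
+    #  "video": "P01-P01_103.MP4", "split": "train", "dataset_name": "EK100",
+    #  "start_timestamp": 12.0, "end_timestamp": 15.0, "num_samples": 1,
+    #  "question_type": "dpo", "task_instruction": ""}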