1414from tqdm import tqdm
1515import csv
1616import llava
17- from llava .action .utils import avion_video_loader , create_multi_choice_from_avion_predictions , generate_label_map , AvionMultiChoiceGenerator , avion_video_render_loader
17+ from llava .action .utils import avion_video_loader , generate_label_map , AvionMultiChoiceGenerator , avion_video_render_loader
1818from llava .action .dataset import datetime2sec
1919
2020client = openai .OpenAI (api_key = os .environ .get ("OPENAI_API_KEY" ))
@@ -197,8 +197,22 @@ def prepare_multiple_images(self, images):
197197 return multi_image_content
198198
199199 def extract_frames (self , vid_path , start_second , end_second ):
200- frames , time_meta = avion_video_render_loader (self .root , self .handobj_root ,
201- # frames, time_meta = avion_video_loader(self.root,
200+ video_loader = avion_video_loader if self .handobj_root is None else avion_video_render_loader
201+
202+ if self .handobj_root is None :
203+ frames , time_meta = video_loader (self .root ,
204+ vid_path ,
205+ 'MP4' ,
206+ start_second ,
207+ end_second ,
208+ chunk_len = 15 ,
209+ clip_length = self .clip_length ,
210+ threads = 1 ,
211+ fast_rrc = False ,
212+ fast_rcc = False ,
213+ jitter = False )
214+ else :
215+ frames , time_meta = video_loader (self .root , self .handobj_root ,
202216 vid_path ,
203217 'MP4' ,
204218 start_second ,
@@ -222,7 +236,7 @@ def __init__(self,
222236 root ,
223237 annotation_file ,
224238 avion_prediction_file ,
225- handobj_root ,
239+ handobj_root = None ,
226240 clip_length = 4 ,
227241 action_representation = 'GT_random_narration' ,
228242 debug = False ,
@@ -294,10 +308,11 @@ def init_data(self):
294308
295309 return ret
296310
297- def multi_process_run (self ):
311+ def multi_process_run (self , n_samples = - 1 ):
298312 # to initialize it
299313
300- indices = list (range (len (self .data )))[:500 ]
314+        limit = len (self .data ) if n_samples == - 1 else n_samples
315+        indices = list (range (len (self .data )))[:limit ]
301316
302317 num_chunks = os .cpu_count () if not self .debug else 2
303318
@@ -340,6 +355,7 @@ def run(self, indices=None):
340355 print ("An exception occurred: " , e )
341356 predicted_answer = parsed_answer .answer
342357 explanation = parsed_answer .explanation
358+ print (explanation )
343359 gt_name = v ['gt_answer' ]
344360 ret [k ] = {
345361 'gt_name' : gt_name ,
@@ -365,19 +381,18 @@ def predict_images(self, images, parsed_item):
365381 end_second = parsed_item ['end_second' ] - parsed_item ['start_second' ]
366382 temperature = 0
367383 duration = end_second - start_second
368- system_prompt_prefix = f"""
384+
385+ system_prompt = f"""
369386 You are seeing video frames from an egocentric view of a person. Pretend that you are the person. Your task is to describe what action you are performing.
370387 To assist you for how to describe the action, the video's start time is { start_second } and the end time is { end_second :.3f} and the duration is { duration :.3f} seconds.
371- You were given multiple choice options { option_text } . Pick the correct one and put that into the answer. Note in the answer do not include the option letter, just the name of the action.
372- Also explain why the correct answer is correct and why the other options are incorrect.
388+ You were given multiple choice options { option_text } . Pick the correct one and put that into the answer. Note in the answer do not include the option letter, just the name of the action.
373389 """
374390
375- # print ('system prompt prefix')
376- # print (system_prompt_prefix)
377-
378- system_prompt_suffix = """"""
391+ if self . handobj_root is not None :
392+ system_prompt += f"""To further assist you, we mark hands and object when they are visible. The left hand is marked with a bounding box that contains letter L and the right hand's bounding box contains letter R. The object is marked as 'O'."""
393+
394+        system_prompt += f"""Before giving the answer, explain why the correct answer is correct and why the other options are incorrect. You must pay attention to the hands and objects to support your reasoning when they are present. """
379395
380- system_prompt = system_prompt_prefix + system_prompt_suffix
381396
382397 system_message = [{"role" : "system" , "content" : system_prompt }]
383398
@@ -523,16 +538,18 @@ def multi_process_annotate(train_file_path, root, debug = False, anno_type = 'gp
523538def multi_process_inference (root ,
524539 annotation_file ,
525540 avion_prediction_file ,
526- handobj_root ,
541+ handobj_root = None ,
527542 action_representation = 'GT_random_narration' ,
528543 clip_length = 4 ,
529544 topk = 5 ,
530- debug = False ):
545+ debug = False ,
546+ n_samples = - 1
547+ ):
531548
532549 annotator = GPTInferenceAnnotator (root ,
533550 annotation_file ,
534551 avion_prediction_file ,
535- handobj_root ,
552+ handobj_root = handobj_root ,
536553 clip_length = clip_length ,
537554 debug = debug ,
538555 action_representation = action_representation ,
@@ -541,7 +558,7 @@ def multi_process_inference(root,
541558 # indices = list(range(len(annotator.data)))[:100]
542559 # annotator.run()
543560
544- annotator .multi_process_run ()
561+ annotator .multi_process_run (n_samples = n_samples )
545562
546563def calculate_gpt_accuracy (path = None , data = None ):
547564
@@ -577,15 +594,18 @@ def convert_json_to_jsonl(path):
577594
578595if __name__ == '__main__' :
579596
580- # train_file_path = '/data/epic_kitchen/AVION_PREDS/avion_mc_top5_GT_random_narration/train_convs_narration.jsonl'
581- # root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
582- # val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
583- # avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
597+ # amg0
598+ train_file_path = '/data/epic_kitchen/AVION_PREDS/avion_mc_top5_GT_random_narration/train_convs_narration.jsonl'
599+ root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
600+ val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
601+ avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
602+ handobj_root = '/data/epic_kitchen/Save_dir'
584603
585- root = '/mediaPFM/data/haozhe/onevision/llava_video/EK100'
586- val_file = '/mediaPFM/data/haozhe/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv'
587- avion_prediction_file = '/mediaPFM/data/haozhe/EK100/EK100_in_LLAVA/avion_pred_ids_val.json'
588- handobj_root = '/mnt/SV_storage/VFM/hand_object_detector/Save_dir'
604+ # haozhe's path
605+ # root = '/mediaPFM/data/haozhe/onevision/llava_video/EK100'
606+ # val_file = '/mediaPFM/data/haozhe/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv'
607+ # avion_prediction_file = '/mediaPFM/data/haozhe/EK100/EK100_in_LLAVA/avion_pred_ids_val.json'
608+ # handobj_root = '/mnt/SV_storage/VFM/hand_object_detector/Save_dir'
589609
590610
591611
@@ -594,13 +614,17 @@ def convert_json_to_jsonl(path):
594614
595615 # multi_process_annotate(train_file_path, root, debug = False, n_samples = 10000)
596616
617+
618+
597619 multi_process_inference (root ,
598620 val_file ,
599621 avion_prediction_file ,
600- handobj_root ,
622+ handobj_root = handobj_root ,
601623 debug = False ,
602- clip_length = 4 ,
603- topk = 5 )
624+ clip_length = 8 ,
625+ topk = 5 ,
626+ n_samples = 100 )
627+
604628
605629 #calculate_gpt_accuracy('valset_chatgpt_inference_results/gpt-4o-avion_top10_4frames_fixed_narration.json')
606630
0 commit comments