Skip to content

Commit ec9bd59

Browse files
committed
Fixed hand-object inference
1 parent 5bb019b commit ec9bd59

File tree

2 files changed

+57
-33
lines changed

2 files changed

+57
-33
lines changed

llava/action/chatgpt_utils.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tqdm import tqdm
1515
import csv
1616
import llava
17-
from llava.action.utils import avion_video_loader, create_multi_choice_from_avion_predictions, generate_label_map, AvionMultiChoiceGenerator, avion_video_render_loader
17+
from llava.action.utils import avion_video_loader, generate_label_map, AvionMultiChoiceGenerator, avion_video_render_loader
1818
from llava.action.dataset import datetime2sec
1919

2020
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@@ -197,8 +197,22 @@ def prepare_multiple_images(self, images):
197197
return multi_image_content
198198

199199
def extract_frames(self, vid_path, start_second, end_second):
200-
frames, time_meta = avion_video_render_loader(self.root, self.handobj_root,
201-
# frames, time_meta = avion_video_loader(self.root,
200+
video_loader = avion_video_loader if self.handobj_root is None else avion_video_render_loader
201+
202+
if self.handobj_root is None:
203+
frames, time_meta = video_loader(self.root,
204+
vid_path,
205+
'MP4',
206+
start_second,
207+
end_second,
208+
chunk_len = 15,
209+
clip_length = self.clip_length,
210+
threads = 1,
211+
fast_rrc=False,
212+
fast_rcc = False,
213+
jitter = False)
214+
else:
215+
frames, time_meta = video_loader(self.root, self.handobj_root,
202216
vid_path,
203217
'MP4',
204218
start_second,
@@ -222,7 +236,7 @@ def __init__(self,
222236
root,
223237
annotation_file,
224238
avion_prediction_file,
225-
handobj_root,
239+
handobj_root = None,
226240
clip_length = 4,
227241
action_representation = 'GT_random_narration',
228242
debug = False,
@@ -294,10 +308,11 @@ def init_data(self):
294308

295309
return ret
296310

297-
def multi_process_run(self):
311+
def multi_process_run(self, n_samples = -1):
298312
# to initialize it
299313

300-
indices = list(range(len(self.data)))[:500]
314+
if n_samples != -1:
315+
indices = list(range(len(self.data)))[:n_samples]
301316

302317
num_chunks = os.cpu_count() if not self.debug else 2
303318

@@ -340,6 +355,7 @@ def run(self, indices=None):
340355
print ("An exception occurred: ", e)
341356
predicted_answer = parsed_answer.answer
342357
explanation = parsed_answer.explanation
358+
print (explanation)
343359
gt_name = v['gt_answer']
344360
ret[k] = {
345361
'gt_name': gt_name,
@@ -365,19 +381,18 @@ def predict_images(self, images, parsed_item):
365381
end_second = parsed_item['end_second'] - parsed_item['start_second']
366382
temperature = 0
367383
duration = end_second - start_second
368-
system_prompt_prefix = f"""
384+
385+
system_prompt = f"""
369386
You are seeing video frames from an egocentric view of a person. Pretend that you are the person. Your task is to describe what action you are performing.
370387
To assist you for how to describe the action, the video's start time is {start_second} and the end time is {end_second:.3f} and the duration is {duration:.3f} seconds.
371-
You were given multiple choice options {option_text}. Pick the correct one and put that into the answer. Note in the answer do not include the option letter, just the name of the action.
372-
Also explain why the correct answer is correct and why the other options are incorrect.
388+
You were given multiple choice options {option_text}. Pick the correct one and put that into the answer. Note in the answer do not include the option letter, just the name of the action.
373389
"""
374390

375-
# print ('system prompt prefix')
376-
# print (system_prompt_prefix)
377-
378-
system_prompt_suffix = """"""
391+
if self.handobj_root is not None:
392+
system_prompt += f"""To further assist you, we mark hands and object when they are visible. The left hand is marked with a bounding box that contains letter L and the right hand's bounding box contains letter R. The object is marked as 'O'."""
393+
394+
system_prompt += f"""Before giving the answer, explain why the correct answer is correct and why the other options are incorrect. You must pay attention to the hands and objects to support your reasoning when they are present."""
379395

380-
system_prompt = system_prompt_prefix + system_prompt_suffix
381396

382397
system_message = [{"role": "system", "content": system_prompt}]
383398

@@ -523,16 +538,18 @@ def multi_process_annotate(train_file_path, root, debug = False, anno_type = 'gp
523538
def multi_process_inference(root,
524539
annotation_file,
525540
avion_prediction_file,
526-
handobj_root,
541+
handobj_root = None,
527542
action_representation = 'GT_random_narration',
528543
clip_length = 4,
529544
topk = 5,
530-
debug = False):
545+
debug = False,
546+
n_samples = -1
547+
):
531548

532549
annotator = GPTInferenceAnnotator(root,
533550
annotation_file,
534551
avion_prediction_file,
535-
handobj_root,
552+
handobj_root = handobj_root,
536553
clip_length = clip_length,
537554
debug = debug,
538555
action_representation = action_representation,
@@ -541,7 +558,7 @@ def multi_process_inference(root,
541558
# indices = list(range(len(annotator.data)))[:100]
542559
# annotator.run()
543560

544-
annotator.multi_process_run()
561+
annotator.multi_process_run(n_samples = n_samples)
545562

546563
def calculate_gpt_accuracy(path = None, data = None):
547564

@@ -577,15 +594,18 @@ def convert_json_to_jsonl(path):
577594

578595
if __name__ == '__main__':
579596

580-
# train_file_path = '/data/epic_kitchen/AVION_PREDS/avion_mc_top5_GT_random_narration/train_convs_narration.jsonl'
581-
# root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
582-
# val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
583-
# avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
597+
# amg0
598+
train_file_path = '/data/epic_kitchen/AVION_PREDS/avion_mc_top5_GT_random_narration/train_convs_narration.jsonl'
599+
root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
600+
val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
601+
avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
602+
handobj_root = '/data/epic_kitchen/Save_dir'
584603

585-
root = '/mediaPFM/data/haozhe/onevision/llava_video/EK100'
586-
val_file = '/mediaPFM/data/haozhe/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv'
587-
avion_prediction_file = '/mediaPFM/data/haozhe/EK100/EK100_in_LLAVA/avion_pred_ids_val.json'
588-
handobj_root = '/mnt/SV_storage/VFM/hand_object_detector/Save_dir'
604+
# haozhe's path
605+
# root = '/mediaPFM/data/haozhe/onevision/llava_video/EK100'
606+
# val_file = '/mediaPFM/data/haozhe/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv'
607+
# avion_prediction_file = '/mediaPFM/data/haozhe/EK100/EK100_in_LLAVA/avion_pred_ids_val.json'
608+
# handobj_root = '/mnt/SV_storage/VFM/hand_object_detector/Save_dir'
589609

590610

591611

@@ -594,13 +614,17 @@ def convert_json_to_jsonl(path):
594614

595615
# multi_process_annotate(train_file_path, root, debug = False, n_samples = 10000)
596616

617+
618+
597619
multi_process_inference(root,
598620
val_file,
599621
avion_prediction_file,
600-
handobj_root,
622+
handobj_root = handobj_root,
601623
debug = False,
602-
clip_length = 4,
603-
topk = 5)
624+
clip_length = 8,
625+
topk = 5,
626+
n_samples = 100)
627+
604628

605629
#calculate_gpt_accuracy('valset_chatgpt_inference_results/gpt-4o-avion_top10_4frames_fixed_narration.json')
606630

llava/action/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -587,10 +587,10 @@ def avion_video_render_loader(root, handobj_root, vid, ext, second, end_second,
587587

588588
frames = render_frames(frames, hand_dets_list, obj_dets_list, thresh_hand=0.5, thresh_obj=0.5)
589589

590-
# plt.figure()
591-
# plt.imshow(frames[0])
592-
# plt.savefig('frame_rendered.png')
593-
# plt.close()
590+
plt.figure()
591+
plt.imshow(frames[0])
592+
plt.savefig('frame_rendered.png')
593+
plt.close()
594594

595595
all_frames.append(frames)
596596
all_frame_ids.append(frame_ids)

0 commit comments

Comments
 (0)