
Commit dda80de

better prompts
1 parent 24e2127 commit dda80de

File tree: 2 files changed, +58 -62 lines changed

llava/action/chatgpt_utils.py

Lines changed: 48 additions & 52 deletions
@@ -7,12 +7,17 @@
 from concurrent.futures import ProcessPoolExecutor
 from tqdm import tqdm
 from llava.action.utils import AvionMultiChoiceGenerator
-from llava.action.utils import avion_video_loader, avion_video_render_loader
+from llava.action.utils import avion_video_loader, avion_video_render_loader, generate_label_map
+from llava.action.dataset import datetime2sec
+from llava.action.ek_eval import process_raw_pred
+import csv
 import copy
 import torch
 import io
 import numpy as np
 import base64
+from pathlib import Path
+
 
 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
 
@@ -93,6 +98,12 @@ class GT_Augmentation_Response(BaseModel):
     caption_with_reasoning: str
     disagree_with_human_annotation: bool
 
+class GT_Agnostic_Response(BaseModel):
+    """
+    The GT was known. The response is to add more information to the GT
+    """
+    answer: str
+
 
 class GPTHandObjectResponse(BaseModel):
     """
@@ -265,6 +276,7 @@ def __init__(self,
                  handobj_root = None,
                  clip_length = 4,
                  action_representation = 'GT_random_narration',
+                 question_type = 'cot_mc',
                  debug = False,
                  topk = 10,
                  ):
@@ -281,6 +293,7 @@ def __init__(self,
         self.annotation_file = annotation_file
         self.avion_prediction_file = avion_prediction_file
         self.handobj_root = handobj_root
+        self.question_type = question_type
         self.annotation_root = Path(annotation_file).parent
         self.action_representation = action_representation
         self.labels, self.mapping_vn2narration, self.mapping_vn2act, self.verb_maps, self.noun_maps = generate_label_map(self.annotation_root,
@@ -323,9 +336,8 @@ def init_data(self):
 
             options = mc_data['options'][0]
 
-            option_string = ','.join(options)
             ret[idx] = {
-                'options': option_string,
+                'options': options,
                 'gt_answer': narration,
                 'start_second': start_second,
                 'end_second': end_second,
@@ -380,14 +392,13 @@ def run(self, indices=None):
             except Exception as e:
                 print ("An exception occurred: ", e)
             predicted_answer = parsed_answer.answer
-            explanation = parsed_answer.explanation
-            print (explanation)
+            print (predicted_answer)
             gt_name = v['gt_answer']
             ret[k] = {
                 'gt_name': gt_name,
-                'chatgpt_answer': predicted_answer,
-                'explanation': explanation
+                'chatgpt_answer': process_raw_pred(predicted_answer),
             }
+            print (ret)
             if self.debug:
                 break
         if indices is None:
@@ -401,24 +412,26 @@ def run(self, indices=None):
     def predict_images(self, images, parsed_item):
         """
         Predict the action from the images
-        """
-        option_text = parsed_item['options']
+        """
+        from llava.action.utils import format_task_related_prompt
+        options = parsed_item['options']
         start_second = 0
         end_second = parsed_item['end_second'] - parsed_item['start_second']
         temperature = 0
-        duration = end_second - start_second
+        video_duration = end_second - start_second
+        n_frames = len(images)
 
-        system_prompt = f"""
-        You are seeing video frames from an egocentric view of a person. Pretend that you are the person. Your task is to describe what action you are performing.
-        To assist you for how to describe the action, the video's start time is {start_second} and the end time is {end_second:.3f} and the duration is {duration:.3f} seconds.
-        You were given multiple choice options {option_text}. Pick the correct one and put that into the answer. Note in the answer do not include the option letter, just the name of the action.
-        """
+        task_related_prompt = format_task_related_prompt(options, self.question_type, perspective = 'first_person')
+
+        time_instruction = f"The provided video lasts for {video_duration:.3f} seconds, and {n_frames} frames are uniformly sampled from it. "
+
+        system_prompt = time_instruction + task_related_prompt
+
+        print (system_prompt)
 
         if self.handobj_root is not None:
             system_prompt += f"""To further assist you, we mark hands and object when they are visible. The left hand is marked with a bounding box that contains letter L and the right hand's bounding box contains letter R. The object is marked as 'O'."""
 
-        system_prompt += f"""Before giving the answer, explain why the correct answer is correct and why the other options are incorrect. You must pay attention to the hands and objects to support your reasoning when they are present."""
-
 
         system_message = [{"role": "system", "content": system_prompt}]
 
@@ -574,7 +587,7 @@ def parse_conversation_from_train_convs(self, item):
         """
         conversations = item['conversations']
         human_dict = conversations[0]
-        option_text = ','.join(eval(human_dict['value']))
+        option_text = ', '.join(eval(human_dict['value']))
         gpt_dict = conversations[1]
         gt_answer = gpt_dict['value']
         print ('gt_answer', gt_answer)
@@ -691,31 +704,6 @@ def multi_process_annotate(train_file_path,
 
     annotator.multi_process_run(n_samples = n_samples)
 
-def multi_process_inference(root,
-                            annotation_file,
-                            avion_prediction_file,
-                            handobj_root = None,
-                            action_representation = 'GT_random_narration',
-                            clip_length = 4,
-                            topk = 5,
-                            debug = False,
-                            n_samples = -1
-                            ):
-
-    annotator = GPTInferenceAnnotator(root,
-                                      annotation_file,
-                                      avion_prediction_file,
-                                      handobj_root = handobj_root,
-                                      clip_length = clip_length,
-                                      debug = debug,
-                                      action_representation = action_representation,
-                                      topk = topk)
-
-    # indices = list(range(len(annotator.data)))[:100]
-    # annotator.run()
-
-    annotator.multi_process_run(n_samples = n_samples)
-
 def calculate_gpt_accuracy(path = None, data = None):
 
     assert path is not None or data is not None
@@ -872,14 +860,22 @@ def convert_instruct_json_to_jsonl(path, apply_filter = False):
     # n_samples = -1,
    # anno_type = 'gpt-gt-instruct-reason')
 
-    # multi_process_inference(root,
-    #                         val_file,
-    #                         avion_prediction_file,
-    #                         handobj_root = handobj_root,
-    #                         debug = False,
-    #                         clip_length = 8,
-    #                         topk = 5,
-    #                         n_samples = 100)
+    root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
+    val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
+    avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
+
+
+    annotator = GPTInferenceAnnotator(root,
+                                      val_file,
+                                      avion_prediction_file,
+                                      clip_length = 4,
+                                      debug = False,
+                                      action_representation = "GT_random_narration",
+                                      question_type = 'mc_GT_random_narration',
+                                      topk = 5)
+
+    annotator.multi_process_run(n_samples = 100)
+
 
     # convert_json_to_jsonl('train_anno_gpt-gt-reason_4_10000.json')
 
@@ -889,7 +885,7 @@ def convert_instruct_json_to_jsonl(path, apply_filter = False):
     # ann = GPTHandObjectAnnotator(train_file_path, debug = False)
     # ann.multi_process_run(n_samples = -1)
 
-    convert_json_to_jsonl('train_anno_gpt-gt-reason_4_first_person_all.json')
+    # convert_json_to_jsonl('train_anno_gpt-gt-reason_4_first_person_all.json')
 
     #calc_disagree_ratio_from_jsonl('train_anno_gpt-gt-reason_4_first_person_all.jsonl')

llava/action/utils.py

Lines changed: 10 additions & 10 deletions
@@ -168,18 +168,18 @@ def format_task_related_prompt(question, question_type, perspective = "first_per
     We are thinking about tweaking the prompt based on the action representation.
     """
     if perspective == "first_person":
-        perspective_prefix = "You are seeing this video from egocentric view and your hands are sometimes interacting with obects. What action are you performing? "
+        perspective_prefix = "You are seeing this video from egocentric view and you are the person. Your hands are sometimes interacting with obects. What action are you performing? "
     elif perspective == "third_person":
         perspective_prefix = "The video is taken from egocentric view. What action is the person performing? "
     if question_type.startswith("mc_"):
         action_rep_suffix = "Given multiple choices, format your answer briefly such as 'A. move knife'. "
         prefix = f"{perspective_prefix}{action_rep_suffix}\n"
         assert isinstance(question, list)
-        suffix = ",".join(question)
-        suffix = "Here are the options you are tasked:\n" + suffix
+        suffix = ", ".join(question)
+        suffix = "Here are the options of actions you are selecting:\n" + suffix
         ret = prefix + suffix
     elif question_type == "gpt-gt-reason":
-        ret = f"{perspective_prefix}Please explain your reasoning steps before reaching to your answer. "
+        ret = f"{perspective_prefix}Describe in details what you see from the video frames."
     elif question_type == "gpt-gt-instruct-reason":
         ret = question
     elif question_type == "gpt-hand-object":
@@ -188,11 +188,11 @@ def format_task_related_prompt(question, question_type, perspective = "first_per
         """
         Explain the reasoning first and do the multiple-choice.
         """
-        action_rep_suffix = "Given multiple choices, explain your reasoning steps before you reach to your answer. "
+        action_rep_suffix = "Describe what you see in details. Afterwards, briefly format your answer such as 'A. move knife'. "
         prefix = f"{perspective_prefix} {action_rep_suffix}\n"
         assert isinstance(question, list)
-        suffix = ",".join(question)
-        suffix = "Here are the options you are tasked:\n" + suffix
+        suffix = ", ".join(question)
+        suffix = "Here are the options of choices you are selecting:\n" + suffix
         ret = prefix + suffix
     else:
         raise NotImplementedError(f"question_type: {question_type} is not supported")
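As an illustration (not in the commit), the updated multiple-choice branch joins the options with ', ' and uses the new wording. With hypothetical options, the helper would return roughly:

```python
# Sketch only; the options are made up.
options = ['A. move knife', 'B. open drawer', 'C. wash plate']
prompt = format_task_related_prompt(options, 'mc_GT_random_narration', perspective='first_person')
# prompt is roughly:
# "You are seeing this video from egocentric view and you are the person. Your hands are
#  sometimes interacting with obects. What action are you performing? Given multiple choices,
#  format your answer briefly such as 'A. move knife'. \n
#  Here are the options of actions you are selecting:\nA. move knife, B. open drawer, C. wash plate"
```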
@@ -202,10 +202,10 @@ def format_task_related_prompt(question, question_type, perspective = "first_per
 
 def format_time_instruction(video_duration, n_frames, include_frame_time = False):
 
-    prefix = f"You are seeing a video taken from egocentric view. The video lasts for {video_duration:.3f} seconds, and {n_frames} frames are uniformly sampled from it."
+    prefix = f"The provided video lasts for {video_duration:.3f} seconds, and {n_frames} frames are uniformly sampled from it."
 
     frame_time = [i * (video_duration / n_frames) for i in range(n_frames)]
-    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+    frame_time = ", ".join([f"{i:.2f}s" for i in frame_time])
 
     suffix = ""
     if include_frame_time:
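A small sketch (not in the commit) of the frame-time string that `format_time_instruction` builds when `include_frame_time=True`, using made-up values:

```python
# Sketch only; the duration and frame count are made up.
video_duration, n_frames = 2.0, 4
frame_time = [i * (video_duration / n_frames) for i in range(n_frames)]
print(", ".join(f"{t:.2f}s" for t in frame_time))  # -> "0.00s, 0.50s, 1.00s, 1.50s"
```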
@@ -671,7 +671,7 @@ def avion_video_render_loader(root, handobj_root, vid, ext, second, end_second,
     all_frame_ids = np.concatenate(all_frame_ids, axis = 0)
     frame_time = [e/fps for e in all_frame_ids]
     frame_time-= frame_time[0]
-    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+    frame_time = ", ".join([f"{i:.2f}s" for i in frame_time])
     time_meta['frame_time'] = frame_time
     assert res.shape[0] == clip_length, "{}, {}, {}, {}, {}, {}, {}".format(root, vid, second, end_second, res.shape[0], rel_frame_ids, frame_ids)
     return res, time_meta
