
Commit aa3a913 ("WIP")
1 parent 45f6eec

2 files changed: +69 additions, -17 deletions

llava/action/chatgpt_utils.py (1 addition, 1 deletion)
@@ -289,7 +289,7 @@ def prepare_multiple_images(self, images):
 
         return multi_image_content
 
-    def extract_frames(self, vid_path, start_second, end_second):
+    def extract_frames(self, vid_path, start_second, end_second):
         if hasattr(self, 'handobj_root') and self.handobj_root is not None:
 
             frames, time_meta = avion_video_render_loader(self.root, self.handobj_root,
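
Only the opening lines of extract_frames are visible in this hunk, but they show the method choosing a render-aware loader when a hand-object detection directory is configured. A minimal sketch of that dispatch, assuming avion_video_loader (imported by name in the next file's diff) mirrors the render loader's calling convention; the fallback branch and argument lists are assumptions, not the real signatures:

    from llava.action.utils import avion_video_loader, avion_video_render_loader

    # Hedged reconstruction: only the first three lines below appear in the hunk;
    # the else-branch and exact loader arguments are assumed for illustration.
    def extract_frames(self, vid_path, start_second, end_second):
        if hasattr(self, 'handobj_root') and self.handobj_root is not None:
            # Loader that overlays hand/object detections onto the decoded frames.
            frames, time_meta = avion_video_render_loader(self.root, self.handobj_root,
                                                          vid_path, start_second, end_second)
        else:
            # Plain decoding path when no detection directory is configured.
            frames, time_meta = avion_video_loader(self.root, vid_path,
                                                   start_second, end_second)
        return frames, time_meta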

llava/action/generate_comparison_dpo.py (68 additions, 16 deletions)
@@ -14,6 +14,9 @@
 import traceback
 from concurrent.futures import ProcessPoolExecutor
 import openai
+from llava.action.utils import avion_video_loader, avion_video_render_loader, generate_label_map
+import copy
+import json
 
 
 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@@ -37,14 +40,15 @@ def __init__(self,
                  root,
                  annotation_file,
                  clip_length = 4,
-                 debug = False
+                 debug = False,
+                 fraction = 0.2
                  ):
         self.root = root
         self.annotation_file = annotation_file
         self.clip_length = clip_length
         self.debug = debug
         self.question_type = 'gpt-gt-reason'
-
+        self.fraction = fraction
         self.data = self.init_data()
 
         print (len(self.data))
@@ -56,10 +60,11 @@ def select_train_subset(self):
         header = csv_reader[0] # Get header
         data = csv_reader[1:] # Get data
         N = len(data)
-        print ('N', N)
+
         # get a random subset of the data such as 20% of them. Give the indices
         random.seed(0)
-        indices = random.sample(range(N), int(N*0.2))
+        indices = random.sample(range(N), int(N*self.fraction))
+        print ('indices', len(indices))
         return indices
 
     def init_data(self):
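
Because the RNG is reseeded with a fixed value immediately before sampling, the same subset of rows is selected on every run. A small standalone illustration of that property (N and fraction are arbitrary stand-ins here, not values from the commit):

    import random

    N = 1000          # illustrative row count; the real value is len(data)
    fraction = 0.2    # default exposed by the new constructor argument

    random.seed(0)
    indices = random.sample(range(N), int(N * fraction))
    print(len(indices))   # 200

    # Reseeding reproduces the identical subset, keeping the split stable across runs.
    random.seed(0)
    assert indices == random.sample(range(N), int(N * fraction))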
@@ -91,11 +96,14 @@ def init_data(self):
                 count+=1
         return ret
 
-    def multi_process_run(self, n_samples = -1):
+
+    def multi_process_run(self, n_samples = -1, filename = 'inference_results.json'):
         # to initialize it
 
         if n_samples != -1:
             indices = list(range(len(self.data)))[:n_samples]
+        else:
+            indices = list(range(len(self.data)))
 
         num_chunks = os.cpu_count() if not self.debug else 1
 
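The visible lines build the full index list and pick one chunk per CPU (or a single chunk in debug mode). A minimal sketch of how such index chunks are typically fanned out over a ProcessPoolExecutor; split_into_chunks and the commented driver are hypothetical, not code from this commit:

    from concurrent.futures import ProcessPoolExecutor

    def split_into_chunks(indices, num_chunks):
        # Strided split: chunk i takes every num_chunks-th index starting at i.
        return [indices[i::num_chunks] for i in range(num_chunks)]

    # Hypothetical driver mirroring the shape of multi_process_run:
    # chunks = split_into_chunks(indices, num_chunks)
    # with ProcessPoolExecutor(max_workers=num_chunks) as executor:
    #     results = list(executor.map(self.run, chunks))
    # combined_results = {k: v for partial in results for k, v in partial.items()}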
@@ -114,20 +122,20 @@ def multi_process_run(self, n_samples = -1):
         if self.debug:
             print (combined_results)
 
+        self.checkpoint(combined_results, filename)
 
     def predict_images(self, images, parsed_item):
         """
         Predict the action from the images
         """
         from llava.action.utils import format_task_related_prompt
-        options = parsed_item['options']
         start_second = 0
         end_second = parsed_item['end_second'] - parsed_item['start_second']
         temperature = 0
         video_duration = end_second - start_second
         n_frames = len(images)
 
-        task_related_prompt = format_task_related_prompt(options, self.question_type, perspective = 'first_person')
+        task_related_prompt = format_task_related_prompt('', self.question_type, perspective = 'first_person')
 
         time_instruction = f"The provided video lasts for {video_duration:.3f} seconds. "
 
@@ -210,23 +218,67 @@ def run(self, indices = None):
 
             caption = parsed_answer.caption
             print ('caption:', caption)
+            print ('gt is ', v['gt_answer'])
+
+            ret[k] = copy.deepcopy(v)
+            ret[k]['caption'] = caption
+
 
+
             if self.debug:
                 break
 
         return ret
-
-
-
-
-
+
+
+def create_comparison_data(positive_filename, negative_filename, out_filename):
+    """
+    Create the comparison data
+    """
+    ret = []
+    with open(positive_filename, 'r') as f:
+        positive_data = json.load(f)
+    with open(negative_filename, 'r') as f:
+        negative_data = json.load(f)
+
+    for key in positive_data:
+        pos_data = positive_data[key]
+        neg_data = negative_data[key]
+        assert pos_data['vid_path'] == neg_data['vid_path']
+        assert pos_data['start_second'] == neg_data['start_second']
+        template = {
+            'id': pos_data['vid_path'].replace('/', '-'),
+            'prompt': '',
+            'answer': pos_data['caption'],
+            'chosen': pos_data['caption'],
+            'rejected': neg_data['caption'],
+            'video': pos_data['vid_path'].replace('/', '-'),
+            'split': 'train',
+            'dataset_name' : 'EK100',
+            'start_timestamp': pos_data['start_second'],
+            'end_timestamp': pos_data['end_second'],
+            'num_samples': 1,
+            'question_type': 'dpo',
+            'task_instruction': '',
+        }
+        ret.append(template)
+
+    # save to jsonl
+    with open(out_filename, 'w') as f:
+        for item in ret:
+            f.write(json.dumps(item) + '\n')
 
 if __name__ == '__main__':
     video_root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
     anno_root = '/data/shaokai/epic-kitchens-100-annotations/'
     clip_length = 8
 
-    cap = CaptionInference(video_root, os.path.join(anno_root, 'EPIC_100_train.csv'), clip_length, debug = True)
-
-    #cap.multi_process_run(n_samples = 2)
-    cap.run()
+    # cap = CaptionInference(video_root,
+    #                        os.path.join(anno_root, 'EPIC_100_train.csv'),
+    #                        clip_length,
+    #                        debug = False,
+    #                        fraction = 0.01)
+    # cap.multi_process_run(n_samples = -1, filename = f'gpt4o_inference_{clip_length}frame_1percent.json')
+
+
+    create_comparison_data('gpt4o_inference_8frame_1percent.json', 'gpt4o_inference_1frame_1percent.json', 'comparison_data_1percent.jsonl')
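
Each line of the resulting comparison_data_1percent.jsonl pairs the caption generated from 8 frames (chosen, from positive_filename) against the caption generated from 1 frame (rejected, from negative_filename) for the same clip, which is the preference signal DPO trains on. A minimal sketch of reading the records back; the field names come from the template in create_comparison_data, while the values are whatever the real data contains:

    import json

    # One JSON object per line, as written by create_comparison_data.
    with open('comparison_data_1percent.jsonl', 'r') as f:
        records = [json.loads(line) for line in f]

    for rec in records:
        assert rec['question_type'] == 'dpo'
        # 'chosen' = 8-frame caption, 'rejected' = 1-frame caption for the same clip.
        print(rec['video'], rec['start_timestamp'], rec['end_timestamp'])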
