
Commit 5bb019b

Merge remote-tracking branch 'dev/spatial_video_loader' into temp_work

2 parents 5edb39b + 050d5d2

4 files changed: +297 -17 lines

llava/action/chatgpt_utils.py

Lines changed: 43 additions & 17 deletions
@@ -2,6 +2,8 @@
 import io
 import json
 import os
+import sys
+sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
 import numpy as np
 import openai
 from pydantic import BaseModel
@@ -12,7 +14,7 @@
 from tqdm import tqdm
 import csv
 import llava
-from llava.action.utils import avion_video_loader, create_multi_choice_from_avion_predictions, generate_label_map, AvionMultiChoiceGenerator
+from llava.action.utils import avion_video_loader, create_multi_choice_from_avion_predictions, generate_label_map, AvionMultiChoiceGenerator, avion_video_render_loader
 from llava.action.dataset import datetime2sec
 
 client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
@@ -34,6 +36,9 @@ def generate_prompt(cls, start_second, end_second, option_text, gt_answer):
 You are seeing video frames from an egocentric view of a person.
 Please talk as if you are the person in the video and describe what action you are performing.
 To assist you for how to describe the action, the video's start time is {start_second} and the end time is {end_second} and the duration is {end_second - start_second} seconds.
+Meanwhile, the left hand region is marked as 'L' in a blue bounding box and the right hand region is marked as 'R' in a red bounding box.
+The contact information is also provided in the bounding box tags, with 'N' for no contact, 'S' for self contact, 'O' for other person contact, 'P' for portable object contact, and 'F' for stationary object contact.
+The contacted objects are also marked as 'O' in yellow bounding boxes.
 To further assist you for how to describe the action, note that in a multi-choice video question answering, you were given following options {option_text} and the correct answer is {gt_answer}.
 In addition to describe what you see, describe why wrong answers were wrong and why right answer was right.
 When you explain why wrong answers were wrong and why right answer was right, you should use the following flow of reasoning:
@@ -192,7 +197,8 @@ def prepare_multiple_images(self, images):
         return multi_image_content
 
     def extract_frames(self, vid_path, start_second, end_second):
-        frames, time_meta = avion_video_loader(self.root,
+        frames, time_meta = avion_video_render_loader(self.root, self.handobj_root,
+        # frames, time_meta = avion_video_loader(self.root,
                                                vid_path,
                                                'MP4',
                                                start_second,
@@ -216,6 +222,7 @@ def __init__(self,
                  root,
                  annotation_file,
                  avion_prediction_file,
+                 handobj_root,
                  clip_length = 4,
                  action_representation = 'GT_random_narration',
                  debug = False,
@@ -233,6 +240,7 @@ def __init__(self,
         self.topk = topk
         self.annotation_file = annotation_file
         self.avion_prediction_file = avion_prediction_file
+        self.handobj_root = handobj_root
         self.annotation_root = Path(annotation_file).parent
         self.action_representation = action_representation
         self.labels, self.mapping_vn2narration, self.mapping_vn2act, self.verb_maps, self.noun_maps = generate_label_map(self.annotation_root,
@@ -289,7 +297,7 @@ def init_data(self):
     def multi_process_run(self):
         # to initialize it
 
-        indices = list(range(len(self.data)))
+        indices = list(range(len(self.data)))[:500]
 
         num_chunks = os.cpu_count() if not self.debug else 2
 
@@ -312,8 +320,11 @@ def multi_process_run(self):
 
         self.checkpoint(combined_results, "gpt_inference_results.json")
 
-    def run(self, indices):
-        data_batch = {i : self.data[i] for i in range(len(self.data)) if i in indices}
+    def run(self, indices=None):
+        if indices is None:
+            data_batch = {i : self.data[i] for i in range(len(self.data)) if i in list(range(len(self.data)))}
+        else:
+            data_batch = {i : self.data[i] for i in range(len(self.data)) if i in indices}
         ret = {}
 
         for k,v in tqdm(data_batch.items()):
@@ -337,7 +348,11 @@ def run(self, indices):
             }
             if self.debug:
                 break
-        return ret
+        if indices is None:
+            calculation = calculate_gpt_accuracy(data = ret)
+            self.checkpoint(ret, "gpt_inference_results.json")
+        else:
+            return ret
 
 
 
@@ -508,6 +523,7 @@ def multi_process_annotate(train_file_path, root, debug = False, anno_type = 'gp
 def multi_process_inference(root,
                             annotation_file,
                             avion_prediction_file,
+                            handobj_root,
                             action_representation = 'GT_random_narration',
                             clip_length = 4,
                             topk = 5,
@@ -516,11 +532,15 @@ def multi_process_inference(root,
     annotator = GPTInferenceAnnotator(root,
                                       annotation_file,
                                       avion_prediction_file,
+                                      handobj_root,
                                       clip_length = clip_length,
                                       debug = debug,
                                       action_representation = action_representation,
                                       topk = topk)
 
+    # indices = list(range(len(annotator.data)))[:100]
+    # annotator.run()
+
    annotator.multi_process_run()
 
 def calculate_gpt_accuracy(path = None, data = None):
@@ -557,10 +577,15 @@ def convert_json_to_jsonl(path):
 
 if __name__ == '__main__':
 
-    train_file_path = '/data/epic_kitchen/AVION_PREDS/avion_mc_top5_GT_random_narration/train_convs_narration.jsonl'
-    root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
-    val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
-    avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
+    # train_file_path = '/data/epic_kitchen/AVION_PREDS/avion_mc_top5_GT_random_narration/train_convs_narration.jsonl'
+    # root = '/data/EK100/EK100_320p_15sec_30fps_libx264'
+    # val_file = '/data/epic_kitchen/epic-kitchens-100-annotations/EPIC_100_validation.csv'
+    # avion_prediction_file = '/data/epic_kitchen/AVION_PREDS/avion_pred_ids_val.json'
+
+    root = '/mediaPFM/data/haozhe/onevision/llava_video/EK100'
+    val_file = '/mediaPFM/data/haozhe/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv'
+    avion_prediction_file = '/mediaPFM/data/haozhe/EK100/EK100_in_LLAVA/avion_pred_ids_val.json'
+    handobj_root = '/mnt/SV_storage/VFM/hand_object_detector/Save_dir'
 
 
 
@@ -569,13 +594,14 @@ def convert_json_to_jsonl(path):
 
     # multi_process_annotate(train_file_path, root, debug = False, n_samples = 10000)
 
-    # multi_process_inference(root,
-    #                         val_file,
-    #                         avion_prediction_file,
-    #                         debug = True,
-    #                         clip_length = 4,
-    #                         topk = 5)
+    multi_process_inference(root,
+                            val_file,
+                            avion_prediction_file,
+                            handobj_root,
+                            debug = False,
+                            clip_length = 4,
+                            topk = 5)
 
     #calculate_gpt_accuracy('valset_chatgpt_inference_results/gpt-4o-avion_top10_4frames_fixed_narration.json')
 
-    convert_json_to_jsonl('train_anno_gpt-gt-reason_4_10000.json')
+    # convert_json_to_jsonl('train_anno_gpt-gt-reason_4_10000.json')
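
Note: `avion_video_render_loader` is imported from `llava.action.utils`, whose diff is not part of this capture. Judging from the call site in `extract_frames` and the `render_utils.py` module added below, it presumably decodes the clip exactly as `avion_video_loader` does and then overlays the hand/object detections stored under `handobj_root`. A minimal sketch of that assumed behavior (the detection-file layout, the column names, and the `time_meta['frame_ids']` key are illustrative assumptions, not the actual implementation):

# Hypothetical sketch only; the real loader lives in llava.action.utils.
# Assumes one CSV of per-frame hand_object_detector outputs per video under
# handobj_root, with 'frame', 'hand_dets', and 'obj_dets' columns, and that
# time_meta exposes the sampled frame indices as 'frame_ids'.
import csv
import os

from llava.action.render_utils import render_frame

def avion_video_render_loader(root, handobj_root, vid_path, ext,
                              start_second, end_second, **kwargs):
    # Decode the clip the same way avion_video_loader does
    # (avion_video_loader is assumed in scope, as in llava.action.utils).
    frames, time_meta = avion_video_loader(root, vid_path, ext,
                                           start_second, end_second, **kwargs)
    # Load this video's detections (assumed file layout).
    det_file = os.path.join(handobj_root, vid_path + '.csv')
    with open(det_file) as f:
        dets = {int(row['frame']): row for row in csv.DictReader(f)}
    # Overlay the 'L'/'R' hand boxes, contact tags, and 'O' object boxes
    # referenced by the new prompt onto every sampled frame.
    rendered = []
    for frame_id, frame in zip(time_meta['frame_ids'], frames):
        row = dets.get(frame_id)
        rendered.append(frame if row is None
                        else render_frame(frame, row['hand_dets'], row['obj_dets']))
    return rendered, time_meta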

llava/action/render_utils.py

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+import numpy as np
+import cv2
+import ast
+from PIL import Image, ImageDraw, ImageFont
+
+color_rgb = [(255,255,0), (255, 128,0), (128,255,0), (0,128,255), (0,0,255), (127,0,255), (255,0,255), (255,0,127), (255,0,0), (255,204,153), (255,102,102), (153,255,153), (153,153,255), (0,0,153)]
+color_rgba = [(255,255,0,70), (255, 128,0,70), (128,255,0,70), (0,128,255,70), (0,0,255,70), (127,0,255,70), (255,0,255,70), (255,0,127,70), (255,0,0,70), (255,204,153,70), (255,102,102,70), (153,255,153,70), (153,153,255,70), (0,0,153,70)]
+
+
+hand_rgb = [(0, 90, 181), (220, 50, 32)]
+hand_rgba = [(0, 90, 181, 70), (220, 50, 32, 70)]
+
+obj_rgb = (255, 194, 10)
+obj_rgba = (255, 194, 10, 70)
+
+side_map = {'l':'Left', 'r':'Right'}
+side_map2 = {0:'Left', 1:'Right'}
+side_map3 = {0:'L', 1:'R'}
+state_map = {0:'No Contact', 1:'Self Contact', 2:'Another Person', 3:'Portable Object', 4:'Stationary Object'}
+state_map2 = {0:'N', 1:'S', 2:'O', 3:'P', 4:'F'}
+
+vis_settings = {'font_size':20, 'line_width':2, 'point_radius':4, 'hand_color':hand_rgb, 'hand_alpha':[None, None], 'obj_color':obj_rgb, 'obj_alpha':None, 'text_alpha':(255, 255, 255, 255)}
+
+def calculate_center(bb):
+    return [(bb[0] + bb[2])/2, (bb[1] + bb[3])/2]
+
+def filter_object(obj_dets, hand_dets):
+    filtered_object = []
+    object_cc_list = []
+    for j in range(obj_dets.shape[0]):
+        object_cc_list.append(calculate_center(obj_dets[j,:4]))
+    object_cc_list = np.array(object_cc_list)
+    img_obj_id = []
+    for i in range(hand_dets.shape[0]):
+        if hand_dets[i, 5] <= 0:
+            img_obj_id.append(-1)
+            continue
+        hand_cc = np.array(calculate_center(hand_dets[i,:4]))
+        point_cc = np.array([(hand_cc[0]+hand_dets[i,6]*10000*hand_dets[i,7]), (hand_cc[1]+hand_dets[i,6]*10000*hand_dets[i,8])])
+        dist = np.sum((object_cc_list - point_cc)**2,axis=1)
+        dist_min = np.argmin(dist)
+        img_obj_id.append(dist_min)
+    return img_obj_id
+
+def draw_obj_mask(image, draw, obj_idx, obj_bbox, obj_score, width, height):
+    font = ImageFont.truetype('llava/action/times_b.ttf', size=vis_settings['font_size'])
+    mask = Image.new('RGBA', (width, height))
+    pmask = ImageDraw.Draw(mask)
+    pmask.rectangle(obj_bbox, outline=vis_settings['obj_color'], width=vis_settings['line_width'], fill=vis_settings['obj_alpha'])
+    image.paste(mask, (0,0), mask)
+
+    draw.rectangle([obj_bbox[0], max(0, obj_bbox[1]-vis_settings['font_size']), obj_bbox[0]+vis_settings['font_size']+2,
+                    max(0, obj_bbox[1]-vis_settings['font_size'])+vis_settings['font_size']],
+                   fill=vis_settings['text_alpha'], outline=vis_settings['obj_color'], width=vis_settings['line_width'])
+    draw.text((obj_bbox[0]+5, max(0, obj_bbox[1]-vis_settings['font_size'])-2), f'O', font=font, fill=(0,0,0)) #
+
+    return image
+
+def draw_hand_mask(image, draw, hand_idx, hand_bbox, hand_score, side, state, width, height):
+    font = ImageFont.truetype('llava/action/times_b.ttf', size=vis_settings['font_size'])
+    if side == 0:
+        side_idx = 0
+    elif side == 1:
+        side_idx = 1
+    mask = Image.new('RGBA', (width, height))
+    pmask = ImageDraw.Draw(mask)
+    pmask.rectangle(hand_bbox, outline=vis_settings['hand_color'][side_idx], width=vis_settings['line_width'], fill=vis_settings['hand_alpha'][side_idx])
+    image.paste(mask, (0,0), mask)
+    # text
+
+    draw = ImageDraw.Draw(image)
+    draw.rectangle([hand_bbox[0], max(0, hand_bbox[1]-vis_settings['font_size']), hand_bbox[0]+vis_settings['font_size']*2+2,
+                    max(0, hand_bbox[1]-vis_settings['font_size'])+vis_settings['font_size']],
+                   fill=vis_settings['text_alpha'], outline=vis_settings['hand_color'][side_idx], width=vis_settings['line_width'])
+    draw.text((hand_bbox[0]+6, max(0, hand_bbox[1]-vis_settings['font_size'])-2), f'{side_map3[int(float(side))]}-{state_map2[int(float(state))]}', font=font, fill=(0,0,0)) #
+
+    return image
+
+def draw_line_point(draw, side_idx, hand_center, object_center):
+
+    draw.line([hand_center, object_center], fill=vis_settings['hand_color'][side_idx], width=vis_settings['line_width'])
+    x, y = hand_center[0], hand_center[1]
+    r=vis_settings['point_radius']
+    draw.ellipse((x-r, y-r, x+r, y+r), fill=vis_settings['hand_color'][side_idx])
+    x, y = object_center[0], object_center[1]
+    draw.ellipse((x-r, y-r, x+r, y+r), fill=vis_settings['obj_color'])
+
+def vis_detections_PIL(im, class_name, dets, thresh=0.8):
+    """Visual debugging of detections."""
+
+    image = Image.fromarray(im).convert("RGBA")
+    draw = ImageDraw.Draw(image)
+    width, height = image.size
+
+    for hand_idx, i in enumerate(range(np.minimum(10, dets.shape[0]))):
+        bbox = list(int(np.round(x)) for x in dets[i, :4])
+        score = dets[i, 4]
+        lr = dets[i, -1]
+        state = dets[i, 5]
+        if score > thresh:
+            image = draw_hand_mask(image, draw, hand_idx, bbox, score, lr, state, width, height)
+
+    return image
+
+def vis_detections_filtered_objects_PIL(im, obj_dets, hand_dets, thresh_hand=0.8, thresh_obj=0.01):
+
+    # convert to PIL
+    im = im[:,:,::-1]
+    image = Image.fromarray(im).convert("RGBA")
+    draw = ImageDraw.Draw(image)
+    width, height = image.size
+
+    if (obj_dets is not None) and (hand_dets is not None):
+        img_obj_id = filter_object(obj_dets, hand_dets)
+        for obj_idx, i in enumerate(range(np.minimum(10, obj_dets.shape[0]))):
+            bbox = list(int(np.round(x)) for x in obj_dets[i, :4])
+            score = obj_dets[i, 4]
+            if score > thresh_obj and i in img_obj_id:
+                # viz obj by PIL
+                image = draw_obj_mask(image, draw, obj_idx, bbox, score, width, height)
+
+        for hand_idx, i in enumerate(range(np.minimum(10, hand_dets.shape[0]))):
+            bbox = list(int(np.round(x)) for x in hand_dets[i, :4])
+            score = hand_dets[i, 4]
+            lr = hand_dets[i, -1]
+            state = hand_dets[i, 5]
+            if score > thresh_hand:
+                # viz hand by PIL
+                image = draw_hand_mask(image, draw, hand_idx, bbox, score, lr, state, width, height)
+
+                if state > 0: # in contact hand
+
+                    obj_cc, hand_cc = calculate_center(obj_dets[img_obj_id[i],:4]), calculate_center(bbox)
+                    # viz line by PIL
+                    if lr == 0:
+                        side_idx = 0
+                    elif lr == 1:
+                        side_idx = 1
+                    draw_line_point(draw, side_idx, (int(hand_cc[0]), int(hand_cc[1])), (int(obj_cc[0]), int(obj_cc[1])))
+
+    elif hand_dets is not None:
+        image = vis_detections_PIL(im, 'hand', hand_dets, thresh_hand)
+
+    return image
+
+def render_frame(im, hand_dets, obj_dets, thresh_hand=0.5, thresh_obj=0.5):
+    im_show = im.copy()
+    im_show = cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR)
+    hand_dets = np.array(ast.literal_eval(hand_dets)) if hand_dets != '[]' else None
+    obj_dets = np.array(ast.literal_eval(obj_dets)) if obj_dets != '[]' else None
+    im_show = vis_detections_filtered_objects_PIL(im_show, obj_dets, hand_dets, thresh_hand, thresh_obj)
+    # im_show.save('test.png')
+    im_show = np.array(im_show)
+    return im_show
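
`render_frame` is the module's entry point: it parses the stringified detection arrays, draws the left (blue 'L') and right (red 'R') hand boxes with their contact-state tags, yellow 'O' boxes for in-contact objects, and hand-to-object contact lines, then returns the frame as an RGBA numpy array. A standalone usage sketch (the image path and detection values are made up; the ten columns per row follow how the code above indexes them: x1, y1, x2, y2, score, contact state, offset magnitude, offset dx, offset dy, side):

# Hypothetical usage sketch; input values are illustrative only.
# Run from the repo root so the relative font path
# 'llava/action/times_b.ttf' used by the drawing helpers resolves.
import cv2

from llava.action.render_utils import render_frame

# render_frame expects an RGB frame (it handles channel conversion itself).
frame = cv2.cvtColor(cv2.imread('frame_000123.jpg'), cv2.COLOR_BGR2RGB)

# One left hand (side=0) in portable-object contact (state=3) and one object.
hand_dets = '[[10.0, 20.0, 120.0, 220.0, 0.95, 3.0, 0.5, 0.1, 0.2, 0.0]]'
obj_dets = '[[130.0, 40.0, 260.0, 180.0, 0.80, 0.0, 0.0, 0.0, 0.0, 0.0]]'

rendered = render_frame(frame, hand_dets, obj_dets)  # RGBA numpy array
cv2.imwrite('rendered.png', cv2.cvtColor(rendered, cv2.COLOR_RGBA2BGR))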

llava/action/times_b.ttf

93.4 KB
Binary file not shown.
