Skip to content

Commit 79829db

Browse files
committed
gpt testing also supports benchmark
1 parent 0f32693 commit 79829db

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

llava/action/benchmark.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
n_frames = 4
1111
topk = 5
1212
action_representation = 'GT_random_narration'
13-
gpt_model = 'gpt-4o-mini-2024-07-18'
14-
# gpt_model = 'gpt-4o-2024-08-06'
15-
perspective = 'third_person'
13+
#gpt_model = 'gpt-4o-mini-2024-07-18'
14+
gpt_model = 'gpt-4o-2024-08-06'
15+
perspective = 'first_person'
16+
benchmark_testing = True
1617

1718

1819
def benchmark_avion_mcq(n_samples):
@@ -26,6 +27,7 @@ def benchmark_avion_mcq(n_samples):
2627
question_type = 'mc_',
2728
action_representation=action_representation,
2829
perspective = perspective,
30+
benchmark_testing = benchmark_testing,
2931
topk = topk)
3032
inferencer.multi_process_run(n_samples)
3133

@@ -40,6 +42,7 @@ def benchmark_tim_mcq(n_samples):
4042
question_type = 'mc_',
4143
action_representation=action_representation,
4244
perspective = perspective,
45+
benchmark_testing = benchmark_testing,
4346
topk = topk)
4447
inferencer.multi_process_run(n_samples)
4548

@@ -53,6 +56,7 @@ def benchmark_random_mcq(n_samples):
5356
question_type = 'mc_',
5457
action_representation=action_representation,
5558
perspective = perspective,
59+
benchmark_testing = benchmark_testing,
5660
topk = topk)
5761

5862
inferencer.multi_process_run(n_samples)
@@ -61,4 +65,4 @@ def benchmark_random_mcq(n_samples):
6165
if __name__ == '__main__':
6266
benchmark_avion_mcq(100)
6367
benchmark_tim_mcq(100)
64-
benchmark_random_mcq(100)
68+
#benchmark_random_mcq(100)

llava/action/chatgpt_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def init_data(self):
411411
self.mapping_vn2narration,
412412
self.verb_maps,
413413
self.noun_maps,
414-
benchmark_tesitng = self.benchmark_testing,
414+
benchmark_testing = self.benchmark_testing,
415415
is_train = False)
416416
else:
417417
mc_data = self.mc_generator.generate_multi_choice(gt_vn,

llava/action/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,8 @@ def generate_multi_choice(self,
411411
mapping_vn2narration,
412412
verb_maps,
413413
noun_maps,
414-
is_train = True
414+
is_train = True,
415+
benchmark_testing = False
415416
):
416417

417418
"""
@@ -425,7 +426,7 @@ def generate_multi_choice(self,
425426
if is_train:
426427
return self.train_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
427428
else:
428-
return self.test_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps)
429+
return self.test_generate(gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps, benchmark_testing = benchmark_testing)
429430

430431
def train_generate(self, gt_vn, narration, k, action_representation, n_narrations, labels, mapping_vn2narration, verb_maps, noun_maps):
431432
# letters as A, B, C, D, .. Note we maximally support 26 letters

0 commit comments

Comments
 (0)