8
8
9
9
#include < executorch/examples/models/llava/runner/llava_runner.h>
10
10
#include < gflags/gflags.h>
11
- #ifndef LLAVA_NO_TORCH_DUMMY_IMAGE
12
- #include < torch/torch.h>
13
- #else
14
- #include < algorithm> // std::fill
15
- #endif
11
+ #define STB_IMAGE_IMPLEMENTATION
12
+ #include < stb_image.h>
13
+ #define STB_IMAGE_RESIZE_IMPLEMENTATION
14
+ #include < stb_image_resize.h>
16
15
17
16
#if defined(ET_USE_THREADPOOL)
18
17
#include < executorch/extension/threadpool/cpuinfo_utils.h>
@@ -28,10 +27,7 @@ DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
28
27
29
28
DEFINE_string (prompt, " The answer to the ultimate question is" , " Prompt." );
30
29
31
- DEFINE_string (
32
- image_path,
33
- " " ,
34
- " The path to a .pt file, a serialized torch tensor for an image, longest edge resized to 336." );
30
+ DEFINE_string (image_path, " " , " The path to a .jpg file." );
35
31
36
32
DEFINE_double (
37
33
temperature,
@@ -50,6 +46,56 @@ DEFINE_int32(
50
46
51
47
using executorch::extension::llm::Image;
52
48
49
+ void load_image (const std::string& image_path, Image& image) {
50
+ int width, height, channels;
51
+ unsigned char * data =
52
+ stbi_load (image_path.c_str (), &width, &height, &channels, 0 );
53
+ if (!data) {
54
+ ET_LOG (Fatal, " Failed to load image: %s" , image_path.c_str ());
55
+ exit (1 );
56
+ }
57
+ // resize the longest edge to 336
58
+ int new_width = width;
59
+ int new_height = height;
60
+ if (width > height) {
61
+ new_width = 336 ;
62
+ new_height = static_cast <int >(height * 336.0 / width);
63
+ } else {
64
+ new_height = 336 ;
65
+ new_width = static_cast <int >(width * 336.0 / height);
66
+ }
67
+ std::vector<uint8_t > resized_data (new_width * new_height * channels);
68
+ stbir_resize_uint8 (
69
+ data,
70
+ width,
71
+ height,
72
+ 0 ,
73
+ resized_data.data (),
74
+ new_width,
75
+ new_height,
76
+ 0 ,
77
+ channels);
78
+ // transpose to CHW
79
+ image.data .resize (channels * new_width * new_height);
80
+ for (int i = 0 ; i < new_width * new_height; ++i) {
81
+ for (int c = 0 ; c < channels; ++c) {
82
+ image.data [c * new_width * new_height + i] =
83
+ resized_data[i * channels + c];
84
+ }
85
+ }
86
+ image.width = new_width;
87
+ image.height = new_height;
88
+ image.channels = channels;
89
+ // convert to tensor
90
+ ET_LOG (
91
+ Info,
92
+ " image Channels: %" PRId32 " , Height: %" PRId32 " , Width: %" PRId32,
93
+ image.channels ,
94
+ image.height ,
95
+ image.width );
96
+ stbi_image_free (data);
97
+ }
98
+
53
99
int32_t main (int32_t argc, char ** argv) {
54
100
gflags::ParseCommandLineFlags (&argc, &argv, true );
55
101
@@ -84,40 +130,9 @@ int32_t main(int32_t argc, char** argv) {
84
130
// create llama runner
85
131
example::LlavaRunner runner (model_path, tokenizer_path, temperature);
86
132
87
- // read image and resize the longest edge to 336
88
- std::vector<uint8_t > image_data;
89
-
90
- #ifdef LLAVA_NO_TORCH_DUMMY_IMAGE
91
- // Work without torch using a random data
92
- image_data.resize (3 * 240 * 336 );
93
- std::fill (image_data.begin (), image_data.end (), 0 ); // black
94
- std::array<int32_t , 3 > image_shape = {3 , 240 , 336 };
95
- std::vector<Image> images = {
96
- {.data = image_data, .width = image_shape[2 ], .height = image_shape[1 ]}};
97
- #else // LLAVA_NO_TORCH_DUMMY_IMAGE
98
- // cv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);
99
- // int longest_edge = std::max(image.rows, image.cols);
100
- // float scale_factor = 336.0f / longest_edge;
101
- // cv::Size new_size(image.cols * scale_factor, image.rows * scale_factor);
102
- // cv::Mat resized_image;
103
- // cv::resize(image, resized_image, new_size);
104
- // image_data.assign(resized_image.datastart, resized_image.dataend);
105
- torch::Tensor image_tensor;
106
- torch::load (image_tensor, image_path); // CHW
107
- ET_LOG (
108
- Info,
109
- " image size(0): %" PRId64 " , size(1): %" PRId64 " , size(2): %" PRId64,
110
- image_tensor.size (0 ),
111
- image_tensor.size (1 ),
112
- image_tensor.size (2 ));
113
- image_data.assign (
114
- image_tensor.data_ptr <uint8_t >(),
115
- image_tensor.data_ptr <uint8_t >() + image_tensor.numel ());
116
- std::vector<Image> images = {
117
- {.data = image_data,
118
- .width = static_cast <int32_t >(image_tensor.size (2 )),
119
- .height = static_cast <int32_t >(image_tensor.size (1 ))}};
120
- #endif // LLAVA_NO_TORCH_DUMMY_IMAGE
133
+ Image image;
134
+ load_image (image_path, image);
135
+ std::vector<Image> images = {image};
121
136
122
137
// generate
123
138
runner.generate (std::move (images), prompt, seq_len);
0 commit comments