Skip to content

Commit b11807c

Browse files
authored
[llava] Remove torch.jit.save in llava example
Differential Revision: D74551753 Pull Request resolved: #10794
1 parent 0a30c42 commit b11807c

File tree

7 files changed

+108
-168
lines changed

7 files changed

+108
-168
lines changed

.ci/scripts/test_llava.sh

+9-18
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@ cmake_build_llava_runner_for_android() {
9393
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
9494
-DANDROID_ABI=arm64-v8a \
9595
${LLAVA_COMMON_CMAKE_ARGS} \
96-
-DCMAKE_PREFIX_PATH="$python_lib" \
97-
-DLLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE=ON \
96+
-DCMAKE_PREFIX_PATH="$python_lib" \
9897
-B${BUILD_DIR}/${dir} \
9998
${dir}
10099

@@ -107,11 +106,10 @@ export_llava() {
107106
$PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts
108107
}
109108

110-
# Download a new image with different size, to test if the model can handle different image sizes
111-
prepare_image_tensor() {
109+
# Download a new image
110+
download_image() {
112111
echo "Downloading image"
113112
curl -o basketball.jpg https://upload.wikimedia.org/wikipedia/commons/7/73/Chicago_Bulls_and_New_Jersey_Nets%2C_March_28%2C_1991.jpg
114-
$PYTHON_EXECUTABLE -m executorch.examples.models.llava.image_util --image-path basketball.jpg --output-path image.pt
115113
}
116114

117115
run_and_verify() {
@@ -121,20 +119,18 @@ run_and_verify() {
121119
echo "Export failed. Abort"
122120
exit 1
123121
fi
124-
if [[ ! -f "image.pt" ]]; then
125-
echo "image.pt is missing."
122+
if [[ ! -f "basketball.jpg" ]]; then
123+
echo "basketball.jpg is missing."
126124
exit 1
127125
fi
128126
if [[ ! -f "tokenizer.bin" ]]; then
129127
echo "tokenizer.bin is missing."
130128
exit 1
131129
fi
132130

133-
134-
135131
RUNTIME_ARGS="--model_path=llava.pte \
136132
--tokenizer_path=tokenizer.bin \
137-
--image_path=image.pt \
133+
--image_path=basketball.jpg \
138134
--prompt=ASSISTANT: \
139135
--temperature=0 \
140136
--seq_len=650"
@@ -149,13 +145,8 @@ run_and_verify() {
149145

150146
# verify result.txt
151147
RESULT=$(cat result.txt)
152-
# set the expected prefix to be the same as prompt because there's a bug in sdpa_with_kv_cache that causes <unk> tokens.
153-
if [[ "$(uname)" == "Darwin" ]]; then
154-
EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with several players on the court. One of the players is dribbling the ball, while the others are in various"
155-
else
156-
# set the expected prefix to be the same as prompt because there's a bug in sdpa_with_kv_cache that causes <unk> tokens.
157-
EXPECTED_PREFIX="ASSISTANT: image"
158-
fi
148+
EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with several players on the court. "
149+
159150
if [[ "${RESULT}" == *"${EXPECTED_PREFIX}"* ]]; then
160151
echo "Expected result prefix: ${EXPECTED_PREFIX}"
161152
echo "Actual result: ${RESULT}"
@@ -184,5 +175,5 @@ fi
184175
export_llava
185176

186177
# Step3. Run
187-
prepare_image_tensor
178+
download_image
188179
run_and_verify

examples/models/llava/CMakeLists.txt

+14-15
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515
# ~~~
1616
# It should also be cmake-lint clean.
1717
#
18-
cmake_minimum_required(VERSION 3.24) # 3.24 is required for WHOLE_ARCHIVE
18+
cmake_minimum_required(VERSION 3.24) # 3.24 is required for WHOLE_ARCHIVE
1919
project(llava)
2020

2121
# Duplicating options as root CMakeLists.txt
2222
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "Build the optimized kernels" OFF)
2323

24-
# This is a temporary hack to get around Torch dep so we can test this on android
25-
option(LLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE "Hack option to feed dummy image to remove torch.load dep" OFF)
2624

2725
include(CMakeDependentOption)
2826
#
@@ -73,15 +71,6 @@ set(_common_include_directories ${EXECUTORCH_ROOT}/..)
7371
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
7472
find_package(gflags REQUIRED)
7573

76-
# Avoid torch dep from torch.load()-ing the image.
77-
# This is a temporary hack.
78-
if(LLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE)
79-
add_definitions(-DLLAVA_NO_TORCH_DUMMY_IMAGE=1)
80-
message("Buidling the runner without Torch, feeding a dummy image!")
81-
else()
82-
find_package_torch()
83-
endif()
84-
8574
#
8675
# llava_main: test binary to run llava, with tokenizer and sampler integrated
8776
#
@@ -95,9 +84,6 @@ target_link_options_shared_lib(executorch)
9584
add_subdirectory(runner)
9685

9786
set(LINK_LIBS executorch gflags)
98-
if(NOT LLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE)
99-
list(APPEND LINK_LIBS torch)
100-
endif()
10187
set(link_libraries ${LINK_LIBS})
10288
set(_srcs main.cpp)
10389

@@ -197,6 +183,19 @@ if(ANDROID)
197183
list(APPEND link_libraries log)
198184
endif()
199185

186+
# stb_image: a lightweight library to load images
187+
include(FetchContent)
188+
FetchContent_Declare(
189+
stb
190+
GIT_REPOSITORY https://github.com/nothings/stb.git
191+
GIT_TAG f0569113c93ad095470c54bf34a17b36646bbbb5
192+
)
193+
FetchContent_MakeAvailable(stb)
194+
# Add deprecated/ to use stb_image_resize.h for internal compatibility
195+
list(APPEND _common_include_directories ${stb_SOURCE_DIR}
196+
${stb_SOURCE_DIR}/deprecated
197+
)
198+
200199
add_executable(llava_main ${_srcs})
201200
if(CMAKE_BUILD_TYPE STREQUAL "Release")
202201
if(APPLE)

examples/models/llava/export_llava.py

-10
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from executorch.examples.models.llama.source_transformation.sdpa import (
3131
replace_sdpa_with_custom_op,
3232
)
33-
from executorch.examples.models.llava.image_util import serialize_image
3433
from executorch.examples.models.llava.model import LlavaModel
3534
from executorch.exir import (
3635
EdgeCompileConfig,
@@ -44,7 +43,6 @@
4443
ConstraintBasedSymShapeEvalPass,
4544
HintBasedSymShapeEvalPass,
4645
)
47-
4846
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
4947
from executorch.util.activation_memory_profiler import generate_memory_trace
5048
from pytorch_tokenizers.llama2c import Llama2cTokenizer as Tokenizer
@@ -265,13 +263,6 @@ def export_all(llava_model: LlavaModel):
265263
return executorch_program
266264

267265

268-
def get_image_tensor_for_llava_runner(llava_model):
269-
# llava runner doesn't have image reader so an image tensor is needed.
270-
(resized,) = llava_model.get_example_inputs()
271-
272-
serialize_image(resized, "image.pt")
273-
274-
275266
def get_tokenizer_for_llava_runner(llava_model):
276267
# serialize tokenizer into tokenizer.bin
277268
llava_model.tokenizer.save_vocabulary("./")
@@ -336,7 +327,6 @@ def main():
336327

337328
# artifacts
338329
if args.with_artifacts:
339-
get_image_tensor_for_llava_runner(llava_model)
340330
get_tokenizer_for_llava_runner(llava_model)
341331

342332

examples/models/llava/image_util.py

-79
This file was deleted.

examples/models/llava/main.cpp

+58-43
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88

99
#include <executorch/examples/models/llava/runner/llava_runner.h>
1010
#include <gflags/gflags.h>
11-
#ifndef LLAVA_NO_TORCH_DUMMY_IMAGE
12-
#include <torch/torch.h>
13-
#else
14-
#include <algorithm> // std::fill
15-
#endif
11+
#define STB_IMAGE_IMPLEMENTATION
12+
#include <stb_image.h>
13+
#define STB_IMAGE_RESIZE_IMPLEMENTATION
14+
#include <stb_image_resize.h>
1615

1716
#if defined(ET_USE_THREADPOOL)
1817
#include <executorch/extension/threadpool/cpuinfo_utils.h>
@@ -28,10 +27,7 @@ DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
2827

2928
DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
3029

31-
DEFINE_string(
32-
image_path,
33-
"",
34-
"The path to a .pt file, a serialized torch tensor for an image, longest edge resized to 336.");
30+
DEFINE_string(image_path, "", "The path to a .jpg file.");
3531

3632
DEFINE_double(
3733
temperature,
@@ -50,6 +46,56 @@ DEFINE_int32(
5046

5147
using executorch::extension::llm::Image;
5248

49+
void load_image(const std::string& image_path, Image& image) {
50+
int width, height, channels;
51+
unsigned char* data =
52+
stbi_load(image_path.c_str(), &width, &height, &channels, 0);
53+
if (!data) {
54+
ET_LOG(Fatal, "Failed to load image: %s", image_path.c_str());
55+
exit(1);
56+
}
57+
// resize the longest edge to 336
58+
int new_width = width;
59+
int new_height = height;
60+
if (width > height) {
61+
new_width = 336;
62+
new_height = static_cast<int>(height * 336.0 / width);
63+
} else {
64+
new_height = 336;
65+
new_width = static_cast<int>(width * 336.0 / height);
66+
}
67+
std::vector<uint8_t> resized_data(new_width * new_height * channels);
68+
stbir_resize_uint8(
69+
data,
70+
width,
71+
height,
72+
0,
73+
resized_data.data(),
74+
new_width,
75+
new_height,
76+
0,
77+
channels);
78+
// transpose to CHW
79+
image.data.resize(channels * new_width * new_height);
80+
for (int i = 0; i < new_width * new_height; ++i) {
81+
for (int c = 0; c < channels; ++c) {
82+
image.data[c * new_width * new_height + i] =
83+
resized_data[i * channels + c];
84+
}
85+
}
86+
image.width = new_width;
87+
image.height = new_height;
88+
image.channels = channels;
89+
// convert to tensor
90+
ET_LOG(
91+
Info,
92+
"image Channels: %" PRId32 ", Height: %" PRId32 ", Width: %" PRId32,
93+
image.channels,
94+
image.height,
95+
image.width);
96+
stbi_image_free(data);
97+
}
98+
5399
int32_t main(int32_t argc, char** argv) {
54100
gflags::ParseCommandLineFlags(&argc, &argv, true);
55101

@@ -84,40 +130,9 @@ int32_t main(int32_t argc, char** argv) {
84130
// create llama runner
85131
example::LlavaRunner runner(model_path, tokenizer_path, temperature);
86132

87-
// read image and resize the longest edge to 336
88-
std::vector<uint8_t> image_data;
89-
90-
#ifdef LLAVA_NO_TORCH_DUMMY_IMAGE
91-
// Work without torch using a random data
92-
image_data.resize(3 * 240 * 336);
93-
std::fill(image_data.begin(), image_data.end(), 0); // black
94-
std::array<int32_t, 3> image_shape = {3, 240, 336};
95-
std::vector<Image> images = {
96-
{.data = image_data, .width = image_shape[2], .height = image_shape[1]}};
97-
#else // LLAVA_NO_TORCH_DUMMY_IMAGE
98-
// cv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);
99-
// int longest_edge = std::max(image.rows, image.cols);
100-
// float scale_factor = 336.0f / longest_edge;
101-
// cv::Size new_size(image.cols * scale_factor, image.rows * scale_factor);
102-
// cv::Mat resized_image;
103-
// cv::resize(image, resized_image, new_size);
104-
// image_data.assign(resized_image.datastart, resized_image.dataend);
105-
torch::Tensor image_tensor;
106-
torch::load(image_tensor, image_path); // CHW
107-
ET_LOG(
108-
Info,
109-
"image size(0): %" PRId64 ", size(1): %" PRId64 ", size(2): %" PRId64,
110-
image_tensor.size(0),
111-
image_tensor.size(1),
112-
image_tensor.size(2));
113-
image_data.assign(
114-
image_tensor.data_ptr<uint8_t>(),
115-
image_tensor.data_ptr<uint8_t>() + image_tensor.numel());
116-
std::vector<Image> images = {
117-
{.data = image_data,
118-
.width = static_cast<int32_t>(image_tensor.size(2)),
119-
.height = static_cast<int32_t>(image_tensor.size(1))}};
120-
#endif // LLAVA_NO_TORCH_DUMMY_IMAGE
133+
Image image;
134+
load_image(image_path, image);
135+
std::vector<Image> images = {image};
121136

122137
// generate
123138
runner.generate(std::move(images), prompt, seq_len);

0 commit comments

Comments
 (0)