From c275119444e4d57d0f99ee93cde084d52de79e8e Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Mon, 12 May 2025 08:01:35 -0700
Subject: [PATCH] Remove torch.jit.save in llava example (#10794)

Summary:
Remove the torch.jit.save/torch::load image path from the llava
example. The runner previously consumed a serialized torch tensor
(image.pt) produced by image_util.py, which pulled a torch dependency
into the C++ binary. It now loads a .jpg directly with stb_image and
resizes the longest edge to 336 at runtime, so image_util.py is
deleted (prepare_image moves into model.py), the
LLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE CMake hack is removed, and the
torch-core-cpp dependency is replaced with stb.

Test Plan: Rely on the CI test (.ci/scripts/test_llava.sh), which
exports llava.pte, downloads a test image, and verifies the runner's
output prefix.

Reviewed By: iseeyuan

Differential Revision: D74551753

Pulled By: larryliu0820
---
 .ci/scripts/test_llava.sh             |  27 +++----
 examples/models/llava/CMakeLists.txt  |  29 ++++----
 examples/models/llava/export_llava.py |  10 ---
 examples/models/llava/image_util.py   |  79 --------------------
 examples/models/llava/main.cpp        | 101 +++++++++++++++----------
 examples/models/llava/model.py        |  28 ++++++-
 examples/models/llava/targets.bzl     |   2 +-
 7 files changed, 108 insertions(+), 168 deletions(-)
 delete mode 100644 examples/models/llava/image_util.py

diff --git a/.ci/scripts/test_llava.sh b/.ci/scripts/test_llava.sh
index 8a1d5683b33..9a0251c9a38 100644
--- a/.ci/scripts/test_llava.sh
+++ b/.ci/scripts/test_llava.sh
@@ -93,8 +93,7 @@ cmake_build_llava_runner_for_android() {
         -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
         -DANDROID_ABI=arm64-v8a \
         ${LLAVA_COMMON_CMAKE_ARGS} \
-        -DCMAKE_PREFIX_PATH="$python_lib" \
-        -DLLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE=ON \
+        -DCMAKE_PREFIX_PATH="$python_lib" \
         -B${BUILD_DIR}/${dir} \
         ${dir}

@@ -107,11 +106,10 @@ export_llava() {
   $PYTHON_EXECUTABLE -m executorch.examples.models.llava.export_llava --pte-name llava.pte --with-artifacts
 }

-# Download a new image with different size, to test if the model can handle different image sizes
-prepare_image_tensor() {
+# Download a new image
+download_image() {
   echo "Downloading image"
   curl -o basketball.jpg https://upload.wikimedia.org/wikipedia/commons/7/73/Chicago_Bulls_and_New_Jersey_Nets%2C_March_28%2C_1991.jpg
-  $PYTHON_EXECUTABLE -m executorch.examples.models.llava.image_util --image-path basketball.jpg --output-path image.pt
 }

 run_and_verify() {
@@ -121,8 +119,8 @@ run_and_verify() {
     echo "Export failed. Abort"
     exit 1
   fi
-  if [[ ! -f "image.pt" ]]; then
-    echo "image.pt is missing."
+  if [[ ! -f "basketball.jpg" ]]; then
+    echo "basketball.jpg is missing."
     exit 1
   fi
   if [[ ! -f "tokenizer.bin" ]]; then
@@ -130,11 +128,9 @@ run_and_verify() {
     exit 1
   fi

-
-
   RUNTIME_ARGS="--model_path=llava.pte \
     --tokenizer_path=tokenizer.bin \
-    --image_path=image.pt \
+    --image_path=basketball.jpg \
     --prompt=ASSISTANT: \
     --temperature=0 \
     --seq_len=650"
@@ -149,13 +145,8 @@ run_and_verify() {

   # verify result.txt
   RESULT=$(cat result.txt)
-  # set the expected prefix to be the same as prompt because there's a bug in sdpa_with_kv_cache that causes tokens.
-  if [[ "$(uname)" == "Darwin" ]]; then
-    EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with several players on the court. 
One of the players is dribbling the ball, while the others are in various" - else - # set the expected prefix to be the same as prompt because there's a bug in sdpa_with_kv_cache that causes tokens. - EXPECTED_PREFIX="ASSISTANT: image" - fi + EXPECTED_PREFIX="ASSISTANT: image captures a basketball game in progress, with several players on the court. " + if [[ "${RESULT}" == *"${EXPECTED_PREFIX}"* ]]; then echo "Expected result prefix: ${EXPECTED_PREFIX}" echo "Actual result: ${RESULT}" @@ -184,5 +175,5 @@ fi export_llava # Step3. Run -prepare_image_tensor +download_image run_and_verify diff --git a/examples/models/llava/CMakeLists.txt b/examples/models/llava/CMakeLists.txt index 232e83d8b0a..fe3eb5628b2 100644 --- a/examples/models/llava/CMakeLists.txt +++ b/examples/models/llava/CMakeLists.txt @@ -15,14 +15,12 @@ # ~~~ # It should also be cmake-lint clean. # -cmake_minimum_required(VERSION 3.24) # 3.24 is required for WHOLE_ARCHIVE +cmake_minimum_required(VERSION 3.24) # 3.24 is required for WHOLE_ARCHIVE project(llava) # Duplicating options as root CMakeLists.txt option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "Build the optimized kernels" OFF) -# This is a temporary hack to get around Torch dep so we can test this on android -option(LLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE "Hack option to feed dummy image to remove torch.load dep" OFF) include(CMakeDependentOption) # @@ -73,15 +71,6 @@ set(_common_include_directories ${EXECUTORCH_ROOT}/..) set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) find_package(gflags REQUIRED) -# Avoid torch dep from torch.load()-ing the image. -# This is a temporary hack. -if(LLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE) - add_definitions(-DLLAVA_NO_TORCH_DUMMY_IMAGE=1) - message("Buidling the runner without Torch, feeding a dummy image!") -else() - find_package_torch() -endif() - # # llava_main: test binary to run llava, with tokenizer and sampler integrated # @@ -95,9 +84,6 @@ target_link_options_shared_lib(executorch) add_subdirectory(runner) set(LINK_LIBS executorch gflags) -if(NOT LLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE) - list(APPEND LINK_LIBS torch) -endif() set(link_libraries ${LINK_LIBS}) set(_srcs main.cpp) @@ -197,6 +183,19 @@ if(ANDROID) list(APPEND link_libraries log) endif() +# stb_image: a lightweight library to load images +include(FetchContent) +FetchContent_Declare( + stb + GIT_REPOSITORY https://github.com/nothings/stb.git + GIT_TAG f0569113c93ad095470c54bf34a17b36646bbbb5 +) +FetchContent_MakeAvailable(stb) +# Add deprecated/ to use stb_image_resize.h for internal compatibility +list(APPEND _common_include_directories ${stb_SOURCE_DIR} + ${stb_SOURCE_DIR}/deprecated +) + add_executable(llava_main ${_srcs}) if(CMAKE_BUILD_TYPE STREQUAL "Release") if(APPLE) diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index 66b61840866..60c21897e7f 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -30,7 +30,6 @@ from executorch.examples.models.llama.source_transformation.sdpa import ( replace_sdpa_with_custom_op, ) -from executorch.examples.models.llava.image_util import serialize_image from executorch.examples.models.llava.model import LlavaModel from executorch.exir import ( EdgeCompileConfig, @@ -44,7 +43,6 @@ ConstraintBasedSymShapeEvalPass, HintBasedSymShapeEvalPass, ) - from executorch.extension.llm.export.builder import DType, LLMEdgeManager from executorch.util.activation_memory_profiler import generate_memory_trace from pytorch_tokenizers.llama2c import 
Llama2cTokenizer as Tokenizer

@@ -265,13 +263,6 @@ def export_all(llava_model: LlavaModel):
     return executorch_program


-def get_image_tensor_for_llava_runner(llava_model):
-    # llava runner doesn't have image reader so an image tensor is needed.
-    (resized,) = llava_model.get_example_inputs()
-
-    serialize_image(resized, "image.pt")
-
-
 def get_tokenizer_for_llava_runner(llava_model):
     # serialize tokenizer into tokenizer.bin
     llava_model.tokenizer.save_vocabulary("./")
@@ -336,7 +327,6 @@ def main():

     # artifacts
     if args.with_artifacts:
-        get_image_tensor_for_llava_runner(llava_model)
         get_tokenizer_for_llava_runner(llava_model)


diff --git a/examples/models/llava/image_util.py b/examples/models/llava/image_util.py
deleted file mode 100644
index 3f78f0a6ed6..00000000000
--- a/examples/models/llava/image_util.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# Utility functions for image processing. Run it with your image:
-
-# python image_util.py --image-path <path_to_image>
-
-import logging
-from argparse import ArgumentParser
-
-import torch
-import torchvision
-from PIL import Image
-from torch import nn
-
-
-FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
-logging.basicConfig(level=logging.INFO, format=FORMAT)
-
-
-# pyre-ignore: Undefined or invalid type [11]: Annotation `Image` is not defined as a type.
-def prepare_image(image: Image, target_h: int, target_w: int) -> torch.Tensor:
-    """Read image into a tensor and resize the image so that it fits in
-    a target_h x target_w canvas.
-
-    Args:
-        image (Image): An Image object.
-        target_h (int): Target height.
-        target_w (int): Target width.
-
-    Returns:
-        torch.Tensor: resized image tensor. 
-    """
-    img = torchvision.transforms.functional.pil_to_tensor(image)
-    # height ratio
-    ratio_h = img.shape[1] / target_h
-    # width ratio
-    ratio_w = img.shape[2] / target_w
-    # resize the image so that it fits in a target_h x target_w canvas
-    ratio = max(ratio_h, ratio_w)
-    output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
-    img = torchvision.transforms.Resize(size=output_size)(img)
-    return img
-
-
-def serialize_image(image: torch.Tensor, path: str) -> None:
-    copy = torch.tensor(image)
-    m = nn.Module()
-    par = nn.Parameter(copy, requires_grad=False)
-    m.register_parameter("0", par)
-    tensors = torch.jit.script(m)
-    tensors.save(path)
-
-    logging.info(f"Saved image tensor to {path}")
-
-
-def main():
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--image-path",
-        required=True,
-        help="Path to the image.",
-    )
-    parser.add_argument(
-        "--output-path",
-        default="image.pt",
-    )
-    args = parser.parse_args()
-
-    image = Image.open(args.image_path)
-    image_tensor = prepare_image(image, target_h=336, target_w=336)
-    serialize_image(image_tensor, args.output_path)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/models/llava/main.cpp b/examples/models/llava/main.cpp
index b01b33f5dd8..bdf191a789c 100644
--- a/examples/models/llava/main.cpp
+++ b/examples/models/llava/main.cpp
@@ -8,11 +8,10 @@
 #include <executorch/examples/models/llava/runner/llava_runner.h>
 #include <gflags/gflags.h>

-#ifndef LLAVA_NO_TORCH_DUMMY_IMAGE
-#include <torch/torch.h>
-#else
-#include <algorithm> // std::fill
-#endif
+#define STB_IMAGE_IMPLEMENTATION
+#include <stb_image.h>
+#define STB_IMAGE_RESIZE_IMPLEMENTATION
+#include <stb_image_resize.h>

 #if defined(ET_USE_THREADPOOL)
 #include <executorch/extension/threadpool/cpuinfo_utils.h>
@@ -28,10 +27,7 @@
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
 DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
-DEFINE_string(
-    image_path,
-    "",
-    "The path to a .pt file, a serialized torch tensor for an image, longest edge resized to 336.");
+DEFINE_string(image_path, "", "The path to a .jpg file.");

 DEFINE_double(
     temperature,
@@ -50,6 +46,56 @@ DEFINE_int32(

 using executorch::extension::llm::Image;

+void load_image(const std::string& image_path, Image& image) {
+  int width, height, channels;
+  unsigned char* data =
+      stbi_load(image_path.c_str(), &width, &height, &channels, 0);
+  if (!data) {
+    ET_LOG(Fatal, "Failed to load image: %s", image_path.c_str());
+    exit(1);
+  }
+  // resize the longest edge to 336
+  int new_width = width;
+  int new_height = height;
+  if (width > height) {
+    new_width = 336;
+    new_height = static_cast<int>(height * 336.0 / width);
+  } else {
+    new_height = 336;
+    new_width = static_cast<int>(width * 336.0 / height);
+  }
+  std::vector<uint8_t> resized_data(new_width * new_height * channels);
+  stbir_resize_uint8(
+      data,
+      width,
+      height,
+      0,
+      resized_data.data(),
+      new_width,
+      new_height,
+      0,
+      channels);
+  // transpose to CHW
+  image.data.resize(channels * new_width * new_height);
+  for (int i = 0; i < new_width * new_height; ++i) {
+    for (int c = 0; c < channels; ++c) {
+      image.data[c * new_width * new_height + i] =
+          resized_data[i * channels + c];
+    }
+  }
+  image.width = new_width;
+  image.height = new_height;
+  image.channels = channels;
+  // log the resized image dimensions
+  ET_LOG(
+      Info,
+      "image Channels: %" PRId32 ", Height: %" PRId32 ", Width: %" PRId32,
+      image.channels,
+      image.height,
+      image.width);
+  stbi_image_free(data);
+}
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -84,40 +130,9 @@ int32_t main(int32_t argc, char** argv) {
   // create llama runner
   example::LlavaRunner runner(model_path, tokenizer_path, 
temperature);

-  // read image and resize the longest edge to 336
-  std::vector<uint8_t> image_data;
-
-#ifdef LLAVA_NO_TORCH_DUMMY_IMAGE
-  // Work without torch using a random data
-  image_data.resize(3 * 240 * 336);
-  std::fill(image_data.begin(), image_data.end(), 0); // black
-  std::array<int32_t, 3> image_shape = {3, 240, 336};
-  std::vector<Image> images = {
-      {.data = image_data, .width = image_shape[2], .height = image_shape[1]}};
-#else // LLAVA_NO_TORCH_DUMMY_IMAGE
-  // cv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);
-  // int longest_edge = std::max(image.rows, image.cols);
-  // float scale_factor = 336.0f / longest_edge;
-  // cv::Size new_size(image.cols * scale_factor, image.rows * scale_factor);
-  // cv::Mat resized_image;
-  // cv::resize(image, resized_image, new_size);
-  // image_data.assign(resized_image.datastart, resized_image.dataend);
-  torch::Tensor image_tensor;
-  torch::load(image_tensor, image_path); // CHW
-  ET_LOG(
-      Info,
-      "image size(0): %" PRId64 ", size(1): %" PRId64 ", size(2): %" PRId64,
-      image_tensor.size(0),
-      image_tensor.size(1),
-      image_tensor.size(2));
-  image_data.assign(
-      image_tensor.data_ptr<uint8_t>(),
-      image_tensor.data_ptr<uint8_t>() + image_tensor.numel());
-  std::vector<Image> images = {
-      {.data = image_data,
-       .width = static_cast<int32_t>(image_tensor.size(2)),
-       .height = static_cast<int32_t>(image_tensor.size(1))}};
-#endif // LLAVA_NO_TORCH_DUMMY_IMAGE
+  Image image;
+  load_image(image_path, image);
+  std::vector<Image> images = {image};

   // generate
   runner.generate(std::move(images), prompt, seq_len);
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 7bcf560536c..1050fbdfae1 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -12,6 +12,7 @@

 import requests
 import torch
+import torchvision
 from executorch.examples.models.llama.llama_transformer import construct_transformer

 from executorch.examples.models.llama.model_args import ModelArgs
@@ -21,8 +22,6 @@
 from executorch.examples.models.llama.source_transformation.sdpa import (
     replace_sdpa_with_custom_op,
 )
-
-from executorch.examples.models.llava.image_util import prepare_image
 from executorch.examples.models.model_base import EagerModelBase
 from PIL import Image
@@ -37,6 +36,31 @@
 )


+# pyre-ignore: Undefined or invalid type [11]: Annotation `Image` is not defined as a type.
+def prepare_image(image: Image, target_h: int, target_w: int) -> torch.Tensor:
+    """Read image into a tensor and resize the image so that it fits in
+    a target_h x target_w canvas.
+
+    Args:
+        image (Image): An Image object.
+        target_h (int): Target height.
+        target_w (int): Target width.
+
+    Returns:
+        torch.Tensor: resized image tensor.
+    """
+    img = torchvision.transforms.functional.pil_to_tensor(image)
+    # height ratio
+    ratio_h = img.shape[1] / target_h
+    # width ratio
+    ratio_w = img.shape[2] / target_w
+    # resize the image so that it fits in a target_h x target_w canvas
+    ratio = max(ratio_h, ratio_w)
+    output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio))
+    img = torchvision.transforms.Resize(size=output_size)(img)
+    return img
+
+
 class Llava(torch.nn.Module):
     def __init__(
         self,
diff --git a/examples/models/llava/targets.bzl b/examples/models/llava/targets.bzl
index 6f3a370acf4..bc653e37144 100644
--- a/examples/models/llava/targets.bzl
+++ b/examples/models/llava/targets.bzl
@@ -15,7 +15,7 @@ def define_common_targets():
         ],
         external_deps = [
             "gflags",
-            "torch-core-cpp",
+            "stb",
         ],
         **get_oss_build_kwargs()
     )
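
For reference, a minimal sketch of the new end-to-end flow. The export and
runtime arguments mirror .ci/scripts/test_llava.sh above; the llava_main
binary path assumes the default cmake-out build directory and may differ in
your checkout:

    # Export the model; --with-artifacts also writes tokenizer.bin.
    python -m executorch.examples.models.llava.export_llava \
        --pte-name llava.pte --with-artifacts

    # Any .jpg works directly now; there is no torch tensor serialization step.
    curl -o basketball.jpg \
        https://upload.wikimedia.org/wikipedia/commons/7/73/Chicago_Bulls_and_New_Jersey_Nets%2C_March_28%2C_1991.jpg

    # Run the runner on the image; it resizes the longest edge to 336
    # internally via stb_image before generation.
    cmake-out/examples/models/llava/llava_main \
        --model_path=llava.pte \
        --tokenizer_path=tokenizer.bin \
        --image_path=basketball.jpg \
        --prompt=ASSISTANT: \
        --temperature=0 \
        --seq_len=650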