47 changes: 47 additions & 0 deletions tests/models/multimodal/processing/test_glm4_1v.py
@@ -5,6 +5,7 @@

from vllm.assets.video import VideoAsset
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.video import OpenCVDynamicVideoBackend, OpenCVVideoBackend

from ...utils import build_model_context

@@ -50,3 +51,49 @@ def test_processor_override(

assert grid_t == expected_grid_t
assert video_tok_count == expected_toks_per_frame * grid_t


@pytest.mark.parametrize("model_id", ["zai-org/GLM-4.1V-9B-Thinking"])
@pytest.mark.parametrize("fps", [2])
def test_video_loader_consistency(
model_id: str,
fps: int,
):
"""
Ensure dynamic video loader (pre-sampled by loader) and normal video
loader (post-sampled by processor) produce same video processing outputs.
"""
ctx = build_model_context(
model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"video": 1},
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
hf_processor_mm_kwargs = {"fps": fps}

# Build the prompt based on the number of videos we pass
prompt = "<|begin_of_video|><|video|><|end_of_video|>"

video_path = VideoAsset(name="baby_reading", num_frames=-1).video_path
with open(video_path, "rb") as f:
video_bytes = f.read()

static_video, static_metadata = OpenCVVideoBackend.load_bytes(video_bytes)
dynamic_video, dynamic_metadata = OpenCVDynamicVideoBackend.load_bytes(
video_bytes, requested_fps=fps)

# pre-sampled loader shouldn't read all frames
assert len(dynamic_video) < len(static_video)

static_mm_data = {"video": [(static_video, static_metadata)]}
dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]}

static_outputs = processor.apply(prompt, static_mm_data,
hf_processor_mm_kwargs)
dynamic_outputs = processor.apply(prompt, dynamic_mm_data,
hf_processor_mm_kwargs)

assert static_outputs["prompt_token_ids"] == dynamic_outputs[
"prompt_token_ids"]
assert static_outputs["mm_kwargs"].get_data(
) == dynamic_outputs["mm_kwargs"].get_data()
26 changes: 26 additions & 0 deletions tests/multimodal/test_utils.py
@@ -204,6 +204,32 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
assert metadata_sync == metadata_async


@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("max_duration", [1, 60, 1800])
@pytest.mark.parametrize("requested_fps", [2, 24])
async def test_fetch_video_http_with_dynamic_loader(
video_url: str, max_duration: int, requested_fps: int,
monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic")
connector = MediaConnector(
media_io_kwargs={
"video": {
"max_duration": max_duration,
"requested_fps": requested_fps,
}
})

video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(
video_url)

assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async
assert metadata_sync["video_backend"] == "opencv_dynamic"


# Used for `test_argsort_mm_positions`.
class TestCase(NamedTuple):
mm_positions: "MultiModalPlaceholderDict"
16 changes: 8 additions & 8 deletions vllm/assets/video.py
@@ -110,22 +110,23 @@ class VideoAsset:
def filename(self) -> str:
return self._NAME_TO_FILE[self.name]

@property
def video_path(self) -> str:
return download_video_asset(self.filename)

@property
def pil_images(self) -> list[Image.Image]:
video_path = download_video_asset(self.filename)
ret = video_to_pil_images_list(video_path, self.num_frames)
ret = video_to_pil_images_list(self.video_path, self.num_frames)
return ret

@property
def np_ndarrays(self) -> npt.NDArray:
video_path = download_video_asset(self.filename)
ret = video_to_ndarrays(video_path, self.num_frames)
ret = video_to_ndarrays(self.video_path, self.num_frames)
return ret

@property
def metadata(self) -> dict[str, Any]:
video_path = download_video_asset(self.filename)
ret = video_get_metadata(video_path)
ret = video_get_metadata(self.video_path)
return ret

def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
@@ -134,5 +135,4 @@ def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:

See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
video_path = download_video_asset(self.filename)
return librosa.load(video_path, sr=sampling_rate)[0]
return librosa.load(self.video_path, sr=sampling_rate)[0]
100 changes: 60 additions & 40 deletions vllm/model_executor/models/glm4_1v.py
@@ -1023,6 +1023,43 @@ def _get_video_second_idx(self, metadata: dict[str, Any],
selected_timestamps.append(timestamps_list[idx])
return selected_timestamps

def _construct_video_placeholder(
self,
video_array: np.ndarray,
metadata: dict[str, Any],
grid_thw: torch.Tensor,
) -> list[int]:
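"""
Build the flat token-id sequence for one video placeholder: the
video_start token, then for each retained frame the image_start token,
(H * W) // merge_size**2 video tokens, the image_end token, and the
frame's timestamp in seconds encoded as text tokens, all closed by the
video_end token.
"""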
hf_processor = self.get_hf_processor()
tokenizer = self.get_tokenizer()
image_processor = hf_processor.image_processor

hf_config = self.get_hf_config()
boi_token_id = hf_config.image_start_token_id
eoi_token_id = hf_config.image_end_token_id
bov_token_id = hf_config.video_start_token_id
eov_token_id = hf_config.video_end_token_id
merge_length = image_processor.merge_size**2

assert isinstance(grid_thw, torch.Tensor)
timestamps = self._get_video_second_idx(metadata, len(video_array))
frames_idx_token = [
tokenizer.encode(str(i), add_special_tokens=False)
for i in timestamps
]
T, H, W = grid_thw
num_tokens_per_frame = int(H * W) // merge_length
placeholder = []
placeholder.append(bov_token_id)
for frame_idx in frames_idx_token:
placeholder.append(boi_token_id)
placeholder.extend([hf_processor.video_token_id] *
num_tokens_per_frame)
placeholder.append(eoi_token_id)
placeholder.extend(frame_idx)
placeholder.append(eov_token_id)

return placeholder


class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):

@@ -1118,17 +1155,10 @@ def _call_hf_processor(
for item in mm_data.pop("videos", []):
video_array, metadata = item

# FIXME(Isotr0py): Activate the below logic after we can disable
# resampling from video loader backend.
# assert metadata["total_num_frames"] == len(video_array), (
# f"Total frames {metadata['total_num_frames']} does not "
# f"match the length of video array {len(video_array)}.")
if metadata["video_backend"] == "opencv_dynamic":
mm_kwargs["do_sample_frames"] = False

# NOTE: Temporary workaround for resampled videos.
# this can cause a divergence with HF implementation if
# the input video is resampled in advance.

if metadata["total_num_frames"] != len(video_array):
elif metadata["total_num_frames"] != len(video_array):
logger.warning(
"Total frames in metadata "
"(%s) does not match the length of "
@@ -1140,23 +1170,34 @@ def _call_hf_processor(
len(video_array),
)
metadata["total_num_frames"] = len(video_array)
metadata = VideoMetadata(**metadata)

video_mm_data = dict()
video_mm_data["videos"] = [[video_array]]
video_mm_data["video_metadata"] = [[metadata]]
video_mm_data["video_metadata"] = [[VideoMetadata(**metadata)]]

video_outputs = super()._call_hf_processor(
prompt="<|begin_of_video|><|video|><|end_of_video|>",
mm_data=video_mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
input_ids = video_outputs.pop("input_ids")
input_ids[input_ids == processor.image_token_id] = (
processor.video_token_id)
video_placeholder = processor.tokenizer.batch_decode(
input_ids)[0]
if "do_sample_frames" in mm_kwargs and not mm_kwargs[
"do_sample_frames"]:
# Transformers v4.55 has an incorrect-timestamps issue when frame
# sampling is skipped, so we construct the placeholder manually to
# get placeholders with correct timestamps.

@DarkLight1337 (Member) commented on Sep 11, 2025:
Is there a link to the relevant issue so we know when to remove this
workaround?

@Isotr0py (Member, Author) replied on Sep 11, 2025:
The root issue is the hardcoded 24 fps in Transformers v4.55's
no-sampling code path:
https://github.com/huggingface/transformers/blob/d79b2d981f28b2730d402244ac3c2e9a8c054eee/src/transformers/models/glm4v/video_processing_glm4v.py#L173-L176
I think huggingface/transformers#39600 should have fixed this issue, so
we can remove this workaround after the Transformers v4.56 update.
(The current GLM-4.1V multimodal processor in vLLM is broken on
Transformers v4.56, but I would like to fix that together in a
follow-up PR.)

placeholder = self.info._construct_video_placeholder(
video_array,
metadata,
video_outputs["video_grid_thw"].squeeze(0),
)
video_placeholder = processor.tokenizer.decode(placeholder)
else:
input_ids = video_outputs.pop("input_ids")
input_ids[input_ids == processor.image_token_id] = (
processor.video_token_id)
video_placeholder = processor.tokenizer.batch_decode(
input_ids)[0]
prompt = prompt.replace(
"<|begin_of_video|><|video|><|end_of_video|>",
video_placeholder,
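
To make the workaround above concrete, here is a minimal illustrative sketch of how per-frame timestamps diverge when a hardcoded fps is applied to frames that the loader has already sampled. This is not the Transformers or vLLM implementation; the helper name frame_timestamps and the 2 fps / 24 fps values are assumptions chosen for illustration.

def frame_timestamps(num_frames: int, fps: float) -> list[int]:
    # One timestamp in whole seconds per retained frame.
    return [int(i / fps) for i in range(num_frames)]

# Eight frames pre-sampled at 2 fps by the dynamic loader cover 4 seconds:
from_metadata = frame_timestamps(num_frames=8, fps=2.0)  # [0, 0, 1, 1, 2, 2, 3, 3]
# Timestamping the same frames with a hardcoded 24 fps collapses them to 0 s:
hardcoded = frame_timestamps(num_frames=8, fps=24.0)  # [0, 0, 0, 0, 0, 0, 0, 0]
# The timestamp tokens encoded into the placeholder would therefore differ.
assert from_metadata != hardcoded

Building the placeholder from the metadata-derived timestamps, as _construct_video_placeholder does above, keeps the output consistent whether or not the loader pre-samples frames.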
@@ -1202,14 +1243,6 @@ def _get_prompt_updates(
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
hf_config = self.info.get_hf_config()

boi_token_id = hf_config.image_start_token_id
eoi_token_id = hf_config.image_end_token_id

bov_token_id = hf_config.video_start_token_id
eov_token_id = hf_config.video_end_token_id

merge_length = image_processor.merge_size**2

Expand All @@ -1227,21 +1260,8 @@ def get_video_replacement_glm4v(item_idx: int):
assert isinstance(grid_thw, torch.Tensor)

video, metadata = mm_items["video"][item_idx]
timestamps = self.info._get_video_second_idx(metadata, len(video))
frames_idx_token = [
tokenizer.encode(str(i), add_special_tokens=False)
for i in timestamps
]
num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length
placeholder = []
placeholder.append(bov_token_id)
for frame_idx in frames_idx_token:
placeholder.append(boi_token_id)
placeholder.extend([hf_processor.video_token_id] *
num_tokens_per_frame)
placeholder.append(eoi_token_id)
placeholder.extend(frame_idx)
placeholder.append(eov_token_id)
placeholder = self.info._construct_video_placeholder(
video, metadata, grid_thw)
return PromptUpdateDetails.select_token_id(
placeholder,
embed_token_id=hf_processor.video_token_id,