Files
AICASGC/my_patch.py
2026-02-26 08:15:21 +00:00

365 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import torch
from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor, Qwen3VLProcessorKwargs
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModelOutputWithPast, BaseModelOutputWithDeepstackFeatures
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging, TransformersKwargs, can_return_tuple
from transformers.video_utils import VideoInput
from transformers.cache_utils import Cache
from transformers.processing_utils import Unpack
logger = logging.get_logger(__name__)
class myQwen3VLProcessor(Qwen3VLProcessor):
    """Qwen3-VL processor whose image placeholders use a configurable fixed token count.

    ``__call__`` runs the image/video processors, expands the image and video
    tokens in ``text`` into per-token placeholders, tokenizes the result and
    returns a ``BatchFeature``. For images, each image token is expanded to
    ``num_image_tokens_override`` tokens. If that attribute is ``None`` (set on
    the class or on an instance), the count is derived from ``image_grid_thw``
    instead.
    """

    # Image tokens emitted per image. 40 reproduces this patch's original
    # hard-coded behaviour; None uses image_grid_thw[i].prod() // merge_size**2.
    # NOTE(review): the model must produce exactly this many visual embeddings
    # per image for the placeholder scatter to line up -- confirm for your inputs.
    num_image_tokens_override: int | None = 40

    def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
        super().__init__(image_processor, tokenizer, video_processor, chat_template, **kwargs)

    def __call__(
        self,
        images: ImageInput = None,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[Qwen3VLProcessorKwargs],
    ) -> BatchFeature:
        r"""
        Image tokens in `text` are expanded to `num_image_tokens_override` tokens
        each (or to the grid-derived count when that attribute is `None`).

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
            - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
        """
        output_kwargs = self._merge_kwargs(
            Qwen3VLProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        if images is not None:
            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        if videos is not None:
            videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
            video_grid_thw = videos_inputs["video_grid_thw"]
            # If user has not requested video metadata, pop it
            if not kwargs.get("return_metadata"):
                video_metadata = videos_inputs.pop("video_metadata")
            else:
                video_metadata = videos_inputs["video_metadata"]
        else:
            videos_inputs = {}
            video_grid_thw = None

        if not isinstance(text, list):
            text = [text]
        text = text.copy()  # below lines change text in-place

        if image_grid_thw is not None:
            merge_length = self.image_processor.merge_size**2
            fixed_count = self.num_image_tokens_override
            index = 0
            for i in range(len(text)):
                while self.image_token in text[i]:
                    if fixed_count is not None:
                        num_image_tokens = fixed_count
                    else:
                        num_image_tokens = image_grid_thw[index].prod() // merge_length
                    # Use a temporary placeholder so the `while` condition only
                    # sees image tokens that have not been expanded yet.
                    text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.image_token)

        if video_grid_thw is not None:
            merge_length = self.video_processor.merge_size**2
            index = 0
            for i in range(len(text)):
                while self.video_token in text[i]:
                    metadata = video_metadata[index]
                    if metadata.fps is None:
                        logger.warning_once(
                            "Qwen3VL requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
                            "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
                            "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
                        )
                    metadata.fps = 24 if metadata.fps is None else metadata.fps
                    # if timestamps are not provided, calculate them
                    curr_timestamp = self._calculate_timestamps(
                        metadata.frames_indices,
                        metadata.fps,
                        self.video_processor.temporal_patch_size,
                    )
                    video_placeholder = ""
                    frame_seqlen = video_grid_thw[index][1:].prod() // merge_length
                    for frame_idx in range(video_grid_thw[index][0]):
                        curr_time = curr_timestamp[frame_idx]
                        video_placeholder += f"<{curr_time:.1f} seconds>"
                        video_placeholder += (
                            self.vision_start_token + "<|placeholder|>" * frame_seqlen + self.vision_end_token
                        )
                    if f"{self.vision_start_token}{self.video_token}{self.vision_end_token}" in text[i]:
                        text[i] = text[i].replace(
                            f"{self.vision_start_token}{self.video_token}{self.vision_end_token}", video_placeholder, 1
                        )
                    else:
                        # vllm may input video token directly
                        text[i] = text[i].replace(self.video_token, video_placeholder, 1)
                    index += 1
                text[i] = text[i].replace("<|placeholder|>", self.video_token)

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])

        if return_mm_token_type_ids:
            # Only image-token positions are flagged (value 1); everything else is 0.
            array_ids = np.array(text_inputs["input_ids"])
            mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
            mm_token_type_ids[array_ids == self.image_token_id] = 1
            text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()

        return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
def _sample_indices_uniform(idx: torch.LongTensor, keep_ratio: float, min_keep: int = 0):
"""
idx: 1D indices in original sequence (sorted)
keep_ratio: 0~1, keep uniformly spaced
"""
n = idx.numel()
if n == 0:
return idx
k = max(min_keep, int(torch.ceil(torch.tensor(n * keep_ratio)).item()))
k = min(k, n)
if k == n:
return idx
# uniform pick: linspace over [0, n-1]
pos = torch.linspace(0, n - 1, steps=k, device=idx.device)
pos = pos.round().long().clamp(0, n - 1)
return idx[pos]
def sparse_keep_and_gather(
    inputs_embeds,  # (B, S, D)
    attention_mask,  # (B, S), or None meaning every position is attended
    position_ids,  # (P, B, S); P rotary axes (any count, e.g. 3 or 4)
    visual_pos_masks,  # (B, S) bool
    deepstack_visual_embeds,  # list[tensor], each (Nvis_total, D), or None
    keep_ratio: float = 0.25,
    min_keep_per_vis: int = 0,
    max_len: int | None = None,
):
    """Thin out visual tokens and compact every row of the batch.

    All attended text tokens are kept. Attended visual tokens are subsampled
    uniformly to about ``keep_ratio`` of their count via
    ``_sample_indices_uniform``. If ``max_len`` is given and a row is still
    longer, its visual tokens are thinned further (text untouched). Text is
    truncated (keeping the first ``max_len`` text tokens) only if text alone
    exceeds ``max_len``.

    Args:
        inputs_embeds: token embeddings, (B, S, D).
        attention_mask: (B, S) mask; non-zero entries are candidates for keeping.
            ``None`` is treated as all ones.
        position_ids: rotary position ids, (P, B, S).
        visual_pos_masks: (B, S) bool, True at visual-token positions.
        deepstack_visual_embeds: per-layer visual features whose rows follow the
            row-major order of the True entries of ``visual_pos_masks``, or None.
        keep_ratio: fraction of each row's visual tokens to keep.
        min_keep_per_vis: minimum visual tokens kept per row. It applies to all
            of a row's visual tokens together, not to each image.
        max_len: optional cap on each row's kept length.

    Returns:
        ``(inputs_embeds, attention_mask, position_ids, visual_pos_masks,
        deepstack_visual_embeds)`` with sequence length ``max_keep`` (the longest
        kept row). Kept tokens keep their original relative order and are packed
        at the start of each row. The tail is zero-filled (attention 0,
        position 0, not visual). Deepstack features are filtered to the
        surviving visual tokens (None if none were given).

    NOTE(review): output rows are always right-padded, whatever the input's
    padding side. With batch > 1 and differing kept lengths, a consumer that
    reads the last position of every row (e.g. left-padded generation) would
    see padding for the shorter rows.
    """
    device = inputs_embeds.device
    B, S, D = inputs_embeds.shape
    if attention_mask is None:
        # No mask: every position is attended.
        attention_mask = torch.ones((B, S), dtype=torch.long, device=device)
    eff = attention_mask.bool()
    keep_mask_token = torch.zeros((B, S), dtype=torch.bool, device=device)

    for b in range(B):
        eff_idx = eff[b].nonzero(as_tuple=False).squeeze(1)  # attended positions
        if eff_idx.numel() == 0:
            continue
        vis_eff = visual_pos_masks[b, eff_idx]  # which attended positions are visual
        text_idx = eff_idx[~vis_eff]  # always kept
        vis_idx = eff_idx[vis_eff]  # to be subsampled
        kept_vis = _sample_indices_uniform(vis_idx, keep_ratio, min_keep=min_keep_per_vis)
        chosen, _ = torch.sort(torch.cat([text_idx, kept_vis], dim=0))  # restore original order

        # Enforce max_len: trim visual tokens first, text only as a last resort.
        if max_len is not None and chosen.numel() > max_len:
            chosen_is_vis = visual_pos_masks[b, chosen]
            chosen_vis = chosen[chosen_is_vis]
            chosen_txt = chosen[~chosen_is_vis]
            if chosen_txt.numel() >= max_len:
                chosen = chosen_txt[:max_len]
            else:
                budget = max_len - chosen_txt.numel()
                # keep_ratio=0 with min_keep=budget requests exactly `budget`
                # tokens, without going through a float budget/n ratio.
                chosen_vis = _sample_indices_uniform(chosen_vis, 0.0, min_keep=budget)
                chosen, _ = torch.sort(torch.cat([chosen_txt, chosen_vis], dim=0))
        keep_mask_token[b, chosen] = True

    # ===== gather kept tokens, zero-pad to the longest kept row =====
    keep_lens = keep_mask_token.sum(dim=1).tolist()
    max_keep = max(keep_lens) if keep_lens else 0
    new_inputs = inputs_embeds.new_zeros((B, max_keep, D))
    new_attn = attention_mask.new_zeros((B, max_keep))
    # Follow the incoming number of rotary axes rather than assuming 4.
    new_pos = position_ids.new_zeros((position_ids.shape[0], B, max_keep))
    new_vis = visual_pos_masks.new_zeros((B, max_keep), dtype=torch.bool)
    for b in range(B):
        idx = keep_mask_token[b].nonzero(as_tuple=False).squeeze(1)
        n_keep = idx.numel()
        if n_keep == 0:
            continue
        new_inputs[b, :n_keep, :] = inputs_embeds[b, idx, :]
        new_attn[b, :n_keep] = attention_mask[b, idx]
        new_pos[:, b, :n_keep] = position_ids[:, b, idx]
        new_vis[b, :n_keep] = visual_pos_masks[b, idx]

    # ===== trim deepstack features to the surviving visual tokens =====
    new_deepstack = None
    if deepstack_visual_embeds is not None:
        # Deepstack rows follow the row-major order of True entries in
        # visual_pos_masks, so indexing keep_mask_token with that mask gives one
        # keep flag per deepstack row (1-D bool of length Nvis_total).
        keep_vis_flat = keep_mask_token[visual_pos_masks]
        new_deepstack = [x[keep_vis_flat.to(x.device)] for x in deepstack_visual_embeds]

    return new_inputs, new_attn, new_pos, new_vis, new_deepstack
@can_return_tuple
def patch_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    pixel_values: torch.Tensor | None = None,
    pixel_values_videos: torch.FloatTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    cache_position: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Qwen3VLModelOutputWithPast:
    r"""
    Qwen3-VL model forward that drops most visual tokens on the prefill pass.

    Presumably installed in place of the Qwen3-VL base model's ``forward`` (it
    uses ``self.get_image_features`` / ``self.language_model``); confirm at the
    patch site. Visual features are scattered into the embeddings as usual. On
    prefill (no cache, or an empty cache) the sequence is then compacted by
    ``sparse_keep_and_gather`` before it reaches the language model.

    Extra keyword arguments consumed here. They are always removed from
    ``kwargs``, so they are never forwarded to the language model:
        visual_keep_ratio (`float`, defaults to 0.1): fraction of visual tokens kept.
        min_keep_per_vis (`int`, defaults to 0): minimum visual tokens kept per row.
        truncate_max_len (`int`, *optional*): cap on each row's kept length.

    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    """
    # Pop the sparsification options up front so they never leak into
    # `self.language_model(**kwargs)`, including on decode steps and
    # text-only calls.
    keep_ratio = kwargs.pop("visual_keep_ratio", 0.1)  # keep 10% of visual tokens by default
    min_keep = kwargs.pop("min_keep_per_vis", 0)  # minimum visual tokens kept per row (e.g. 16)
    max_len = kwargs.pop("truncate_max_len", None)  # optional cap on kept sequence length

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    image_mask = None
    video_mask = None
    if pixel_values is not None:
        image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features(
            pixel_values, image_grid_thw, return_dict=True
        )
        image_embeds = image_outputs.pooler_output
        deepstack_image_embeds = image_outputs.deepstack_features
        image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        image_mask, _ = self.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
        )
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if pixel_values_videos is not None:
        video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features(
            pixel_values_videos, video_grid_thw, return_dict=True
        )
        video_embeds = video_outputs.pooler_output
        deepstack_video_embeds = video_outputs.deepstack_features
        video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        _, video_mask = self.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
        )
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    visual_pos_masks = None
    deepstack_visual_embeds = None
    if image_mask is not None and video_mask is not None:
        # aggregate visual_pos_masks and deepstack_visual_embeds
        image_mask = image_mask[..., 0]
        video_mask = video_mask[..., 0]
        visual_pos_masks = image_mask | video_mask
        deepstack_visual_embeds = []
        image_mask_joint = image_mask[visual_pos_masks]
        video_mask_joint = video_mask[visual_pos_masks]
        for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
            embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
            embed_joint[image_mask_joint, :] = img_embed
            embed_joint[video_mask_joint, :] = vid_embed
            deepstack_visual_embeds.append(embed_joint)
    elif image_mask is not None:
        image_mask = image_mask[..., 0]
        visual_pos_masks = image_mask
        deepstack_visual_embeds = deepstack_image_embeds
    elif video_mask is not None:
        video_mask = video_mask[..., 0]
        visual_pos_masks = video_mask
        deepstack_visual_embeds = deepstack_video_embeds

    if position_ids is None:
        position_ids = self.compute_3d_position_ids(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )

    # ====== Sparse visual-token sampling: prefill only ======
    # With no cache (e.g. use_cache=False) every call is a full forward pass, so
    # it counts as prefill instead of dereferencing None.
    is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0
    if is_prefill and visual_pos_masks is not None:
        if attention_mask is None:
            # No mask means every position is attended.
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
            )
        (
            inputs_embeds,
            attention_mask,
            position_ids,
            visual_pos_masks,
            deepstack_visual_embeds,
        ) = sparse_keep_and_gather(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            keep_ratio=keep_ratio,
            min_keep_per_vis=min_keep,
            max_len=max_len,
        )
        # Rebuild cache_position as 0..L-1 for the compacted sequence. It must be
        # 1-D (seq_len,): a (B, L) tensor would make shape[0] the batch size.
        cache_position = torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
        )
        # Recompute rope_deltas for the compacted sequence so later steps stay consistent.
        # NOTE(review): max_pos is taken over every rotary axis in position_ids;
        # confirm this matches how rope_deltas is consumed on decode steps.
        eff_len = attention_mask.sum(dim=1).to(torch.long)  # (B,)
        max_pos = position_ids.max(dim=0).values.max(dim=1).values  # (B,)
        self.rope_deltas = (max_pos + 1 - eff_len).unsqueeze(1)
    # ====== end of sparse sampling ======

    outputs = self.language_model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        visual_pos_masks=visual_pos_masks,
        deepstack_visual_embeds=deepstack_visual_embeds,
        **kwargs,
    )
    return Qwen3VLModelOutputWithPast(
        **outputs,
        rope_deltas=self.rope_deltas,
    )