# AICASGC/my_patch.py — Qwen3-VL patch: sparse visual-token pruning at prefill.
# NOTE(review): the lines previously here ("Files", "Raw Blame History",
# "295 lines", Unicode-character warning, ...) were repository-viewer chrome
# captured along with the source, not Python; replaced with this header so
# the file parses.
import math
import os
import time

import numpy as np
import torch
from transformers.cache_utils import Cache
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModelOutputWithPast, BaseModelOutputWithDeepstackFeatures
from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor, Qwen3VLProcessorKwargs
from transformers.processing_utils import Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging, TransformersKwargs, can_return_tuple
from transformers.video_utils import VideoInput
logger = logging.get_logger(__name__)
def _sample_indices_uniform(idx: torch.LongTensor, keep_ratio: float, min_keep: int = 0):
"""
idx: 1D indices in original sequence (sorted)
keep_ratio: 0~1, keep uniformly spaced
"""
n = idx.numel()
if n == 0:
return idx
k = max(min_keep, int(torch.ceil(torch.tensor(n * keep_ratio)).item()))
k = min(k, n)
if k == n:
return idx
# uniform pick: linspace over [0, n-1]
pos = torch.linspace(0, n - 1, steps=k, device=idx.device)
pos = pos.round().long().clamp(0, n - 1)
return idx[pos]
def sparse_keep_and_gather(
    inputs_embeds,            # (B, S, D)
    attention_mask,           # (B, S)
    position_ids,             # (4, B, S)
    visual_pos_masks,         # (B, S) bool
    deepstack_visual_embeds,  # list[tensor], each (Nvis_total, D), or None
    keep_ratio: float = 0.25,
    min_keep_per_vis: int = 0,
    max_len: int | None = None,
):
    """Sparsely prune visual tokens while keeping all text tokens.

    Text tokens are always kept; visual tokens are uniformly subsampled
    down to ``keep_ratio``. If ``max_len`` is given and the result is still
    too long, additional visual tokens are dropped first — text is only
    truncated when text alone already exceeds ``max_len``. The kept tokens
    of each sample are packed to the front and the batch is re-padded to
    the longest kept length.

    Returns:
        Tuple of ``(inputs_embeds, attention_mask, position_ids,
        visual_pos_masks, deepstack_visual_embeds)`` after pruning; the
        last element is ``None`` when ``deepstack_visual_embeds`` is None.
    """
    device = inputs_embeds.device
    B, S, D = inputs_embeds.shape
    eff = attention_mask.bool()
    keep_mask_token = torch.zeros((B, S), dtype=torch.bool, device=device)
    for b in range(B):
        eff_idx = eff[b].nonzero(as_tuple=False).squeeze(1)  # valid (non-pad) tokens
        if eff_idx.numel() == 0:
            continue
        vis_eff = visual_pos_masks[b, eff_idx]  # which valid tokens are visual
        text_idx = eff_idx[~vis_eff]            # text: keep everything
        vis_idx = eff_idx[vis_eff]              # visual: to be subsampled
        # uniform subsampling of visual tokens (this is what drops the middle)
        kept_vis = _sample_indices_uniform(vis_idx, keep_ratio, min_keep=min_keep_per_vis)
        chosen = torch.cat([text_idx, kept_vis], dim=0)
        chosen, _ = torch.sort(chosen)  # restore original order
        # optional hard length cap: shave more visual tokens first, text last
        if max_len is not None and chosen.numel() > max_len:
            chosen_vis = chosen[visual_pos_masks[b, chosen]]
            chosen_txt = chosen[~visual_pos_masks[b, chosen]]
            if chosen_txt.numel() >= max_len:
                # text alone exceeds the cap: truncate text (rare)
                chosen = chosen_txt[:max_len]
            else:
                budget = max_len - chosen_txt.numel()
                # Sample exactly `budget` visual tokens. The previous
                # ratio-based call (budget / nvis) could overshoot max_len
                # by one via float rounding, e.g. ceil(7 * (3/7)) == 4.
                if chosen_vis.numel() > budget:
                    pick = torch.linspace(
                        0, chosen_vis.numel() - 1, steps=budget, device=chosen_vis.device
                    ).round().long().clamp(0, chosen_vis.numel() - 1)
                    chosen_vis = chosen_vis[pick]
                chosen = torch.cat([chosen_txt, chosen_vis], dim=0)
                chosen, _ = torch.sort(chosen)
        keep_mask_token[b, chosen] = True
    # ===== gather kept tokens and pad to the batch max kept length =====
    keep_lens = keep_mask_token.sum(dim=1).tolist()
    max_keep = max(keep_lens) if keep_lens else 0
    new_inputs = inputs_embeds.new_zeros((B, max_keep, D))
    new_attn = attention_mask.new_zeros((B, max_keep))
    new_pos = position_ids.new_zeros((4, B, max_keep))
    new_vis = visual_pos_masks.new_zeros((B, max_keep), dtype=torch.bool)
    for b in range(B):
        idx = keep_mask_token[b].nonzero(as_tuple=False).squeeze(1)
        L = idx.numel()
        if L == 0:
            continue
        new_inputs[b, :L, :] = inputs_embeds[b, idx, :]
        new_attn[b, :L] = attention_mask[b, idx]
        new_pos[:, b, :L] = position_ids[:, b, idx]
        new_vis[b, :L] = visual_pos_masks[b, idx]
    # ===== prune deepstack features in sync (essential!) =====
    new_deepstack = None
    if deepstack_visual_embeds is not None:
        # deepstack row order == order of True entries in the flattened
        # visual_pos_masks, so the same boolean selection applies directly
        keep_vis_flat = keep_mask_token[visual_pos_masks]  # 1-D bool, length Nvis_total
        new_deepstack = [x[keep_vis_flat] for x in deepstack_visual_embeds]
    return new_inputs, new_attn, new_pos, new_vis, new_deepstack
@can_return_tuple
def patch_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: torch.Tensor | None = None,
    position_ids: torch.LongTensor | None = None,
    past_key_values: Cache | None = None,
    inputs_embeds: torch.FloatTensor | None = None,
    pixel_values: torch.Tensor | None = None,
    pixel_values_videos: torch.FloatTensor | None = None,
    image_grid_thw: torch.LongTensor | None = None,
    video_grid_thw: torch.LongTensor | None = None,
    cache_position: torch.LongTensor | None = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | Qwen3VLModelOutputWithPast:
    r"""
    Patched Qwen3-VL model forward that sparsely prunes visual tokens at prefill.

    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.

    Extra pruning controls popped from ``kwargs``:
        visual_keep_ratio (float, default 0.1): fraction of visual tokens kept.
        min_keep_per_vis (int, default 0): minimum visual tokens kept per sample.
        truncate_max_len (int, optional): hard cap on total sequence length.
    """
    def _sync():
        # synchronize only on CUDA so CPU runs don't error; required for the
        # wall-clock timings below to be meaningful with async GPU kernels
        if torch.cuda.is_available() and inputs_embeds is not None and inputs_embeds.is_cuda:
            torch.cuda.synchronize()

    def _ms(t0):
        return (time.perf_counter() - t0) * 1000.0

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)
    image_mask = None
    video_mask = None
    if pixel_values is not None:
        _sync()
        t_img = time.perf_counter()
        image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features(
            pixel_values, image_grid_thw, return_dict=True
        )
        _sync()
        print(f"[VLPATCH_DEBUG] get_image_features: {_ms(t_img):.3f} ms")
        image_embeds = image_outputs.pooler_output
        deepstack_image_embeds = image_outputs.deepstack_features
        image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        image_mask, _ = self.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
        )
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
    if pixel_values_videos is not None:
        video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features(
            pixel_values_videos, video_grid_thw, return_dict=True
        )
        video_embeds = video_outputs.pooler_output
        deepstack_video_embeds = video_outputs.deepstack_features
        video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
        _, video_mask = self.get_placeholder_mask(
            input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds
        )
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
    visual_pos_masks = None
    deepstack_visual_embeds = None
    if image_mask is not None and video_mask is not None:
        # aggregate visual_pos_masks and deepstack_visual_embeds
        image_mask = image_mask[..., 0]
        video_mask = video_mask[..., 0]
        visual_pos_masks = image_mask | video_mask
        deepstack_visual_embeds = []
        image_mask_joint = image_mask[visual_pos_masks]
        video_mask_joint = video_mask[visual_pos_masks]
        for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
            embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
            embed_joint[image_mask_joint, :] = img_embed
            embed_joint[video_mask_joint, :] = vid_embed
            deepstack_visual_embeds.append(embed_joint)
    elif image_mask is not None:
        image_mask = image_mask[..., 0]
        visual_pos_masks = image_mask
        deepstack_visual_embeds = deepstack_image_embeds
    elif video_mask is not None:
        video_mask = video_mask[..., 0]
        visual_pos_masks = video_mask
        deepstack_visual_embeds = deepstack_video_embeds
    if position_ids is None:
        position_ids = self.compute_3d_position_ids(
            input_ids=input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )
    # ====== sparse pruning: prefill only (no cache, or empty cache) ======
    # FIX: past_key_values may legitimately be None (e.g. use_cache=False);
    # the previous unconditional .get_seq_length() call raised AttributeError.
    is_prefill = past_key_values is None or past_key_values.get_seq_length() == 0
    if is_prefill and visual_pos_masks is not None:
        # these knobs can be supplied by the caller through kwargs
        keep_ratio = kwargs.pop("visual_keep_ratio", 0.1)   # keep 10% of visual tokens by default
        min_keep = kwargs.pop("min_keep_per_vis", 0)        # minimum kept per sample (e.g. 16)
        max_len = kwargs.pop("truncate_max_len", None)      # optional total-length cap
        # stats before pruning
        L0 = inputs_embeds.shape[1]
        nvis0 = int(visual_pos_masks.sum().item()) if visual_pos_masks is not None else -1
        eff0 = int(attention_mask.sum().item()) if attention_mask is not None else -1
        print(f"[VLPATCH_DEBUG] BEFORE prune: L={L0}, visual={nvis0}, eff={eff0}")
        _sync()
        t_prune = time.perf_counter()
        inputs_embeds, attention_mask, position_ids, visual_pos_masks, deepstack_visual_embeds = sparse_keep_and_gather(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            keep_ratio=keep_ratio,
            min_keep_per_vis=min_keep,
            max_len=max_len,
        )
        _sync()
        print(f"[VLPATCH_DEBUG] sparse_keep_and_gather: {_ms(t_prune):.3f} ms")
        L1 = inputs_embeds.shape[1]
        nvis1 = int(visual_pos_masks.sum().item()) if visual_pos_masks is not None else -1
        eff1 = int(attention_mask.sum().item()) if attention_mask is not None else -1
        print(f"[VLPATCH_DEBUG] AFTER prune: L={L1}, visual={nvis1}, eff={eff1}")
        if L0 > 0 and nvis0 >= 0:
            print(f"[VLPATCH_DEBUG] ΔL={L1-L0} ({(L1/L0*100):.1f}%), "
                  f"Δvisual={nvis1-nvis0} ({(nvis1/max(nvis0,1)*100):.1f}%)")
        # Rebuild cache_position as 0..L-1 over the pruned sequence.
        # FIX: transformers documents cache_position as a 1-D tensor of shape
        # (sequence_length); the previous code expanded it to (B, L).
        cache_position = torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
        )
        # recompute rope_deltas from the pruned sequence to stay consistent
        eff_len = attention_mask.sum(dim=1).to(torch.long)            # (B,)
        max_pos = position_ids.max(dim=0).values.max(dim=1).values    # (B,)
        self.rope_deltas = (max_pos + 1 - eff_len).unsqueeze(1)
    # ====== end of pruning ======
    _sync()
    t_lm = time.perf_counter()
    outputs = self.language_model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        visual_pos_masks=visual_pos_masks,
        deepstack_visual_embeds=deepstack_visual_embeds,
        **kwargs,
    )
    _sync()
    print(f"[VLPATCH_DEBUG] language_model: {_ms(t_lm):.3f} ms")
    return Qwen3VLModelOutputWithPast(
        **outputs,
        rope_deltas=self.rope_deltas,
    )