This commit is contained in:
2026-02-26 10:51:08 +08:00
parent 92f33eb431
commit bac7838dcd
2 changed files with 32 additions and 26 deletions

View File

@ -34,7 +34,7 @@ class VLMModel:
3. All optimizations are applied in __init__ by calling optimization methods.
"""
def __init__(self, model_path: str, device: str = "cuda:0"):
def __init__(self, model_path: str, device: str = "cpu"):
"""
Initialize model and apply optimizations.

View File

@ -249,6 +249,9 @@ def patch_forward(
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
"""
import time
start = time.time()
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
@ -316,34 +319,34 @@ def patch_forward(
past_key_values=past_key_values,
)
# ====== 稀疏采样裁剪:只在 prefill 阶段做(past_key_values 为空,即 get_seq_length() == 0)======
if past_key_values.get_seq_length() == 0 and visual_pos_masks is not None:
# 这些参数你可以通过 kwargs 传入
keep_ratio = kwargs.pop("visual_keep_ratio", 0.1) # 默认只保留 10% 视觉 token
min_keep = kwargs.pop("min_keep_per_vis", 0) # 每段视觉最少保留多少(可设比如 16)
max_len = kwargs.pop("truncate_max_len", None) # 总长度上限(可选)
# # ====== 稀疏采样裁剪:只在 prefill 阶段做(past_key_values 为空,即 get_seq_length() == 0)======
# if past_key_values.get_seq_length() == 0 and visual_pos_masks is not None:
# # 这些参数你可以通过 kwargs 传入
# keep_ratio = kwargs.pop("visual_keep_ratio", 0.1) # 默认只保留 10% 视觉 token
# min_keep = kwargs.pop("min_keep_per_vis", 0) # 每段视觉最少保留多少(可设比如 16)
# max_len = kwargs.pop("truncate_max_len", None) # 总长度上限(可选)
inputs_embeds, attention_mask, position_ids, visual_pos_masks, deepstack_visual_embeds = sparse_keep_and_gather(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
visual_pos_masks=visual_pos_masks,
deepstack_visual_embeds=deepstack_visual_embeds,
keep_ratio=keep_ratio,
min_keep_per_vis=min_keep,
max_len=max_len,
)
# inputs_embeds, attention_mask, position_ids, visual_pos_masks, deepstack_visual_embeds = sparse_keep_and_gather(
# inputs_embeds=inputs_embeds,
# attention_mask=attention_mask,
# position_ids=position_ids,
# visual_pos_masks=visual_pos_masks,
# deepstack_visual_embeds=deepstack_visual_embeds,
# keep_ratio=keep_ratio,
# min_keep_per_vis=min_keep,
# max_len=max_len,
# )
# cache_position 建议重建为 0..L-1, 避免对齐问题
cache_position = torch.arange(
inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
).unsqueeze(0).expand(inputs_embeds.shape[0], -1)
# # cache_position 建议重建为 0..L-1, 避免对齐问题
# cache_position = torch.arange(
# inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
# ).unsqueeze(0).expand(inputs_embeds.shape[0], -1)
# rope_deltas 建议也按裁剪后的序列重算(防止不一致)
eff_len = attention_mask.sum(dim=1).to(torch.long) # (B,)
max_pos = position_ids.max(dim=0).values.max(dim=1).values # (B,)
self.rope_deltas = (max_pos + 1 - eff_len).unsqueeze(1)
# ====== 裁剪结束 ======
# # rope_deltas 建议也按裁剪后的序列重算(防止不一致)
# eff_len = attention_mask.sum(dim=1).to(torch.long) # (B,)
# max_pos = position_ids.max(dim=0).values.max(dim=1).values # (B,)
# self.rope_deltas = (max_pos + 1 - eff_len).unsqueeze(1)
# # ====== 裁剪结束 ======
outputs = self.language_model(
input_ids=None,
@ -357,6 +360,9 @@ def patch_forward(
**kwargs,
)
end = time.time()
print('程序运行时间:%s毫秒' % ((end - start)*1000))
return Qwen3VLModelOutputWithPast(
**outputs,
rope_deltas=self.rope_deltas,