commit e028635027
parent 1391fd9f4e
Date: 2026-02-27 09:44:48 +00:00

3 changed files with 11 additions and 6 deletions

File 1 (class VLMModel)

@@ -67,7 +67,7 @@ class VLMModel:
         # ================================================================
         # 1. Vision Encoder Acceleration
-        # self._optimize_vision_encoder()
+        self._optimize_vision_encoder()
         # 2. KV Cache Management
         # self._optimize_kv_cache()
@@ -169,13 +169,18 @@ class VLMModel:
         # 3. Replace: layer.self_attn.forward = optimized_attention
         # 4. Test: Run benchmark to verify improvement
         import types
-        from triton_layer_norm import path_forward
+        from triton_layer_norm import path_norm_forward
         for i in range(len(self._model.model.visual.blocks)):
             norm1 = self._model.model.visual.blocks[i].norm1
-            norm1.forward = types.MethodType(path_forward, norm1)
+            norm1.forward = types.MethodType(path_norm_forward, norm1)
             norm2 = self._model.model.visual.blocks[i].norm2
-            norm2.forward = types.MethodType(path_forward, norm2)
+            norm2.forward = types.MethodType(path_norm_forward, norm2)
+        from triton_linear_gelu import path_linear_forward
+        for i in range(len(self._model.model.visual.blocks)):
+            mlp = self._model.model.visual.blocks[i].mlp
+            mlp.forward = types.MethodType(path_linear_forward, mlp)
         if 'vision_encoder' not in self._optimizations_applied:
             self._optimizations_applied.append('vision_encoder')
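
Both patches use the same pattern: types.MethodType binds a plain function as an instance method, so only the patched objects change behavior while the nn.Module classes stay intact (and the patch is reversible by deleting the instance attribute). A minimal self-contained sketch of the pattern, using torch.nn.functional.layer_norm as a stand-in for the Triton kernel:

import types

import torch
from torch import nn


def patched_forward(self, x):
    # Stand-in for the Triton-backed forward. It takes `self` as the
    # first argument so types.MethodType can bind it to an existing
    # module instance.
    return nn.functional.layer_norm(
        x, self.normalized_shape, self.weight, self.bias, self.eps
    )


norm = nn.LayerNorm(64)
# Instance-level patch: the bound method shadows nn.LayerNorm.forward
# for this object only; other LayerNorm instances are unaffected.
norm.forward = types.MethodType(patched_forward, norm)

out = norm(torch.randn(2, 8, 64))  # dispatches to patched_forward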

File 2 (triton_layer_norm.py)

@@ -75,7 +75,7 @@ def trition_layer_norm(x, normalized_shape, weight, bias, eps):
         BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
     return y

-def path_forward(self, input: Tensor) -> Tensor:
+def path_norm_forward(self, input: Tensor) -> Tensor:
     return trition_layer_norm(
         input, self.normalized_shape, self.weight, self.bias, self.eps
     )
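
Since path_norm_forward is a drop-in replacement for nn.LayerNorm.forward, it can be sanity-checked against the eager implementation before patching the encoder (step 4 in the checklist above). A sketch of such a parity check, assuming a CUDA device and fp16 inputs, neither of which this diff confirms:

import types

import torch
from torch import nn

from triton_layer_norm import path_norm_forward

norm = nn.LayerNorm(1280).cuda().half()
x = torch.randn(4, 256, 1280, device="cuda", dtype=torch.half)

ref = norm(x)  # eager LayerNorm as the reference
norm.forward = types.MethodType(path_norm_forward, norm)
out = norm(x)  # Triton path

# Loose tolerances: fp16 accumulation order differs between kernels.
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)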

File 3 (triton_linear_gelu.py)

@@ -133,5 +133,5 @@ def fc1_bias_gelu_triton(x, w, b,
 import functools
 from torch import nn

-def path_forward(self, hidden_state):
+def path_linear_forward(self, hidden_state):
     return self.linear_fc2(fc1_bias_gelu_triton(hidden_state, self.linear_fc1.weight, self.linear_fc1.bias))
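
path_linear_forward routes the block's whole MLP through the fused path: fc1_bias_gelu_triton presumably computes GELU(x @ W1.T + b1) in a single kernel (fusing the bias add and activation into the matmul epilogue saves two global-memory round trips), and linear_fc2 then runs as an ordinary linear. An eager-mode equivalent, assuming the exact erf-based GELU variant, which this diff does not confirm:

import torch.nn.functional as F


def mlp_forward_reference(self, hidden_state):
    # Unfused equivalent of path_linear_forward: matmul, bias add, and
    # GELU as separate eager ops instead of one Triton kernel, then fc2.
    h = F.linear(hidden_state, self.linear_fc1.weight, self.linear_fc1.bias)
    return self.linear_fc2(F.gelu(h))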