all test
This commit is contained in:
@ -67,7 +67,7 @@ class VLMModel:
|
||||
# ================================================================
|
||||
|
||||
# 1. Vision Encoder Acceleration
|
||||
# self._optimize_vision_encoder()
|
||||
self._optimize_vision_encoder()
|
||||
|
||||
# 2. KV Cache Management
|
||||
# self._optimize_kv_cache()
|
||||
@ -169,13 +169,18 @@ class VLMModel:
|
||||
# 3. Replace: layer.self_attn.forward = optimized_attention
|
||||
# 4. Test: Run benchmark to verify improvement
|
||||
import types
|
||||
from triton_layer_norm import path_forward
|
||||
from triton_layer_norm import path_norm_forward
|
||||
|
||||
for i in range(len(self._model.model.visual.blocks)):
|
||||
norm1 = self._model.model.visual.blocks[i].norm1
|
||||
norm1.forward = types.MethodType(path_forward, norm1)
|
||||
norm1.forward = types.MethodType(path_norm_forward, norm1)
|
||||
norm2 = self._model.model.visual.blocks[i].norm2
|
||||
norm2.forward = types.MethodType(path_forward, norm2)
|
||||
norm2.forward = types.MethodType(path_norm_forward, norm2)
|
||||
|
||||
from triton_linear_gelu import path_linear_forward
|
||||
for i in range(len(self._model.model.visual.blocks)):
|
||||
mlp = self._model.model.visual.blocks[i].mlp
|
||||
mlp.forward = types.MethodType(path_linear_forward, mlp)
|
||||
|
||||
if 'vision_encoder' not in self._optimizations_applied:
|
||||
self._optimizations_applied.append('vision_encoder')
|
||||
|
||||
@ -75,7 +75,7 @@ def trition_layer_norm(x, normalized_shape, weight, bias, eps):
|
||||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
|
||||
return y
|
||||
|
||||
def path_norm_forward(self, input: Tensor) -> Tensor:
    """Drop-in replacement for ``nn.LayerNorm.forward`` backed by the Triton kernel.

    Bound onto an existing LayerNorm module via ``types.MethodType``, so it
    reads the module's own ``normalized_shape``, ``weight``, ``bias`` and
    ``eps`` attributes and forwards them, together with the input tensor,
    to the fused Triton layer-norm implementation.

    NOTE(review): ``trition_layer_norm`` is the (typo'd) name of the kernel
    wrapper defined elsewhere in this module — keep the spelling in sync.
    """
    return trition_layer_norm(
        x=input,
        normalized_shape=self.normalized_shape,
        weight=self.weight,
        bias=self.bias,
        eps=self.eps,
    )
|
||||
@ -133,5 +133,5 @@ def fc1_bias_gelu_triton(x, w, b,
|
||||
import functools
|
||||
from torch import nn
|
||||
|
||||
def path_linear_forward(self, hidden_state):
    """Drop-in replacement for the vision-block MLP ``forward``.

    Bound onto an MLP module via ``types.MethodType``. Runs the fused
    Triton fc1 + bias + GELU kernel on ``hidden_state`` using the module's
    ``linear_fc1`` weights, then applies the ``linear_fc2`` projection.
    """
    fused = fc1_bias_gelu_triton(
        hidden_state,
        self.linear_fc1.weight,
        self.linear_fc1.bias,
    )
    return self.linear_fc2(fused)
|
||||
Reference in New Issue
Block a user