Revert Triton LayerNorm kernel accumulation back to fp32

This commit is contained in:
2026-02-27 03:15:41 +00:00
parent 3e1436c8d4
commit 3a3c59238c
2 changed files with 16 additions and 10 deletions

View File

@ -67,7 +67,7 @@ class VLMModel:
# ================================================================
# 1. Vision Encoder Acceleration
self._optimize_vision_encoder()
# self._optimize_vision_encoder()
# 2. KV Cache Management
# self._optimize_kv_cache()
@ -168,8 +168,14 @@ class VLMModel:
# 2. Inspect: print(self._model.vision_model) to find target layers
# 3. Replace: layer.self_attn.forward = optimized_attention
# 4. Test: Run benchmark to verify improvement
import types
from triton_layer_norm import path_forward
self._model.model.visual.blocks[0].norm1.__class__.forward = path_forward
for i in range(len(self._model.model.visual.blocks)):
norm1 = self._model.model.visual.blocks[i].norm1
norm1.forward = types.MethodType(path_forward, norm1)
norm2 = self._model.model.visual.blocks[i].norm2
norm2.forward = types.MethodType(path_forward, norm2)
if 'vision_encoder' not in self._optimizations_applied:
self._optimizations_applied.append('vision_encoder')

View File

@ -23,18 +23,18 @@ def _layer_norm_fwd_fused(
X += row * stride
# Compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float16)
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float16)
a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# Compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float16)
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float16)
x = tl.where(cols < N, x - mean, 0.).to(tl.float16)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.where(cols < N, x - mean, 0.)
_var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
@ -47,7 +47,7 @@ def _layer_norm_fwd_fused(
mask = cols < N
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float16)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
x_hat = (x - mean) * rstd
y = x_hat * w + b
# Write output
@ -59,8 +59,8 @@ def trition_layer_norm(x, normalized_shape, weight, bias, eps):
# reshape input data into 2D tensor
# x_arg = x.reshape(-1, x.shape[-1])
M, N = x.shape
mean = torch.empty((M, ), dtype=torch.float16, device=x.device)
rstd = torch.empty((M, ), dtype=torch.float16, device=x.device)
mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))