back fp32
This commit is contained in:
@ -67,7 +67,7 @@ class VLMModel:
|
||||
# ================================================================
|
||||
|
||||
# 1. Vision Encoder Acceleration
|
||||
self._optimize_vision_encoder()
|
||||
# self._optimize_vision_encoder()
|
||||
|
||||
# 2. KV Cache Management
|
||||
# self._optimize_kv_cache()
|
||||
@ -168,8 +168,14 @@ class VLMModel:
|
||||
# 2. Inspect: print(self._model.vision_model) to find target layers
|
||||
# 3. Replace: layer.self_attn.forward = optimized_attention
|
||||
# 4. Test: Run benchmark to verify improvement
|
||||
import types
|
||||
from triton_layer_norm import path_forward
|
||||
self._model.model.visual.blocks[0].norm1.__class__.forward = path_forward
|
||||
|
||||
for i in range(len(self._model.model.visual.blocks)):
|
||||
norm1 = self._model.model.visual.blocks[i].norm1
|
||||
norm1.forward = types.MethodType(path_forward, norm1)
|
||||
norm2 = self._model.model.visual.blocks[i].norm2
|
||||
norm2.forward = types.MethodType(path_forward, norm2)
|
||||
|
||||
if 'vision_encoder' not in self._optimizations_applied:
|
||||
self._optimizations_applied.append('vision_encoder')
|
||||
|
||||
@ -23,18 +23,18 @@ def _layer_norm_fwd_fused(
|
||||
X += row * stride
|
||||
# Compute mean
|
||||
mean = 0
|
||||
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float16)
|
||||
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for off in range(0, N, BLOCK_SIZE):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||
a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float16)
|
||||
a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
_mean += a
|
||||
mean = tl.sum(_mean, axis=0) / N
|
||||
# Compute variance
|
||||
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float16)
|
||||
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for off in range(0, N, BLOCK_SIZE):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float16)
|
||||
x = tl.where(cols < N, x - mean, 0.).to(tl.float16)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
x = tl.where(cols < N, x - mean, 0.)
|
||||
_var += x * x
|
||||
var = tl.sum(_var, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
@ -47,7 +47,7 @@ def _layer_norm_fwd_fused(
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask)
|
||||
b = tl.load(B + cols, mask=mask)
|
||||
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float16)
|
||||
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd
|
||||
y = x_hat * w + b
|
||||
# Write output
|
||||
@ -59,8 +59,8 @@ def trition_layer_norm(x, normalized_shape, weight, bias, eps):
|
||||
# reshape input data into 2D tensor
|
||||
# x_arg = x.reshape(-1, x.shape[-1])
|
||||
M, N = x.shape
|
||||
mean = torch.empty((M, ), dtype=torch.float16, device=x.device)
|
||||
rstd = torch.empty((M, ), dtype=torch.float16, device=x.device)
|
||||
mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
|
||||
rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
|
||||
Reference in New Issue
Block a user