tritionLayerNorm
This commit is contained in:
@ -34,7 +34,7 @@ class VLMModel:
|
||||
3. All optimizations are applied in __init__ by calling optimization methods.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path: str, device: str = "cuda:0"):
|
||||
def __init__(self, model_path: str, device: str = "cpu"):
|
||||
"""
|
||||
Initialize model and apply optimizations.
|
||||
|
||||
@ -67,7 +67,7 @@ class VLMModel:
|
||||
# ================================================================
|
||||
|
||||
# 1. Vision Encoder Acceleration
|
||||
# self._optimize_vision_encoder()
|
||||
self._optimize_vision_encoder()
|
||||
|
||||
# 2. KV Cache Management
|
||||
# self._optimize_kv_cache()
|
||||
@ -168,6 +168,8 @@ class VLMModel:
|
||||
# 2. Inspect: print(self._model.vision_model) to find target layers
|
||||
# 3. Replace: layer.self_attn.forward = optimized_attention
|
||||
# 4. Test: Run benchmark to verify improvement
|
||||
from triton_layer_norm import path_forward
|
||||
self._model.model.visual.blocks[0].norm1.__class__.forward = path_forward
|
||||
|
||||
if 'vision_encoder' not in self._optimizations_applied:
|
||||
self._optimizations_applied.append('vision_encoder')
|
||||
|
||||
81
triton_layer_norm.py
Normal file
81
triton_layer_norm.py
Normal file
@ -0,0 +1,81 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_fused(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride, # how much to increase the pointer when moving by 1 row
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
Y += row * stride
|
||||
X += row * stride
|
||||
# Compute mean
|
||||
mean = 0
|
||||
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for off in range(0, N, BLOCK_SIZE):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||
a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
_mean += a
|
||||
mean = tl.sum(_mean, axis=0) / N
|
||||
# Compute variance
|
||||
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for off in range(0, N, BLOCK_SIZE):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
x = tl.where(cols < N, x - mean, 0.)
|
||||
_var += x * x
|
||||
var = tl.sum(_var, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
# Write mean / rstd
|
||||
tl.store(Mean + row, mean)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
for off in range(0, N, BLOCK_SIZE):
|
||||
cols = off + tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask)
|
||||
b = tl.load(B + cols, mask=mask)
|
||||
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd
|
||||
y = x_hat * w + b
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
def trition_layer_norm(x, normalized_shape, weight, bias, eps):
|
||||
# allocate output
|
||||
y = torch.empty_like(x)
|
||||
# reshape input data into 2D tensor
|
||||
x_arg = x.reshape(-1, x.shape[-1])
|
||||
M, N = x_arg.shape
|
||||
mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
|
||||
rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
if N > BLOCK_SIZE:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
# enqueue kernel
|
||||
_layer_norm_fwd_fused[(M, )]( #
|
||||
x_arg, y, weight, bias, mean, rstd, #
|
||||
x_arg.stride(0), N, eps, #
|
||||
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
|
||||
return y
|
||||
|
||||
def path_forward(self, input: Tensor) -> Tensor:
|
||||
return trition_layer_norm(
|
||||
input, self.normalized_shape, self.weight, self.bias, self.eps
|
||||
)
|
||||
Reference in New Issue
Block a user