From 2f4420bb2d62837187af106082a44add1648a6c3 Mon Sep 17 00:00:00 2001
From: noctis <970308389@qq.com>
Date: Fri, 27 Feb 2026 09:32:13 +0800
Subject: [PATCH] tritonLayerNorm

Add a fused Triton LayerNorm forward kernel (triton_layer_norm.py) and
install it as the forward of the vision encoder's norm layer class.

The wrapper falls back to torch.nn.functional.layer_norm whenever the
kernel does not apply (non-CUDA tensors, missing weight/bias, autograd
required, unsupported dtype, empty input), so patching the class is safe
for every instance that shares it.

NOTE(review): this commit also changes the default device to "cpu", where
only the fallback path runs - confirm that default is intended.
---
 evaluation_wrapper.py |   6 +-
 triton_layer_norm.py  | 136 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 140 insertions(+), 2 deletions(-)
 create mode 100644 triton_layer_norm.py

diff --git a/evaluation_wrapper.py b/evaluation_wrapper.py
index 4f6e9b9..14e60f3 100755
--- a/evaluation_wrapper.py
+++ b/evaluation_wrapper.py
@@ -34,7 +34,7 @@ class VLMModel:
     3. All optimizations are applied in __init__ by calling optimization methods.
     """
 
-    def __init__(self, model_path: str, device: str = "cuda:0"):
+    def __init__(self, model_path: str, device: str = "cpu"):
         """
         Initialize model and apply optimizations.
 
@@ -67,7 +67,7 @@ class VLMModel:
         # ================================================================
 
         # 1. Vision Encoder Acceleration
-        # self._optimize_vision_encoder()
+        self._optimize_vision_encoder()
 
         # 2. KV Cache Management
         # self._optimize_kv_cache()
@@ -168,6 +168,8 @@ class VLMModel:
         # 2. Inspect: print(self._model.vision_model) to find target layers
         # 3. Replace: layer.self_attn.forward = optimized_attention
         # 4. Test: Run benchmark to verify improvement
+        from triton_layer_norm import path_forward
+        self._model.model.visual.blocks[0].norm1.__class__.forward = path_forward
 
         if 'vision_encoder' not in self._optimizations_applied:
             self._optimizations_applied.append('vision_encoder')
diff --git a/triton_layer_norm.py b/triton_layer_norm.py
new file mode 100644
--- /dev/null
+++ b/triton_layer_norm.py
@@ -0,0 +1,136 @@
+"""Fused Triton LayerNorm forward kernel and a drop-in ``forward`` override."""
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+import triton
+import triton.language as tl
+
+# Floating dtypes the kernel handles; it accumulates in float32, so wider
+# inputs (e.g. float64) are routed to PyTorch to avoid losing precision.
+_FUSED_DTYPES = (torch.float16, torch.bfloat16, torch.float32)
+
+
+@triton.jit
+def _layer_norm_fwd_fused(
+    X,  # pointer to the input
+    Y,  # pointer to the output
+    W,  # pointer to the weights
+    B,  # pointer to the biases
+    Mean,  # pointer to the mean
+    Rstd,  # pointer to the 1/std
+    stride,  # how much to increase the pointer when moving by 1 row
+    N,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    BLOCK_SIZE: tl.constexpr,
+):
+    # Map the program id to the row of X and Y it should compute.
+    row = tl.program_id(0)
+    Y += row * stride
+    X += row * stride
+    # Compute mean
+    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+        _mean += a
+    mean = tl.sum(_mean, axis=0) / N
+    # Compute variance
+    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+        x = tl.where(cols < N, x - mean, 0.)
+        _var += x * x
+    var = tl.sum(_var, axis=0) / N
+    rstd = 1 / tl.sqrt(var + eps)
+    # Write mean / rstd
+    tl.store(Mean + row, mean)
+    tl.store(Rstd + row, rstd)
+    # Normalize and apply linear transformation
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        mask = cols < N
+        w = tl.load(W + cols, mask=mask)
+        b = tl.load(B + cols, mask=mask)
+        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
+        x_hat = (x - mean) * rstd
+        y = x_hat * w + b
+        # Write output
+        tl.store(Y + cols, y, mask=mask)
+
+
+def _can_use_fused_kernel(x, normalized_shape, weight, bias):
+    """Return True when the Triton kernel can stand in for F.layer_norm."""
+    if not x.is_cuda or x.dtype not in _FUSED_DTYPES or x.numel() == 0:
+        return False
+    # The kernel loads W and B unconditionally, so both must exist.
+    if weight is None or bias is None:
+        return False
+    # Let PyTorch raise its usual error on a shape mismatch.
+    if tuple(x.shape[-len(normalized_shape):]) != tuple(normalized_shape):
+        return False
+    # The kernel is forward-only: its output carries no autograd graph.
+    return not (torch.is_grad_enabled() and any(
+        t.requires_grad for t in (x, weight, bias)))
+
+
+def trition_layer_norm(x, normalized_shape, weight, bias, eps):
+    """Layer-normalize ``x`` over its trailing ``normalized_shape`` dims.
+
+    CUDA inference with affine parameters runs the fused Triton kernel;
+    every other case (CPU tensors, missing weight/bias, autograd needed,
+    unsupported dtype, empty input) defers to ``F.layer_norm``.
+
+    Args:
+        x: input tensor.
+        normalized_shape: int or sequence of trailing dims to normalize.
+        weight: per-element scale, or None.
+        bias: per-element shift, or None.
+        eps: value added to the variance for numerical stability.
+
+    Returns:
+        The normalized tensor, shaped like ``x``.
+
+    Raises:
+        RuntimeError: if one normalized row exceeds 64KB on the fused path.
+    """
+    if isinstance(normalized_shape, int):
+        normalized_shape = (normalized_shape,)
+    if not _can_use_fused_kernel(x, normalized_shape, weight, bias):
+        return F.layer_norm(x, normalized_shape, weight, bias, eps)
+    # Flatten to densely packed rows: the kernel uses one row stride for
+    # both input and output and assumes unit stride inside a row.
+    N = math.prod(normalized_shape)
+    x_arg = x.reshape(-1, N).contiguous()
+    y = torch.empty_like(x_arg)
+    M = x_arg.shape[0]
+    mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
+    rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+    if N > BLOCK_SIZE:
+        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+    # Launch on the tensor's own GPU, which may not be the current device.
+    with torch.cuda.device(x.device):
+        _layer_norm_fwd_fused[(M, )](  #
+            x_arg, y, weight.contiguous(), bias.contiguous(), mean, rstd,  #
+            x_arg.stride(0), N, eps,  #
+            BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
+    return y.view(x.shape)
+
+
+# Correctly spelled alias; the original name is kept for existing imports.
+triton_layer_norm = trition_layer_norm
+
+
+def path_forward(self, input: Tensor) -> Tensor:
+    """Forward override for modules exposing normalized_shape/weight/bias/eps."""
+    return trition_layer_norm(
+        input, self.normalized_shape, self.weight, self.bias, self.eps
+    )