diff --git a/evaluation_wrapper.py b/evaluation_wrapper.py
index 910dda4..56a1e26 100755
--- a/evaluation_wrapper.py
+++ b/evaluation_wrapper.py
@@ -67,7 +67,7 @@ class VLMModel:
         # ================================================================

         # 1. Vision Encoder Acceleration
-        self._optimize_vision_encoder()
+        # self._optimize_vision_encoder()

         # 2. KV Cache Management
         # self._optimize_kv_cache()
@@ -168,8 +168,14 @@ class VLMModel:
        # 2. Inspect: print(self._model.vision_model) to find target layers
        # 3. Replace: layer.self_attn.forward = optimized_attention
        # 4. Test: Run benchmark to verify improvement
+        import types
        from triton_layer_norm import path_forward
-        self._model.model.visual.blocks[0].norm1.__class__.forward = path_forward
+
+        for i in range(len(self._model.model.visual.blocks)):
+            norm1 = self._model.model.visual.blocks[i].norm1
+            norm1.forward = types.MethodType(path_forward, norm1)
+            norm2 = self._model.model.visual.blocks[i].norm2
+            norm2.forward = types.MethodType(path_forward, norm2)

        if 'vision_encoder' not in self._optimizations_applied:
            self._optimizations_applied.append('vision_encoder')
diff --git a/triton_layer_norm.py b/triton_layer_norm.py
index 7744939..9b7702b 100644
--- a/triton_layer_norm.py
+++ b/triton_layer_norm.py
@@ -23,18 +23,18 @@ def _layer_norm_fwd_fused(
    X += row * stride
    # Compute mean
    mean = 0
-    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float16)
+    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
-        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float16)
+        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Compute variance
-    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float16)
+    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
-        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float16)
-        x = tl.where(cols < N, x - mean, 0.).to(tl.float16)
+        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
@@ -47,7 +47,7 @@ def _layer_norm_fwd_fused(
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
-        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float16)
+        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write output
@@ -59,8 +59,8 @@ def trition_layer_norm(x, normalized_shape, weight, bias, eps):
    # reshape input data into 2D tensor
    # x_arg = x.reshape(-1, x.shape[-1])
    M, N = x.shape
-    mean = torch.empty((M, ), dtype=torch.float16, device=x.device)
-    rstd = torch.empty((M, ), dtype=torch.float16, device=x.device)
+    mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
+    rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))