# 2026-02-27 09:32:09 +00:00
|
|
|
import triton
|
|
|
|
|
import triton.language as tl
|
|
|
|
|
|
|
|
|
|
@triton.jit
def _fc1_bias_gelu_kernel(
    X_ptr, W_ptr, B_ptr, Y_ptr,
    M, N: tl.constexpr, K: tl.constexpr,
    stride_xm, stride_xk,
    stride_wn, stride_wk,  # W is [N, K]
    stride_ym, stride_yn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Fused ``Y = GELU_tanh(X @ W^T + B)`` for one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of Y.

    X has shape (M, K); W has shape (N, K) and is read transposed through its
    strides; B holds N elements read with unit stride; Y has shape (M, N).
    Accumulation, bias add and GELU run in fp32, and the result is cast to Y's
    element type when stored.

    M is a runtime argument, so a new row count (e.g. a different number of
    tokens) does not force a recompile. N and K remain compile-time constants.
    """
    # -----------------------------------------------------------
    # Map the program id to the (pid_m, pid_n) tile of Y it computes.
    # Tiles are visited in groups of GROUP_SIZE_M tile-rows, so programs that
    # run close together share W tiles and reuse them from L2.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M tile-rows.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # -----------------------------------------------------------
    # Integer bound hints that let the backend simplify address arithmetic.
    # Callers must therefore pass strictly positive strides.
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_xm > 0)
    tl.assume(stride_xk > 0)
    tl.assume(stride_wn > 0)
    tl.assume(stride_wk > 0)
    tl.assume(stride_ym > 0)
    tl.assume(stride_yn > 0)

    # -----------------------------------------------------------
    # Row/column offsets are wrapped with % so that edge tiles load valid
    # (duplicated) rows/columns instead of needing a load mask; the final
    # store uses the un-wrapped offsets with a mask.
    # NOTE(review): offsets are integer products of ids and strides; for
    # tensors above 2**31 elements an int64 cast may be required -- confirm.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_wn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # X_ptrs: [BLOCK_SIZE_M, BLOCK_SIZE_K] tile of X.
    # W_ptrs: [BLOCK_SIZE_K, BLOCK_SIZE_N] tile of W^T (W itself is stored [N, K]).
    X_ptrs = X_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    W_ptrs = W_ptr + (offs_k[:, None] * stride_wk + offs_wn[None, :] * stride_wn)

    # -----------------------------------------------------------
    # Reduce over K into an fp32 accumulator.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Zero-fill the K tail so it contributes nothing to the dot product.
        k_remaining = K - k * BLOCK_SIZE_K
        x_tile = tl.load(X_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
        w_tile = tl.load(W_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
        accumulator = tl.dot(x_tile, w_tile, accumulator)
        # Advance both tiles to the next K block.
        X_ptrs += BLOCK_SIZE_K * stride_xk
        W_ptrs += BLOCK_SIZE_K * stride_wk

    # Bias add, broadcast over rows. offs_wn is already in [0, N), so the
    # mask never rejects anything; it is kept only as a guard.
    bias = tl.load(B_ptr + offs_wn, mask=offs_wn < N, other=0.0).to(tl.float32)
    accumulator = accumulator + bias[None, :]

    # GELU, tanh approximation, computed in fp32:
    #   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    # tanh(u) is evaluated as sign(u) * (1 - exp(-2|u|)) / (1 + exp(-2|u|)),
    # so the exponent is never positive and cannot overflow.
    c0 = 0.7978845608028654  # sqrt(2 / pi)
    c1 = 0.044715
    u = c0 * (accumulator + c1 * accumulator * accumulator * accumulator)
    e = tl.exp(-2.0 * tl.abs(u))
    t = (1.0 - e) / (1.0 + e)
    t = tl.where(u >= 0, t, -t)
    accumulator = 0.5 * accumulator * (1.0 + t)

    # Cast straight to Y's element type. A fixed fp16 cast would overflow
    # (|v| > 65504) or round twice when Y is bf16/fp32.
    out = accumulator.to(Y_ptr.dtype.element_ty)

    # -----------------------------------------------------------
    # Masked write-back of the tile, using un-wrapped coordinates.
    offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    Y_ptrs = Y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
    Y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
    tl.store(Y_ptrs, out, mask=Y_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fc1_bias_gelu_triton(x, w, b,
                         BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32, GROUP_SIZE_M=8, num_warps=8):
    """Compute ``GELU_tanh(x @ w.T + b)`` with a single fused Triton launch.

    x: [..., K] CUDA tensor. Any leading dimensions are flattened into M rows
       and restored on the output; a plain [M, K] input behaves as before.
    w: [N, K] (PyTorch Linear weight layout), same device as x and, presumably,
       the same dtype (both are fed to one matmul).
    b: [N], same device as x; made contiguous because the kernel reads it
       with unit stride.
    BLOCK_SIZE_M / BLOCK_SIZE_N / BLOCK_SIZE_K / GROUP_SIZE_M / num_warps:
       tiling and launch parameters forwarded unchanged to the kernel.

    returns y: [..., N] with x's dtype and device.

    Raises AssertionError on non-CUDA inputs, device mismatch or shape mismatch.
    NOTE(review): the kernel assumes (tl.assume) every stride is > 0, so
    zero-stride (expanded/broadcast) inputs are not supported.
    """
    import math

    import torch

    assert x.is_cuda and w.is_cuda and b.is_cuda
    assert w.device == x.device and b.device == x.device, "x, w and b must share a device"
    assert x.dim() >= 1 and w.dim() == 2
    N, K = w.shape
    assert x.shape[-1] == K and b.shape[0] == N

    # Flatten leading dims into rows. math.prod(()) == 1 handles a 1-D x, and
    # an explicit M (rather than -1) keeps the reshape unambiguous when K == 0.
    lead_shape = x.shape[:-1]
    M = math.prod(lead_shape)
    x2d = x.reshape(M, K)
    # The kernel indexes the bias as B_ptr + n, i.e. it ignores b's stride.
    b = b.contiguous()

    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    if M == 0 or N == 0:
        # Nothing to compute; skip the launch entirely.
        return y.reshape(*lead_shape, N)

    # One program per output tile; the kernel remaps ids into grouped order.
    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)

    _fc1_bias_gelu_kernel[grid](
        x2d, w, b, y,
        M=M, N=N, K=K,
        stride_xm=x2d.stride(0), stride_xk=x2d.stride(1),
        stride_wn=w.stride(0), stride_wk=w.stride(1),
        stride_ym=y.stride(0), stride_yn=y.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=GROUP_SIZE_M,
        num_warps=num_warps,
    )
    return y.reshape(*lead_shape, N)
|
|
|
|
|
|
|
|
|
|
import functools
|
|
|
|
|
from torch import nn
|
|
|
|
|
|
# 2026-02-27 09:44:48 +00:00
|
|
|
def path_linear_forward(self, hidden_state):
    """Replacement ``forward``: ``linear_fc2(GELU_tanh(linear_fc1(hidden_state)))``.

    The first projection (matmul + bias) and the tanh-approximate GELU run as
    one fused Triton launch via ``fc1_bias_gelu_triton``. ``linear_fc2`` is
    then applied as an ordinary module call.

    NOTE(review): the activation is hard-wired to tanh-approximate GELU and no
    activation attribute on ``self`` is consulted -- confirm this matches the
    module this forward is patched onto. ``self.linear_fc1.bias`` must be a
    tensor (not None), since it is passed as the fused bias.
    """
    fc1 = self.linear_fc1
    activated = fc1_bias_gelu_triton(hidden_state, fc1.weight, fc1.bias)
    return self.linear_fc2(activated)
|