GPU | Triton | Performance Optimization

GraphCUDA: Fusing Sparse-Dense and Dense-Dense Matrix Multiplication (Part 1)

2026-04-29 | 8 min read

This blog post is the first in a series where I iteratively build kernels for fusing sparse-dense and dense-dense matrix multiplication. It serves as a tool to build a step-by-step understanding of how GPU architectures have improved, and how different DSLs and tools can be used to gain more control over the optimization process and further improve performance.

This first blog post covers the initial implementation and inspiration for the kernel, written in Triton and targeting the Ampere architecture. In Part 2, I switch to CUDA C++ to learn the low-level implementation details that Triton abstracts away. Finally, in Part 3, I switch to the CuTe DSL to optimize further for Hopper and Blackwell architectures. All code is here.

I made this kernel for my own high-performance graph neural network library, GraphCUDA, specifically for the GCNConv forward pass, which, when stripped down to its core operations, is AXW, where A is the sparse adjacency matrix of the graph, X is the matrix of input embeddings for the graph, and W is the weight matrix.

You’ll get the most from this installment if you’re comfortable with matrix multiply as blocked computation and have some understanding of Triton and GPU architecture. If not, start with Triton’s matrix multiplication tutorial.

Inspiration from FlashAttention

I read the FlashAttention 2 paper, which, when you drop the online softmax, fuses a chain of matrix multiplications among matrices of shapes N × d_model, d_model × N, and N × d_model.

FlashAttention fused matrix multiplication
Image 1: FlashAttention fused matrix multiplication.

Note that in this setup, each program in Triton computes a BLOCK_ROW × d_model block of output. For this to be optimal, d_model needs to be relatively small or bounded compared to the sequence length, which can grow without bound, so that blocks fit in shared memory. I realized a similar approach can be applied to sparse-dense and dense-dense matrix multiplication. In the general problem of multiplying three matrices A, B, and C, where A is M × K1, B is K1 × K2, and C is K2 × N, we can keep the M and K1 dimensions small: M by tiling into row blocks, and K1 by only considering the columns that are non-zero within each row block of A. I will go into more depth on this in the next sections.
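To ground the shapes, the unfused computation is just two chained matmuls; a minimal NumPy reference (the shapes and density here are arbitrary, chosen only for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative shapes: M x K1 sparse adjacency, K1 x K2 embeddings, K2 x N weights
M, K1, K2, N = 8, 8, 16, 4
A = (rng.random((M, K1)) < 0.2).astype(np.float32)  # sparse adjacency matrix
X = rng.random((K1, K2)).astype(np.float32)          # input embeddings
W = rng.random((K2, N)).astype(np.float32)           # weight matrix

# The fused kernel computes this chain in a single pass over the output
out = (A @ X) @ W
```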

Defining the custom sparse matrix format

BSR format example and sparse dense matmul
Image 2: BSR-RM (Row Major) format example and sparse-dense matrix multiplication.

This format is similar to the block sparse row (BSR) matrix format, but instead of making BLOCK_M × BLOCK_N blocks, we block only by rows, and in each row block store only the non-zero columns along with their column indices.

In fact, we can use torch.sparse utilities to create our modified sparse BSR format. Link to the file here.
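To make the layout concrete, here is a standalone NumPy sketch of the construction (the library itself builds the format from torch.sparse utilities; `to_bsr_rm` is illustrative only):

```python
import numpy as np

def to_bsr_rm(A, block_m):
    """Build a row-block sparse format: for each BLOCK_M row block,
    keep only columns with at least one non-zero, stored row-major."""
    M, K1 = A.shape
    n_blocks = (M + block_m - 1) // block_m
    crow = [0]    # prefix sum of kept-column counts per row block
    cols = []     # kept column indices, concatenated block by block
    values = []   # flattened (block_m, nnz_cols) value blocks
    for b in range(n_blocks):
        rows = A[b * block_m:(b + 1) * block_m]
        nz_cols = np.flatnonzero(rows.any(axis=0))
        crow.append(crow[-1] + len(nz_cols))
        cols.extend(nz_cols.tolist())
        # pad a short last block to block_m rows so blocks stay rectangular
        padded = np.zeros((block_m, len(nz_cols)), dtype=A.dtype)
        padded[:rows.shape[0]] = rows[:, nz_cols]
        values.append(padded.ravel())  # row-major within the block
    vals = np.concatenate(values) if values else np.array([])
    return np.array(crow), np.array(cols), vals
```

Note that the value block for row block i starts at offset `BLOCK_M * crow[i]`, which is exactly the `BLOCK_M * tile_m_crow` indexing the kernel uses later.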

Implementation details in Triton + NCU profiling

Approach outlined compared to flash attention
Image 3: The fused SpMM-GEMM approach compared with FlashAttention.

With this new matrix format, assuming that for any given row block the number of non-zero columns is small (denote this value K1(i) for row block i), each program computes a BLOCK_M × N chunk of the output matrix and loops over the sparse K1 dimension in blocks of size BLOCK_K1.

K1 blocks
Image 4: K1 blocks.

Note that in Triton we cannot have an arbitrary block size, e.g. load a BLOCK_M × K1(i) block of bsr.values_rm, since tl.arange must have a power-of-2 constant end. So we add some complexity compared to FlashAttention by introducing K1-dimension blocks. The idea is that if sparsity is evenly distributed, each Triton program does close to the same amount of work, i.e. the same number of loops over the K1 dimension, because each row block of matrix A has roughly the same number of non-zero columns. With this, we are finally ready to implement the kernel:

python
@triton.jit
def fused_spmm_gemm_relu_small_n_kernel(
    # sizes
    M: tl.constexpr,
    K1: tl.constexpr,
    K2: tl.constexpr,
    N: tl.constexpr,
    # ptrs
    x_ptr, # dense X (K1, K2)
    w_ptr, # dense weights (K2, N)
    out_ptr, # dense output (M, N)
    relu_mask_ptr, # dense uint8 mask (M, N), only written when apply_relu
    bsr_values_ptr, # BSR adjacency (M, K1)
    bsr_crow_ptr,
    bsr_col_ptr,
    bias_ptr,
    # strides
    sx_k1, sx_k2,
    sw_k2, sw_n,
    so_m, so_n,
    sm_m, sm_n,
    sb_n,
    # flags
    apply_relu: tl.constexpr,
    has_bias: tl.constexpr,
    # tile sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K1: tl.constexpr,
    BLOCK_K2: tl.constexpr,
):
    # ------------------- Compute PIDS -------------------
    pid_m = tl.program_id(0)
    
    # ------------------- Compute Logical Offsets -------------------
    offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M))
    offs_n = (tl.arange(0, BLOCK_N))
    
    tile_m_crow = tl.load(bsr_crow_ptr + pid_m)
    tile_m_next_crow = tl.load(bsr_crow_ptr + pid_m + 1)
    K1_TILE_M = tile_m_next_crow - tile_m_crow
    
    tile_m_adjm_bsr_values_offs = BLOCK_M * tile_m_crow
    offs_m_adjm = tl.arange(0, BLOCK_M)
    
    # ------------------- Loop over K1 blocks -------------------
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # The K1 loop length varies per row block, which can create asymmetric work
    # across programs; hopefully this is okay since M is typically large, so many row blocks exist.
    for k1_block in tl.range(0, K1_TILE_M, BLOCK_K1):
        # ------------------- Compute K1 block indices and offsets -------------------
        offs_k1_adjm = k1_block + tl.arange(0, BLOCK_K1) # will be masked by K1_TILE_M
        offs_adjm_bsr_cols = k1_block + tl.arange(0, BLOCK_K1) # indexes we want to read from BSR col indices, will be masked by K1_TILE_M, need this because arange needs constant end value
        offs_k1_x = tl.load(bsr_col_ptr + tile_m_crow + offs_adjm_bsr_cols, mask=offs_adjm_bsr_cols < K1_TILE_M, other=0)
        
        # ------------------- Load adjm values -------------------
        adjm_values_ptrs = (
            bsr_values_ptr
            + tile_m_adjm_bsr_values_offs
            + offs_m_adjm[:, None] * K1_TILE_M
            + offs_k1_adjm[None, :] * 1 # all contiguous in row major order
        )
        adjm_values_mask = (offs_m[:, None] < M) & (offs_k1_adjm[None, :] < K1_TILE_M)
        adjm_values = tl.load(adjm_values_ptrs, mask=adjm_values_mask, other=0.0)
    
        # ------------------- Loop over K2 blocks -------------------
        for k2_block in range(0, K2, BLOCK_K2):
            # ------------------- Compute K2 block indices and offsets -------------------
            offs_k2 = (k2_block + tl.arange(0, BLOCK_K2))
            
            # ------------------- Load x values -------------------
            x_ptrs = (
                x_ptr
                + offs_k1_x[:, None] * sx_k1
                + offs_k2[None, :] * sx_k2
            )
            x_mask = (offs_adjm_bsr_cols[:, None] < K1_TILE_M) & (offs_k1_x[:, None] < K1) & (offs_k2[None, :] < K2)
            x_values = tl.load(x_ptrs, mask=x_mask, other=0.0)
            
            # ------------------- Load weights values -------------------
            w_ptrs = (
                w_ptr
                + offs_k2[:, None] * sw_k2
                + offs_n[None, :] * sw_n
            )
            w_mask = (offs_k2[:, None] < K2) & (offs_n[None, :] < N)
            w_values = tl.load(w_ptrs, mask=w_mask, other=0.0)
            
            # ------------------- Compute GEMM -------------------
            acc1 = tl.zeros((BLOCK_M, BLOCK_K2), dtype=tl.float32)
            acc1 = tl.dot(adjm_values, x_values, acc=acc1, input_precision="ieee", out_dtype=tl.float32)
            acc2 = tl.dot(acc1.to(w_values.dtype), w_values, acc=acc2, input_precision="ieee", out_dtype=tl.float32)
    
    if has_bias:
        bias_ptrs = bias_ptr + offs_n * sb_n
        bias_mask = (offs_n < N)
        bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0)
        acc2 = acc2 + bias_vals[None, :]
    
    # ------------------- Apply ReLU -------------------
    out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    if apply_relu:
        relu_mask_vals = (acc2 > 0.0)
        relu_mask_ptrs = (
            relu_mask_ptr
            + offs_m[:, None] * sm_m
            + offs_n[None, :] * sm_n
        )
        tl.store(relu_mask_ptrs, relu_mask_vals, mask=out_mask)
        acc2 = tl.maximum(acc2, 0.0)
    
    # ------------------- Write back to output -------------------
    out_ptrs = (
        out_ptr
        + offs_m[:, None] * so_m
        + offs_n[None, :] * so_n
    )
    tl.store(out_ptrs, acc2, mask=out_mask)
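On the host side, the kernel is launched with one program per row block; a hypothetical sketch of the grid computation (`cdiv` and `launch_grid` are illustrative names, triton.cdiv provides the same ceiling division):

```python
def cdiv(a, b):
    # Ceiling division, equivalent to triton.cdiv
    return (a + b - 1) // b

def launch_grid(M, BLOCK_M):
    # One Triton program per BLOCK_M row block; the entire N dimension is
    # covered by a single BLOCK_N tile, so the grid is one-dimensional.
    return (cdiv(M, BLOCK_M),)
```

For example, on Cora (M = 2708) with a hypothetical BLOCK_M of 32, this yields 85 programs.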

Edge case: Unbalanced sparsity

Consider this example.

Example edge case with uneven distribution of sparsity
Image 5: Example edge case with an uneven distribution of sparsity.

Profiling the edge case against the regular case with the NCU profiler, we initially don’t see occupancy as the issue. The theoretical and achieved occupancy are almost unchanged: the regular case has 8.37% achieved occupancy, while the edge case has 8.30%. This is because all Triton programs still do some work. However, program 0 just does a lot more work, while the rest of the CTAs finish almost immediately, leaving only a small amount of useful work running on the GPU.

text
Regular case
------------
Compute Throughput:        18.47%
Memory Throughput:         153.81 GB/s
SM Active Cycles:          17,704.77

Edge case
---------
Compute Throughput:         0.42%
Memory Throughput:          24.00 GB/s
SM Active Cycles:           3,717.88

The hope is that most adjacency matrices will be at most slightly uneven in sparsity, so there will still be many CTAs for the scheduler to run to counteract this scenario.
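One cheap way to check whether a given graph hits this edge case is to inspect the crow array of the sparse format before launching; a small sketch (the function name and imbalance metric are mine, not part of GraphCUDA):

```python
import numpy as np

def block_balance(crow):
    """Per-row-block non-zero-column counts from the BSR crow array,
    plus a max/mean imbalance ratio as a rough load-balance proxy."""
    counts = np.diff(np.asarray(crow))
    mean = counts.mean() if counts.size else 0.0
    ratio = counts.max() / mean if mean > 0 else 1.0
    return counts, ratio
```

A ratio near 1 means row blocks do similar amounts of work; in the pathological example above, one block would dominate and the ratio would approach the number of blocks.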

Looping K1 then K2 or vice versa

Another consideration is which dimension we should loop over first: K1 or K2.

python
for k1_block in tl.range(0, K1_TILE_M, BLOCK_K1):
    # ------------------- Compute K1 block indices and offsets -------------------
    offs_k1_adjm = k1_block + tl.arange(0, BLOCK_K1) # will be masked by K1_TILE_M
    offs_adjm_bsr_cols = k1_block + tl.arange(0, BLOCK_K1) # indexes we want to read from BSR col indices, will be masked by K1_TILE_M, need this because arange needs constant end value
    offs_k1_x = tl.load(bsr_col_ptr + tile_m_crow + offs_adjm_bsr_cols, mask=offs_adjm_bsr_cols < K1_TILE_M, other=0)
    
    # ------------------- Load adjm values -------------------
    adjm_values_ptrs = (
        bsr_values_ptr
        + tile_m_adjm_bsr_values_offs
        + offs_m_adjm[:, None] * K1_TILE_M
        + offs_k1_adjm[None, :] * 1 # all contiguous in row major order
    )
    adjm_values_mask = (offs_m[:, None] < M) & (offs_k1_adjm[None, :] < K1_TILE_M)
    adjm_values = tl.load(adjm_values_ptrs, mask=adjm_values_mask, other=0.0)

    # ------------------- Loop over K2 blocks -------------------
    for k2_block in range(0, K2, BLOCK_K2):
        # ------------------- Compute K2 block indices and offsets -------------------
        offs_k2 = (k2_block + tl.arange(0, BLOCK_K2))
        
        # ------------------- Load x values -------------------
        x_ptrs = (
            x_ptr
            + offs_k1_x[:, None] * sx_k1
            + offs_k2[None, :] * sx_k2
        )
        x_mask = (offs_adjm_bsr_cols[:, None] < K1_TILE_M) & (offs_k1_x[:, None] < K1) & (offs_k2[None, :] < K2)
        x_values = tl.load(x_ptrs, mask=x_mask, other=0.0)
        
        # ------------------- Load weights values -------------------
        w_ptrs = (
            w_ptr
            + offs_k2[:, None] * sw_k2
            + offs_n[None, :] * sw_n
        )
        w_mask = (offs_k2[:, None] < K2) & (offs_n[None, :] < N)
        w_values = tl.load(w_ptrs, mask=w_mask, other=0.0)
        
        # ------------------- Compute GEMM -------------------
        acc1 = tl.zeros((BLOCK_M, BLOCK_K2), dtype=tl.float32)
        acc1 = tl.dot(adjm_values, x_values, acc=acc1, input_precision="ieee", out_dtype=tl.float32)
        acc2 = tl.dot(acc1.to(w_values.dtype), w_values, acc=acc2, input_precision="ieee", out_dtype=tl.float32)
Looping over K1 (outer), then K2 (inner).
python
for k2_block in range(0, K2, BLOCK_K2):
    # ------------------- Compute K2 block indices and offsets -------------------
    offs_k2 = (k2_block + tl.arange(0, BLOCK_K2))
    
    # ------------------- Load weights values -------------------
    w_ptrs = (
        w_ptr
        + offs_k2[:, None] * sw_k2
        + offs_n[None, :] * sw_n
    )
    w_mask = (offs_k2[:, None] < K2) & (offs_n[None, :] < N)
    w_values = tl.load(w_ptrs, mask=w_mask, other=0.0)

    # ------------------- Loop over K1 blocks -------------------
    # The K1 loop length varies per row block, which can create asymmetric work
    # across programs; hopefully this is okay since M is typically large, so many row blocks exist.
    for k1_block in tl.range(0, K1_TILE_M, BLOCK_K1):
        # ------------------- Compute K1 block indices and offsets -------------------
        offs_k1_adjm = k1_block + tl.arange(0, BLOCK_K1) # will be masked by K1_TILE_M
        offs_adjm_bsr_cols = k1_block + tl.arange(0, BLOCK_K1) # indexes we want to read from BSR col indices, will be masked by K1_TILE_M, need this because arange needs constant end value
        offs_k1_x = tl.load(bsr_col_ptr + tile_m_crow + offs_adjm_bsr_cols, mask=offs_adjm_bsr_cols < K1_TILE_M, other=0)
        
        # ------------------- Load adjm values -------------------
        adjm_values_ptrs = (
            bsr_values_ptr
            + tile_m_adjm_bsr_values_offs
            + offs_m_adjm[:, None] * K1_TILE_M
            + offs_k1_adjm[None, :] * 1 # all contiguous in row major order
        )
        adjm_values_mask = (offs_m[:, None] < M) & (offs_k1_adjm[None, :] < K1_TILE_M)
        adjm_values = tl.load(adjm_values_ptrs, mask=adjm_values_mask, other=0.0)
        
        # ------------------- Load x values -------------------
        x_ptrs = (
            x_ptr
            + offs_k1_x[:, None] * sx_k1
            + offs_k2[None, :] * sx_k2
        )
        x_mask = (offs_adjm_bsr_cols[:, None] < K1_TILE_M) & (offs_k1_x[:, None] < K1) & (offs_k2[None, :] < K2)
        x_values = tl.load(x_ptrs, mask=x_mask, other=0.0)
        
        # ------------------- Compute GEMM -------------------
        acc1 = tl.zeros((BLOCK_M, BLOCK_K2), dtype=tl.float32)
        acc1 = tl.dot(adjm_values, x_values, acc=acc1, input_precision="ieee", out_dtype=tl.float32)
        acc2 = tl.dot(acc1.to(w_values.dtype), w_values, acc=acc2, input_precision="ieee", out_dtype=tl.float32)
Looping over K2 (outer), then K1 (inner).

The difference between the two loop orders is what gets kept outside the inner loop. When looping over K1 and then K2, we load the BSR column indices and adjacency values once per K1 block and reuse them across the K2 loop. When switching the order, we instead reuse the weight tile across K1, but repeatedly fetch the sparse column indices and the sparse values inside the inner loop.

The profile suggests this extra work is the problem. The regular kernel executes far fewer instructions and its busiest pipeline is the Tensor pipeline, while the switched-loop kernel executes many more instructions and is dominated by ALU work:

text
Regular kernel (K1 outer, K2 inner)
-----------------------------------
Compute Workload Analysis:
    Tensor is the highest-utilized pipeline (22.0%) ...

Instruction Statistics:
    Executed Instructions                           803,584


Switched-loop kernel (K2 outer, K1 inner)
-----------------------------------------
Compute Workload Analysis:
    ALU is the highest-utilized pipeline (20.7%) ...
    It executes integer and logic operations.

Instruction Statistics:
    Executed Instructions                         2,663,424
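A back-of-the-envelope count of the redundant tile fetches points in the same direction as the instruction gap; a small sketch (the example numbers below are illustrative, not taken from the profiled run):

```python
def index_loads(k1_tile, K2, BLOCK_K1, BLOCK_K2):
    """How many times a program fetches the BSR column-index and
    adjacency tiles under each loop order."""
    k1_iters = -(-k1_tile // BLOCK_K1)  # ceil division
    k2_iters = -(-K2 // BLOCK_K2)
    k1_outer = k1_iters              # fetched once per K1 block
    k2_outer = k1_iters * k2_iters   # re-fetched on every inner iteration
    return k1_outer, k2_outer
```

For instance, with a row block holding 256 non-zero columns, K2 = 1433, BLOCK_K1 = 64, and BLOCK_K2 = 32, the K1-outer order fetches the sparse tiles 4 times per program versus 180 times for the K2-outer order.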

Padding the K2 dimension to be a multiple of 16

Padding the K2 dimension to a multiple of 16 makes the memory access pattern cleaner. In GraphCUDA, we do this by preprocessing the input data, specifically the input and weight matrices. Without padding, Triton struggles to compile code that uses the hardware’s full capabilities, such as clean vectorized memory access. Even though the Triton code stays the same, the generated code can become much faster because the compiler no longer has to handle as many misaligned accesses.
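The preprocessing itself is plain zero-padding, which leaves the product unchanged since the extra rows and columns contribute nothing; a NumPy sketch of the idea (GraphCUDA does this on torch tensors; `pad_dim` is an illustrative helper):

```python
import numpy as np

def pad_dim(x, axis, multiple=16):
    """Zero-pad `axis` of x up to a multiple of `multiple` so the
    compiler can emit aligned, vectorized loads along that dimension."""
    pad = (-x.shape[axis]) % multiple
    if pad == 0:
        return x
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, pad)
    return np.pad(x, widths)
```

X is padded along its last dimension and W along its first, so the shared K2 dimension stays consistent between the two operands.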

Autotuning

We can also add a Triton autotuning config for optimal BLOCK_K1 and BLOCK_K2.

python
_FUSED_SPMM_GEMM_RELU_AUTOTUNE_CONFIGS = [
    triton.Config(
        {"BLOCK_K1": bk1, "BLOCK_K2": bk2},
        num_stages=ns,
        num_warps=nw,
    )
    for bk1 in (32, 64, 128, 256)
    for bk2 in (32, 64)
    for nw in (2, 4, 8)
    for ns in (2, 4)
]


@triton.autotune(
    configs=_FUSED_SPMM_GEMM_RELU_AUTOTUNE_CONFIGS,
    key=["M", "N", "K1", "K2"],
    cache_results=True,
)

Additionally, we could tune BLOCK_M. However, since BLOCK_M is baked into the sparse format when the matrix is constructed, we would have to choose it via a heuristic at construction time, which makes it harder to tune alongside BLOCK_K1 and BLOCK_K2.
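Such a heuristic might, for example, bound the expected number of non-zero columns per row block; a purely illustrative sketch (the budget, candidates, and function name are made up, not what GraphCUDA does):

```python
def choose_block_m(avg_row_nnz, nnz_col_budget=256, candidates=(16, 32, 64, 128)):
    """Pick the largest BLOCK_M (candidates assumed sorted ascending) whose
    estimated per-block non-zero column count stays within a budget,
    falling back to the smallest candidate if none fit."""
    best = candidates[0]
    for bm in candidates:
        # Crude estimate: distinct non-zero columns grow roughly
        # with BLOCK_M * average row degree (ignores column overlap).
        if bm * avg_row_nnz <= nnz_col_budget:
            best = bm
    return best
```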

Benchmarks on A100

Initial benchmarks show that the forward pass is faster than the PyTorch sparse CSR implementation.

Benchmark plot comparing the Triton implementation against the Torch sparse implementation
Image 6: Benchmark plot comparing the Triton implementation against the Torch sparse implementation.

Integrating this with my own GraphCUDA GCNConv layer and comparing against PyTorch Geometric, we see that training is faster.

text
Cora: nodes=2708, edges=10556, features=1433, classes=7, dtype=fp16
Epochs: warmup=1000, measured=10000, hidden_dim=16

Benchmarking pygeometric_gcn
pygeometric_gcn: total=13.243882s, avg_epoch=1.324ms, loss=0.0018, train_acc=1.0000, val_acc=0.7600, test_acc=0.7920

Benchmarking graphcuda_gcn
graphcuda_gcn: total=12.076949s, avg_epoch=1.208ms, loss=0.0004, train_acc=1.0000, val_acc=0.7740, test_acc=0.7860

GraphCUDA speedup vs PyG: 1.10x

Next steps

In the next blog post, I will try to implement the lower-level details of this kernel in CUDA C++, optimized specifically for Ampere architectures. See Part 2.