GraphCUDA: Fusing Sparse-Dense and Dense-Dense Matrix Multiplication (Part 1)
2026-04-29 | 8 min read
This blog post is the first in a series where I iteratively build kernels for fusing sparse-dense and dense-dense matrix multiplication. It serves as a tool to build a step-by-step understanding of how GPU architectures have improved, and how different DSLs and tools can be used to gain more control over the optimization process and further improve performance.
This first blog post covers the initial implementation and inspiration for the kernel, written in Triton and targeting the Ampere architecture. In Part 2, I switch to CUDA C++ to learn the low-level implementation details that Triton abstracts away. Finally, in Part 3, I switch to the CuTe DSL to optimize further for Hopper and Blackwell architectures. All code is here.
I made this kernel for my own high-performance graph neural network library, GraphCUDA, specifically for the GCNConv forward pass, which, when stripped down to its core operations, is AXW, where A is the sparse adjacency matrix of the graph, X is the matrix of input embeddings for the graph, and W is the weight matrix.
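To make the target computation concrete, here is a minimal reference version of that forward pass in plain PyTorch. This is a sketch of the math only, not the GraphCUDA API; tensor names are illustrative.
python
import torch

def gcn_forward_reference(A: torch.Tensor, X: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    # A: sparse adjacency (M, K1), X: node embeddings (K1, K2), W: weights (K2, N)
    AX = torch.sparse.mm(A, X)  # sparse-dense matmul
    return AX @ W               # dense-dense matmul

# Tiny example
A = torch.eye(8).to_sparse()
X = torch.randn(8, 16)
W = torch.randn(16, 4)
out = gcn_forward_reference(A, X, W)  # (8, 4)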
You’ll get the most from this installment if you’re comfortable with matrix multiply as blocked computation and have some understanding of Triton and GPU architecture. If not, start with Triton’s matrix multiplication tutorial.
Inspiration from FlashAttention
I read the FlashAttention 2 paper, which, when you drop the online softmax, fuses a chain of matrix multiplications among matrices of shapes N×d_model, d_model×N, and N×d_model.
Note that in this setup, each Triton program computes a BLOCK_ROW×d_model block of output. For this to be optimal, d_model needs to stay relatively small or bounded, while the sequence length N can grow without bound, so that the per-program blocks fit in shared memory. I realized a similar approach can be applied to fused sparse-dense and dense-dense matrix multiplication. In the general problem of multiplying three matrices A, B, and C, where A is M×K1, B is K1×K2, and C is K2×N, we can keep the per-program M dimension small by blocking over rows, and make the effective K1 dimension small by only considering the columns that are non-zero within each row block. I will go into more depth on this in the next sections.
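Before sparsity enters the picture, the row-block idea itself is simple; here is a small NumPy sketch (my own illustration, not from the paper) of computing (A @ B) @ C one row block at a time, which only pays off when K1 and K2 are small enough for the per-block operands to stay on-chip:
python
import numpy as np

def blocked_chain(A, B, C, BLOCK_M=32):
    # A: (M, K1), B: (K1, K2), C: (K2, N) -> out: (M, N)
    M = A.shape[0]
    out = np.empty((M, C.shape[1]), dtype=A.dtype)
    for m0 in range(0, M, BLOCK_M):
        A_blk = A[m0:m0 + BLOCK_M]               # one row block of A
        out[m0:m0 + BLOCK_M] = (A_blk @ B) @ C   # fused chain for this row block
    return out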
Defining the custom sparse matrix format
Image 2: BSR-RM (Row Major) format example and sparse-dense matrix multiplication.
This format is similar to the block sparse row (BSR) matrix format, but instead of tiling the matrix into BLOCK_M×BLOCK_N blocks, we form blocks by rows only, and within each row block we store just the non-zero columns along with their column indices.
In fact, we can use torch.sparse utilities to create our modified sparse BSR format. Link to the file here.
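The actual conversion in GraphCUDA goes through torch.sparse utilities (see the linked file). As a rough illustration of what the format contains, here is a sketch that builds it straight from a dense adjacency matrix; the helper name and the dense starting point are mine, and it assumes M is a multiple of BLOCK_M for simplicity.
python
import torch

def to_bsr_rm(A_dense: torch.Tensor, BLOCK_M: int):
    """Sketch: build (crow, col, values_rm) row-block format from a dense adjacency.

    For each block of BLOCK_M rows we keep only the columns that are non-zero
    anywhere in that block, store their indices, and store the corresponding
    (BLOCK_M, K1_i) slice of values contiguously in row-major order.
    """
    M = A_dense.shape[0]
    crow = [0]
    cols, values = [], []
    for m0 in range(0, M, BLOCK_M):
        block = A_dense[m0:m0 + BLOCK_M]                              # (BLOCK_M, K1)
        nz_cols = (block != 0).any(dim=0).nonzero(as_tuple=True)[0]   # non-zero columns in this row block
        crow.append(crow[-1] + nz_cols.numel())
        cols.append(nz_cols)
        values.append(block[:, nz_cols].flatten())                    # (BLOCK_M, K1_i) row-major
    return (torch.tensor(crow, dtype=torch.int32),
            torch.cat(cols).to(torch.int32),
            torch.cat(values))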
Implementation details in Triton + NCU profiling
Image 3: The fused SpMM-GEMM approach compared with FlashAttention.
Now we can use this new matrix format. Assuming that for any given row block the number of non-zero columns is small (call this value K1(i) for row block i), each program computes a BLOCK_M×N chunk of the output matrix and loops over the sparse K1 dimension in blocks of size BLOCK_K1.
Image 4: K1 blocks.
Note that in Triton we cannot use an arbitrary block size, e.g. load a BLOCK_M×K1(i) block of bsr.values_rm, since tl.arange requires a power-of-2 constant end. So we add some complexity compared to FlashAttention by blocking the K1 dimension as well. The idea is that if sparsity is evenly distributed, each row block of matrix A has roughly the same number of non-zero columns, so each Triton program performs close to the same number of iterations over the K1 dimension. With this, we are finally ready to implement the kernel:
python
@triton.jit
def fused_spmm_gemm_relu_small_n_kernel(
    # sizes
    M: tl.constexpr, K1: tl.constexpr, K2: tl.constexpr, N: tl.constexpr,
    # ptrs
    x_ptr,           # dense X (K1, K2)
    w_ptr,           # dense weights (K2, N)
    out_ptr,         # dense output (M, N)
    relu_mask_ptr,   # dense uint8 mask (M, N), only written when apply_relu
    bsr_values_ptr,  # BSR adjacency (M, K1)
    bsr_crow_ptr,
    bsr_col_ptr,
    bias_ptr,
    # strides
    sx_k1, sx_k2,
    sw_k2, sw_n,
    so_m, so_n,
    sm_m, sm_n,
    sb_n,
    # flags
    apply_relu: tl.constexpr, has_bias: tl.constexpr,
    # tile sizes
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    BLOCK_K1: tl.constexpr, BLOCK_K2: tl.constexpr,
):
    # ------------------- Compute PIDs -------------------
    pid_m = tl.program_id(0)

    # ------------------- Compute logical offsets -------------------
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    tile_m_crow = tl.load(bsr_crow_ptr + pid_m)
    tile_m_next_crow = tl.load(bsr_crow_ptr + pid_m + 1)
    K1_TILE_M = tile_m_next_crow - tile_m_crow
    tile_m_adjm_bsr_values_offs = BLOCK_M * tile_m_crow
    offs_m_adjm = tl.arange(0, BLOCK_M)

    # ------------------- Loop over K1 blocks -------------------
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Can create asymmetric work across block M, but hopefully this is okay
    # since many block Ms exist (M is typically large).
    for k1_block in tl.range(0, K1_TILE_M, BLOCK_K1):
        # ------------------- Compute K1 block indices and offsets -------------------
        offs_k1_adjm = k1_block + tl.arange(0, BLOCK_K1)       # will be masked by K1_TILE_M
        # indices to read from BSR col indices, masked by K1_TILE_M;
        # needed because arange requires a constant end value
        offs_adjm_bsr_cols = k1_block + tl.arange(0, BLOCK_K1)
        offs_k1_x = tl.load(
            bsr_col_ptr + tile_m_crow + offs_adjm_bsr_cols,
            mask=offs_adjm_bsr_cols < K1_TILE_M,
            other=0,
        )

        # ------------------- Load adjm values -------------------
        adjm_values_ptrs = (
            bsr_values_ptr
            + tile_m_adjm_bsr_values_offs
            + offs_m_adjm[:, None] * K1_TILE_M
            + offs_k1_adjm[None, :] * 1   # all contiguous in row-major order
        )
        adjm_values_mask = (offs_m[:, None] < M) & (offs_k1_adjm[None, :] < K1_TILE_M)
        adjm_values = tl.load(adjm_values_ptrs, mask=adjm_values_mask, other=0.0)

        # ------------------- Loop over K2 blocks -------------------
        for k2_block in range(0, K2, BLOCK_K2):
            # ------------------- Compute K2 block indices and offsets -------------------
            offs_k2 = k2_block + tl.arange(0, BLOCK_K2)

            # ------------------- Load x values -------------------
            x_ptrs = (
                x_ptr
                + offs_k1_x[:, None] * sx_k1
                + offs_k2[None, :] * sx_k2
            )
            x_mask = (
                (offs_adjm_bsr_cols[:, None] < K1_TILE_M)
                & (offs_k1_x[:, None] < K1)
                & (offs_k2[None, :] < K2)
            )
            x_values = tl.load(x_ptrs, mask=x_mask, other=0.0)

            # ------------------- Load weights values -------------------
            w_ptrs = (
                w_ptr
                + offs_k2[:, None] * sw_k2
                + offs_n[None, :] * sw_n
            )
            w_mask = (offs_k2[:, None] < K2) & (offs_n[None, :] < N)
            w_values = tl.load(w_ptrs, mask=w_mask, other=0.0)

            # ------------------- Compute GEMM -------------------
            acc1 = tl.zeros((BLOCK_M, BLOCK_K2), dtype=tl.float32)
            acc1 = tl.dot(adjm_values, x_values, acc=acc1, input_precision="ieee", out_dtype=tl.float32)
            acc2 = tl.dot(acc1.to(w_values.dtype), w_values, acc=acc2, input_precision="ieee", out_dtype=tl.float32)

    if has_bias:
        bias_ptrs = bias_ptr + offs_n * sb_n
        bias_mask = offs_n < N
        bias_vals = tl.load(bias_ptrs, mask=bias_mask, other=0.0)
        acc2 = acc2 + bias_vals[None, :]

    # ------------------- Apply ReLU -------------------
    out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    if apply_relu:
        relu_mask_vals = acc2 > 0.0
        relu_mask_ptrs = (
            relu_mask_ptr
            + offs_m[:, None] * sm_m
            + offs_n[None, :] * sm_n
        )
        tl.store(relu_mask_ptrs, relu_mask_vals, mask=out_mask)
        acc2 = tl.maximum(acc2, 0.0)

    # ------------------- Write back to output -------------------
    out_ptrs = (
        out_ptr
        + offs_m[:, None] * so_m
        + offs_n[None, :] * so_n
    )
    tl.store(out_ptrs, acc2, mask=out_mask)
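The host-side launch is not shown above; a minimal sketch of how I would expect it to be invoked is below. The wrapper name and defaults are my own, the strides follow the kernel signature, and there is one program per row block (the "small N" assumption means a single program spans all N output columns).
python
import torch
import triton

def fused_spmm_gemm_relu(M, bsr_values, bsr_crow, bsr_col, x, w, bias=None,
                         apply_relu=True, BLOCK_M=32, BLOCK_K1=32, BLOCK_K2=32):
    K1, K2 = x.shape
    N = w.shape[1]
    out = torch.empty((M, N), device=x.device, dtype=x.dtype)
    relu_mask = torch.empty((M, N), device=x.device, dtype=torch.uint8)
    BLOCK_N = triton.next_power_of_2(N)   # "small N": one program covers all N columns
    grid = (triton.cdiv(M, BLOCK_M),)     # one program per row block of A
    fused_spmm_gemm_relu_small_n_kernel[grid](
        M, K1, K2, N,
        x, w, out, relu_mask,
        bsr_values, bsr_crow, bsr_col,
        bias if bias is not None else x,  # dummy pointer; never read when has_bias=False
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        out.stride(0), out.stride(1),
        relu_mask.stride(0), relu_mask.stride(1),
        bias.stride(0) if bias is not None else 0,
        apply_relu=apply_relu, has_bias=bias is not None,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        BLOCK_K1=BLOCK_K1, BLOCK_K2=BLOCK_K2,
    )
    return out, relu_mask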
Edge case: Unbalanced sparsity
Consider this example.
Image 5: Example edge case with an uneven distribution of sparsity.
Profiling the edge case against the regular case with NCU, occupancy does not initially stand out as the issue. The theoretical and achieved occupancy are almost unchanged: 8.37% achieved occupancy in the regular case versus 8.30% in the edge case. This is because all Triton programs still do some work; program 0 simply does far more of it, while the remaining CTAs finish almost immediately, leaving only a small amount of useful work running on the GPU.
text
Regular case
------------
Compute Throughput: 18.47%
Memory Throughput: 153.81 GB/s
SM Active Cycles: 17,704.77
Edge case
---------
Compute Throughput: 0.42%
Memory Throughput: 24.00 GB/s
SM Active Cycles: 3,717.88
The hope is that most adjacency matrices will be at most slightly uneven in sparsity, so there will still be many CTAs for the scheduler to run to counteract this scenario.
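For reference, the exact test matrices are not shown here, but the edge case above can be reproduced with an adjacency matrix whose non-zeros all land in the first row block; this construction is illustrative.
python
import torch

M = K1 = 4096
BLOCK_M = 32

# Regular case: non-zeros spread evenly, so every row block has a similar K1(i)
A_regular = (torch.rand(M, K1) < 0.01).float()

# Edge case: almost all non-zeros concentrated in the first row block, so
# program 0 loops over a huge K1_TILE_M while every other program exits quickly
A_edge = torch.zeros(M, K1)
A_edge[:BLOCK_M] = (torch.rand(BLOCK_M, K1) < 0.9).float()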
Looping K1 then K2 or vice versa
Another consideration is which dimension we should loop over first: K1 or K2.
python
for k1_block in tl.range(0, K1_TILE_M, BLOCK_K1):
    # ------------------- Compute K1 block indices and offsets -------------------
    offs_k1_adjm = k1_block + tl.arange(0, BLOCK_K1)        # will be masked by K1_TILE_M
    offs_adjm_bsr_cols = k1_block + tl.arange(0, BLOCK_K1)  # masked by K1_TILE_M; arange needs a constant end value
    offs_k1_x = tl.load(bsr_col_ptr + tile_m_crow + offs_adjm_bsr_cols,
                        mask=offs_adjm_bsr_cols < K1_TILE_M, other=0)

    # ------------------- Load adjm values -------------------
    adjm_values_ptrs = (
        bsr_values_ptr
        + tile_m_adjm_bsr_values_offs
        + offs_m_adjm[:, None] * K1_TILE_M
        + offs_k1_adjm[None, :] * 1   # all contiguous in row-major order
    )
    adjm_values_mask = (offs_m[:, None] < M) & (offs_k1_adjm[None, :] < K1_TILE_M)
    adjm_values = tl.load(adjm_values_ptrs, mask=adjm_values_mask, other=0.0)

    # ------------------- Loop over K2 blocks -------------------
    for k2_block in range(0, K2, BLOCK_K2):
        # ------------------- Compute K2 block indices and offsets -------------------
        offs_k2 = k2_block + tl.arange(0, BLOCK_K2)

        # ------------------- Load x values -------------------
        x_ptrs = (
            x_ptr
            + offs_k1_x[:, None] * sx_k1
            + offs_k2[None, :] * sx_k2
        )
        x_mask = ((offs_adjm_bsr_cols[:, None] < K1_TILE_M)
                  & (offs_k1_x[:, None] < K1)
                  & (offs_k2[None, :] < K2))
        x_values = tl.load(x_ptrs, mask=x_mask, other=0.0)

        # ------------------- Load weights values -------------------
        w_ptrs = (
            w_ptr
            + offs_k2[:, None] * sw_k2
            + offs_n[None, :] * sw_n
        )
        w_mask = (offs_k2[:, None] < K2) & (offs_n[None, :] < N)
        w_values = tl.load(w_ptrs, mask=w_mask, other=0.0)

        # ------------------- Compute GEMM -------------------
        acc1 = tl.zeros((BLOCK_M, BLOCK_K2), dtype=tl.float32)
        acc1 = tl.dot(adjm_values, x_values, acc=acc1, input_precision="ieee", out_dtype=tl.float32)
        acc2 = tl.dot(acc1.to(w_values.dtype), w_values, acc=acc2, input_precision="ieee", out_dtype=tl.float32)
Looping over K1 (outer) then K2 (inner).
python
for k2_block in range(0, K2, BLOCK_K2):
    # ------------------- Compute K2 block indices and offsets -------------------
    offs_k2 = k2_block + tl.arange(0, BLOCK_K2)

    # ------------------- Load weights values -------------------
    w_ptrs = (
        w_ptr
        + offs_k2[:, None] * sw_k2
        + offs_n[None, :] * sw_n
    )
    w_mask = (offs_k2[:, None] < K2) & (offs_n[None, :] < N)
    w_values = tl.load(w_ptrs, mask=w_mask, other=0.0)

    # ------------------- Loop over K1 blocks -------------------
    # Can create asymmetric work across block M, but hopefully this is okay
    # since many block Ms exist (M is typically large).
    for k1_block in tl.range(0, K1_TILE_M, BLOCK_K1):
        # ------------------- Compute K1 block indices and offsets -------------------
        offs_k1_adjm = k1_block + tl.arange(0, BLOCK_K1)        # will be masked by K1_TILE_M
        offs_adjm_bsr_cols = k1_block + tl.arange(0, BLOCK_K1)  # masked by K1_TILE_M; arange needs a constant end value
        offs_k1_x = tl.load(bsr_col_ptr + tile_m_crow + offs_adjm_bsr_cols,
                            mask=offs_adjm_bsr_cols < K1_TILE_M, other=0)

        # ------------------- Load adjm values -------------------
        adjm_values_ptrs = (
            bsr_values_ptr
            + tile_m_adjm_bsr_values_offs
            + offs_m_adjm[:, None] * K1_TILE_M
            + offs_k1_adjm[None, :] * 1   # all contiguous in row-major order
        )
        adjm_values_mask = (offs_m[:, None] < M) & (offs_k1_adjm[None, :] < K1_TILE_M)
        adjm_values = tl.load(adjm_values_ptrs, mask=adjm_values_mask, other=0.0)

        # ------------------- Load x values -------------------
        x_ptrs = (
            x_ptr
            + offs_k1_x[:, None] * sx_k1
            + offs_k2[None, :] * sx_k2
        )
        x_mask = ((offs_adjm_bsr_cols[:, None] < K1_TILE_M)
                  & (offs_k1_x[:, None] < K1)
                  & (offs_k2[None, :] < K2))
        x_values = tl.load(x_ptrs, mask=x_mask, other=0.0)

        # ------------------- Compute GEMM -------------------
        acc1 = tl.zeros((BLOCK_M, BLOCK_K2), dtype=tl.float32)
        acc1 = tl.dot(adjm_values, x_values, acc=acc1, input_precision="ieee", out_dtype=tl.float32)
        acc2 = tl.dot(acc1.to(w_values.dtype), w_values, acc=acc2, input_precision="ieee", out_dtype=tl.float32)
Looping over K2 (outer) then K1 (inner).
The difference between the two loop orders is what gets kept outside the inner loop. When looping over K1 and then K2, we load the BSR column indices and adjacency values once per K1 block and reuse them across the K2 loop. When switching the order, we instead reuse the weight tile across K1, but repeatedly fetch the sparse column indices and the sparse values inside the inner loop.
The profile suggests this extra work is the problem. The regular kernel executes far fewer instructions and its busiest pipeline is the Tensor pipeline, while the switched-loop kernel executes many more instructions and is dominated by ALU work:
text
Regular kernel (K1 outer, K2 inner)
-----------------------------------
Compute Workload Analysis:
Tensor is the highest-utilized pipeline (22.0%) ...
Instruction Statistics:
Executed Instructions 803,584
Switched-loop kernel (K2 outer, K1 inner)
-----------------------------------------
Compute Workload Analysis:
ALU is the highest-utilized pipeline (20.7%) ...
It executes integer and logic operations.
Instruction Statistics:
Executed Instructions 2,663,424
Padding the K2 dimension to be a multiple of 16
Padding the K2 dimension to a multiple of 16 makes the memory access pattern cleaner. In GraphCUDA, we do this by preprocessing the input data, specifically the input and weight matrices. Without padding, Triton struggles to compile code that uses the hardware’s full capabilities, such as clean vectorized memory access. Even though the Triton code stays the same, the generated code can become much faster because the compiler no longer has to handle as many misaligned accesses.
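In GraphCUDA this padding happens during preprocessing. A minimal sketch of the idea (the helper name is mine) is to zero-pad X's feature dimension and W's corresponding input dimension up to the next multiple of 16; the extra zero rows and columns do not change the result of A @ X @ W.
python
import torch
import torch.nn.functional as F

def pad_k2_to_multiple_of_16(x: torch.Tensor, w: torch.Tensor):
    # x: (K1, K2), w: (K2, N)
    K2 = x.shape[1]
    pad = (-K2) % 16
    if pad:
        x = F.pad(x, (0, pad))        # pad last dim of X -> (K1, K2 + pad)
        w = F.pad(w, (0, 0, 0, pad))  # pad first dim of W -> (K2 + pad, N)
    return x, w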
Autotuning
We can also add a Triton autotuning config for optimal BLOCK_K1 and BLOCK_K2.
Additionally, we could tune BLOCK_M. However, we would have to do this via a heuristic when constructing the sparse matrix, so it is harder to implement alongside tuning BLOCK_K1 and BLOCK_K2.
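A sketch of what such an autotuning config for BLOCK_K1 and BLOCK_K2 could look like is below; the listed block sizes and num_warps values are illustrative, not GraphCUDA's tuned set, and only the decorator placement matters here.
python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_K1": 16, "BLOCK_K2": 32}, num_warps=4),
        triton.Config({"BLOCK_K1": 32, "BLOCK_K2": 32}, num_warps=4),
        triton.Config({"BLOCK_K1": 32, "BLOCK_K2": 64}, num_warps=8),
        triton.Config({"BLOCK_K1": 64, "BLOCK_K2": 64}, num_warps=8),
    ],
    key=["K1", "K2", "N"],  # re-tune whenever the problem shape changes
)
@triton.jit
def fused_spmm_gemm_relu_small_n_kernel(
    M: tl.constexpr, K1: tl.constexpr, K2: tl.constexpr, N: tl.constexpr,
    # ... remaining arguments and body exactly as in the full kernel above;
    # BLOCK_K1 and BLOCK_K2 are then dropped from the launch call and chosen
    # by the autotuner instead ...
):
    ...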
Benchmarks on A100
Initial benchmarks show that the forward pass is faster than the CSR sparse Torch implementation.
Image 6: Benchmark plot comparing the Triton implementation against the Torch sparse implementation.
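The Torch baseline in that comparison is the usual two-step formulation, roughly as sketched below; this is my reconstruction of the baseline, the actual benchmark script lives in the repo.
python
import torch

def torch_csr_baseline(A_csr: torch.Tensor, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # A_csr: sparse CSR adjacency (M, K1); two separate kernels: SpMM, then a dense GEMM
    return torch.sparse.mm(A_csr, x) @ w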
Integrating this with my own GraphCUDA GCNConv layer and comparing against PyTorch Geometric, we see that training is faster.
In the next blog post, I will try to implement the lower-level details of this kernel in CUDA C++, optimized specifically for Ampere architectures. See Part 2.