Table of Contents

  1. Sparse Tensor
  2. Sparse Tensor Decomposition
  3. Sparse GPU Encodings

Sparse Tensor

In the sparse compiler, a sparse tensor is represented in the same way as a dense tensor, except that its type is annotated with an additional sparse encoding. This encoding provides the necessary information to interpret the tensor’s sparse format. The presence of a sparse encoding is the only difference between sparse and dense tensor types.

Consider the following kernel:

import triton
import triton.language as tl


@triton.jit
def dot_kernel(D, # Pointer to the dense output tensor
               A, # Pointer to the sparse input tensor
               B, # Pointer to the dense input tensor
               # Tile dimensions
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    a_ptrs = tl.make_block_ptr(A,
                               shape=(BLOCK_M, BLOCK_K),
                               strides=(BLOCK_K, 1),
                               offsets=(0, 0),
                               block_shape=(BLOCK_M, BLOCK_K),
                               order=(1, 0))
    b_ptrs = tl.make_block_ptr(B,
                               shape=(BLOCK_K, BLOCK_N),
                               strides=(BLOCK_N, 1),
                               offsets=(0, 0),
                               block_shape=(BLOCK_K, BLOCK_N),
                               order=(1, 0))
    d_ptrs = tl.make_block_ptr(D,
                               shape=(BLOCK_M, BLOCK_N),
                               strides=(BLOCK_N, 1),
                               offsets=(0, 0),
                               block_shape=(BLOCK_M, BLOCK_N),
                               order=(1, 0))

    a = tl.load(a_ptrs)
    b = tl.load(b_ptrs)

    d = tl.dot(a, b).to(tl.float16)

    tl.store(d_ptrs, d)

The initial IR generated by this kernel is as follows:

module {
  tt.func public @dot_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16, #tt.sparse<"NV24">>, %arg2: !tt.ptr<f16>) {
    %A_ptrs = tt.make_tensor_ptr %arg1, shape=[32, 64], strides=[64, 1], offsets=[0, 0] : <tensor<32x64xf16, #tt.sparse<"NV24">>>
    %B_ptrs = tt.make_tensor_ptr %arg2, shape=[64, 16], strides=[16, 1], offsets=[0, 0] : <tensor<64x16xf16>>
    %D_ptrs = tt.make_tensor_ptr %arg0, shape=[32, 16], strides=[16, 1], offsets=[0, 0] : <tensor<32x16xf16>>
    %A = tt.load %A_ptrs : !tt.ptr<tensor<32x64xf16, #tt.sparse<"NV24">>>
    %B = tt.load %B_ptrs : !tt.ptr<tensor<64x16xf16>>
    %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf32>
    %D = tt.dot %A, %B, %cst : tensor<32x64xf16, #tt.sparse<"NV24">> * tensor<64x16xf16> -> tensor<32x16xf32>
    %D_trunc = arith.truncf %D : tensor<32x16xf32> to tensor<32x16xf16>
    tt.store %D_ptrs, %D_trunc : !tt.ptr<tensor<32x16xf16>>
    tt.return
  }
}

All IR on this page has been simplified for brevity.

Notice that the Triton pointer for the sparse operand, and every operation involving it, is tagged with the #tt.sparse encoding.

Sparse Tensor Decomposition

The DecomposeSparseTensors transformation walks the entire IR and decomposes each sparse tensor into its underlying dense arrays (e.g. the values and metadata arrays in the case of 2:4 sparsity). In our example, the sparse argument is decomposed into two arguments that point to the beginning of the values and metadata arrays in global memory. The creation and loading of the sparse tensor are each decomposed into two operations, one for the values array and one for the metadata array. The parameters requested for the sparse tensor (shape, strides, and offsets) are recomputed so that they correctly describe the values and metadata arrays. The metadata array is passed to the dot operation as an extra operand, and a sparseIndex attribute is attached to specify that the \(0^{\text{th}}\) operand is the sparse one. After the decomposition, no sparse tensors remain in the IR. Here is the simplified IR after this pass:

module {
  tt.func public @dot_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<i16>, %arg3:!tt.ptr<f16>) {
    %A_ptrs = tt.make_tensor_ptr %arg1, shape=[32, 32], strides=[32, 1], offsets=[0, 0] : <tensor<32x32xf16>>
    %E_ptrs = tt.make_tensor_ptr %arg2, shape=[2, 64], strides=[64, 1], offsets=[0, 0] : <tensor<2x64xi16>>
    %B_ptrs = tt.make_tensor_ptr %arg3, shape=[64, 16], strides=[16, 1], offsets=[0, 0] : <tensor<64x16xf16>>
    %D_ptrs = tt.make_tensor_ptr %arg0, shape=[32, 16], strides=[16, 1], offsets=[0, 0] : <tensor<32x16xf16>>
    %A = tt.load %A_ptrs : !tt.ptr<tensor<32x32xf16>>
    %E = tt.load %E_ptrs : !tt.ptr<tensor<2x64xi16>>
    %B = tt.load %B_ptrs : !tt.ptr<tensor<64x16xf16>>
    %cst = arith.constant dense<0.000000e+00> : tensor<32x16xf32>
    %D = tt.dot %A, %B, %E, %cst, sparseIndex=0 : tensor<32x32xf16>, tensor<2x64xi16> * tensor<64x16xf16> -> tensor<32x16xf32>
    %D_trunc = arith.truncf %D : tensor<32x16xf32> to tensor<32x16xf16>
    tt.store %D_ptrs, %D_trunc : !tt.ptr<tensor<32x16xf16>>
    tt.return
  }
}
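
For reference, the decomposed shapes above follow directly from the 2:4 compression. The short sketch below is illustrative only, not the compiler's code: it derives the sizes of the values and metadata arrays for the 32x64 f16 operand, assuming the usual 2-bit-per-kept-element metadata encoding packed into 16-bit words; the particular two-dimensional layout of the metadata words (2x64 here) is hardware specific and is simply asserted, not derived.

def nv24_decomposition(M, K):
    # Values: 2 of every 4 elements along K are kept.
    values_shape = (M, K // 2)
    # Metadata: a 2-bit index per kept element, packed into 16-bit words.
    kept_elements = M * (K // 2)
    meta_words = kept_elements * 2 // 16
    return values_shape, meta_words

values_shape, meta_words = nv24_decomposition(32, 64)
print(values_shape)   # (32, 32) -> tensor<32x32xf16>
print(meta_words)     # 128      -> laid out here as tensor<2x64xi16>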

Sparse GPU Encodings

Optimization passes treat a dot operand differently depending on whether it is the values array of a sparse tensor, its metadata array, or a dense tensor. For example, to achieve maximal efficiency, the swizzling parameters for %A, %B, and %E differ because their access patterns differ. Many passes operate on dot operands without direct access to the dot operation itself, so this information (whether a tensor is the %A, %B, or %E operand) must be encoded directly in the tensor type. To support this, a parameter, meta, has been added to the Dot Operand Encoding Attribute. Here is a snippet of the GPU IR with the attribute included:

%A = ... : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2, meta = 0}>>
%B = ... : tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%E = ... : tensor<2x64xi16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2, meta = 1}>>

%D = tt.dot %A, %B, %cst, %E

The following values are used for the classification:

\[\textbf{meta} = \left\{ \begin{array}{ll} -1 & \text{dense tensor} \\ 0 & \text{values tensor} \\ 1 & \text{metadata tensor} \\ \end{array} \right.\]

This information helps the GPU optimization passes choose how best to lower the sparse code. The scheme can be extended to sparsity formats with more than one metadata array by using additional values of meta for the extra tensors.
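
As a purely hypothetical illustration (the helper below is not part of Triton or the sparse compiler), a pass that only sees a dot operand's type can recover the operand's role from the meta value carried by its encoding, without inspecting the tt.dot operation itself:

from enum import Enum

class OperandRole(Enum):
    DENSE = -1     # ordinary dense dot operand
    VALUES = 0     # values array of a sparse operand
    METADATA = 1   # metadata array of a sparse operand

def classify(meta=-1):
    # meta defaults to -1 when the encoding carries no meta field.
    return OperandRole(meta)

print(classify(0))   # %A -> OperandRole.VALUES
print(classify())    # %B -> OperandRole.DENSE
print(classify(1))   # %E -> OperandRole.METADATA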