dim3 Class — pytorch Architecture
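Architecture documentation for the code in SparseSemiStructuredPack.h (pytorch codebase) that uses CUDA's built-in `dim3` type; the excerpt below is the `KernelTypes` struct, whose `Params` helpers return `dim3` launch dimensions.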
Entity Profile
Source Code
aten/src/ATen/native/sparse/cuda/SparseSemiStructuredPack.h lines 121–518
template <typename Element_>
struct KernelTypes {
using Element = Element_;
using Fragment =
cutlass::Array<Element, 8>; // always read from gmem in chunks of 128bits
using Fragment4 = cutlass::Array<Element, 4>;
using ValuesPacked = cutlass::Array<Element, 8>; // 4 first col, 4 second col
/// Arguments shared by the kernels of this struct: the input matrix, the two
/// packed outputs (regular and transposed) with their strides, and the buffer
/// holding one `Tile8x8Masks` entry per thread.
struct Params {
  /// inputs
  Element const* input;
  int64_t input_s0;
  int64_t input_dim0;
  int64_t input_dim1;
  /// outputs
  Element* packed;
  int64_t packed_stride;
  Element* packed_trans;
  int64_t packed_trans_stride;
  uint64_t* threads_masks;

  /// Grid dimensions: enough blocks of `kWarpX` x `kWarpY` elements to cover
  /// `input_dim0` x `input_dim1` (rounded up in both directions).
  __host__ dim3 getBlocksGrid() const {
    auto const blocks_x = cutlass::ceil_div(input_dim0, kWarpX);
    auto const blocks_y = cutlass::ceil_div(input_dim1, kWarpY);
    return dim3(blocks_x, blocks_y, 1);
  }

  /// Block dimensions: one thread per `kThreadX` x `kThreadY` patch of the
  /// `kWarpX` x `kWarpY` area owned by a block.
  static CUTLASS_HOST_DEVICE dim3 getThreadsGrid() {
    auto const threads_x = kWarpX / kThreadX;
    auto const threads_y = kWarpY / kThreadY;
    return dim3(threads_x, threads_y, 1);
  }

  /// Address of the calling thread's `Tile8x8Masks` slot inside
  /// `threads_masks`. Slots are laid out with the y thread coordinate varying
  /// fastest: offset = global_y + global_x * (gridDim.y * threads_per_block_y).
  CUTLASS_DEVICE Tile8x8Masks* getCurrentThreadIndices() const {
    dim3 const block_shape = getThreadsGrid();
    Tile8x8Masks* const first_slot =
        reinterpret_cast<Tile8x8Masks*>(threads_masks);
    // Global thread coordinates, computed in the builtin (unsigned) index type.
    auto const global_y = blockIdx.y * block_shape.y + threadIdx.y;
    auto const global_x = blockIdx.x * block_shape.x + threadIdx.x;
    // Number of slots in one "x" line, widened to 64 bits before multiplying.
    int64_t const slots_per_x = gridDim.y * block_shape.y;
    return first_slot + global_y + global_x * slots_per_x;
  }
};
/// Lightweight view of a 4x4 sub-tile inside the 8 row fragments (`_lines`)
/// held by a thread. `_start_row` / `_start_col` give the sub-tile's top-left
/// corner; coordinates passed to `at` are relative to that corner.
struct Tile4x4Accessor {
  using Element = Element_;
  Fragment (&_lines)[8];
  int _start_row;
  int _start_col;

  CUTLASS_DEVICE Tile4x4Accessor(
      Fragment (&lines)[8],
      int start_row,
      int start_col)
      : _lines(lines), _start_row(start_row), _start_col(start_col) {}

  /// Reference to the element at row `r`, column `c` of the sub-tile.
  CUTLASS_DEVICE typename Fragment::reference at(int r, int c) {
    int const line_idx = r + _start_row;
    int const elem_idx = c + _start_col;
    return _lines[line_idx][elem_idx];
  }
};
/// Result of packing a 4x4 tile: two 4-element columns of kept values.
/// Both columns start out cleared, so slots that are never written stay
/// at the cleared value.
struct Tile4x4Packed {
  Fragment4 values[2];

  CUTLASS_DEVICE Tile4x4Packed() {
    CUTLASS_PRAGMA_UNROLL
    for (int slot = 0; slot < 2; ++slot) {
      values[slot].clear();
    }
  }
};
// Returns a packed 4x4 tile (eg 2x4 values) which correspond to the values
// that are in `indices`. Also fills the `meta` array in the right format
// for consumption in the TensorCores.
// Example:
//  indices:  0011
//            1001
//            1001
//            0100 (<- note, only 1 value on the last line)
//  packed: values[0][2] values[1][0] values[2][0] values[3][1]
//          values[0][3] values[1][3] values[2][3] Element(0)
//
// Parameters:
//  indices   - selection mask; bit `col + 4 * row` is tested for element
//              (row, col) of the tile.
//  tile      - accessor for the 4x4 source tile.
//  meta      - metadata word, updated in place (bits are OR-ed in). For each
//              of the 4 iterations, a 4-bit field is written at bit offset
//              `8 * row + meta_pos`: low 2 bits = source index of the value in
//              packed slot 0, high 2 bits = source index for packed slot 1.
//  meta_pos  - bit offset of this tile's 4-bit field inside each 8-bit group
//              (callers in this file pass 0 or 4).
//  transpose - when true the tile is walked column by column: the loop index
//              is the source column, the selected indices are source rows,
//              and the tested bit becomes `row + 4 * col`.
//
// A slot that receives no value keeps the cleared value from the
// `Tile4x4Packed` constructor and its 2-bit selector is 0.
CUTLASS_DEVICE static Tile4x4Packed pack_4x4(
    Indices4x4 indices,
    Tile4x4Accessor tile,
    uint32_t& meta,
    int meta_pos,
    bool transpose = false) {
  Tile4x4Packed packed;
  CUTLASS_PRAGMA_UNROLL
  for (int row = 0; row < 4; ++row) {
    // Explicitly zero the selectors: `packValue` only assigns the slot it
    // fills, so a line with fewer than two selected values (see the example
    // above) would otherwise feed an indeterminate value into `meta`.
    uint2b_t col0_from = uint2b_t(0);
    uint2b_t col1_from = uint2b_t(0);
    // Copies the selected source element into packed slot `col_to` and
    // remembers which source index it came from.
    auto packValue = [&](uint2b_t col_to, uint2b_t col_from) {
      auto value = transpose ? tile.at(col_from, row).get()
                             : tile.at(row, col_from).get();
      packed.values[col_to][row] = value;
      if (col_to == uint2b_t(0)) {
        col0_from = col_from;
      } else {
        col1_from = col_from;
      }
    };
    // Non-zero when source index `col` of the current line is selected.
    auto isSelected = [&](int col) {
      if (transpose) {
        return indices & (1 << (row + 4 * col));
      }
      return indices & (1 << (col + 4 * row));
    };
    // Process cols 0/1
    // We know that col0 is always packed to position 0 if it's there
    // and col1 is packed to pos 0 or 1 (depending if col0 is selected)
    if (isSelected(1)) {
      packValue(uint2b_t(0), uint2b_t(1));
    }
    if (isSelected(0)) {
      packValue(uint2b_t(0), uint2b_t(0));
    }
    if (isSelected(0) && isSelected(1)) {
      packValue(uint2b_t(1), uint2b_t(1));
    }
    // Process cols 2/3
    // same sort of heuristic: col3 always goes to position 1 if it's there,
    // col2 goes to position 1, or to position 0 when col3 is selected too
    if (isSelected(2)) {
      packValue(uint2b_t(1), uint2b_t(2));
    }
    if (isSelected(3)) {
      packValue(uint2b_t(1), uint2b_t(3));
    }
    if (isSelected(2) && isSelected(3)) {
      packValue(uint2b_t(0), uint2b_t(2));
    }
    // Shift in unsigned arithmetic: the field can land in the top bits of the
    // 32-bit word (value up to 15, shift up to 28).
    uint32_t add_mask = uint32_t(col0_from | (col1_from << 2))
        << (8 * row + meta_pos);
    meta |= add_mask;
  }
  return packed;
}
/// Metadata words accumulated by a thread for its 8x8 tile. All four words
/// start at zero and are filled by OR-ing bit fields into them.
struct Tile8x8Meta {
  // Non-transposed layout, one word per pair of 4x4 tiles (A|B and C|D):
  // meta_ab[row] |= (real_col << (8*row + 2*pos))
  uint32_t meta_ab;
  uint32_t meta_cd;
  // Transposed layout, one word per pair of 4x4 tiles (A|C and B|D):
  // meta_ac_trans[col] |= (real_row << (8*col + 2*pos))
  uint32_t meta_ac_trans;
  uint32_t meta_bd_trans;

  CUTLASS_DEVICE Tile8x8Meta()
      : meta_ab(0), meta_cd(0), meta_ac_trans(0), meta_bd_trans(0) {}
};
// Concatenates two 4-element fragments (`packed0` first, then `packed1`) into
// one 8-element fragment and stores it at `ptr` with a single
// `global_store`. The store is issued unconditionally (last argument `true`).
CUTLASS_DEVICE static void writePacked(
    Element* ptr,
    Fragment4 packed0,
    Fragment4 packed1) {
  Fragment joined;
  // Lower half comes from packed0 ...
  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < 4; ++k) {
    joined[k] = packed0[k].get();
  }
  // ... upper half from packed1.
  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < 4; ++k) {
    joined[k + 4] = packed1[k].get();
  }
  cutlass::arch::global_store<Fragment, sizeof(Fragment)>(joined, ptr, true);
}
// Stores two packed tiles side by side, one output line per iteration.
// Line `n` (0..3) is written at `ptr + n * stride` and contains, in order:
//   a.values[0][n], a.values[1][n], b.values[0][n], b.values[1][n]
// Every store is issued unconditionally (last argument `true`).
CUTLASS_DEVICE static void writePackedT(
    Element* ptr,
    int64_t stride,
    Tile4x4Packed a,
    Tile4x4Packed b) {
  CUTLASS_PRAGMA_UNROLL
  for (int line = 0; line < 4; ++line) {
    Fragment4 out;
    // First two slots from tile `a`, last two from tile `b`.
    out[0] = a.values[0][line].get();
    out[1] = a.values[1][line].get();
    out[2] = b.values[0][line].get();
    out[3] = b.values[1][line].get();
    Element* const dst = ptr + line * stride;
    cutlass::arch::global_store<Fragment4, sizeof(Fragment4)>(out, dst, true);
  }
}
// Device-side body of the sparsification kernel. For the 8x8 tile owned by the
// calling thread it:
//   1. loads the tile from `p.input`,
//   2. asks `compute_tile_indices` which values to keep in each 4x4 sub-tile,
//   3. writes the kept values to `p.packed` and `p.packed_trans`,
//   4. saves the selection masks to `p.threads_masks`,
//   5. writes the metadata words through `metadata_gmem`.
// Parameters:
//   p                    - input/output pointers and strides (see `Params`).
//   metadata_gmem        - provides `get_metaN` / `get_metaT`, which return the
//                          destination pointers for the non-transposed and
//                          transposed metadata.
//   compute_tile_indices - callable taking a `Tile4x4Accessor` and returning
//                          the selection mask for that 4x4 tile; `Algorithm`
//                          also provides `outOfBoundsFillValue<Element>()`.
// NOTE(review): only the row index is bounds-checked on load (`x + i <
// p.input_dim0`) and all stores are unconditional; presumably the caller
// guarantees suitably sized/padded buffers - confirm at the launch site.
template <typename Algorithm, typename MetadataStore>
CUTLASS_DEVICE static void sparse_semi_structured_tile_kernel(
Params p,
MetadataStore metadata_gmem,
Algorithm compute_tile_indices) {
// Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
// A, B, C and D, as displayed in the following schema:
// +---+---+
// | A | B |
// +---+---+
// | C | D |
// +---+---+
// Each warp (32 threads) will then be responsible for a 32x64 tile of the
// input.
// This configuration allows to read/write data in 128bits chunks. These
// memory accesses are coalesced at the warp-level into 128bytes. See also:
// https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g2494f30c7cf_0_0
// Top-left of the 8x8 tile we own
// (`x` indexes rows / dim0, `y` indexes columns / dim1)
int warp_x = blockIdx.x * kWarpX;
int warp_y = blockIdx.y * kWarpY;
int x = warp_x + threadIdx.x * kThreadX;
int y = warp_y + threadIdx.y * kThreadY;
// Input rows are `p.input_s0` elements apart; columns are contiguous.
Element const* input = p.input + x * p.input_s0 + y;
// The packed output keeps half of the columns, hence `y / 2` ...
Element* packed = p.packed + x * p.packed_stride + (y / 2);
// ... and the transposed packed output keeps half of the rows (`x / 2`),
// with `y` selecting the output line.
Element* packed_trans =
p.packed_trans + (x / 2) + y * p.packed_trans_stride;
Fragment lines[8]; // Contains all values from the 8x8 tile
Tile8x8Meta metadata;
Tile8x8Masks indices;
// Load/process tiles `A` and `B`
// Each line is pre-filled with the algorithm's out-of-bounds value; the load
// is predicated on the row being inside the input (presumably a false
// predicate leaves `lines[i]` untouched - confirm against
// cutlass::arch::global_load).
Element fillValue = Algorithm::template outOfBoundsFillValue<Element>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
lines[i].fill(fillValue);
cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
lines[i], input + i * p.input_s0, x + i < p.input_dim0);
}
// Select which values to keep in A (columns 0-3) and B (columns 4-7)
indices.a = compute_tile_indices(Tile4x4Accessor(lines, 0, 0));
indices.b = compute_tile_indices(Tile4x4Accessor(lines, 0, 4));
// Compute packed tiles A & B
// (both contribute to `meta_ab`: A at bit position 0, B at bit position 4 of
// each 8-bit group) and store them right away
{
Tile4x4Packed packed_a = pack_4x4(
indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
Tile4x4Packed packed_b = pack_4x4(
indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
writePackedT(packed, p.packed_stride, packed_a, packed_b);
}
// Compute packed tiles A & B for the transposed output. They are only kept
// in registers here; the store happens further down, together with C & D.
Tile4x4Packed packed_trans_a = pack_4x4(
indices.a,
Tile4x4Accessor(lines, 0, 0),
metadata.meta_ac_trans,
0,
true);
Tile4x4Packed packed_trans_b = pack_4x4(
indices.b,
Tile4x4Accessor(lines, 0, 4),
metadata.meta_bd_trans,
0,
true);
// (NOTE) Now we no longer need A & B (`lines[0:4]`)
// Load/process tiles `C` and `D`
CUTLASS_PRAGMA_UNROLL
for (int i = 4; i < 8; ++i) {
lines[i].fill(fillValue);
cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
lines[i], input + i * p.input_s0, x + i < p.input_dim0);
}
// Select which values to keep in C (columns 0-3) and D (columns 4-7)
indices.c = compute_tile_indices(Tile4x4Accessor(lines, 4, 0));
indices.d = compute_tile_indices(Tile4x4Accessor(lines, 4, 4));
// Compute packed tiles C & D
// and store them 4 output lines below A & B
{
Tile4x4Packed packed_c = pack_4x4(
indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
Tile4x4Packed packed_d = pack_4x4(
indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
writePackedT(
packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
}
// Compute packed tiles C & D for the transposed output (C shares
// `meta_ac_trans` with A, D shares `meta_bd_trans` with B, at bit position 4)
Tile4x4Packed packed_trans_c = pack_4x4(
indices.c,
Tile4x4Accessor(lines, 4, 0),
metadata.meta_ac_trans,
4,
true);
Tile4x4Packed packed_trans_d = pack_4x4(
indices.d,
Tile4x4Accessor(lines, 4, 4),
metadata.meta_bd_trans,
4,
true);
// Save the selection masks of this 8x8 tile into `p.threads_masks`
// (`sparse_semi_structured_apply_kernel` reads them back from the same slot)
*p.getCurrentThreadIndices() = indices;
// Store packed A, B, C & D for transposed matrix
// First 4 output lines: A next to C; following 4 lines: B next to D
writePackedT(
packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
packed_trans += 4 * p.packed_trans_stride;
writePackedT(
packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);
// Writing meta non-transposed
// (`meta_cd` goes 32 elements after `meta_ab`)
{
ElementInputE* packed_meta_reordered = metadata_gmem.get_metaN(
warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
warp_shuffle_and_write_meta(packed_meta_reordered, metadata.meta_ab);
warp_shuffle_and_write_meta(packed_meta_reordered + 32, metadata.meta_cd);
}
// Writing meta transposed
// (same layout, with the extra `true` flag passed to the writer)
{
ElementInputE* packed_trans_meta_reordered = metadata_gmem.get_metaT(
warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
warp_shuffle_and_write_meta(
packed_trans_meta_reordered, metadata.meta_ac_trans, true);
warp_shuffle_and_write_meta(
packed_trans_meta_reordered + 32, metadata.meta_bd_trans, true);
}
}
// Device-side body of the "apply" kernel. It follows the same steps as
// `sparse_semi_structured_tile_kernel`, except that the selection masks are
// not computed: they are read back from `p.threads_masks`. The kept values
// are written to `p.packed` and `p.packed_trans`; the metadata words that
// `pack_4x4` fills along the way are not stored anywhere by this function.
CUTLASS_DEVICE static void sparse_semi_structured_apply_kernel(Params p) {
  // Top-left corner of the 8x8 tile owned by this thread
  // (rows along dim0, columns along dim1).
  int const block_row = blockIdx.x * kWarpX;
  int const block_col = blockIdx.y * kWarpY;
  int const row0 = block_row + threadIdx.x * kThreadX;
  int const col0 = block_col + threadIdx.y * kThreadY;

  Element const* const src = p.input + row0 * p.input_s0 + col0;
  Element* const dst = p.packed + row0 * p.packed_stride + (col0 / 2);
  Element* const dst_trans_ac =
      p.packed_trans + (row0 / 2) + col0 * p.packed_trans_stride;

  Fragment lines[8]; // The whole 8x8 tile, one fragment per row
  Tile8x8Meta metadata;
  Tile8x8Masks const masks = *p.getCurrentThreadIndices();

  // Load all 8 rows (tiles A, B, C and D) in one go.
  CUTLASS_PRAGMA_UNROLL
  for (int r = 0; r < 8; ++r) {
    // NB: rows outside the input are not pre-filled here, so their content
    // is undefined - it shouldn't be used anywhere.
    cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
        lines[r], src + r * p.input_s0, row0 + r < p.input_dim0);
  }

  // Views of the four 4x4 sub-tiles:
  //   A = rows 0-3 / cols 0-3      B = rows 0-3 / cols 4-7
  //   C = rows 4-7 / cols 0-3      D = rows 4-7 / cols 4-7
  Tile4x4Accessor tile_a(lines, 0, 0);
  Tile4x4Accessor tile_b(lines, 0, 4);
  Tile4x4Accessor tile_c(lines, 4, 0);
  Tile4x4Accessor tile_d(lines, 4, 4);

  // Top half, regular layout: pack A and B, store immediately.
  {
    Tile4x4Packed top_left = pack_4x4(masks.a, tile_a, metadata.meta_ab, 0);
    Tile4x4Packed top_right = pack_4x4(masks.b, tile_b, metadata.meta_ab, 4);
    writePackedT(dst, p.packed_stride, top_left, top_right);
  }

  // Top half, transposed layout: kept in registers until C and D are ready.
  Tile4x4Packed trans_a =
      pack_4x4(masks.a, tile_a, metadata.meta_ac_trans, 0, true);
  Tile4x4Packed trans_b =
      pack_4x4(masks.b, tile_b, metadata.meta_bd_trans, 0, true);

  // Bottom half, regular layout: pack C and D, store 4 lines below A/B.
  {
    Tile4x4Packed bottom_left = pack_4x4(masks.c, tile_c, metadata.meta_cd, 0);
    Tile4x4Packed bottom_right =
        pack_4x4(masks.d, tile_d, metadata.meta_cd, 4);
    writePackedT(
        dst + 4 * p.packed_stride, p.packed_stride, bottom_left, bottom_right);
  }

  // Bottom half, transposed layout.
  Tile4x4Packed trans_c =
      pack_4x4(masks.c, tile_c, metadata.meta_ac_trans, 4, true);
  Tile4x4Packed trans_d =
      pack_4x4(masks.d, tile_d, metadata.meta_bd_trans, 4, true);

  // Transposed output: A next to C on the first 4 lines, then B next to D on
  // the following 4 lines.
  writePackedT(dst_trans_ac, p.packed_trans_stride, trans_a, trans_c);
  Element* const dst_trans_bd = dst_trans_ac + 4 * p.packed_trans_stride;
  writePackedT(dst_trans_bd, p.packed_trans_stride, trans_b, trans_d);
}
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free