sparse_semi_structured_tile_kernel — PyTorch Architecture
Architecture documentation for the sparse_semi_structured_tile_kernel device function in SparseSemiStructuredPack.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/sparse/cuda/SparseSemiStructuredPack.h lines 301–431
// Packs one 8x8 tile of the input into the semi-structured sparse format,
// producing three outputs per thread: the packed (pruned) values, the packed
// values for the transposed matrix, and the reordered metadata consumed via
// `metadata_gmem`. The half-width packed pointers (y/2, x/2) indicate that
// packing keeps half of the values along the pruned dimension — presumably
// 2:4 semi-structured sparsity; confirm against the surrounding file.
//
// `compute_tile_indices` is the pruning algorithm: given a 4x4 accessor it
// returns the mask/indices of the values to keep. `metadata_gmem` maps
// (warp, thread) coordinates to the metadata locations in global memory.
template <typename Algorithm, typename MetadataStore>
CUTLASS_DEVICE static void sparse_semi_structured_tile_kernel(
    Params p,
    MetadataStore metadata_gmem,
    Algorithm compute_tile_indices) {
  // Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
  // A, B, C and D, as displayed in the following schema:
  // +---+---+
  // | A | B |
  // +---+---+
  // | C | D |
  // +---+---+
  // Each warp (32 threads) will then be responsible for a 32x64 tile of the
  // input.
  // This configuration allows to read/write data in 128bits chunks. These
  // memory accesses are coalesced at the warp-level into 128bytes. See also:
  // https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g2494f30c7cf_0_0

  // Top-left of the 8x8 tile we own
  int warp_x = blockIdx.x * kWarpX;
  int warp_y = blockIdx.y * kWarpY;
  int x = warp_x + threadIdx.x * kThreadX;
  int y = warp_y + threadIdx.y * kThreadY;

  // Per-thread base pointers. `input` is strided by `input_s0` per row;
  // both packed outputs are half-width along the pruned dimension, hence
  // the /2 on the column coordinate (y for the normal output, x for the
  // transposed one).
  Element const* input = p.input + x * p.input_s0 + y;
  Element* packed = p.packed + x * p.packed_stride + (y / 2);
  Element* packed_trans =
      p.packed_trans + (x / 2) + y * p.packed_trans_stride;

  Fragment lines[8]; // Contains all values from the 8x8 tile
  Tile8x8Meta metadata;
  Tile8x8Masks indices;

  // Load/process tiles `A` and `B` (rows 0..3 of the 8x8 tile).
  // Each row is pre-filled with the algorithm's out-of-bounds fill value so
  // that rows whose load predicate (`x + i < p.input_dim0`) is false still
  // hold a well-defined value for the pruning/packing steps below.
  Element fillValue = Algorithm::template outOfBoundsFillValue<Element>();
  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 4; ++i) {
    lines[i].fill(fillValue);
    cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
        lines[i], input + i * p.input_s0, x + i < p.input_dim0);
  }
  // Decide which values survive pruning in A (cols 0..3) and B (cols 4..7).
  indices.a = compute_tile_indices(Tile4x4Accessor(lines, 0, 0));
  indices.b = compute_tile_indices(Tile4x4Accessor(lines, 0, 4));

  // Compute packed tiles A & B. A and B share a metadata word
  // (`meta_ab`) at bit-offsets 0 and 4 respectively, and are written side
  // by side into the packed output.
  {
    Tile4x4Packed packed_a = pack_4x4(
        indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
    Tile4x4Packed packed_b = pack_4x4(
        indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
    writePackedT(packed, p.packed_stride, packed_a, packed_b);
  }

  // Compute/store packed tiles A & B in transpose output.
  // NOTE(review): in the transposed layout A pairs with C and B pairs with
  // D — consistent with the `writePackedT(packed_trans, ..., a, c)` /
  // `(..., b, d)` calls below — so A's metadata goes into `meta_ac_trans`
  // and B's into `meta_bd_trans`, both at offset 0 (C and D use offset 4).
  // The trailing `true` flags the transposed packing mode.
  Tile4x4Packed packed_trans_a = pack_4x4(
      indices.a,
      Tile4x4Accessor(lines, 0, 0),
      metadata.meta_ac_trans,
      0,
      true);
  Tile4x4Packed packed_trans_b = pack_4x4(
      indices.b,
      Tile4x4Accessor(lines, 0, 4),
      metadata.meta_bd_trans,
      0,
      true);
  // (NOTE) Now we no longer need A & B (`lines[0:4]`)

  // Load/process tiles `C` and `D` (rows 4..7), same scheme as A & B.
  CUTLASS_PRAGMA_UNROLL
  for (int i = 4; i < 8; ++i) {
    lines[i].fill(fillValue);
    cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
        lines[i], input + i * p.input_s0, x + i < p.input_dim0);
  }
  indices.c = compute_tile_indices(Tile4x4Accessor(lines, 4, 0));
  indices.d = compute_tile_indices(Tile4x4Accessor(lines, 4, 4));

  // Compute packed tiles C & D. They land 4 packed rows below A & B,
  // hence the `+ 4 * p.packed_stride` on the destination.
  {
    Tile4x4Packed packed_c = pack_4x4(
        indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
    Tile4x4Packed packed_d = pack_4x4(
        indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
    writePackedT(
        packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
  }

  // Compute/store packed tiles C & D in transpose output (metadata offset 4,
  // completing the A/C and B/D pairs started above).
  Tile4x4Packed packed_trans_c = pack_4x4(
      indices.c,
      Tile4x4Accessor(lines, 4, 0),
      metadata.meta_ac_trans,
      4,
      true);
  Tile4x4Packed packed_trans_d = pack_4x4(
      indices.d,
      Tile4x4Accessor(lines, 4, 4),
      metadata.meta_bd_trans,
      4,
      true);

  // Persist this thread's pruning masks for all four sub-tiles.
  *p.getCurrentThreadIndices() = indices;

  // Store packed A, B, C & D for transposed matrix: first the A/C column
  // pair, then (4 strides further) the B/D pair.
  writePackedT(
      packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
  packed_trans += 4 * p.packed_trans_stride;
  writePackedT(
      packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);

  // Writing meta non-transposed. `warp_shuffle_and_write_meta` reorders the
  // metadata across the warp before storing; the `+ 32` places the C/D
  // metadata after the A/B metadata (layout owned by `metadata_gmem`).
  {
    ElementInputE* packed_meta_reordered = metadata_gmem.get_metaN(
        warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
    warp_shuffle_and_write_meta(packed_meta_reordered, metadata.meta_ab);
    warp_shuffle_and_write_meta(packed_meta_reordered + 32, metadata.meta_cd);
  }

  // Writing meta transposed (the extra `true` selects the transposed
  // shuffle/write path).
  {
    ElementInputE* packed_trans_meta_reordered = metadata_gmem.get_metaT(
        warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
    warp_shuffle_and_write_meta(
        packed_trans_meta_reordered, metadata.meta_ac_trans, true);
    warp_shuffle_and_write_meta(
        packed_trans_meta_reordered + 32, metadata.meta_bd_trans, true);
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free