Home / Class / void Class — pytorch Architecture

void Class — pytorch Architecture

Architecture documentation for the `sparse_semi_structured_tile_kernel` device function (return type `void`) in SparseSemiStructuredPack.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/sparse/cuda/SparseSemiStructuredPack.h lines 301–431

  // Per-thread body of the semi-structured sparse packing kernel.
  //
  // Each thread loads one 8x8 tile of `p.input`, asks `compute_tile_indices`
  // which values to keep in each of its four 4x4 sub-tiles (A, B, C, D), and
  // then writes:
  //   - the packed tile to `p.packed` (rows strided by `p.packed_stride`,
  //     column offset `y / 2`),
  //   - the packed transposed tile to `p.packed_trans`,
  //   - the per-sub-tile indices to `*p.getCurrentThreadIndices()`,
  //   - the non-transposed and transposed metadata to the destinations
  //     returned by `metadata_gmem.get_metaN(...)` / `get_metaT(...)`.
  //
  // Template parameters / arguments:
  //   Algorithm            - type of `compute_tile_indices`. It is called with a
  //                          `Tile4x4Accessor` and its result is stored in the
  //                          `a`/`b`/`c`/`d` fields of `Tile8x8Masks`. It must
  //                          also provide the static
  //                          `outOfBoundsFillValue<Element>()`, which gives the
  //                          value used for rows that are not loaded.
  //   MetadataStore        - type of `metadata_gmem`. Its `get_metaN` and
  //                          `get_metaT` map (warp origin, in-warp offset)
  //                          coordinates to `ElementInputE*` destinations.
  //   p                    - kernel parameters: input and output pointers,
  //                          strides, and the `input_dim0` row bound.
  //   metadata_gmem        - metadata destination provider (see above).
  //   compute_tile_indices - per-4x4 selection callable (see above).
  template <typename Algorithm, typename MetadataStore>
  CUTLASS_DEVICE static void sparse_semi_structured_tile_kernel(
      Params p,
      MetadataStore metadata_gmem,
      Algorithm compute_tile_indices) {
    // Each thread is responsible for an 8x8 tile, which contains 4 4x4 tiles:
    // A, B, C and D, as displayed in the following schema:
    // +---+---+
    // | A | B |
    // +---+---+
    // | C | D |
    // +---+---+
    // Each warp (32 threads) will then be responsible for a 32x64 tile of the
    // input.
    // This configuration allows to read/write data in 128bits chunks. These
    // memory accesses are coalesced at the warp-level into 128bytes. See also:
    // https://docs.google.com/presentation/d/1DtmKThv8S5QAyBktuLRYzZhRzCvS1qSkBbrqNCjMPeA/edit#slide=id.g2494f30c7cf_0_0

    // Top-left of the 8x8 tile we own.
    // `x` is the row (dim 0) and `y` the column of that element, as the
    // `input` offset below (`x * p.input_s0 + y`) shows.
    int warp_x = blockIdx.x * kWarpX;
    int warp_y = blockIdx.y * kWarpY;
    int x = warp_x + threadIdx.x * kThreadX;
    int y = warp_y + threadIdx.y * kThreadY;

    // Non-transposed packed output: same row `x`, column halved (`y / 2`).
    // Presumably this is because only half of each row's values are kept;
    // confirm against the allocation of `p.packed`.
    // Transposed packed output: input column `y` selects the row and the
    // halved input row `x / 2` selects the column.
    Element const* input = p.input + x * p.input_s0 + y;
    Element* packed = p.packed + x * p.packed_stride + (y / 2);
    Element* packed_trans =
        p.packed_trans + (x / 2) + y * p.packed_trans_stride;

    Fragment lines[8]; // Contains all values from the 8x8 tile

    // `metadata` collects the metadata words produced by pack_4x4 below.
    // `indices` collects compute_tile_indices' result for each sub-tile.
    Tile8x8Meta metadata;
    Tile8x8Masks indices;

    // Load/process tiles `A` and `B`
    Element fillValue = Algorithm::template outOfBoundsFillValue<Element>();
    // Load rows 0..3 of the tile. Each row is pre-filled with `fillValue`, and
    // the load is predicated on `x + i < p.input_dim0`. Rows past the end of
    // the input therefore keep the fill value, assuming `global_load` leaves
    // the fragment untouched when its guard is false.
    // NOTE(review): only the row bound is checked. The columns covered by
    // each Fragment are not guarded here; presumably the caller guarantees
    // them (confirm).
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      lines[i].fill(fillValue);
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }
    // Accessor arguments are (row offset, column offset) within the 8x8 tile,
    // matching the A/B/C/D schema above.
    indices.a = compute_tile_indices(Tile4x4Accessor(lines, 0, 0));
    indices.b = compute_tile_indices(Tile4x4Accessor(lines, 0, 4));

    // Compute packed tiles A & B.
    // Both write into `metadata.meta_ab`. The 0 / 4 argument presumably
    // selects where in that word each sub-tile's metadata goes; confirm in
    // pack_4x4.
    {
      Tile4x4Packed packed_a = pack_4x4(
          indices.a, Tile4x4Accessor(lines, 0, 0), metadata.meta_ab, 0);
      Tile4x4Packed packed_b = pack_4x4(
          indices.b, Tile4x4Accessor(lines, 0, 4), metadata.meta_ab, 4);
      writePackedT(packed, p.packed_stride, packed_a, packed_b);
    }

    // Compute/store packed tiles A & B in transpose output.
    // The trailing `true` requests the transposed packing. A goes into
    // `meta_ac_trans` and B into `meta_bd_trans`, both at offset 0; C and D
    // fill offset 4 of the same words further down. The results are held in
    // locals and written only after C & D are packed.
    Tile4x4Packed packed_trans_a = pack_4x4(
        indices.a,
        Tile4x4Accessor(lines, 0, 0),
        metadata.meta_ac_trans,
        0,
        true);
    Tile4x4Packed packed_trans_b = pack_4x4(
        indices.b,
        Tile4x4Accessor(lines, 0, 4),
        metadata.meta_bd_trans,
        0,
        true);
    // (NOTE) Now we no longer need A & B (`lines[0:4]`)

    // Load/process tiles `C` and `D`.
    // Same predicated load as above, for rows 4..7 of the tile.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 4; i < 8; ++i) {
      lines[i].fill(fillValue);
      cutlass::arch::global_load<Fragment, sizeof(Fragment)>(
          lines[i], input + i * p.input_s0, x + i < p.input_dim0);
    }
    indices.c = compute_tile_indices(Tile4x4Accessor(lines, 4, 0));
    indices.d = compute_tile_indices(Tile4x4Accessor(lines, 4, 4));

    // Compute packed tiles C & D.
    // They share `metadata.meta_cd` and are written 4 rows below A & B.
    {
      Tile4x4Packed packed_c = pack_4x4(
          indices.c, Tile4x4Accessor(lines, 4, 0), metadata.meta_cd, 0);
      Tile4x4Packed packed_d = pack_4x4(
          indices.d, Tile4x4Accessor(lines, 4, 4), metadata.meta_cd, 4);
      writePackedT(
          packed + 4 * p.packed_stride, p.packed_stride, packed_c, packed_d);
    }

    // Compute/store packed tiles C & D in transpose output
    Tile4x4Packed packed_trans_c = pack_4x4(
        indices.c,
        Tile4x4Accessor(lines, 4, 0),
        metadata.meta_ac_trans,
        4,
        true);
    Tile4x4Packed packed_trans_d = pack_4x4(
        indices.d,
        Tile4x4Accessor(lines, 4, 4),
        metadata.meta_bd_trans,
        4,
        true);

    // Dump the metadata in a nice format.
    // This stores the per-sub-tile indices (a, b, c, d) computed for this
    // thread.
    *p.getCurrentThreadIndices() = indices;

    // Store packed A, B, C & D for transposed matrix.
    // A and C cover the same input columns, so they are written together
    // first. The pointer then advances 4 rows of the transposed output for B
    // and D, which cover input columns y + 4 .. y + 7.
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_a, packed_trans_c);
    packed_trans += 4 * p.packed_trans_stride;
    writePackedT(
        packed_trans, p.packed_trans_stride, packed_trans_b, packed_trans_d);

    // Writing meta non-transposed.
    // The destination comes from this thread's warp origin and in-warp offset.
    // NOTE(review): `warp_shuffle_and_write_meta` is presumably a
    // warp-collective operation, judging by its name; if so, every lane must
    // reach it. This function has no early return. The `+ 32` separating the
    // AB and CD words is defined by the metadata layout, which is not visible
    // here.
    {
      ElementInputE* packed_meta_reordered = metadata_gmem.get_metaN(
          warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
      warp_shuffle_and_write_meta(packed_meta_reordered, metadata.meta_ab);
      warp_shuffle_and_write_meta(packed_meta_reordered + 32, metadata.meta_cd);
    }

    // Writing meta transposed (same scheme, `true` selects the transposed
    // variant; AC word first, BD word at `+ 32`).
    {
      ElementInputE* packed_trans_meta_reordered = metadata_gmem.get_metaT(
          warp_x, threadIdx.x * kThreadX, warp_y, threadIdx.y * kThreadY);
      warp_shuffle_and_write_meta(
          packed_trans_meta_reordered, metadata.meta_ac_trans, true);
      warp_shuffle_and_write_meta(
          packed_trans_meta_reordered + 32, metadata.meta_bd_trans, true);
    }
  }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free