
Threads Class — pytorch Architecture

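Architecture documentation for Threads in tile_smem_loader.h from the PyTorch codebase. Despite the page title, Threads is not a class. It is the int non-type template parameter of the TileSmemLoader class template, and it sets how many threads cooperate to load each tile from global memory into shared memory.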
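Because Threads splits a tile among the participating threads, each thread's share of the work follows from simple arithmetic. The sketch below uses a hypothetical helper, accesses_per_thread, which is not part of PyTorch or CUTLASS. It assumes the tile's element count is an exact multiple of Threads * ElementsPerAccess.

```cpp
#include <cassert>

// Hypothetical helper (not part of PyTorch or CUTLASS): the number of vector
// accesses each thread issues when `threads` threads cooperatively move a
// rows x cols tile, `elems_per_access` elements per access.
// Assumes rows * cols is an exact multiple of threads * elems_per_access.
constexpr int accesses_per_thread(int rows, int cols, int threads,
                                  int elems_per_access) {
  return (rows * cols) / (threads * elems_per_access);
}

// Example: a 64x64 tile of 16-bit elements, 128 threads, and 8-element
// (16-byte) accesses gives 4096 / (128 * 8) = 4 accesses per thread.
static_assert(accesses_per_thread(64, 64, 128, 8) == 4,
              "expected 4 accesses per thread");
```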


Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h lines 20–66

```cpp
template <
    typename scalar_t, // scalar type
    typename ThreadblockTileShape, // size of tile to load
    int Threads, // number of participating threads
    int ElementsPerAccess> // thread access width in elements
class TileSmemLoader {
 public:
  using SmemTile =
      cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;

  using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
      cutlass::layout::PitchLinearShape<
          ThreadblockTileShape::kColumn, // contiguous
          ThreadblockTileShape::kRow>, // strided
      Threads, // Threads
      ElementsPerAccess>; // ElementsPerAccess

  using GmemTileIterator =
      cutlass::transform::threadblock::PredicatedTileIterator<
          ThreadblockTileShape, // Shape
          scalar_t, // Element
          cutlass::layout::RowMajor, // Layout
          0, // AdvanceRank
          ThreadMap>; // ThreadMap

  using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator<
      ThreadblockTileShape, // Shape
      scalar_t, // Element
      cutlass::layout::RowMajor, // Layout
      0, // AdvanceRank
      ThreadMap>; // ThreadMap

  using Fragment = typename GmemTileIterator::Fragment;

  /// load a tile from global memory into shared memory
  CUTLASS_DEVICE
  static void load(
      GmemTileIterator tile_load_iter,
      SmemTileIterator tile_store_iter) {
    Fragment tb_frag;
    tb_frag.clear();
    tile_load_iter.load(tb_frag);
    tile_store_iter.store(tb_frag);

    __syncthreads();
  }
};
```
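The load() method clears a register fragment, fills it through the predicated global-memory iterator, writes it to shared memory, and then synchronizes the block. The host-side C++ model below sketches that data flow; it is not the CUTLASS code. TileModel and load_tile are hypothetical names, and a plain per-element bounds check stands in for PredicatedTileIterator's predicates. The thread-to-vector assignment (thread t owns vectors t, t + Threads, t + 2 * Threads, ...) is an assumed reading of PitchLinearStripminedThreadMap for evenly divisible shapes.

```cpp
#include <cassert>
#include <vector>

// Hypothetical host-side model of TileSmemLoader::load for one thread block.
// Not the CUTLASS implementation; it only mirrors the data flow.
struct TileModel {
  int rows, cols;        // tile shape; cols is the contiguous dimension
  int threads;           // participating threads (the Threads parameter)
  int elems_per_access;  // vector width in elements (ElementsPerAccess)
};

// Copies the rows x cols tile whose top-left corner is (row0, col0) out of a
// row-major src matrix of size src_rows x src_cols into a "shared" buffer.
// Elements outside src stay 0 because each fragment starts zeroed and
// out-of-bounds loads are skipped.
std::vector<float> load_tile(const TileModel& m, const std::vector<float>& src,
                             int src_rows, int src_cols, int row0, int col0) {
  std::vector<float> smem(m.rows * m.cols, -1.0f);  // -1 marks "never written"
  const int vecs_per_row = m.cols / m.elems_per_access;
  const int total_vecs = m.rows * vecs_per_row;
  for (int t = 0; t < m.threads; ++t) {  // each simulated thread
    // Assumed strip-mined ownership: vectors t, t + threads, t + 2*threads, ...
    for (int v = t; v < total_vecs; v += m.threads) {
      const int r = v / vecs_per_row;
      const int c = (v % vecs_per_row) * m.elems_per_access;
      std::vector<float> frag(m.elems_per_access, 0.0f);  // like tb_frag.clear()
      for (int e = 0; e < m.elems_per_access; ++e) {
        const int gr = row0 + r;
        const int gc = col0 + c + e;
        if (gr < src_rows && gc < src_cols)   // simplified bounds predicate
          frag[e] = src[gr * src_cols + gc];  // like tile_load_iter.load
      }
      for (int e = 0; e < m.elems_per_access; ++e)
        smem[r * m.cols + c + e] = frag[e];  // like tile_store_iter.store
    }
  }
  // On the GPU, __syncthreads() follows here so that no thread reads the
  // shared tile before every thread has finished storing its part.
  return smem;
}
```

The -1 sentinel makes coverage visible: after a call no entry should still be -1, and entries that fall outside the source matrix come back as 0. This zero padding of edge tiles is presumably why tb_frag.clear() runs before the predicated load.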
